diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..45232b80ed8cd1c7ce3892a9d181d38b131240d7 --- /dev/null +++ b/.clang-format @@ -0,0 +1,161 @@ +--- +Language: Cpp +AlignAfterOpenBracket: Align +AlignArrayOfStructures: Left +AlignConsecutiveAssignments: AcrossComments +AlignConsecutiveBitFields: AcrossComments +AlignConsecutiveDeclarations: AcrossComments +AlignConsecutiveMacros: AcrossComments +# AlignConsecutiveShortCaseStatements: AcrossComments +AlignEscapedNewlines: Left # LeftWithLastLine +AlignOperands: Align +AlignTrailingComments: + Kind: Always + OverEmptyLines: 1 +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false +# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: Inline +AllowShortLoopsOnASingleLine: false +AlwaysBreakBeforeMultilineStrings: true +BinPackArguments: true +BinPackParameters: true # OnePerLine +BitFieldColonSpacing: Both +BreakBeforeBraces: Custom # Attach +BraceWrapping: + AfterCaseLabel: true + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + BeforeLambdaBody: false + BeforeWhile: false + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +# BreakAdjacentStringLiterals: true +BreakAfterAttributes: Never +BreakBeforeBinaryOperators: None +BreakBeforeInlineASMColon: OnlyMultiline +BreakBeforeTernaryOperators: false +# BreakBinaryOperations: Never +BreakConstructorInitializers: AfterColon +# BreakFunctionDefinitionParameters: false +BreakInheritanceList: AfterComma +BreakStringLiterals: true +# BreakTemplateDeclarations: Yes +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: false +DerivePointerAlignment: false +DisableFormat: false +EmptyLineBeforeAccessModifier: Leave +EmptyLineAfterAccessModifier: Never +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +IncludeBlocks: Regroup +IncludeCategories: + - Regex: '^<.*\.h>' + Priority: 1 + SortPriority: 0 + - Regex: '^<.*' + Priority: 2 + SortPriority: 0 + - Regex: '.*' + Priority: 3 + SortPriority: 0 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IncludeIsMainSourceRegex: '' +IndentAccessModifiers: false +IndentCaseBlocks: true +IndentCaseLabels: true +IndentExternBlock: NoIndent +IndentGotoLabels: false +IndentPPDirectives: AfterHash +IndentWidth: 4 +IndentWrappedFunctionNames: false +InsertBraces: true # NOTE: may lead to incorrect formatting +InsertNewlineAtEOF: true +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +LambdaBodyIndentation: Signature +LineEnding: LF +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 4 +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: true +PPIndentWidth: -1 +PackConstructorInitializers: CurrentLine +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Middle +QualifierAlignment: Left +#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict'] +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' +ReferenceAlignment: Middle +ReflowComments: false # IndentOnly +SeparateDefinitionBlocks: Always +SortIncludes: CaseInsensitive +SortUsingDeclarations: LexicographicNumeric +SpaceAfterCStyleCast: true +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: Never +SpacesInContainerLiterals: true +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParentheses: false +SpacesInSquareBrackets: false +SpaceBeforeSquareBrackets: false +Standard: c++17 +TabWidth: 4 +UseTab: Never +WhitespaceSensitiveMacros: ['STRINGIZE'] +... + diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 0000000000000000000000000000000000000000..5bc63bc6e27b632a1ada1ee776f5ab560bddb243 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,27 @@ +--- +Checks: > + bugprone-*, + -bugprone-easily-swappable-parameters, + -bugprone-implicit-widening-of-multiplication-result, + -bugprone-misplaced-widening-cast, + -bugprone-narrowing-conversions, + readability-*, + -readability-avoid-unconditional-preprocessor-if, + -readability-function-cognitive-complexity, + -readability-identifier-length, + -readability-implicit-bool-conversion, + -readability-magic-numbers, + -readability-uppercase-literal-suffix, + -readability-simplify-boolean-expr, + -readability-math-missing-parentheses, + clang-analyzer-*, + -clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling, + performance-*, + portability-*, + -portability-simd-intrinsics, + misc-*, + -misc-const-correctness, + -misc-non-private-member-variables-in-classes, + -misc-no-recursion, + -misc-use-anonymous-namespace, +FormatStyle: none diff --git a/.devops/cloud-v-pipeline b/.devops/cloud-v-pipeline new file mode 100644 index 0000000000000000000000000000000000000000..af8c0cea6155c71c0c162e034715908ea16ee9cf --- /dev/null +++ b/.devops/cloud-v-pipeline @@ -0,0 +1,22 @@ +node('x86_runner1'){ // Running on x86 runner containing latest vector qemu, latest vector gcc and all the necessary libraries + stage('Cleanup'){ + cleanWs() // Cleaning previous CI build in workspace + } + stage('checkout repo'){ + retry(5){ // Retry if the cloning fails due to some reason + checkout scm // Clone the repo on Runner + } + } + stage('Compiling llama.cpp'){ + sh'''#!/bin/bash + make RISCV=1 RISCV_CROSS_COMPILE=1 # Compiling llama for RISC-V + ''' + } + stage('Running llama.cpp'){ + sh'''#!/bin/bash + module load gnu-bin2/0.1 # loading latest versions of vector qemu and vector gcc + qemu-riscv64 -L /softwares/gnu-bin2/sysroot -cpu rv64,v=true,vlen=256,elen=64,vext_spec=v1.0 ./llama-cli -m /home/alitariq/codellama-7b.Q4_K_M.gguf -p "Anything" -n 9 > llama_log.txt # Running llama.cpp on vector qemu-riscv64 + cat llama_log.txt # Printing results + ''' + } +} diff --git a/.devops/cpu.Dockerfile b/.devops/cpu.Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..9459f08c10c9454b546acd6c178e6608de7e8d07 --- /dev/null +++ b/.devops/cpu.Dockerfile @@ -0,0 +1,92 @@ +ARG UBUNTU_VERSION=22.04 + +FROM ubuntu:$UBUNTU_VERSION AS build + +ARG TARGETARCH + +ARG GGML_CPU_ARM_ARCH=armv8-a + +RUN apt-get update && \ + apt-get install -y build-essential git cmake libcurl4-openssl-dev + +WORKDIR /app + +COPY . . + +RUN if [ "$TARGETARCH" = "amd64" ]; then \ + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \ + elif [ "$TARGETARCH" = "arm64" ]; then \ + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_CPU_ARM_ARCH=${GGML_CPU_ARM_ARCH}; \ + else \ + echo "Unsupported architecture"; \ + exit 1; \ + fi && \ + cmake --build build -j $(nproc) + +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base image +FROM ubuntu:$UBUNTU_VERSION AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app + +### Full +FROM base AS full + +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3 \ + python3-pip \ + && pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/cuda.Dockerfile b/.devops/cuda.Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..94f143397233f4f3dbd85763b54ea978ad6dc7fa --- /dev/null +++ b/.devops/cuda.Dockerfile @@ -0,0 +1,94 @@ +ARG UBUNTU_VERSION=22.04 +# This needs to generally match the container host's environment. +ARG CUDA_VERSION=12.4.0 +# Target the CUDA build image +ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} + +ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} + +FROM ${BASE_CUDA_DEV_CONTAINER} AS build + +# CUDA architecture to build for (defaults to all supported archs) +ARG CUDA_DOCKER_ARCH=default + +RUN apt-get update && \ + apt-get install -y build-essential cmake python3 python3-pip git libcurl4-openssl-dev libgomp1 + +WORKDIR /app + +COPY . . + +RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ + export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ + fi && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake --build build --config Release -j$(nproc) + +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base image +FROM ${BASE_CUDA_RUN_CONTAINER} AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app + +### Full +FROM base AS full + +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3 \ + python3-pip \ + && pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/intel.Dockerfile b/.devops/intel.Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..9ce80a71eb9501094a753363a1a521b6082e4489 --- /dev/null +++ b/.devops/intel.Dockerfile @@ -0,0 +1,95 @@ +ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04 + +## Build Image + +FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build + +ARG GGML_SYCL_F16=OFF +RUN apt-get update && \ + apt-get install -y git libcurl4-openssl-dev + +WORKDIR /app + +COPY . . + +RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ + echo "GGML_SYCL_F16 is set" \ + && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \ + fi && \ + echo "Building with dynamic libs" && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${OPT_SYCL_F16} && \ + cmake --build build --config Release -j$(nproc) + +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +FROM intel/oneapi-basekit:$ONEAPI_VERSION AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +### Full +FROM base AS full + +COPY --from=build /app/lib/ /app +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y \ + git \ + python3 \ + python3-pip \ + python3-venv && \ + python3 -m venv /opt/venv && \ + . /opt/venv/bin/activate && \ + pip install --upgrade pip setuptools wheel && \ + pip install -r requirements.txt && \ + apt autoremove -y && \ + apt clean -y && \ + rm -rf /tmp/* /var/tmp/* && \ + find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete && \ + find /var/cache -type f -delete + +ENV PATH="/opt/venv/bin:$PATH" + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/lib/ /app +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/lib/ /app +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] + diff --git a/.devops/llama-cli-cann.Dockerfile b/.devops/llama-cli-cann.Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..ef43d78cd2a856036896a7ef2aa98a42132d0ffa --- /dev/null +++ b/.devops/llama-cli-cann.Dockerfile @@ -0,0 +1,44 @@ +ARG ASCEND_VERSION=8.1.RC1.alpha001-910b-openeuler22.03-py3.10 + +FROM ascendai/cann:$ASCEND_VERSION AS build + +WORKDIR /app + +COPY . . + +RUN yum install -y gcc g++ cmake make libcurl-devel +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH} +ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH} +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH} +ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit +ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME} + +# find libascend_hal.so, because the drive hasn`t been mounted. +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH + +RUN echo "Building with static libs" && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_TESTS=OFF && \ + cmake --build build --config Release --target llama-cli + +# TODO: use image with NNRT +FROM ascendai/cann:$ASCEND_VERSION AS runtime +COPY --from=build /app/build/bin/llama-cli /llama-cli + +ENV LC_ALL=C.utf8 + +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH} +ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH} +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH} +ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit +ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME} + +ENTRYPOINT ["/llama-cli" ] diff --git a/.devops/llama-cpp-cuda.srpm.spec b/.devops/llama-cpp-cuda.srpm.spec new file mode 100644 index 0000000000000000000000000000000000000000..3bbf4a4def2a59cae23c4d504fa1aff9d813e668 --- /dev/null +++ b/.devops/llama-cpp-cuda.srpm.spec @@ -0,0 +1,83 @@ +# SRPM for building from source and packaging an RPM for RPM-based distros. +# https://docs.fedoraproject.org/en-US/quick-docs/creating-rpm-packages +# Built and maintained by John Boero - boeroboy@gmail.com +# In honor of Seth Vidal https://www.redhat.com/it/blog/thank-you-seth-vidal + +# Notes for llama.cpp: +# 1. Tags are currently based on hash - which will not sort asciibetically. +# We need to declare standard versioning if people want to sort latest releases. +# 2. Builds for CUDA/OpenCL support are separate, with different depenedencies. +# 3. NVidia's developer repo must be enabled with nvcc, cublas, clblas, etc installed. +# Example: https://developer.download.nvidia.com/compute/cuda/repos/fedora37/x86_64/cuda-fedora37.repo +# 4. OpenCL/CLBLAST support simply requires the ICD loader and basic opencl libraries. +# It is up to the user to install the correct vendor-specific support. + +Name: llama.cpp-cuda +Version: %( date "+%%Y%%m%%d" ) +Release: 1%{?dist} +Summary: CPU Inference of LLaMA model in pure C/C++ (no CUDA/OpenCL) +License: MIT +Source0: https://github.com/ggml-org/llama.cpp/archive/refs/heads/master.tar.gz +BuildRequires: coreutils make gcc-c++ git cuda-toolkit +Requires: cuda-toolkit +URL: https://github.com/ggml-org/llama.cpp + +%define debug_package %{nil} +%define source_date_epoch_from_changelog 0 + +%description +CPU inference for Meta's Lllama2 models using default options. + +%prep +%setup -n llama.cpp-master + +%build +make -j GGML_CUDA=1 + +%install +mkdir -p %{buildroot}%{_bindir}/ +cp -p llama-cli %{buildroot}%{_bindir}/llama-cuda-cli +cp -p llama-server %{buildroot}%{_bindir}/llama-cuda-server +cp -p llama-simple %{buildroot}%{_bindir}/llama-cuda-simple + +mkdir -p %{buildroot}/usr/lib/systemd/system +%{__cat} < %{buildroot}/usr/lib/systemd/system/llamacuda.service +[Unit] +Description=Llama.cpp server, CPU only (no GPU support in this build). +After=syslog.target network.target local-fs.target remote-fs.target nss-lookup.target + +[Service] +Type=simple +EnvironmentFile=/etc/sysconfig/llama +ExecStart=/usr/bin/llama-cuda-server $LLAMA_ARGS +ExecReload=/bin/kill -s HUP $MAINPID +Restart=never + +[Install] +WantedBy=default.target +EOF + +mkdir -p %{buildroot}/etc/sysconfig +%{__cat} < %{buildroot}/etc/sysconfig/llama +LLAMA_ARGS="-m /opt/llama2/ggml-model-f32.bin" +EOF + +%clean +rm -rf %{buildroot} +rm -rf %{_builddir}/* + +%files +%{_bindir}/llama-cuda-cli +%{_bindir}/llama-cuda-server +%{_bindir}/llama-cuda-simple +/usr/lib/systemd/system/llamacuda.service +%config /etc/sysconfig/llama + +%pre + +%post + +%preun +%postun + +%changelog diff --git a/.devops/llama-cpp.srpm.spec b/.devops/llama-cpp.srpm.spec new file mode 100644 index 0000000000000000000000000000000000000000..45902dcf896e0b239ce17560947c532b87379a0f --- /dev/null +++ b/.devops/llama-cpp.srpm.spec @@ -0,0 +1,85 @@ +# SRPM for building from source and packaging an RPM for RPM-based distros. +# https://docs.fedoraproject.org/en-US/quick-docs/creating-rpm-packages +# Built and maintained by John Boero - boeroboy@gmail.com +# In honor of Seth Vidal https://www.redhat.com/it/blog/thank-you-seth-vidal + +# Notes for llama.cpp: +# 1. Tags are currently based on hash - which will not sort asciibetically. +# We need to declare standard versioning if people want to sort latest releases. +# In the meantime, YYYYMMDD format will be used. +# 2. Builds for CUDA/OpenCL support are separate, with different depenedencies. +# 3. NVidia's developer repo must be enabled with nvcc, cublas, clblas, etc installed. +# Example: https://developer.download.nvidia.com/compute/cuda/repos/fedora37/x86_64/cuda-fedora37.repo +# 4. OpenCL/CLBLAST support simply requires the ICD loader and basic opencl libraries. +# It is up to the user to install the correct vendor-specific support. + +Name: llama.cpp +Version: %( date "+%%Y%%m%%d" ) +Release: 1%{?dist} +Summary: CPU Inference of LLaMA model in pure C/C++ (no CUDA/OpenCL) +License: MIT +Source0: https://github.com/ggml-org/llama.cpp/archive/refs/heads/master.tar.gz +BuildRequires: coreutils make gcc-c++ git libstdc++-devel +Requires: libstdc++ +URL: https://github.com/ggml-org/llama.cpp + +%define debug_package %{nil} +%define source_date_epoch_from_changelog 0 + +%description +CPU inference for Meta's Lllama2 models using default options. +Models are not included in this package and must be downloaded separately. + +%prep +%setup -n llama.cpp-master + +%build +make -j + +%install +mkdir -p %{buildroot}%{_bindir}/ +cp -p llama-cli %{buildroot}%{_bindir}/llama-cli +cp -p llama-server %{buildroot}%{_bindir}/llama-server +cp -p llama-simple %{buildroot}%{_bindir}/llama-simple + +mkdir -p %{buildroot}/usr/lib/systemd/system +%{__cat} < %{buildroot}/usr/lib/systemd/system/llama.service +[Unit] +Description=Llama.cpp server, CPU only (no GPU support in this build). +After=syslog.target network.target local-fs.target remote-fs.target nss-lookup.target + +[Service] +Type=simple +EnvironmentFile=/etc/sysconfig/llama +ExecStart=/usr/bin/llama-server $LLAMA_ARGS +ExecReload=/bin/kill -s HUP $MAINPID +Restart=never + +[Install] +WantedBy=default.target +EOF + +mkdir -p %{buildroot}/etc/sysconfig +%{__cat} < %{buildroot}/etc/sysconfig/llama +LLAMA_ARGS="-m /opt/llama2/ggml-model-f32.bin" +EOF + +%clean +rm -rf %{buildroot} +rm -rf %{_builddir}/* + +%files +%{_bindir}/llama-cli +%{_bindir}/llama-server +%{_bindir}/llama-simple +/usr/lib/systemd/system/llama.service +%config /etc/sysconfig/llama + +%pre + +%post + +%preun +%postun + +%changelog diff --git a/.devops/musa.Dockerfile b/.devops/musa.Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..87ce2393f6bf9b1ef7acbfedb05c3eac32cb405d --- /dev/null +++ b/.devops/musa.Dockerfile @@ -0,0 +1,101 @@ +ARG UBUNTU_VERSION=22.04 +# This needs to generally match the container host's environment. +ARG MUSA_VERSION=rc4.0.1 +# Target the MUSA build image +ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-devel-ubuntu${UBUNTU_VERSION} + +ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-runtime-ubuntu${UBUNTU_VERSION} + +FROM ${BASE_MUSA_DEV_CONTAINER} AS build + +# MUSA architecture to build for (defaults to all supported archs) +ARG MUSA_DOCKER_ARCH=default + +RUN apt-get update && \ + apt-get install -y \ + build-essential \ + cmake \ + python3 \ + python3-pip \ + git \ + libcurl4-openssl-dev \ + libgomp1 + +WORKDIR /app + +COPY . . + +RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \ + export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \ + fi && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake --build build --config Release -j$(nproc) + +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base image +FROM ${BASE_MUSA_RUN_CONTAINER} AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app + +### Full +FROM base AS full + +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3 \ + python3-pip \ + && pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/nix/apps.nix b/.devops/nix/apps.nix new file mode 100644 index 0000000000000000000000000000000000000000..0ecf19fc56d554c69aeef8a03b253fc15338688e --- /dev/null +++ b/.devops/nix/apps.nix @@ -0,0 +1,21 @@ +{ + perSystem = + { config, lib, ... }: + { + apps = + let + inherit (config.packages) default; + binaries = [ + "llama-cli" + "llama-embedding" + "llama-server" + "llama-quantize" + ]; + mkApp = name: { + type = "app"; + program = "${default}/bin/${name}"; + }; + in + lib.genAttrs binaries mkApp; + }; +} diff --git a/.devops/nix/devshells.nix b/.devops/nix/devshells.nix new file mode 100644 index 0000000000000000000000000000000000000000..bfd304af14dcda59e284d30b9fc7491466bf71d1 --- /dev/null +++ b/.devops/nix/devshells.nix @@ -0,0 +1,52 @@ +{ inputs, ... }: + +{ + perSystem = + { + config, + lib, + system, + ... + }: + { + devShells = + let + pkgs = import inputs.nixpkgs { inherit system; }; + stdenv = pkgs.stdenv; + scripts = config.packages.python-scripts; + in + lib.pipe (config.packages) [ + (lib.concatMapAttrs ( + name: package: { + ${name} = pkgs.mkShell { + name = "${name}"; + inputsFrom = [ package ]; + shellHook = '' + echo "Entering ${name} devShell" + ''; + }; + "${name}-extra" = + if (name == "python-scripts") then + null + else + pkgs.mkShell { + name = "${name}-extra"; + inputsFrom = [ + package + scripts + ]; + # Extra packages that *may* be used by some scripts + packages = [ + pkgs.python3Packages.tiktoken + ]; + shellHook = '' + echo "Entering ${name} devShell" + addToSearchPath "LD_LIBRARY_PATH" "${lib.getLib stdenv.cc.cc}/lib" + ''; + }; + } + )) + (lib.filterAttrs (name: value: value != null)) + ]; + }; +} diff --git a/.devops/nix/docker.nix b/.devops/nix/docker.nix new file mode 100644 index 0000000000000000000000000000000000000000..d607b4575772c5330e962649ff6e14e5562ecfad --- /dev/null +++ b/.devops/nix/docker.nix @@ -0,0 +1,37 @@ +{ + lib, + dockerTools, + buildEnv, + llama-cpp, + interactive ? true, + coreutils, +}: + +# A tar that can be fed into `docker load`: +# +# $ nix build .#llamaPackages.docker +# $ docker load < result + +# For details and variations cf. +# - https://nixos.org/manual/nixpkgs/unstable/#ssec-pkgs-dockerTools-buildLayeredImage +# - https://discourse.nixos.org/t/a-faster-dockertools-buildimage-prototype/16922 +# - https://nixery.dev/ + +# Approximate (compressed) sizes, at the time of writing, are: +# +# .#llamaPackages.docker: 125M; +# .#llamaPackagesCuda.docker: 537M; +# .#legacyPackages.aarch64-linux.llamaPackagesXavier.docker: 415M. + +dockerTools.buildLayeredImage { + name = llama-cpp.pname; + tag = "latest"; + + contents = + [ llama-cpp ] + ++ lib.optionals interactive [ + coreutils + dockerTools.binSh + dockerTools.caCertificates + ]; +} diff --git a/.devops/nix/jetson-support.nix b/.devops/nix/jetson-support.nix new file mode 100644 index 0000000000000000000000000000000000000000..78e2e40e03864e3df046389f7b751a1fd4575656 --- /dev/null +++ b/.devops/nix/jetson-support.nix @@ -0,0 +1,39 @@ +{ inputs, ... }: +{ + perSystem = + { + config, + system, + lib, + pkgsCuda, + ... + }: + { + legacyPackages = + let + caps.llamaPackagesXavier = "7.2"; + caps.llamaPackagesOrin = "8.7"; + caps.llamaPackagesTX2 = "6.2"; + caps.llamaPackagesNano = "5.3"; + + pkgsFor = + cap: + import inputs.nixpkgs { + inherit system; + config = { + cudaSupport = true; + cudaCapabilities = [ cap ]; + cudaEnableForwardCompat = false; + inherit (pkgsCuda.config) allowUnfreePredicate; + }; + }; + in + builtins.mapAttrs (name: cap: (pkgsFor cap).callPackage ./scope.nix { }) caps; + + packages = lib.optionalAttrs (system == "aarch64-linux") { + jetson-xavier = config.legacyPackages.llamaPackagesXavier.llama-cpp; + jetson-orin = config.legacyPackages.llamaPackagesOrin.llama-cpp; + jetson-nano = config.legacyPackages.llamaPackagesNano.llama-cpp; + }; + }; +} diff --git a/.devops/nix/nixpkgs-instances.nix b/.devops/nix/nixpkgs-instances.nix new file mode 100644 index 0000000000000000000000000000000000000000..90d683a713aa1f6f7b7e263549afe5edef0a4033 --- /dev/null +++ b/.devops/nix/nixpkgs-instances.nix @@ -0,0 +1,45 @@ +{ inputs, ... }: +{ + # The _module.args definitions are passed on to modules as arguments. E.g. + # the module `{ pkgs ... }: { /* config */ }` implicitly uses + # `_module.args.pkgs` (defined in this case by flake-parts). + perSystem = + { system, ... }: + { + _module.args = { + # Note: bringing up https://zimbatm.com/notes/1000-instances-of-nixpkgs + # again, the below creates several nixpkgs instances which the + # flake-centric CLI will be forced to evaluate e.g. on `nix flake show`. + # + # This is currently "slow" and "expensive", on a certain scale. + # This also isn't "right" in that this hinders dependency injection at + # the level of flake inputs. This might get removed in the foreseeable + # future. + # + # Note that you can use these expressions without Nix + # (`pkgs.callPackage ./devops/nix/scope.nix { }` is the entry point). + + pkgsCuda = import inputs.nixpkgs { + inherit system; + # Ensure dependencies use CUDA consistently (e.g. that openmpi, ucc, + # and ucx are built with CUDA support) + config.cudaSupport = true; + config.allowUnfreePredicate = + p: + builtins.all ( + license: + license.free + || builtins.elem license.shortName [ + "CUDA EULA" + "cuDNN EULA" + ] + ) (p.meta.licenses or [ p.meta.license ]); + }; + # Ensure dependencies use ROCm consistently + pkgsRocm = import inputs.nixpkgs { + inherit system; + config.rocmSupport = true; + }; + }; + }; +} diff --git a/.devops/nix/package-gguf-py.nix b/.devops/nix/package-gguf-py.nix new file mode 100644 index 0000000000000000000000000000000000000000..cca2f36a5bd4df214b3b7da84aa7a05d242fce0e --- /dev/null +++ b/.devops/nix/package-gguf-py.nix @@ -0,0 +1,36 @@ +{ + lib, + llamaVersion, + numpy, + tqdm, + sentencepiece, + pyyaml, + poetry-core, + buildPythonPackage, + pytestCheckHook, +}: + +buildPythonPackage { + pname = "gguf"; + version = llamaVersion; + pyproject = true; + nativeBuildInputs = [ poetry-core ]; + propagatedBuildInputs = [ + numpy + tqdm + sentencepiece + pyyaml + ]; + src = lib.cleanSource ../../gguf-py; + pythonImportsCheck = [ + "numpy" + "gguf" + ]; + nativeCheckInputs = [ pytestCheckHook ]; + doCheck = true; + meta = with lib; { + description = "Python package for writing binary files in the GGUF format"; + license = licenses.mit; + maintainers = [ maintainers.ditsuke ]; + }; +} diff --git a/.devops/nix/package.nix b/.devops/nix/package.nix new file mode 100644 index 0000000000000000000000000000000000000000..6e8050a49963528323e074b1e9d8bf32d618ed4b --- /dev/null +++ b/.devops/nix/package.nix @@ -0,0 +1,247 @@ +{ + lib, + glibc, + config, + stdenv, + runCommand, + cmake, + ninja, + pkg-config, + git, + mpi, + blas, + cudaPackages, + autoAddDriverRunpath, + darwin, + rocmPackages, + vulkan-headers, + vulkan-loader, + curl, + shaderc, + useBlas ? + builtins.all (x: !x) [ + useCuda + useMetalKit + useRocm + useVulkan + ] + && blas.meta.available, + useCuda ? config.cudaSupport, + useMetalKit ? stdenv.isAarch64 && stdenv.isDarwin, + # Increases the runtime closure size by ~700M + useMpi ? false, + useRocm ? config.rocmSupport, + rocmGpuTargets ? builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets, + enableCurl ? true, + useVulkan ? false, + llamaVersion ? "0.0.0", # Arbitrary version, substituted by the flake + + # It's necessary to consistently use backendStdenv when building with CUDA support, + # otherwise we get libstdc++ errors downstream. + effectiveStdenv ? if useCuda then cudaPackages.backendStdenv else stdenv, + enableStatic ? effectiveStdenv.hostPlatform.isStatic, + precompileMetalShaders ? false, +}: + +let + inherit (lib) + cmakeBool + cmakeFeature + optionals + strings + ; + + stdenv = throw "Use effectiveStdenv instead"; + + suffices = + lib.optionals useBlas [ "BLAS" ] + ++ lib.optionals useCuda [ "CUDA" ] + ++ lib.optionals useMetalKit [ "MetalKit" ] + ++ lib.optionals useMpi [ "MPI" ] + ++ lib.optionals useRocm [ "ROCm" ] + ++ lib.optionals useVulkan [ "Vulkan" ]; + + pnameSuffix = + strings.optionalString (suffices != [ ]) + "-${strings.concatMapStringsSep "-" strings.toLower suffices}"; + descriptionSuffix = strings.optionalString ( + suffices != [ ] + ) ", accelerated with ${strings.concatStringsSep ", " suffices}"; + + xcrunHost = runCommand "xcrunHost" { } '' + mkdir -p $out/bin + ln -s /usr/bin/xcrun $out/bin + ''; + + # apple_sdk is supposed to choose sane defaults, no need to handle isAarch64 + # separately + darwinBuildInputs = + with darwin.apple_sdk.frameworks; + [ + Accelerate + CoreVideo + CoreGraphics + ] + ++ optionals useMetalKit [ MetalKit ]; + + cudaBuildInputs = with cudaPackages; [ + cuda_cudart + cuda_cccl # + libcublas + ]; + + rocmBuildInputs = with rocmPackages; [ + clr + hipblas + rocblas + ]; + + vulkanBuildInputs = [ + vulkan-headers + vulkan-loader + shaderc + ]; +in + +effectiveStdenv.mkDerivation (finalAttrs: { + pname = "llama-cpp${pnameSuffix}"; + version = llamaVersion; + + # Note: none of the files discarded here are visible in the sandbox or + # affect the output hash. This also means they can be modified without + # triggering a rebuild. + src = lib.cleanSourceWith { + filter = + name: type: + let + noneOf = builtins.all (x: !x); + baseName = baseNameOf name; + in + noneOf [ + (lib.hasSuffix ".nix" name) # Ignore *.nix files when computing outPaths + (lib.hasSuffix ".md" name) # Ignore *.md changes whe computing outPaths + (lib.hasPrefix "." baseName) # Skip hidden files and directories + (baseName == "flake.lock") + ]; + src = lib.cleanSource ../../.; + }; + + postPatch = '' + substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \ + --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";" + substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \ + --replace '[bundle pathForResource:@"default" ofType:@"metallib"];' "@\"$out/bin/default.metallib\";" + ''; + + # With PR#6015 https://github.com/ggml-org/llama.cpp/pull/6015, + # `default.metallib` may be compiled with Metal compiler from XCode + # and we need to escape sandbox on MacOS to access Metal compiler. + # `xcrun` is used find the path of the Metal compiler, which is varible + # and not on $PATH + # see https://github.com/ggml-org/llama.cpp/pull/6118 for discussion + __noChroot = effectiveStdenv.isDarwin && useMetalKit && precompileMetalShaders; + + nativeBuildInputs = + [ + cmake + ninja + pkg-config + git + ] + ++ optionals useCuda [ + cudaPackages.cuda_nvcc + + autoAddDriverRunpath + ] + ++ optionals (effectiveStdenv.hostPlatform.isGnu && enableStatic) [ glibc.static ] + ++ optionals (effectiveStdenv.isDarwin && useMetalKit && precompileMetalShaders) [ xcrunHost ]; + + buildInputs = + optionals effectiveStdenv.isDarwin darwinBuildInputs + ++ optionals useCuda cudaBuildInputs + ++ optionals useMpi [ mpi ] + ++ optionals useRocm rocmBuildInputs + ++ optionals useBlas [ blas ] + ++ optionals useVulkan vulkanBuildInputs + ++ optionals enableCurl [ curl ]; + + cmakeFlags = + [ + (cmakeBool "LLAMA_BUILD_SERVER" true) + (cmakeBool "BUILD_SHARED_LIBS" (!enableStatic)) + (cmakeBool "CMAKE_SKIP_BUILD_RPATH" true) + (cmakeBool "LLAMA_CURL" enableCurl) + (cmakeBool "GGML_NATIVE" false) + (cmakeBool "GGML_BLAS" useBlas) + (cmakeBool "GGML_CUDA" useCuda) + (cmakeBool "GGML_HIP" useRocm) + (cmakeBool "GGML_METAL" useMetalKit) + (cmakeBool "GGML_VULKAN" useVulkan) + (cmakeBool "GGML_STATIC" enableStatic) + ] + ++ optionals useCuda [ + ( + with cudaPackages.flags; + cmakeFeature "CMAKE_CUDA_ARCHITECTURES" ( + builtins.concatStringsSep ";" (map dropDot cudaCapabilities) + ) + ) + ] + ++ optionals useRocm [ + (cmakeFeature "CMAKE_HIP_COMPILER" "${rocmPackages.llvm.clang}/bin/clang") + (cmakeFeature "CMAKE_HIP_ARCHITECTURES" rocmGpuTargets) + ] + ++ optionals useMetalKit [ + (lib.cmakeFeature "CMAKE_C_FLAGS" "-D__ARM_FEATURE_DOTPROD=1") + (cmakeBool "GGML_METAL_EMBED_LIBRARY" (!precompileMetalShaders)) + ]; + + # Environment variables needed for ROCm + env = optionals useRocm { + ROCM_PATH = "${rocmPackages.clr}"; + HIP_DEVICE_LIB_PATH = "${rocmPackages.rocm-device-libs}/amdgcn/bitcode"; + }; + + # TODO(SomeoneSerge): It's better to add proper install targets at the CMake level, + # if they haven't been added yet. + postInstall = '' + mkdir -p $out/include + cp $src/include/llama.h $out/include/ + ''; + + meta = { + # Configurations we don't want even the CI to evaluate. Results in the + # "unsupported platform" messages. This is mostly a no-op, because + # cudaPackages would've refused to evaluate anyway. + badPlatforms = optionals useCuda lib.platforms.darwin; + + # Configurations that are known to result in build failures. Can be + # overridden by importing Nixpkgs with `allowBroken = true`. + broken = (useMetalKit && !effectiveStdenv.isDarwin); + + description = "Inference of LLaMA model in pure C/C++${descriptionSuffix}"; + homepage = "https://github.com/ggml-org/llama.cpp/"; + license = lib.licenses.mit; + + # Accommodates `nix run` and `lib.getExe` + mainProgram = "llama-cli"; + + # These people might respond, on the best effort basis, if you ping them + # in case of Nix-specific regressions or for reviewing Nix-specific PRs. + # Consider adding yourself to this list if you want to ensure this flake + # stays maintained and you're willing to invest your time. Do not add + # other people without their consent. Consider removing people after + # they've been unreachable for long periods of time. + + # Note that lib.maintainers is defined in Nixpkgs, but you may just add + # an attrset following the same format as in + # https://github.com/NixOS/nixpkgs/blob/f36a80e54da29775c78d7eff0e628c2b4e34d1d7/maintainers/maintainer-list.nix + maintainers = with lib.maintainers; [ + philiptaron + SomeoneSerge + ]; + + # Extend `badPlatforms` instead + platforms = lib.platforms.all; + }; +}) diff --git a/.devops/nix/python-scripts.nix b/.devops/nix/python-scripts.nix new file mode 100644 index 0000000000000000000000000000000000000000..56ea1827887646424fa08017fe91c4bdcc82c465 --- /dev/null +++ b/.devops/nix/python-scripts.nix @@ -0,0 +1,66 @@ +{ + lib, + stdenv, + buildPythonPackage, + poetry-core, + mkShell, + python3Packages, + gguf-py, +}@inputs: + +let + llama-python-deps = with python3Packages; [ + numpy + sentencepiece + transformers + protobuf + torchWithoutCuda + gguf-py + tqdm + + # for scripts/compare-llama-bench.py + gitpython + tabulate + + # for examples/pydantic-models-to-grammar-examples.py + docstring-parser + pydantic + + ]; + + llama-python-test-deps = with python3Packages; [ + # Server bench + matplotlib + + # server tests + openai + pytest + prometheus-client + ]; +in + +buildPythonPackage ({ + pname = "llama-scripts"; + version = "0.0.0"; + pyproject = true; + + # NOTE: The files filtered out here are not visible in the build sandbox, neither + # do they affect the output hash. They can be modified without triggering a rebuild. + src = lib.cleanSourceWith { + filter = + name: type: + let + any = builtins.any (x: x); + baseName = builtins.baseNameOf name; + in + any [ + (lib.hasSuffix ".py" name) + (baseName == "README.md") + (baseName == "pyproject.toml") + ]; + src = lib.cleanSource ../../.; + }; + nativeBuildInputs = [ poetry-core ]; + nativeCheckInputs = llama-python-test-deps; + dependencies = llama-python-deps; +}) diff --git a/.devops/nix/scope.nix b/.devops/nix/scope.nix new file mode 100644 index 0000000000000000000000000000000000000000..478e8c4228afad449c3bd396d7ed4a12759a5de3 --- /dev/null +++ b/.devops/nix/scope.nix @@ -0,0 +1,41 @@ +{ + lib, + newScope, + python3, + llamaVersion ? "0.0.0", +}: + +let + pythonPackages = python3.pkgs; + buildPythonPackage = pythonPackages.buildPythonPackage; + numpy = pythonPackages.numpy; + tqdm = pythonPackages.tqdm; + sentencepiece = pythonPackages.sentencepiece; + pyyaml = pythonPackages.pyyaml; + poetry-core = pythonPackages.poetry-core; + pytestCheckHook = pythonPackages.pytestCheckHook; +in + +# We're using `makeScope` instead of just writing out an attrset +# because it allows users to apply overlays later using `overrideScope'`. +# Cf. https://noogle.dev/f/lib/makeScope + +lib.makeScope newScope (self: { + inherit llamaVersion; + gguf-py = self.callPackage ./package-gguf-py.nix { + inherit + buildPythonPackage + numpy + tqdm + sentencepiece + poetry-core + pyyaml + pytestCheckHook + ; + }; + python-scripts = self.callPackage ./python-scripts.nix { inherit buildPythonPackage poetry-core; }; + llama-cpp = self.callPackage ./package.nix { }; + docker = self.callPackage ./docker.nix { }; + docker-min = self.callPackage ./docker.nix { interactive = false; }; + sif = self.callPackage ./sif.nix { }; +}) diff --git a/.devops/nix/sif.nix b/.devops/nix/sif.nix new file mode 100644 index 0000000000000000000000000000000000000000..7a5e1dd0ffc4c61e9b88b25d14a10afbd4f8cda9 --- /dev/null +++ b/.devops/nix/sif.nix @@ -0,0 +1,27 @@ +{ + lib, + singularity-tools, + llama-cpp, + bashInteractive, + interactive ? false, +}: + +let + optionalInt = cond: x: if cond then x else 0; +in +singularity-tools.buildImage rec { + inherit (llama-cpp) name; + contents = [ llama-cpp ] ++ lib.optionals interactive [ bashInteractive ]; + + # These are excessive (but safe) for most variants. Building singularity + # images requires superuser privileges, so we build them inside a VM in a + # writable image of pre-determined size. + # + # ROCm is currently affected by https://github.com/NixOS/nixpkgs/issues/276846 + # + # Expected image sizes: + # - cpu/blas: 150M, + # - cuda, all gencodes: 560M, + diskSize = 4096 + optionalInt llama-cpp.useRocm 16384; + memSize = diskSize; +} diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..1c00f1b9c2cd319c071ad19a92d27b9f9a7f23ff --- /dev/null +++ b/.devops/rocm.Dockerfile @@ -0,0 +1,113 @@ +ARG UBUNTU_VERSION=24.04 + +# This needs to generally match the container host's environment. +ARG ROCM_VERSION=6.3 +ARG AMDGPU_VERSION=6.3 + +# Target the CUDA build image +ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete + +### Build image +FROM ${BASE_ROCM_DEV_CONTAINER} AS build + +# Unless otherwise specified, we make a fat build. +# List from https://github.com/ggml-org/llama.cpp/pull/1087#issuecomment-1682807878 +# This is mostly tied to rocBLAS supported archs. +# gfx803, gfx900, gfx1032, gfx1101, gfx1102,not officialy supported +# gfx906 is deprecated +#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.2.4/reference/system-requirements.html + +ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102' +#ARG ROCM_DOCKER_ARCH=gfx1100 + +# Set nvcc architectured +ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH} +# Enable ROCm +# ENV CC=/opt/rocm/llvm/bin/clang +# ENV CXX=/opt/rocm/llvm/bin/clang++ + +RUN apt-get update \ + && apt-get install -y \ + build-essential \ + cmake \ + git \ + libcurl4-openssl-dev \ + curl \ + libgomp1 + +WORKDIR /app + +COPY . . + +RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ + cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \ + && cmake --build build --config Release -j$(nproc) + +RUN mkdir -p /app/lib \ + && find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base image +FROM ${BASE_ROCM_DEV_CONTAINER} AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app + +### Full +FROM base AS full + +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3-pip \ + python3 \ + python3-wheel\ + && pip install --break-system-packages --upgrade setuptools \ + && pip install --break-system-packages -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/tools.sh b/.devops/tools.sh new file mode 100644 index 0000000000000000000000000000000000000000..41a6b1e55c7d295c9075a24eeaf2b8b0bc8eaa01 --- /dev/null +++ b/.devops/tools.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set -e + +# Read the first argument into a variable +arg1="$1" + +# Shift the arguments to remove the first one +shift + +if [[ "$arg1" == '--convert' || "$arg1" == '-c' ]]; then + exec python3 ./convert_hf_to_gguf.py "$@" +elif [[ "$arg1" == '--quantize' || "$arg1" == '-q' ]]; then + exec ./llama-quantize "$@" +elif [[ "$arg1" == '--run' || "$arg1" == '-r' ]]; then + exec ./llama-cli "$@" +elif [[ "$arg1" == '--bench' || "$arg1" == '-b' ]]; then + exec ./llama-bench "$@" +elif [[ "$arg1" == '--perplexity' || "$arg1" == '-p' ]]; then + exec ./llama-perplexity "$@" +elif [[ "$arg1" == '--all-in-one' || "$arg1" == '-a' ]]; then + echo "Converting PTH to GGML..." + for i in $(ls $1/$2/ggml-model-f16.bin*); do + if [ -f "${i/f16/q4_0}" ]; then + echo "Skip model quantization, it already exists: ${i/f16/q4_0}" + else + echo "Converting PTH to GGML: $i into ${i/f16/q4_0}..." + exec ./llama-quantize "$i" "${i/f16/q4_0}" q4_0 + fi + done +elif [[ "$arg1" == '--server' || "$arg1" == '-s' ]]; then + exec ./llama-server "$@" +else + echo "Unknown command: $arg1" + echo "Available commands: " + echo " --run (-r): Run a model previously converted into ggml" + echo " ex: -m /models/7B/ggml-model-q4_0.bin -p \"Building a website can be done in 10 simple steps:\" -n 512" + echo " --bench (-b): Benchmark the performance of the inference for various parameters." + echo " ex: -m model.gguf" + echo " --perplexity (-p): Measure the perplexity of a model over a given text." + echo " ex: -m model.gguf -f file.txt" + echo " --convert (-c): Convert a llama model into ggml" + echo " ex: --outtype f16 \"/models/7B/\" " + echo " --quantize (-q): Optimize with quantization process ggml" + echo " ex: \"/models/7B/ggml-model-f16.bin\" \"/models/7B/ggml-model-q4_0.bin\" 2" + echo " --all-in-one (-a): Execute --convert & --quantize" + echo " ex: \"/models/\" 7B" + echo " --server (-s): Run a model on the server" + echo " ex: -m /models/7B/ggml-model-q4_0.bin -c 2048 -ngl 43 -mg 1 --port 8080" +fi diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..fcd81ffa1e94e120107eadb9117d703268cce264 --- /dev/null +++ b/.devops/vulkan.Dockerfile @@ -0,0 +1,89 @@ +ARG UBUNTU_VERSION=24.04 + +FROM ubuntu:$UBUNTU_VERSION AS build + +# Install build tools +RUN apt update && apt install -y git build-essential cmake wget + +# Install Vulkan SDK and cURL +RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \ + wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \ + apt update -y && \ + apt-get install -y vulkan-sdk libcurl4-openssl-dev curl + +# Build it +WORKDIR /app + +COPY . . + +RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \ + cmake --build build --config Release -j$(nproc) + +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base image +FROM ubuntu:$UBUNTU_VERSION AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl libvulkan-dev \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app + +### Full +FROM base AS full + +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3 \ + python3-pip \ + python3-wheel \ + && pip install --break-system-packages --upgrade setuptools \ + && pip install --break-system-packages -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..064b7c7be86d02d7bbc17b17a0ca40f638d66611 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,20 @@ +*.o +*.a +.cache/ +# Do not ignore .git directory, otherwise the reported build number will always be 0 +.github/ +.gitignore +.vs/ +.vscode/ +.DS_Store + +build*/ + +models/* + +/llama-cli +/llama-quantize + +arm_neon.h +compile_commands.json +Dockerfile diff --git a/.ecrc b/.ecrc new file mode 100644 index 0000000000000000000000000000000000000000..c68877ec211f1c7f9458b8d59c278fc7e88f38c8 --- /dev/null +++ b/.ecrc @@ -0,0 +1,6 @@ +{ + "Exclude": ["^\\.gitmodules$", "stb_image\\.h"], + "Disable": { + "IndentSize": true + } +} diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000000000000000000000000000000000..c90b171f5567681f0426bf499d8acc6a95e1e5a6 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,54 @@ +# https://EditorConfig.org + +# Top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file, utf-8 charset +[*] +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true +charset = utf-8 +indent_style = space +indent_size = 4 + +[Makefile] +indent_style = tab + +[scripts/*.mk] +indent_style = tab + +[prompts/*.txt] +insert_final_newline = unset + +[tools/server/public/*] +indent_size = 2 + +[tools/server/public/deps_*] +trim_trailing_whitespace = unset +indent_style = unset +indent_size = unset + +[tools/server/deps_*] +trim_trailing_whitespace = unset +indent_style = unset +indent_size = unset + +[examples/llama.swiftui/llama.swiftui.xcodeproj/*] +indent_style = tab + +[tools/cvector-generator/*.txt] +trim_trailing_whitespace = unset +insert_final_newline = unset + +[models/templates/*.jinja] +indent_style = unset +indent_size = unset +end_of_line = unset +charset = unset +trim_trailing_whitespace = unset +insert_final_newline = unset + +[vendor/miniaudio/miniaudio.h] +trim_trailing_whitespace = unset +insert_final_newline = unset diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..669d231f1f63bf661bb9c380ea901a43dd97ca6f --- /dev/null +++ b/.flake8 @@ -0,0 +1,18 @@ +[flake8] +max-line-length = 125 +ignore = E203,E211,E221,E225,E231,E241,E251,E261,E266,E501,E701,E704,W503 +exclude = + # Do not traverse examples and tools + examples, + tools, + # Do not include package initializers + __init__.py, + # No need to traverse our git directory + .git, + # There's no value in checking cache directories + __pycache__, + # No need to include the build path + build, + # This contains builds that we don't want to check + dist # This is generated with `python build .` for package releases +# max-complexity = 10 diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..3c92951e7c53593ed62f29639a69e765f16cee00 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,42 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +docs/development/llama-star/idea-arch.key filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/android-editor.jpg filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/binder-cpp.jpg filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/binder-python.jpg filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/kompute-cpp-video.png filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/kompute-python-video.png filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/kompute-vulkan-architecture.jpg filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/komputer-2.gif filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/komputer-godot-4.gif filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/komputer-logos.gif filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/logistic-regression.jpg filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/docs/images/queue-allocation.jpg filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/examples/android/android-simple/app/src/main/assets/komputer-2.gif filter=lfs diff=lfs merge=lfs -text +ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/roboto.ttf filter=lfs diff=lfs merge=lfs -text +media/llama0-banner.png filter=lfs diff=lfs merge=lfs -text +media/llama0-logo.png filter=lfs diff=lfs merge=lfs -text +media/matmul.png filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-aquila.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-baichuan.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-bert-bge.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-command-r.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-deepseek-coder.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-deepseek-llm.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-falcon.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-gpt-2.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-gpt-neox.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-llama-bpe.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-llama-spm.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-mpt.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-nomic-bert-moe.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-phi-3.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-qwen2.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-refact.gguf filter=lfs diff=lfs merge=lfs -text +models/ggml-vocab-starcoder.gguf filter=lfs diff=lfs merge=lfs -text +tools/mtmd/test-1.jpeg filter=lfs diff=lfs merge=lfs -text +tools/mtmd/test-2.mp3 filter=lfs diff=lfs merge=lfs -text +tools/server/themes/buttons-top/buttons_top.png filter=lfs diff=lfs merge=lfs -text +tools/server/themes/wild/llamapattern.png filter=lfs diff=lfs merge=lfs -text +tools/server/themes/wild/wild.png filter=lfs diff=lfs merge=lfs -text diff --git a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml new file mode 100644 index 0000000000000000000000000000000000000000..b85bf5741e5a38de8f5b40153e720ded2fb07835 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml @@ -0,0 +1,87 @@ +name: Bug (compilation) +description: Something goes wrong when trying to compile llama.cpp. +title: "Compile bug: " +labels: ["bug-unconfirmed", "compilation"] +body: + - type: markdown + attributes: + value: > + Thanks for taking the time to fill out this bug report! + This issue template is intended for bug reports where the compilation of llama.cpp fails. + Before opening an issue, please confirm that the compilation still fails with `-DGGML_CCACHE=OFF`. + If the compilation succeeds with ccache disabled you should be able to permanently fix the issue + by clearing `~/.cache/ccache` (on Linux). + - type: textarea + id: commit + attributes: + label: Git commit + description: Which commit are you trying to compile? + placeholder: | + $git rev-parse HEAD + 84a07a17b1b08cf2b9747c633a2372782848a27f + validations: + required: true + - type: dropdown + id: operating-system + attributes: + label: Operating systems + description: Which operating systems do you know to be affected? + multiple: true + options: + - Linux + - Mac + - Windows + - BSD + - Other? (Please let us know in description) + validations: + required: true + - type: dropdown + id: backends + attributes: + label: GGML backends + description: Which GGML backends do you know to be affected? + options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan] + multiple: true + validations: + required: true + - type: textarea + id: info + attributes: + label: Problem description & steps to reproduce + description: > + Please give us a summary of the problem and tell us how to reproduce it. + If you can narrow down the bug to specific compile flags, that information would be very much appreciated by us. + placeholder: > + I'm trying to compile llama.cpp with CUDA support on a fresh install of Ubuntu and get error XY. + Here are the exact commands that I used: ... + validations: + required: true + - type: textarea + id: first_bad_commit + attributes: + label: First Bad Commit + description: > + If the bug was not present on an earlier version: when did it start appearing? + If possible, please do a git bisect and identify the exact commit that introduced the bug. + validations: + required: false + - type: textarea + id: command + attributes: + label: Compile command + description: > + Please provide the exact command you used to compile llama.cpp. For example: `cmake -B ...`. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: true + - type: textarea + id: logs + attributes: + label: Relevant log output + description: > + Please copy and paste any relevant log output, including any generated text. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/011-bug-results.yml b/.github/ISSUE_TEMPLATE/011-bug-results.yml new file mode 100644 index 0000000000000000000000000000000000000000..1ccef0793d45e890195b104b837cb568ccb11dff --- /dev/null +++ b/.github/ISSUE_TEMPLATE/011-bug-results.yml @@ -0,0 +1,101 @@ +name: Bug (model use) +description: Something goes wrong when using a model (in general, not specific to a single llama.cpp module). +title: "Eval bug: " +labels: ["bug-unconfirmed", "model evaluation"] +body: + - type: markdown + attributes: + value: > + Thanks for taking the time to fill out this bug report! + This issue template is intended for bug reports where the model evaluation results + (i.e. the generated text) are incorrect or llama.cpp crashes during model evaluation. + If you encountered the issue while using an external UI (e.g. ollama), + please reproduce your issue using one of the examples/binaries in this repository. + The `llama-cli` binary can be used for simple and reproducible model inference. + - type: textarea + id: version + attributes: + label: Name and Version + description: Which version of our software are you running? (use `--version` to get a version string) + placeholder: | + $./llama-cli --version + version: 2999 (42b4109e) + built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu + validations: + required: true + - type: dropdown + id: operating-system + attributes: + label: Operating systems + description: Which operating systems do you know to be affected? + multiple: true + options: + - Linux + - Mac + - Windows + - BSD + - Other? (Please let us know in description) + validations: + required: true + - type: dropdown + id: backends + attributes: + label: GGML backends + description: Which GGML backends do you know to be affected? + options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan] + multiple: true + validations: + required: true + - type: textarea + id: hardware + attributes: + label: Hardware + description: Which CPUs/GPUs are you using? + placeholder: > + e.g. Ryzen 5950X + 2x RTX 4090 + validations: + required: true + - type: textarea + id: model + attributes: + label: Models + description: > + Which model(s) at which quantization were you using when encountering the bug? + If you downloaded a GGUF file off of Huggingface, please provide a link. + placeholder: > + e.g. Meta LLaMA 3.1 Instruct 8b q4_K_M + validations: + required: false + - type: textarea + id: info + attributes: + label: Problem description & steps to reproduce + description: > + Please give us a summary of the problem and tell us how to reproduce it. + If you can narrow down the bug to specific hardware, compile flags, or command line arguments, + that information would be very much appreciated by us. + placeholder: > + e.g. when I run llama-cli with -ngl 99 I get garbled outputs. + When I use -ngl 0 it works correctly. + Here are the exact commands that I used: ... + validations: + required: true + - type: textarea + id: first_bad_commit + attributes: + label: First Bad Commit + description: > + If the bug was not present on an earlier version: when did it start appearing? + If possible, please do a git bisect and identify the exact commit that introduced the bug. + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output + description: > + Please copy and paste any relevant log output, including the command that you entered and any generated text. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/019-bug-misc.yml b/.github/ISSUE_TEMPLATE/019-bug-misc.yml new file mode 100644 index 0000000000000000000000000000000000000000..1904e31fdc4362502ecaf0a6c16d1b30f720e50c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/019-bug-misc.yml @@ -0,0 +1,91 @@ +name: Bug (misc.) +description: Something is not working the way it should (and it's not covered by any of the above cases). +title: "Misc. bug: " +labels: ["bug-unconfirmed"] +body: + - type: markdown + attributes: + value: > + Thanks for taking the time to fill out this bug report! + This issue template is intended for miscellaneous bugs that don't fit into any other category. + If you encountered the issue while using an external UI (e.g. ollama), + please reproduce your issue using one of the examples/binaries in this repository. + - type: textarea + id: version + attributes: + label: Name and Version + description: Which version of our software is affected? (You can use `--version` to get a version string.) + placeholder: | + $./llama-cli --version + version: 2999 (42b4109e) + built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu + validations: + required: true + - type: dropdown + id: operating-system + attributes: + label: Operating systems + description: Which operating systems do you know to be affected? + multiple: true + options: + - Linux + - Mac + - Windows + - BSD + - Other? (Please let us know in description) + validations: + required: false + - type: dropdown + id: module + attributes: + label: Which llama.cpp modules do you know to be affected? + multiple: true + options: + - Documentation/Github + - libllama (core library) + - llama-cli + - llama-server + - llama-bench + - llama-quantize + - Python/Bash scripts + - Test code + - Other (Please specify in the next section) + validations: + required: false + - type: textarea + id: command + attributes: + label: Command line + description: > + Please provide the exact commands you entered, if applicable. For example: `llama-server -m ... -c ...`, `llama-cli -m ...`, etc. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: false + - type: textarea + id: info + attributes: + label: Problem description & steps to reproduce + description: > + Please give us a summary of the problem and tell us how to reproduce it (if applicable). + validations: + required: true + - type: textarea + id: first_bad_commit + attributes: + label: First Bad Commit + description: > + If the bug was not present on an earlier version and it's not trivial to track down: when did it start appearing? + If possible, please do a git bisect and identify the exact commit that introduced the bug. + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output + description: > + If applicable, please copy and paste any relevant log output, including any generated text. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/020-enhancement.yml b/.github/ISSUE_TEMPLATE/020-enhancement.yml new file mode 100644 index 0000000000000000000000000000000000000000..cee1446f5a097a643e444cf204d22f512a665ada --- /dev/null +++ b/.github/ISSUE_TEMPLATE/020-enhancement.yml @@ -0,0 +1,51 @@ +name: Enhancement +description: Used to request enhancements for llama.cpp. +title: "Feature Request: " +labels: ["enhancement"] +body: + - type: markdown + attributes: + value: | + [Please post your idea first in Discussion if there is not yet a consensus for this enhancement request. This will help to keep this issue tracker focused on enhancements that the community has agreed needs to be implemented.](https://github.com/ggml-org/llama.cpp/discussions/categories/ideas) + + - type: checkboxes + id: prerequisites + attributes: + label: Prerequisites + description: Please confirm the following before submitting your enhancement request. + options: + - label: I am running the latest code. Mention the version if possible as well. + required: true + - label: I carefully followed the [README.md](https://github.com/ggml-org/llama.cpp/blob/master/README.md). + required: true + - label: I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed). + required: true + - label: I reviewed the [Discussions](https://github.com/ggml-org/llama.cpp/discussions), and have a new and useful enhancement to share. + required: true + + - type: textarea + id: feature-description + attributes: + label: Feature Description + description: Please provide a detailed written description of what you were trying to do, and what you expected `llama.cpp` to do as an enhancement. + placeholder: Detailed description of the enhancement + validations: + required: true + + - type: textarea + id: motivation + attributes: + label: Motivation + description: Please provide a detailed written description of reasons why this feature is necessary and how it is useful to `llama.cpp` users. + placeholder: Explanation of why this feature is needed and its benefits + validations: + required: true + + - type: textarea + id: possible-implementation + attributes: + label: Possible Implementation + description: If you have an idea as to how it can be implemented, please write a detailed description. Feel free to give links to external sources or share visuals that might be helpful to understand the details better. + placeholder: Detailed description of potential implementation + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/030-research.yml b/.github/ISSUE_TEMPLATE/030-research.yml new file mode 100644 index 0000000000000000000000000000000000000000..e774550d5908c38aeb7a6823b7dc38d8e49cd3df --- /dev/null +++ b/.github/ISSUE_TEMPLATE/030-research.yml @@ -0,0 +1,52 @@ +name: Research +description: Track new technical research area. +title: "Research: " +labels: ["research 🔬"] +body: + - type: markdown + attributes: + value: | + Don't forget to check for any [duplicate research issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3A%22research+%F0%9F%94%AC%22) + + - type: checkboxes + id: research-stage + attributes: + label: Research Stage + description: Track general state of this research ticket + options: + - label: Background Research (Let's try to avoid reinventing the wheel) + - label: Hypothesis Formed (How do you think this will work and it's effect?) + - label: Strategy / Implementation Forming + - label: Analysis of results + - label: Debrief / Documentation (So people in the future can learn from us) + + - type: textarea + id: background + attributes: + label: Previous existing literature and research + description: Whats the current state of the art and whats the motivation for this research? + + - type: textarea + id: hypothesis + attributes: + label: Hypothesis + description: How do you think this will work and it's effect? + + - type: textarea + id: implementation + attributes: + label: Implementation + description: Got an approach? e.g. a PR ready to go? + + - type: textarea + id: analysis + attributes: + label: Analysis + description: How does the proposed implementation behave? + + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. + render: shell diff --git a/.github/ISSUE_TEMPLATE/040-refactor.yml b/.github/ISSUE_TEMPLATE/040-refactor.yml new file mode 100644 index 0000000000000000000000000000000000000000..2fe94e26c6988a9577fa30d91e88ed993e526a92 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/040-refactor.yml @@ -0,0 +1,28 @@ +name: Refactor (Maintainers) +description: Used to track refactoring opportunities. +title: "Refactor: " +labels: ["refactor"] +body: + - type: markdown + attributes: + value: | + Don't forget to [check for existing refactor issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3Arefactoring) in case it's already covered. + Also you may want to check [Pull request refactor label as well](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Aopen+is%3Apr+label%3Arefactoring) for duplicates too. + + - type: textarea + id: background-description + attributes: + label: Background Description + description: Please provide a detailed written description of the pain points you are trying to solve. + placeholder: Detailed description behind your motivation to request refactor + validations: + required: true + + - type: textarea + id: possible-approaches + attributes: + label: Possible Refactor Approaches + description: If you have some idea of possible approaches to solve this problem. You may want to make it a todo list. + placeholder: Your idea of possible refactoring opportunity/approaches + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..0d246533c95158ed3eb574f6a58fab59f02e3a8a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,11 @@ +blank_issues_enabled: true +contact_links: + - name: Got an idea? + url: https://github.com/ggml-org/llama.cpp/discussions/categories/ideas + about: Pop it there. It may then become an enhancement ticket. + - name: Got a question? + url: https://github.com/ggml-org/llama.cpp/discussions/categories/q-a + about: Ask a question there! + - name: Want to contribute? + url: https://github.com/ggml-org/llama.cpp/wiki/contribute + about: Head to the contribution guide page of the wiki for areas you can help with diff --git a/.github/actions/get-tag-name/action.yml b/.github/actions/get-tag-name/action.yml new file mode 100644 index 0000000000000000000000000000000000000000..7ace23b2a3e7642822091162f05f529e40dc2bc0 --- /dev/null +++ b/.github/actions/get-tag-name/action.yml @@ -0,0 +1,22 @@ +name: "Determine tag name" +description: "Determine the tag name to use for a release" +outputs: + name: + description: "The name of the tag" + value: ${{ steps.tag.outputs.name }} + +runs: + using: "composite" + steps: + - name: Determine tag name + id: tag + shell: bash + run: | + BUILD_NUMBER="$(git rev-list --count HEAD)" + SHORT_HASH="$(git rev-parse --short=7 HEAD)" + if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT + else + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT + fi diff --git a/.github/actions/windows-setup-cuda/action.yml b/.github/actions/windows-setup-cuda/action.yml new file mode 100644 index 0000000000000000000000000000000000000000..5575caeca31a27c2b2edb00364cf5b6c4fc66a39 --- /dev/null +++ b/.github/actions/windows-setup-cuda/action.yml @@ -0,0 +1,67 @@ +name: "Windows - Setup CUDA Toolkit" +description: "Setup CUDA Toolkit for Windows" +inputs: + cuda_version: + description: "CUDA toolkit version" + required: true + +runs: + using: "composite" + steps: + - name: Install Cuda Toolkit 11.7 + if: ${{ inputs.cuda_version == '11.7' }} + shell: pwsh + run: | + mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" + choco install unzip -y + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-11.7.99-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-11.7.99-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-11.7.99-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-11.7.4.6-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-11.7.91-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-11.7.91-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-11.7.101-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-11.7.91-archive.zip" + unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_cudart-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvcc-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvrtc-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\libcublas-windows-x86_64-11.7.4.6-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvtx-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\visual_studio_integration-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvprof-windows-x86_64-11.7.101-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_cccl-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V11_7=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + + - name: Install Cuda Toolkit 12.4 + if: ${{ inputs.cuda_version == '12.4' }} + shell: pwsh + run: | + mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" + choco install unzip -y + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.4.131-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.4.5.8-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.4.127-archive.zip" + unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cudart-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvcc-windows-x86_64-12.4.131-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvrtc-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libcublas-windows-x86_64-12.4.5.8-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvtx-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_profiler_api-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\visual_studio_integration-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvprof-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cccl-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V12_4=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 diff --git a/.github/actions/windows-setup-curl/action.yml b/.github/actions/windows-setup-curl/action.yml new file mode 100644 index 0000000000000000000000000000000000000000..446f799fac34ab0d1e07ed626ca30d9bc7caa09e --- /dev/null +++ b/.github/actions/windows-setup-curl/action.yml @@ -0,0 +1,30 @@ +name: 'Windows - Setup CURL' +description: 'Composite action, to be reused in other workflow' +inputs: + curl_version: + description: 'CURL version' + required: false + default: '8.6.0_6' + architecture: + description: 'Architecture of the libcurl to download' + required: false + default: 'win64' +outputs: + curl_path: + description: "Path to the downloaded libcurl" + value: ${{ steps.get_libcurl.outputs.curl_path }} + +runs: + using: "composite" + steps: + - name: libCURL + id: get_libcurl + shell: powershell + env: + CURL_VERSION: ${{ inputs.curl_version }} + ARCHITECTURE: ${{ inputs.architecture }} + run: | + curl.exe -o $env:RUNNER_TEMP/curl.zip -L "https://curl.se/windows/dl-${env:CURL_VERSION}/curl-${env:CURL_VERSION}-${env:ARCHITECTURE}-mingw.zip" + mkdir $env:RUNNER_TEMP/libcurl + tar.exe -xvf $env:RUNNER_TEMP/curl.zip --strip-components=1 -C $env:RUNNER_TEMP/libcurl + echo "curl_path=$env:RUNNER_TEMP/libcurl" >> $env:GITHUB_OUTPUT diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 0000000000000000000000000000000000000000..3c2f67707b024421bec8af69335fe69f97e18c89 --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,95 @@ +# https://github.com/actions/labeler +Kompute: + - changed-files: + - any-glob-to-any-file: + - ggml/include/ggml-kompute.h + - ggml/src/ggml-kompute/** + - README-kompute.md +Apple Metal: + - changed-files: + - any-glob-to-any-file: + - ggml/include/ggml-metal.h + - ggml/src/ggml-metal/** + - README-metal.md +SYCL: + - changed-files: + - any-glob-to-any-file: + - ggml/include/ggml-sycl.h + - ggml/src/ggml-sycl/** + - docs/backend/SYCL.md + - examples/sycl/** +Nvidia GPU: + - changed-files: + - any-glob-to-any-file: + - ggml/include/ggml-cuda.h + - ggml/src/ggml-cuda/** +Vulkan: + - changed-files: + - any-glob-to-any-file: + - ggml/include/ggml-vulkan.h + - ggml/src/ggml-vulkan/** +documentation: + - changed-files: + - any-glob-to-any-file: + - docs/** + - media/** +testing: + - changed-files: + - any-glob-to-any-file: + - tests/** +build: + - changed-files: + - any-glob-to-any-file: + - cmake/** + - CMakeLists.txt + - CMakePresets.json +examples: + - changed-files: + - any-glob-to-any-file: + - examples/** + - tools/** +devops: + - changed-files: + - any-glob-to-any-file: + - .devops/** + - .github/** + - ci/** +python: + - changed-files: + - any-glob-to-any-file: + - "**/*.py" + - requirements/** + - gguf-py/** + - .flake8 +script: + - changed-files: + - any-glob-to-any-file: + - scripts/** +android: + - changed-files: + - any-glob-to-any-file: + - examples/llama.android/** +server: + - changed-files: + - any-glob-to-any-file: + - tools/server/** +ggml: + - changed-files: + - any-glob-to-any-file: + - ggml/** +nix: + - changed-files: + - any-glob-to-any-file: + - "**/*.nix" + - .github/workflows/nix-*.yml + - .devops/nix/nixpkgs-instances.nix +embedding: + - changed-files: + - any-glob-to-any-file: examples/embedding/ + +Ascend NPU: + - changed-files: + - any-glob-to-any-file: + - ggml/include/ggml-cann.h + - ggml/src/ggml-cann/** + - docs/backend/CANN.md diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000000000000000000000000000000000000..d0bdd73c4439c59506c94c6cfadeeb2b09375214 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1 @@ +*Make sure to read the [contributing guidelines](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md) before submitting a PR* diff --git a/.github/workflows/bench.yml.disabled b/.github/workflows/bench.yml.disabled new file mode 100644 index 0000000000000000000000000000000000000000..f2d7e16e981ac7c694749bb17e4e78bc68828871 --- /dev/null +++ b/.github/workflows/bench.yml.disabled @@ -0,0 +1,304 @@ +# TODO: there have been some issues with the workflow, so disabling for now +# https://github.com/ggml-org/llama.cpp/issues/7893 +# +# Benchmark +name: Benchmark + +on: + workflow_dispatch: + inputs: + gpu-series: + description: 'Azure GPU series to run with' + required: true + type: choice + options: + - Standard_NC4as_T4_v3 + - Standard_NC24ads_A100_v4 + - Standard_NC80adis_H100_v5 + sha: + description: 'Commit SHA1 to build' + required: false + type: string + duration: + description: 'Duration of the bench' + type: string + default: 10m + + push: + branches: + - master + paths: ['llama.cpp', 'ggml.c', 'ggml-backend.cpp', 'ggml-quants.c', '**/*.cu', 'tools/server/*.h*', 'tools/server/*.cpp'] + pull_request_target: + types: [opened, synchronize, reopened] + paths: ['llama.cpp', 'ggml.c', 'ggml-backend.cpp', 'ggml-quants.c', '**/*.cu', 'tools/server/*.h*', 'tools/server/*.cpp'] + schedule: + - cron: '04 2 * * *' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }}-${{ github.event.inputs.sha }} + cancel-in-progress: true + +jobs: + bench-server-baseline: + runs-on: Standard_NC4as_T4_v3 + env: + RUNNER_LABEL: Standard_NC4as_T4_v3 # FIXME Do not find a way to not duplicate it + N_USERS: 8 + DURATION: 10m + + strategy: + matrix: + model: [phi-2] + ftype: [q4_0, q8_0, f16] + include: + - model: phi-2 + ftype: q4_0 + pr_comment_enabled: "true" + + if: | + inputs.gpu-series == 'Standard_NC4as_T4_v3' + || github.event_name == 'pull_request_target' + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + + - name: Install python env + id: pipenv + run: | + cd tools/server/bench + python3 -m venv venv + source venv/bin/activate + pip install -r requirements.txt + + - name: Prometheus + id: install_prometheus + run: | + wget --quiet https://github.com/prometheus/prometheus/releases/download/v2.51.0/prometheus-2.51.0.linux-amd64.tar.gz + tar xzf prometheus*.tar.gz --strip-components=1 + ./prometheus --config.file=tools/server/bench/prometheus.yml & + while ! nc -z localhost 9090; do + sleep 0.1 + done + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + + - name: Install k6 and xk6-sse + id: k6_installation + run: | + cd tools/server/bench + go install go.k6.io/xk6/cmd/xk6@latest + xk6 build master \ + --with github.com/phymbert/xk6-sse + + - name: Build + id: cmake_build + run: | + set -eux + cmake -B build \ + -DGGML_NATIVE=OFF \ + -DLLAMA_BUILD_SERVER=ON \ + -DLLAMA_CUBLAS=ON \ + -DCUDAToolkit_ROOT=/usr/local/cuda \ + -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \ + -DCMAKE_CUDA_ARCHITECTURES=75 \ + -DLLAMA_FATAL_WARNINGS=OFF \ + -DLLAMA_ALL_WARNINGS=OFF \ + -DCMAKE_BUILD_TYPE=Release; + cmake --build build --config Release -j $(nproc) --target llama-server + + - name: Download the dataset + id: download_dataset + run: | + cd tools/server/bench + wget --quiet https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + + - name: Server bench + id: server_bench + env: + HEAD_REF: ${{ github.head_ref || github.ref_name }} + run: | + set -eux + + cd tools/server/bench + source venv/bin/activate + python bench.py \ + --runner-label ${{ env.RUNNER_LABEL }} \ + --name ${{ github.job }} \ + --branch $HEAD_REF \ + --commit ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha }} \ + --scenario script.js \ + --duration ${{ github.event.inputs.duration || env.DURATION }} \ + --hf-repo ggml-org/models \ + --hf-file ${{ matrix.model }}/ggml-model-${{ matrix.ftype }}.gguf \ + --model-path-prefix /models \ + --parallel ${{ env.N_USERS }} \ + -ngl 33 \ + --batch-size 2048 \ + --ubatch-size 256 \ + --ctx-size 16384 \ + --n-prompts 1000 \ + --max-prompt-tokens 1024 \ + --max-tokens 2048 + + cat results.github.env >> $GITHUB_ENV + + # Remove dataset as we do not want it in the artefact + rm ShareGPT_V3_unfiltered_cleaned_split.json + + - uses: actions/upload-artifact@v4 + with: + name: bench-server-${{ github.job }}-${{ env.RUNNER_LABEL }}-${{ matrix.model }}-${{ matrix.ftype }} + compression-level: 9 + path: | + tools/server/bench/*.jpg + tools/server/bench/*.json + tools/server/bench/*.log + + - name: Commit status + uses: Sibz/github-status-action@v1 + with: + authToken: ${{secrets.GITHUB_TOKEN}} + sha: ${{ inputs.sha || github.event.pull_request.head.sha || github.sha }} + context: bench-server-${{ github.job }}-${{ env.RUNNER_LABEL }}-${{ matrix.model }}-${{ matrix.ftype }} + description: | + ${{ env.BENCH_RESULTS }} + state: 'success' + + - name: Upload benchmark images + uses: devicons/public-upload-to-imgur@v2.2.2 + continue-on-error: true # Important as it looks unstable: 503 + id: imgur_step + with: + client_id: ${{secrets.IMGUR_CLIENT_ID}} + path: | + tools/server/bench/prompt_tokens_seconds.jpg + tools/server/bench/predicted_tokens_seconds.jpg + tools/server/bench/kv_cache_usage_ratio.jpg + tools/server/bench/requests_processing.jpg + + - name: Extract mermaid + id: set_mermaid + run: | + set -eux + + cd tools/server/bench + PROMPT_TOKENS_SECONDS=$(cat prompt_tokens_seconds.mermaid) + echo "PROMPT_TOKENS_SECONDS<> $GITHUB_ENV + echo "$PROMPT_TOKENS_SECONDS" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + + PREDICTED_TOKENS_SECONDS=$(cat predicted_tokens_seconds.mermaid) + echo "PREDICTED_TOKENS_SECONDS<> $GITHUB_ENV + echo "$PREDICTED_TOKENS_SECONDS" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + + KV_CACHE_USAGE_RATIO=$(cat kv_cache_usage_ratio.mermaid) + echo "KV_CACHE_USAGE_RATIO<> $GITHUB_ENV + echo "$KV_CACHE_USAGE_RATIO" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + + REQUESTS_PROCESSING=$(cat requests_processing.mermaid) + echo "REQUESTS_PROCESSING<> $GITHUB_ENV + echo "$REQUESTS_PROCESSING" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + + - name: Extract image url + id: extract_image_url + continue-on-error: true + run: | + set -eux + + echo "IMAGE_O=${{ fromJSON(steps.imgur_step.outputs.imgur_urls)[0] }}" >> $GITHUB_ENV + echo "IMAGE_1=${{ fromJSON(steps.imgur_step.outputs.imgur_urls)[1] }}" >> $GITHUB_ENV + echo "IMAGE_2=${{ fromJSON(steps.imgur_step.outputs.imgur_urls)[2] }}" >> $GITHUB_ENV + echo "IMAGE_3=${{ fromJSON(steps.imgur_step.outputs.imgur_urls)[3] }}" >> $GITHUB_ENV + + - name: Comment PR + uses: mshick/add-pr-comment@v2 + id: comment_pr + if: ${{ github.event.pull_request != '' && matrix.pr_comment_enabled == 'true' }} + with: + message-id: bench-server-${{ github.job }}-${{ env.RUNNER_LABEL }}-${{ matrix.model }}-${{ matrix.ftype }} + message: | +

+ + 📈 **llama.cpp server** for _${{ github.job }}_ on _${{ env.RUNNER_LABEL }}_ for `${{ matrix.model }}`-`${{ matrix.ftype }}`: **${{ env.BENCH_ITERATIONS}} iterations** 🚀 + +

+ +
+ + Expand details for performance related PR only + + - Concurrent users: ${{ env.N_USERS }}, duration: ${{ github.event.inputs.duration || env.DURATION }} + - HTTP request : avg=${{ env.HTTP_REQ_DURATION_AVG }}ms p(95)=${{ env.HTTP_REQ_DURATION_P_95_ }}ms fails=${{ env.HTTP_REQ_FAILED_PASSES }}, finish reason: stop=${{ env.LLAMACPP_COMPLETIONS_STOP_RATE_PASSES }} truncated=${{ env.LLAMACPP_COMPLETIONS_TRUNCATED_RATE_PASSES }} + - Prompt processing (pp): avg=${{ env.LLAMACPP_PROMPT_PROCESSING_SECOND_AVG }}tk/s p(95)=${{ env.LLAMACPP_PROMPT_PROCESSING_SECOND_P_95_ }}tk/s + - Token generation (tg): avg=${{ env.LLAMACPP_TOKENS_SECOND_AVG }}tk/s p(95)=${{ env.LLAMACPP_TOKENS_SECOND_P_95_ }}tk/s + - ${{ env.BENCH_GRAPH_XLABEL }} + + +

+ + prompt_tokens_seconds + +

+ + More + + ```mermaid + ${{ env.PROMPT_TOKENS_SECONDS }} + ``` + +
+ + predicted_tokens_seconds + +
+ More + + ```mermaid + ${{ env.PREDICTED_TOKENS_SECONDS }} + ``` + +
+ +

+ +
+ + Details + +

+ + kv_cache_usage_ratio + +

+ More + + ```mermaid + ${{ env.KV_CACHE_USAGE_RATIO }} + ``` + +
+ + requests_processing + +
+ More + + ```mermaid + ${{ env.REQUESTS_PROCESSING }} + ``` + +
+ +

+
+
diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml new file mode 100644 index 0000000000000000000000000000000000000000..7cfc82ba4e277283a75d6a6ab589fa944907e47d --- /dev/null +++ b/.github/workflows/build-linux-cross.yml @@ -0,0 +1,346 @@ +name: Build on Linux using cross-compiler +on: + workflow_dispatch: + workflow_call: + +jobs: + ubuntu-24-riscv64-cpu-cross: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v4 + - name: Setup Riscv + run: | + sudo dpkg --add-architecture riscv64 + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + EOF + + sudo apt-get update || true ;# Prevent failure due to missing URLs. + + sudo apt-get install -y --no-install-recommends \ + build-essential \ + gcc-14-riscv64-linux-gnu \ + g++-14-riscv64-linux-gnu + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + ubuntu-24-riscv64-vulkan-cross: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v4 + - name: Setup Riscv + run: | + sudo dpkg --add-architecture riscv64 + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + EOF + + sudo apt-get update || true ;# Prevent failure due to missing URLs. + + sudo apt-get install -y --no-install-recommends \ + build-essential \ + glslc \ + gcc-14-riscv64-linux-gnu \ + g++-14-riscv64-linux-gnu \ + libvulkan-dev:riscv64 + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_VULKAN=ON \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + ubuntu-24-arm64-vulkan-cross: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v4 + - name: Setup Arm64 + run: | + sudo dpkg --add-architecture arm64 + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | sudo tee /etc/apt/sources.list.d/arm64-ports.list + deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + EOF + + sudo apt-get update || true ;# Prevent failure due to missing URLs. + + sudo apt-get install -y --no-install-recommends \ + build-essential \ + glslc \ + crossbuild-essential-arm64 \ + libvulkan-dev:arm64 + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_VULKAN=ON \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=aarch64 \ + -DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc \ + -DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/aarch64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + ubuntu-24-ppc64el-cpu-cross: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v4 + - name: Setup PowerPC64le + run: | + sudo dpkg --add-architecture ppc64el + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + EOF + + sudo apt-get update || true ;# Prevent failure due to missing URLs. + + sudo apt-get install -y --no-install-recommends \ + build-essential \ + gcc-14-powerpc64le-linux-gnu \ + g++-14-powerpc64le-linux-gnu + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=ppc64 \ + -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + ubuntu-24-ppc64el-vulkan-cross: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v4 + - name: Setup PowerPC64le + run: | + sudo dpkg --add-architecture ppc64el + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + EOF + + sudo apt-get update || true ;# Prevent failure due to missing URLs. + + sudo apt-get install -y --no-install-recommends \ + build-essential \ + glslc \ + gcc-14-powerpc64le-linux-gnu \ + g++-14-powerpc64le-linux-gnu \ + libvulkan-dev:ppc64el + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_VULKAN=ON \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=ppc64 \ + -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + debian-13-loongarch64-cpu-cross: + runs-on: ubuntu-24.04 + container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 + + steps: + - uses: actions/checkout@v4 + - name: Setup LoongArch + run: | + rm -f /etc/apt/sources.list.d/* + cat << EOF | tee /etc/apt/sources.list.d/debian-ports.list + deb http://snapshot.debian.org/archive/debian/20250515T202920Z/ trixie main + EOF + ( echo 'quiet "true";'; \ + echo 'APT::Get::Assume-Yes "true";'; \ + echo 'APT::Install-Recommends "false";'; \ + echo 'Acquire::Check-Valid-Until "false";'; \ + echo 'Acquire::Retries "5";'; \ + ) > /etc/apt/apt.conf.d/99snapshot-repos + + apt-get update + apt-get install -y ca-certificates debian-ports-archive-keyring cmake git zip + dpkg --add-architecture loong64 + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | tee /etc/apt/sources.list.d/loong64-ports.list + deb [arch=loong64] http://snapshot.debian.org/archive/debian-ports/20250515T194251Z/ sid main + EOF + + apt-get update || true ;# Prevent failure due to missing URLs. + + apt-get install -y --no-install-recommends \ + build-essential \ + gcc-14-loongarch64-linux-gnu \ + g++-14-loongarch64-linux-gnu + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=loongarch64 \ + -DCMAKE_C_COMPILER=loongarch64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=loongarch64-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/loongarch64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + debian-13-loongarch64-vulkan-cross: + runs-on: ubuntu-24.04 + container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 + + steps: + - uses: actions/checkout@v4 + - name: Setup LoongArch + run: | + rm -f /etc/apt/sources.list.d/* + cat << EOF | tee /etc/apt/sources.list.d/debian-ports.list + deb http://snapshot.debian.org/archive/debian/20250515T202920Z/ trixie main + EOF + ( echo 'quiet "true";'; \ + echo 'APT::Get::Assume-Yes "true";'; \ + echo 'APT::Install-Recommends "false";'; \ + echo 'Acquire::Check-Valid-Until "false";'; \ + echo 'Acquire::Retries "5";'; \ + ) > /etc/apt/apt.conf.d/99snapshot-repos + + apt-get update + apt-get install -y ca-certificates debian-ports-archive-keyring cmake git zip + dpkg --add-architecture loong64 + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | tee /etc/apt/sources.list.d/loong64-ports.list + deb [arch=loong64] http://snapshot.debian.org/archive/debian-ports/20250515T194251Z/ sid main + EOF + + apt-get update || true ;# Prevent failure due to missing URLs. + + apt-get install -y --no-install-recommends \ + build-essential \ + glslc \ + gcc-14-loongarch64-linux-gnu \ + g++-14-loongarch64-linux-gnu \ + libvulkan-dev:loong64 + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_VULKAN=ON \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=loongarch64 \ + -DCMAKE_C_COMPILER=loongarch64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=loongarch64-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/loongarch64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000000000000000000000000000000000000..c4783a6df88820ce933e6569b5fbabbe8f7c5bca --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,1080 @@ +name: CI + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build.yml', '.github/workflows/build-linux-cross.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] + pull_request: + types: [opened, synchronize, reopened] + paths: ['.github/workflows/build.yml', '.github/workflows/build-linux-cross.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + GGML_NLOOP: 3 + GGML_N_THREADS: 1 + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 + +jobs: + macOS-latest-cmake-arm64: + runs-on: macos-14 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-arm64 + evict-old-files: 1d + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + brew install curl + + - name: Build + id: cmake_build + run: | + sysctl -a + cmake -B build \ + -DCMAKE_BUILD_RPATH="@loader_path" \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DGGML_RPC=ON + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L 'main|curl' --verbose --timeout 900 + + macOS-latest-cmake-x64: + runs-on: macos-13 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-x64 + evict-old-files: 1d + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + brew install curl + + - name: Build + id: cmake_build + run: | + sysctl -a + # Metal is disabled due to intermittent failures with Github runners not having a GPU: + # https://github.com/ggml-org/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313 + cmake -B build \ + -DCMAKE_BUILD_RPATH="@loader_path" \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_METAL=OFF \ + -DGGML_RPC=ON + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + + ubuntu-cpu-cmake: + strategy: + matrix: + include: + - build: 'x64' + os: ubuntu-22.04 + - build: 'arm64' + os: ubuntu-22.04-arm + + runs-on: ${{ matrix.os }} + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-cpu-cmake + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_RPC=ON + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L 'main|curl' --verbose --timeout 900 + + - name: Test llama2c conversion + id: llama2c_test + run: | + cd build + echo "Fetch tokenizer" + wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/tok512.bin + echo "Fetch llama2c model" + wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/stories260K.bin + ./bin/llama-convert-llama2c-to-ggml --copy-vocab-from-model ./tok512.bin --llama2c-model stories260K.bin --llama2c-output-model stories260K.gguf + ./bin/llama-cli -m stories260K.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256 + + ubuntu-latest-cmake-sanitizer: + runs-on: ubuntu-latest + + continue-on-error: true + + strategy: + matrix: + sanitizer: [ADDRESS, THREAD, UNDEFINED] + build_type: [Debug] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-latest-cmake-sanitizer-${{ matrix.sanitizer }} + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Build + id: cmake_build + if: ${{ matrix.sanitizer != 'THREAD' }} + run: | + cmake -B build \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) + + - name: Build (no OpenMP) + id: cmake_build_no_openmp + if: ${{ matrix.sanitizer == 'THREAD' }} + run: | + cmake -B build \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DGGML_OPENMP=OFF + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + + ubuntu-latest-llguidance: + runs-on: ubuntu-latest + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Build + id: cmake_build + run: | + mkdir build + cd build + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_LLGUIDANCE=ON + cmake --build . --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + + ubuntu-latest-cmake-rpc: + runs-on: ubuntu-latest + + continue-on-error: true + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-latest-cmake-rpc + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DGGML_RPC=ON + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose + + ubuntu-22-cmake-vulkan: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-vulkan + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add - + sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list + sudo apt-get update -y + sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DGGML_VULKAN=ON + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + export GGML_VK_VISIBLE_DEVICES=0 + # This is using llvmpipe and runs slower than other backends + ctest -L main --verbose --timeout 3600 + + ubuntu-22-cmake-hip: + runs-on: ubuntu-22.04 + container: rocm/dev-ubuntu-22.04:6.0.2 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-hip + evict-old-files: 1d + + - name: Build with native CMake HIP support + id: cmake_build + run: | + cmake -B build -S . \ + -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \ + -DGGML_HIP_ROCWMMA_FATTN=ON \ + -DGGML_HIP=ON + cmake --build build --config Release -j $(nproc) + + - name: Build with legacy HIP support + id: cmake_build_legacy_hip + run: | + cmake -B build2 -S . \ + -DCMAKE_C_COMPILER=hipcc \ + -DCMAKE_CXX_COMPILER=hipcc \ + -DGGML_HIP_ROCWMMA_FATTN=ON \ + -DGGML_HIP=ON + cmake --build build2 --config Release -j $(nproc) + + ubuntu-22-cmake-musa: + runs-on: ubuntu-22.04 + container: mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Dependencies + id: depends + run: | + apt-get update + apt-get install -y build-essential git cmake libcurl4-openssl-dev + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-musa + evict-old-files: 1d + + - name: Build with native CMake MUSA support + id: cmake_build + run: | + cmake -B build -S . \ + -DGGML_MUSA=ON + cmake --build build --config Release -j $(nproc) + + ubuntu-22-cmake-sycl: + runs-on: ubuntu-22.04 + + continue-on-error: true + + steps: + - uses: actions/checkout@v4 + + - name: add oneAPI to apt + shell: bash + run: | + cd /tmp + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" + + - name: install oneAPI dpcpp compiler + shell: bash + run: | + sudo apt update + sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev + + - name: install oneAPI MKL library + shell: bash + run: | + sudo apt install intel-oneapi-mkl-devel + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-sycl + evict-old-files: 1d + + - name: Build + id: cmake_build + run: | + source /opt/intel/oneapi/setvars.sh + cmake -B build \ + -DGGML_SYCL=ON \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx + cmake --build build --config Release -j $(nproc) + + ubuntu-22-cmake-sycl-fp16: + runs-on: ubuntu-22.04 + + continue-on-error: true + + steps: + - uses: actions/checkout@v4 + + - name: add oneAPI to apt + shell: bash + run: | + cd /tmp + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" + + - name: install oneAPI dpcpp compiler + shell: bash + run: | + sudo apt update + sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev + + - name: install oneAPI MKL library + shell: bash + run: | + sudo apt install intel-oneapi-mkl-devel + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-sycl-fp16 + evict-old-files: 1d + + - name: Build + id: cmake_build + run: | + source /opt/intel/oneapi/setvars.sh + cmake -B build \ + -DGGML_SYCL=ON \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx \ + -DGGML_SYCL_F16=ON + cmake --build build --config Release -j $(nproc) + + build-linux-cross: + uses: ./.github/workflows/build-linux-cross.yml + + macOS-latest-cmake-ios: + runs-on: macos-latest + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-ios + evict-old-files: 1d + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + + - name: Build + id: cmake_build + run: | + sysctl -a + cmake -B build -G Xcode \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DLLAMA_BUILD_COMMON=OFF \ + -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ + -DLLAMA_BUILD_TESTS=OFF \ + -DLLAMA_BUILD_SERVER=OFF \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + + macOS-latest-cmake-tvos: + runs-on: macos-latest + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-tvos + evict-old-files: 1d + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + + - name: Build + id: cmake_build + run: | + sysctl -a + cmake -B build -G Xcode \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DLLAMA_BUILD_COMMON=OFF \ + -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ + -DLLAMA_BUILD_TESTS=OFF \ + -DLLAMA_BUILD_SERVER=OFF \ + -DCMAKE_SYSTEM_NAME=tvOS \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + + macOS-latest-cmake-visionos: + runs-on: macos-latest + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + + - name: Build + id: cmake_build + run: | + sysctl -a + cmake -B build -G Xcode \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DLLAMA_BUILD_COMMON=OFF \ + -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ + -DLLAMA_BUILD_TESTS=OFF \ + -DLLAMA_BUILD_SERVER=OFF \ + -DCMAKE_SYSTEM_NAME=visionOS \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=1.0 \ + -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + + macOS-latest-swift: + runs-on: macos-latest + + strategy: + matrix: + destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS'] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-swift + evict-old-files: 1d + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + + - name: Build llama.cpp with CMake + id: cmake_build + run: | + sysctl -a + cmake -B build -G Xcode \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DLLAMA_CURL=OFF \ + -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ + -DLLAMA_BUILD_TESTS=OFF \ + -DLLAMA_BUILD_SERVER=OFF \ + -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) + + - name: xcodebuild for swift package + id: xcodebuild + run: | + ./build-xcframework.sh + + windows-msys2: + runs-on: windows-latest + + strategy: + fail-fast: false + matrix: + include: + - { sys: UCRT64, env: ucrt-x86_64, build: Release } + - { sys: CLANG64, env: clang-x86_64, build: Release } + + steps: + - name: Clone + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-msys2 + variant: ccache + evict-old-files: 1d + + - name: Setup ${{ matrix.sys }} + uses: msys2/setup-msys2@v2 + with: + update: true + msystem: ${{matrix.sys}} + install: >- + base-devel + git + mingw-w64-${{matrix.env}}-toolchain + mingw-w64-${{matrix.env}}-cmake + mingw-w64-${{matrix.env}}-openblas + + - name: Build using CMake + shell: msys2 {0} + run: | + cmake -B build + cmake --build build --config ${{ matrix.build }} -j $(nproc) + + - name: Clean after building using CMake + shell: msys2 {0} + run: | + rm -rf build + + - name: Build using CMake w/ OpenBLAS + shell: msys2 {0} + run: | + cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS + cmake --build build --config ${{ matrix.build }} -j $(nproc) + + windows-latest-cmake: + runs-on: windows-latest + + env: + OPENBLAS_VERSION: 0.3.23 + SDE_VERSION: 9.33.0-2024-01-07 + VULKAN_VERSION: 1.4.309.0 + + strategy: + matrix: + include: + - build: 'cpu-x64 (static)' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF' + - build: 'openblas-x64' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' + - build: 'vulkan-x64' + defines: '-DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON' + - build: 'llvm-arm64' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON' + - build: 'llvm-arm64-opencl-adreno' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON' + # - build: 'kompute-x64' + # defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON' + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-${{ matrix.build }} + variant: ccache + evict-old-files: 1d + + - name: Clone Kompute submodule + id: clone_kompute + if: ${{ matrix.build == 'kompute-x64' }} + run: | + git submodule update --init ggml/src/ggml-kompute/kompute + + - name: Download OpenBLAS + id: get_openblas + if: ${{ matrix.build == 'openblas-x64' }} + run: | + curl.exe -o $env:RUNNER_TEMP/openblas.zip -L "https://github.com/xianyi/OpenBLAS/releases/download/v${env:OPENBLAS_VERSION}/OpenBLAS-${env:OPENBLAS_VERSION}-x64.zip" + curl.exe -o $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt -L "https://github.com/xianyi/OpenBLAS/raw/v${env:OPENBLAS_VERSION}/LICENSE" + mkdir $env:RUNNER_TEMP/openblas + tar.exe -xvf $env:RUNNER_TEMP/openblas.zip -C $env:RUNNER_TEMP/openblas + $vcdir = $(vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath) + $msvc = $(join-path $vcdir $('VC\Tools\MSVC\'+$(gc -raw $(join-path $vcdir 'VC\Auxiliary\Build\Microsoft.VCToolsVersion.default.txt')).Trim())) + $lib = $(join-path $msvc 'bin\Hostx64\x64\lib.exe') + & $lib /machine:x64 "/def:${env:RUNNER_TEMP}/openblas/lib/libopenblas.def" "/out:${env:RUNNER_TEMP}/openblas/lib/openblas.lib" /name:openblas.dll + + - name: Install Vulkan SDK + id: get_vulkan + if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }} + run: | + curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe" + & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install + Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}" + Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin" + + - name: Install Ninja + id: install_ninja + run: | + choco install ninja + + - name: Install OpenCL Headers and Libs + id: install_opencl + if: ${{ matrix.build == 'llvm-arm64-opencl-adreno' }} + run: | + git clone https://github.com/KhronosGroup/OpenCL-Headers + cd OpenCL-Headers + cmake -B build ` + -DBUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF ` + -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" + cmake --build build --target install + git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader + cd OpenCL-ICD-Loader + cmake -B build-arm64-release ` + -A arm64 ` + -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" ` + -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" + cmake --build build-arm64-release --target install --config release + + - name: libCURL + id: get_libcurl + uses: ./.github/actions/windows-setup-curl + + - name: Build + id: cmake_build + env: + CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} + run: | + cmake -S . -B build ${{ matrix.defines }} ` + -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" + cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} + cp $env:CURL_PATH/bin/libcurl-*.dll build/bin/Release + + - name: Add libopenblas.dll + id: add_libopenblas_dll + if: ${{ matrix.build == 'openblas-x64' }} + run: | + cp $env:RUNNER_TEMP/openblas/bin/libopenblas.dll ./build/bin/Release/openblas.dll + cp $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt ./build/bin/Release/OpenBLAS-${env:OPENBLAS_VERSION}.txt + + - name: Test + id: cmake_test + if: ${{ matrix.build != 'llvm-arm64' && matrix.build != 'llvm-arm64-opencl-adreno' }} + run: | + cd build + ctest -L main -C Release --verbose --timeout 900 + + # TODO: disabled for now, consider adding tests for all CPU variants instead + # - name: Test (Intel SDE) + # id: cmake_test_sde + # if: ${{ matrix.build == 'avx512-x64' && env.HAS_AVX512F == '0' }} # use Intel SDE for AVX-512 emulation + # run: | + # curl.exe -o $env:RUNNER_TEMP/sde.tar.xz -L "https://downloadmirror.intel.com/813591/sde-external-${env:SDE_VERSION}-win.tar.xz" + # # for some weird reason windows tar doesn't like sde tar.xz + # 7z x "-o${env:RUNNER_TEMP}" $env:RUNNER_TEMP/sde.tar.xz + # 7z x "-o${env:RUNNER_TEMP}" $env:RUNNER_TEMP/sde.tar + # $sde = $(join-path $env:RUNNER_TEMP sde-external-${env:SDE_VERSION}-win/sde.exe) + # cd build + # $env:LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR = 1 + # & $sde -future -- ctest -L main -C Release --verbose --timeout 900 + + ubuntu-latest-cmake-cuda: + runs-on: ubuntu-latest + container: nvidia/cuda:12.6.2-devel-ubuntu24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Install dependencies + env: + DEBIAN_FRONTEND: noninteractive + run: | + apt update + apt install -y cmake build-essential ninja-build libgomp1 git libcurl4-openssl-dev + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-latest-cmake-cuda + evict-old-files: 1d + + - name: Build with CMake + run: | + cmake -S . -B build -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CUDA_ARCHITECTURES=89-real \ + -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_NATIVE=OFF \ + -DGGML_CUDA=ON + cmake --build build + + windows-2022-cmake-cuda: + runs-on: windows-2022 + + strategy: + matrix: + cuda: ['12.4'] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Install ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-cuda-${{ matrix.cuda }} + variant: ccache + evict-old-files: 1d + + - name: Install Cuda Toolkit + uses: ./.github/actions/windows-setup-cuda + with: + cuda_version: ${{ matrix.cuda }} + + - name: Install Ninja + id: install_ninja + run: | + choco install ninja + + - name: libCURL + id: get_libcurl + uses: ./.github/actions/windows-setup-curl + + - name: Build + id: cmake_build + shell: cmd + env: + CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} + run: | + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64 + cmake -S . -B build -G "Ninja Multi-Config" ^ + -DLLAMA_BUILD_SERVER=ON ^ + -DGGML_NATIVE=OFF ^ + -DGGML_BACKEND_DL=ON ^ + -DGGML_CPU_ALL_VARIANTS=ON ^ + -DGGML_CUDA=ON ^ + -DGGML_RPC=ON ^ + -DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" + set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 + cmake --build build --config Release -j %NINJA_JOBS% -t ggml + cmake --build build --config Release + + windows-latest-cmake-sycl: + runs-on: windows-latest + + defaults: + run: + shell: bash + + env: + WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe + WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel + ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-sycl + variant: ccache + evict-old-files: 1d + + - name: Install + run: | + scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL + + # TODO: add libcurl support ; we will also need to modify win-build-sycl.bat to accept user-specified args + + - name: Build + id: cmake_build + run: examples/sycl/win-build-sycl.bat + + windows-latest-cmake-hip: + if: ${{ github.event.inputs.create_release != 'true' }} + runs-on: windows-latest + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Clone rocWMMA repository + id: clone_rocwmma + run: | + git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 + + - name: Install + id: depends + run: | + $ErrorActionPreference = "Stop" + write-host "Downloading AMD HIP SDK Installer" + Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" + write-host "Installing AMD HIP SDK" + Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait + write-host "Completed AMD HIP SDK installation" + + - name: Verify ROCm + id: verify + run: | + & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version + + - name: Install ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ${{ github.job }} + evict-old-files: 1d + + - name: libCURL + id: get_libcurl + uses: ./.github/actions/windows-setup-curl + + - name: Build + id: cmake_build + env: + CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} + run: | + $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) + $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" + cmake -G "Unix Makefiles" -B build -S . ` + -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` + -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` + -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" ` + -DCMAKE_BUILD_TYPE=Release ` + -DGGML_HIP=ON ` + -DGGML_HIP_ROCWMMA_FATTN=ON ` + -DGGML_RPC=ON ` + -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" + cmake --build build -j ${env:NUMBER_OF_PROCESSORS} + + ios-xcode-build: + runs-on: macos-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Build + id: cmake_build + run: | + sysctl -a + cmake -B build -G Xcode \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DLLAMA_CURL=OFF \ + -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ + -DLLAMA_BUILD_TESTS=OFF \ + -DLLAMA_BUILD_SERVER=OFF \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + + - name: xcodebuild for swift package + id: xcodebuild + run: | + ./build-xcframework.sh + + - name: Build Xcode project + run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build + + android-build: + runs-on: ubuntu-latest + + steps: + - name: Clone + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: android-build + evict-old-files: 1d + + - name: Set up JDK + uses: actions/setup-java@v3 + with: + java-version: 17 + distribution: zulu + + - name: Setup Android SDK + uses: android-actions/setup-android@v3 + with: + log-accepted-android-sdk-licenses: false + + - name: Build + run: | + cd examples/llama.android + ./gradlew build --no-daemon + + openEuler-latest-cmake-cann: + if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }} + defaults: + run: + shell: bash -el {0} + strategy: + matrix: + arch: [x86, aarch64] + cann: + - '8.1.RC1.alpha001-910b-openeuler22.03-py3.10' + device: + - 'ascend910b3' + build: + - 'Release' + runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} + container: ascendai/cann:${{ matrix.cann }} + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Dependencies + run: | + yum update -y + yum install -y git gcc gcc-c++ make cmake libcurl-devel + + - name: Build + run: | + export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH} + + cmake -S . -B build \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_CANN=on \ + -DSOC_TYPE=${{ matrix.device }} + cmake --build build -j $(nproc) diff --git a/.github/workflows/close-issue.yml b/.github/workflows/close-issue.yml new file mode 100644 index 0000000000000000000000000000000000000000..276a217d4500523b8a951510a918de276f088cc3 --- /dev/null +++ b/.github/workflows/close-issue.yml @@ -0,0 +1,28 @@ +name: Close inactive issues +on: + schedule: + - cron: "42 0 * * *" + +# Fine-grant permission +# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token +permissions: + issues: write + +jobs: + close-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v5 + with: + exempt-issue-labels: "refactor,help wanted,good first issue,research,bug,roadmap" + days-before-issue-stale: 30 + days-before-issue-close: 14 + stale-issue-label: "stale" + close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." + days-before-pr-stale: -1 + days-before-pr-close: -1 + operations-per-run: 10000 + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml new file mode 100644 index 0000000000000000000000000000000000000000..2067927be56caae7d9a80a074123a9a6859ec71e --- /dev/null +++ b/.github/workflows/docker.yml @@ -0,0 +1,178 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +# GitHub recommends pinning actions to a commit SHA. +# To get a newer version, you will need to update the SHA. +# You can also reference a tag or branch, but the action may change without warning. + +name: Publish Docker image + +on: + workflow_dispatch: # allows manual triggering + schedule: + # Rebuild daily rather than on every push because it is expensive + - cron: '12 4 * * *' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +# Fine-grant permission +# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token +permissions: + packages: write + +jobs: + push_to_registry: + name: Push Docker image to Docker Hub + + runs-on: ubuntu-22.04 + env: + COMMIT_SHA: ${{ github.sha }} + strategy: + fail-fast: false + matrix: + config: + # Multi-stage build + # Note: the arm64 images are failing, which prevents the amd64 images from being built + # https://github.com/ggml-org/llama.cpp/issues/11888 + #- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false } + - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } + - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } + - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true } + - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true } + - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } + # Note: the rocm images are failing due to a compiler error and are disabled until this is fixed to allow the workflow to complete + #- {tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: true } + steps: + - name: Check out the repo + uses: actions/checkout@v4 + with: + fetch-depth: 0 # preserve git history, so we can determine the build number + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + with: + image: tonistiigi/binfmt:qemu-v7.0.0-28 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Determine tag name + id: tag + shell: bash + run: | + BUILD_NUMBER="$(git rev-list --count HEAD)" + SHORT_HASH="$(git rev-parse --short=7 HEAD)" + REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case + REPO_NAME="${{ github.event.repository.name }}" + + # determine tag name postfix (build number, commit hash) + if [[ "${{ env.GITHUB_BRANCH_NAME }}" == "master" ]]; then + TAG_POSTFIX="-b${BUILD_NUMBER}" + else + SAFE_NAME=$(echo "${{ env.GITHUB_BRANCH_NAME }}" | tr '/' '-') + TAG_POSTFIX="-${SAFE_NAME}-${SHORT_HASH}" + fi + # list all tags possible + if [[ "${{ matrix.config.tag }}" == "cpu" ]]; then + TYPE="" + else + TYPE="-${{ matrix.config.tag }}" + fi + PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:" + FULLTAGS="${PREFIX}full${TYPE},${PREFIX}full${TYPE}${TAG_POSTFIX}" + LIGHTTAGS="${PREFIX}light${TYPE},${PREFIX}light${TYPE}${TAG_POSTFIX}" + SERVERTAGS="${PREFIX}server${TYPE},${PREFIX}server${TYPE}${TAG_POSTFIX}" + echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT + echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT + echo "server_output_tags=$SERVERTAGS" >> $GITHUB_OUTPUT + echo "full_output_tags=$FULLTAGS" # print out for debugging + echo "light_output_tags=$LIGHTTAGS" # print out for debugging + echo "server_output_tags=$SERVERTAGS" # print out for debugging + env: + GITHUB_BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}' + + - name: Free Disk Space (Ubuntu) + if: ${{ matrix.config.free_disk_space == true }} + uses: ggml-org/free-disk-space@v1.3.1 + with: + # this might remove tools that are actually needed, + # if set to "true" but frees about 6 GB + tool-cache: false + + # all of these default to true, but feel free to set to + # "false" if necessary for your workflow + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: true + swap-storage: true + + - name: Build and push Full Docker image (tagged + versioned) + if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.full == true }} + uses: docker/build-push-action@v6 + with: + context: . + push: true + platforms: ${{ matrix.config.platforms }} + # tag list is generated from step above + tags: ${{ steps.tag.outputs.full_output_tags }} + file: ${{ matrix.config.dockerfile }} + target: full + provenance: false + # using github experimental cache + cache-from: type=gha + cache-to: type=gha,mode=max + # return to this if the experimental github cache is having issues + #cache-to: type=local,dest=/tmp/.buildx-cache + #cache-from: type=local,src=/tmp/.buildx-cache + + - name: Build and push Light Docker image (tagged + versioned) + if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }} + uses: docker/build-push-action@v6 + with: + context: . + push: true + platforms: ${{ matrix.config.platforms }} + # tag list is generated from step above + tags: ${{ steps.tag.outputs.light_output_tags }} + file: ${{ matrix.config.dockerfile }} + target: light + provenance: false + # using github experimental cache + cache-from: type=gha + cache-to: type=gha,mode=max + # return to this if the experimental github cache is having issues + #cache-to: type=local,dest=/tmp/.buildx-cache + #cache-from: type=local,src=/tmp/.buildx-cache + + - name: Build and push Server Docker image (tagged + versioned) + if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }} + uses: docker/build-push-action@v6 + with: + context: . + push: true + platforms: ${{ matrix.config.platforms }} + # tag list is generated from step above + tags: ${{ steps.tag.outputs.server_output_tags }} + file: ${{ matrix.config.dockerfile }} + target: server + provenance: false + # using github experimental cache + cache-from: type=gha + cache-to: type=gha,mode=max + # return to this if the experimental github cache is having issues + #cache-to: type=local,dest=/tmp/.buildx-cache + #cache-from: type=local,src=/tmp/.buildx-cache diff --git a/.github/workflows/editorconfig.yml b/.github/workflows/editorconfig.yml new file mode 100644 index 0000000000000000000000000000000000000000..f02b7c2194bcf57568eb29e1d072ec9b5923f4cd --- /dev/null +++ b/.github/workflows/editorconfig.yml @@ -0,0 +1,29 @@ +name: EditorConfig Checker + +on: + workflow_dispatch: # allows manual triggering + inputs: + create_release: + description: 'Create new release' + required: true + type: boolean + push: + branches: + - master + pull_request: + branches: + - master + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + editorconfig: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: editorconfig-checker/action-editorconfig-checker@v2 + with: + version: v3.0.3 + - run: editorconfig-checker diff --git a/.github/workflows/gguf-publish.yml b/.github/workflows/gguf-publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..3ca4d30581074cc1b69d7b7306eef170f0baa0d6 --- /dev/null +++ b/.github/workflows/gguf-publish.yml @@ -0,0 +1,44 @@ +# This workflow will upload a Python Package using Twine when a GGUF release is created +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +# See `gguf-py/README.md` for how to make a release. + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +name: Upload Python Package + +on: + workflow_dispatch: + push: + # Pattern matched against refs/tags + tags: + - 'gguf-v*' # Push events to every version tag + + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.9.x' + - name: Install dependencies + run: | + cd gguf-py + python -m pip install poetry + poetry install + + - name: Build package + run: cd gguf-py && poetry build + - name: Publish package + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} + packages-dir: gguf-py/dist diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 0000000000000000000000000000000000000000..0b0f300aa402ad80634d1338eb698311e4dc3a9c --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,17 @@ +name: "Pull Request Labeler" +on: +- pull_request_target + +jobs: + labeler: + permissions: + contents: read + pull-requests: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + repository: "ggml-org/llama.cpp" + - uses: actions/labeler@v5 + with: + configuration-path: '.github/labeler.yml' diff --git a/.github/workflows/python-check-requirements.yml b/.github/workflows/python-check-requirements.yml new file mode 100644 index 0000000000000000000000000000000000000000..46e80aecd0a0cd9d44847d120e2bc9c0026dfc77 --- /dev/null +++ b/.github/workflows/python-check-requirements.yml @@ -0,0 +1,33 @@ +name: Python check requirements.txt + +on: + push: + paths: + - '.github/workflows/python-check-requirements.yml' + - 'scripts/check-requirements.sh' + - 'convert*.py' + - '**/requirements*.txt' + pull_request: + paths: + - '.github/workflows/python-check-requirements.yml' + - 'scripts/check-requirements.sh' + - 'convert*.py' + - '**/requirements*.txt' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + python-check-requirements: + runs-on: ubuntu-latest + name: check-requirements + steps: + - name: Check out source repository + uses: actions/checkout@v4 + - name: Set up Python environment + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Run check-requirements.sh script + run: bash scripts/check-requirements.sh diff --git a/.github/workflows/python-lint.yml b/.github/workflows/python-lint.yml new file mode 100644 index 0000000000000000000000000000000000000000..ddfdf73b8fce2c12d413127eb24575dc81a8e64d --- /dev/null +++ b/.github/workflows/python-lint.yml @@ -0,0 +1,30 @@ +name: flake8 Lint + +on: + push: + branches: + - master + paths: ['.github/workflows/python-lint.yml', '**/*.py'] + pull_request: + types: [opened, synchronize, reopened] + paths: ['.github/workflows/python-lint.yml', '**/*.py'] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + flake8-lint: + runs-on: ubuntu-latest + name: Lint + steps: + - name: Check out source repository + uses: actions/checkout@v4 + - name: Set up Python environment + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: flake8 Lint + uses: py-actions/flake8@v2 + with: + plugins: "flake8-no-print" diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml new file mode 100644 index 0000000000000000000000000000000000000000..373bb601020b25597408efdb4bcbf31dc4562a18 --- /dev/null +++ b/.github/workflows/python-type-check.yml @@ -0,0 +1,40 @@ +name: Python Type-Check + +on: + push: + paths: + - '.github/workflows/python-type-check.yml' + - 'pyrightconfig.json' + - '**.py' + - '**/requirements*.txt' + pull_request: + paths: + - '.github/workflows/python-type-check.yml' + - 'pyrightconfig.json' + - '**.py' + - '**/requirements*.txt' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + python-type-check: + runs-on: ubuntu-latest + name: pyright type-check + steps: + - name: Check out source repository + uses: actions/checkout@v4 + - name: Set up Python environment + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install Python dependencies + # TODO: use a venv + run: pip install -r requirements/requirements-all.txt + - name: Type-check with Pyright + uses: jakebailey/pyright-action@v2 + with: + version: 1.1.382 + level: warning + warnings: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000000000000000000000000000000000000..9874736cbd8def25970b0c6fed22106f7b9028dc --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,749 @@ +name: Release + +on: + workflow_dispatch: # allows manual triggering + inputs: + create_release: + description: 'Create new release' + required: true + type: boolean + push: + branches: + - master + paths: ['.github/workflows/release.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + CMAKE_ARGS: "-DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_TOOLS=ON -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON" + +jobs: + macOS-arm64: + runs-on: macos-14 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-arm64 + evict-old-files: 1d + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + brew install curl + + - name: Build + id: cmake_build + run: | + sysctl -a + cmake -B build \ + -DCMAKE_BUILD_RPATH="@loader_path" \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DGGML_RPC=ON \ + ${{ env.CMAKE_ARGS }} + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/bin/ + zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip + name: llama-bin-macos-arm64.zip + + macOS-x64: + runs-on: macos-13 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-x64 + evict-old-files: 1d + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + brew install curl + + - name: Build + id: cmake_build + run: | + sysctl -a + # Metal is disabled due to intermittent failures with Github runners not having a GPU: + # https://github.com/ggml-org/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313 + cmake -B build \ + -DCMAKE_BUILD_RPATH="@loader_path" \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_METAL=OFF \ + -DGGML_RPC=ON + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/bin/ + zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip + name: llama-bin-macos-x64.zip + + ubuntu-22-cpu: + strategy: + matrix: + include: + - build: 'x64' + os: ubuntu-22.04 + # GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm + # - build: 'arm64' + # os: ubuntu-22.04-arm + + runs-on: ${{ matrix.os }} + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-cpu-cmake + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DGGML_BACKEND_DL=ON \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ALL_VARIANTS=ON \ + -DLLAMA_FATAL_WARNINGS=ON \ + ${{ env.CMAKE_ARGS }} + cmake --build build --config Release -j $(nproc) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/bin/ + zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip + name: llama-bin-ubuntu-${{ matrix.build }}.zip + + ubuntu-22-vulkan: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-vulkan + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add - + sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list + sudo apt-get update -y + sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DGGML_BACKEND_DL=ON \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ALL_VARIANTS=ON \ + -DGGML_VULKAN=ON \ + ${{ env.CMAKE_ARGS }} + cmake --build build --config Release -j $(nproc) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/bin/ + zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip + name: llama-bin-ubuntu-vulkan-x64.zip + + windows-cpu: + runs-on: windows-latest + + strategy: + matrix: + include: + - arch: 'x64' + - arch: 'arm64' + + steps: + - name: Clone + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-cpu-${{ matrix.arch }} + variant: ccache + evict-old-files: 1d + + - name: Install Ninja + run: | + choco install ninja + + - name: libCURL + id: get_libcurl + uses: ./.github/actions/windows-setup-curl + with: + architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }} + + - name: Build + shell: cmd + env: + CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} + run: | + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch }} + cmake -S . -B build -G "Ninja Multi-Config" ^ + -D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^ + -DGGML_NATIVE=OFF ^ + -DGGML_BACKEND_DL=ON ^ + -DGGML_CPU_ALL_VARIANTS=${{ matrix.arch == 'x64' && 'ON' || 'OFF' }} ^ + -DGGML_OPENMP=ON ^ + -DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" ^ + ${{ env.CMAKE_ARGS }} + cmake --build build --config Release + + - name: Pack artifacts + id: pack_artifacts + env: + CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} + run: | + Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\ + Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.42.34433\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\ + 7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-cpu-${{ matrix.arch }}.zip + name: llama-bin-win-cpu-${{ matrix.arch }}.zip + + windows: + runs-on: windows-latest + + env: + OPENBLAS_VERSION: 0.3.23 + VULKAN_VERSION: 1.4.309.0 + + strategy: + matrix: + include: + - backend: 'vulkan' + arch: 'x64' + defines: '-DGGML_VULKAN=ON' + target: 'ggml-vulkan' + - backend: 'opencl-adreno' + arch: 'arm64' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON' + target: 'ggml-opencl' + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-${{ matrix.backend }}-${{ matrix.arch }} + variant: ccache + evict-old-files: 1d + + - name: Install Vulkan SDK + id: get_vulkan + if: ${{ matrix.backend == 'vulkan' }} + run: | + curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe" + & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install + Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}" + Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin" + + - name: Install Ninja + id: install_ninja + run: | + choco install ninja + + - name: Install OpenCL Headers and Libs + id: install_opencl + if: ${{ matrix.backend == 'opencl-adreno' && matrix.arch == 'arm64' }} + run: | + git clone https://github.com/KhronosGroup/OpenCL-Headers + cd OpenCL-Headers + cmake -B build ` + -DBUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF ` + -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" + cmake --build build --target install + git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader + cd OpenCL-ICD-Loader + cmake -B build-arm64-release ` + -A arm64 ` + -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" ` + -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" + cmake --build build-arm64-release --target install --config release + + - name: Build + id: cmake_build + run: | + cmake -S . -B build ${{ matrix.defines }} -DGGML_NATIVE=OFF -DGGML_CPU=OFF -DGGML_BACKEND_DL=ON -DLLAMA_CURL=OFF + cmake --build build --config Release --target ${{ matrix.target }} + + - name: Pack artifacts + id: pack_artifacts + run: | + 7z a llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip + name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip + + windows-cuda: + runs-on: windows-2022 + + strategy: + matrix: + cuda: ['12.4'] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Install ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-cuda-${{ matrix.cuda }} + variant: ccache + evict-old-files: 1d + + - name: Install Cuda Toolkit + uses: ./.github/actions/windows-setup-cuda + with: + cuda_version: ${{ matrix.cuda }} + + - name: Install Ninja + id: install_ninja + run: | + choco install ninja + + - name: Build + id: cmake_build + shell: cmd + run: | + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64 + cmake -S . -B build -G "Ninja Multi-Config" ^ + -DGGML_BACKEND_DL=ON ^ + -DGGML_NATIVE=OFF ^ + -DGGML_CPU=OFF ^ + -DGGML_CUDA=ON ^ + -DLLAMA_CURL=OFF + set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 + cmake --build build --config Release -j %NINJA_JOBS% --target ggml-cuda + + - name: Pack artifacts + id: pack_artifacts + run: | + 7z a llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip + name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip + + - name: Copy and pack Cuda runtime + run: | + echo "Cuda install location: ${{ env.CUDA_PATH }}" + $dst='.\build\bin\cudart\' + robocopy "${{env.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll + robocopy "${{env.CUDA_PATH}}\lib" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll + 7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\* + + - name: Upload Cuda runtime + uses: actions/upload-artifact@v4 + with: + path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip + name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip + + windows-sycl: + runs-on: windows-latest + + defaults: + run: + shell: bash + + env: + WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe + WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel + ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-sycl + variant: ccache + evict-old-files: 1d + + - name: Install + run: | + scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL + + - name: Build + id: cmake_build + shell: cmd + run: | + call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force + cmake -G "Ninja" -B build ^ + -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx ^ + -DCMAKE_BUILD_TYPE=Release ^ + -DGGML_BACKEND_DL=ON -DBUILD_SHARED_LIBS=ON ^ + -DGGML_CPU=OFF -DGGML_SYCL=ON ^ + -DLLAMA_CURL=OFF + cmake --build build --target ggml-sycl -j + + - name: Build the release package + id: pack_artifacts + run: | + echo "cp oneAPI running time dll files in ${{ env.ONEAPI_ROOT }} to ./build/bin" + + cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_sycl_blas.5.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_core.2.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_tbb_thread.2.dll" ./build/bin + + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_level_zero.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_opencl.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_loader.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_win_proxy_loader.dll" ./build/bin + + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl8.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/svml_dispmd.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libmmd.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libiomp5md.dll" ./build/bin + + cp "${{ env.ONEAPI_ROOT }}/dnnl/latest/bin/dnnl.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin + + echo "cp oneAPI running time dll files to ./build/bin done" + 7z a llama-bin-win-sycl-x64.zip ./build/bin/* + + - name: Upload the release package + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-sycl-x64.zip + name: llama-bin-win-sycl-x64.zip + + windows-hip: + runs-on: windows-latest + + strategy: + matrix: + include: + - name: "radeon" + gpu_targets: "gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Clone rocWMMA repository + id: clone_rocwmma + run: | + git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-hip-${{ matrix.name }}-x64 + evict-old-files: 1d + + - name: Install + id: depends + run: | + $ErrorActionPreference = "Stop" + write-host "Downloading AMD HIP SDK Installer" + Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" + write-host "Installing AMD HIP SDK" + Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait + write-host "Completed AMD HIP SDK installation" + + - name: Verify ROCm + id: verify + run: | + & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version + + - name: Build + id: cmake_build + run: | + $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) + $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" + cmake -G "Unix Makefiles" -B build -S . ` + -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` + -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` + -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/ -Wno-ignored-attributes -Wno-nested-anon-types" ` + -DCMAKE_BUILD_TYPE=Release ` + -DGGML_BACKEND_DL=ON ` + -DGGML_NATIVE=OFF ` + -DGGML_CPU=OFF ` + -DAMDGPU_TARGETS="${{ matrix.gpu_targets }}" ` + -DGGML_HIP_ROCWMMA_FATTN=ON ` + -DGGML_HIP=ON ` + -DLLAMA_CURL=OFF + cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS} + md "build\bin\rocblas\library\" + cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\" + cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\" + cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\" + + - name: Pack artifacts + id: pack_artifacts + run: | + 7z a llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-hip-${{ matrix.name }}-x64.zip + name: llama-bin-win-hip-${{ matrix.name }}-x64.zip + + ios-xcode-build: + runs-on: macos-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Build + id: cmake_build + run: | + sysctl -a + cmake -B build -G Xcode \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DLLAMA_CURL=OFF \ + -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ + -DLLAMA_BUILD_TESTS=OFF \ + -DLLAMA_BUILD_SERVER=OFF \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + + - name: xcodebuild for swift package + id: xcodebuild + run: | + ./build-xcframework.sh + + - name: Build Xcode project + run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + zip --symlinks -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-xcframework.zip + name: llama-${{ steps.tag.outputs.name }}-xcframework + + release: + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + + # Fine-grant permission + # https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token + permissions: + contents: write # for creating release + + runs-on: ubuntu-latest + + needs: + - windows + - windows-cpu + - windows-cuda + - windows-sycl + - windows-hip + - ubuntu-22-cpu + - ubuntu-22-vulkan + - macOS-arm64 + - macOS-x64 + - ios-xcode-build + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Download artifacts + id: download-artifact + uses: actions/download-artifact@v4 + with: + path: ./artifact + merge-multiple: true + + - name: Move artifacts + id: move_artifacts + run: | + mkdir -p release + + echo "Adding CPU backend files to existing zips..." + for arch in x64 arm64; do + cpu_zip="artifact/llama-bin-win-cpu-${arch}.zip" + temp_dir=$(mktemp -d) + echo "Extracting CPU backend for $arch..." + unzip "$cpu_zip" -d "$temp_dir" + + echo "Adding CPU files to $arch zips..." + for target_zip in artifact/llama-bin-win-*-${arch}.zip; do + if [[ "$target_zip" == "$cpu_zip" ]]; then + continue + fi + echo "Adding CPU backend to $(basename "$target_zip")" + realpath_target_zip=$(realpath "$target_zip") + (cd "$temp_dir" && zip -r "$realpath_target_zip" .) + done + + rm -rf "$temp_dir" + done + + echo "Renaming and moving zips to release..." + for zip_file in artifact/llama-bin-win-*.zip; do + base_name=$(basename "$zip_file" .zip) + zip_name="llama-${{ steps.tag.outputs.name }}-${base_name#llama-}.zip" + echo "Moving $zip_file to release/$zip_name" + mv "$zip_file" "release/$zip_name" + done + + echo "Moving other artifacts..." + mv -v artifact/*.zip release + + - name: Create release + id: create_release + uses: ggml-org/action-create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.tag.outputs.name }} + + - name: Upload release + id: upload_release + uses: actions/github-script@v3 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + script: | + const path = require('path'); + const fs = require('fs'); + const release_id = '${{ steps.create_release.outputs.id }}'; + for (let file of await fs.readdirSync('./release')) { + if (path.extname(file) === '.zip') { + console.log('uploadReleaseAsset', file); + await github.repos.uploadReleaseAsset({ + owner: context.repo.owner, + repo: context.repo.repo, + release_id: release_id, + name: file, + data: await fs.readFileSync(`./release/${file}`) + }); + } + } diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml new file mode 100644 index 0000000000000000000000000000000000000000..f6da488576937915d9df4b434aff3dbadc4ddfd1 --- /dev/null +++ b/.github/workflows/server.yml @@ -0,0 +1,237 @@ +# Server build and tests +name: Server + +on: + workflow_dispatch: # allows manual triggering + inputs: + sha: + description: 'Commit SHA1 to build' + required: false + type: string + slow_tests: + description: 'Run slow tests' + required: true + type: boolean + push: + branches: + - master + paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*'] + pull_request: + types: [opened, synchronize, reopened] + paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*'] + +env: + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 + LLAMA_LOG_VERBOSITY: 10 + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + server: + runs-on: ubuntu-latest + + strategy: + matrix: + sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken + build_type: [RelWithDebInfo] + include: + - build_type: Release + sanitizer: "" + fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken + + steps: + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get -y install \ + build-essential \ + xxd \ + git \ + cmake \ + curl \ + wget \ + language-pack-en \ + libcurl4-openssl-dev + + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + + - name: Python setup + id: setup_python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Tests dependencies + id: test_dependencies + run: | + pip install -r tools/server/tests/requirements.txt + + # Setup nodejs (to be used for verifying bundled index.html) + - uses: actions/setup-node@v4 + with: + node-version: '22.11.0' + + - name: WebUI - Install dependencies + id: webui_lint + run: | + cd tools/server/webui + npm ci + + - name: WebUI - Check code format + id: webui_format + run: | + git config --global --add safe.directory $(realpath .) + cd tools/server/webui + git status + + npm run format + git status + modified_files="$(git status -s)" + echo "Modified files: ${modified_files}" + if [ -n "${modified_files}" ]; then + echo "Files do not follow coding style. To fix: npm run format" + echo "${modified_files}" + exit 1 + fi + + - name: Verify bundled index.html + id: verify_server_index_html + run: | + git config --global --add safe.directory $(realpath .) + cd tools/server/webui + git status + + npm run build + git status + modified_files="$(git status -s)" + echo "Modified files: ${modified_files}" + if [ -n "${modified_files}" ]; then + echo "Repository is dirty or server/webui is not built as expected" + echo "Hint: You may need to follow Web UI build guide in server/README.md" + echo "${modified_files}" + exit 1 + fi + + - name: Build (no OpenMP) + id: cmake_build_no_openmp + if: ${{ matrix.sanitizer == 'THREAD' }} + run: | + cmake -B build \ + -DGGML_NATIVE=OFF \ + -DLLAMA_BUILD_SERVER=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_OPENMP=OFF ; + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server + + - name: Build (sanitizers) + id: cmake_build_sanitizers + if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }} + run: | + cmake -B build \ + -DGGML_NATIVE=OFF \ + -DLLAMA_BUILD_SERVER=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ; + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server + + - name: Build (sanitizers) + id: cmake_build + if: ${{ matrix.sanitizer == '' }} + run: | + cmake -B build \ + -DGGML_NATIVE=OFF \ + -DLLAMA_BUILD_SERVER=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ; + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server + + - name: Tests + id: server_integration_tests + if: ${{ matrix.sanitizer == '' }} + env: + GITHUB_ACTIONS: "true" + run: | + cd tools/server/tests + ./tests.sh + + - name: Tests (sanitizers) + id: server_integration_tests_sanitizers + if: ${{ matrix.sanitizer != '' }} + run: | + cd tools/server/tests + LLAMA_SANITIZE=1 ./tests.sh + + - name: Slow tests + id: server_integration_tests_slow + if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} + run: | + cd tools/server/tests + SLOW_TESTS=1 ./tests.sh + + + server-windows: + runs-on: windows-2022 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + + - name: libCURL + id: get_libcurl + uses: ./.github/actions/windows-setup-curl + + - name: Build + id: cmake_build + env: + CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} + run: | + cmake -B build -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" + cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server + + - name: Python setup + id: setup_python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Tests dependencies + id: test_dependencies + run: | + pip install -r tools/server/tests/requirements.txt + + - name: Copy Libcurl + id: prepare_libcurl + env: + CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} + run: | + cp $env:CURL_PATH/bin/libcurl-x64.dll ./build/bin/Release/libcurl-x64.dll + + - name: Tests + id: server_integration_tests + if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }} + run: | + cd tools/server/tests + $env:PYTHONIOENCODING = ":replace" + pytest -v -x -m "not slow" + + - name: Slow tests + id: server_integration_tests_slow + if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} + run: | + cd tools/server/tests + $env:SLOW_TESTS = "1" + pytest -v -x diff --git a/.github/workflows/winget.yml b/.github/workflows/winget.yml new file mode 100644 index 0000000000000000000000000000000000000000..5c286155951e5ef3bb8d8e5b923ccc3b13456c04 --- /dev/null +++ b/.github/workflows/winget.yml @@ -0,0 +1,42 @@ +name: Update Winget Package + +on: + workflow_dispatch: # allows manual triggering + schedule: + - cron: '28 5 * * *' # Update every day at 5:28 UTC + +jobs: + update: + name: Update Winget Package + runs-on: ubuntu-latest + + steps: + - name: Install cargo binstall + uses: cargo-bins/cargo-binstall@268643a6b5ea099f5718ee5cd3ff7dc89a5eb49b + + - name: Install komac + run: | + cargo binstall komac@2.11.2 -y + + - name: Find latest release + id: find_latest_release + uses: actions/github-script@v6 + with: + script: | + const { data: releases } = await github.rest.repos.listReleases({ + owner: context.repo.owner, + repo: context.repo.repo, + }); + console.log("Latest release:", releases[0].tag_name); + return releases[0].tag_name; + + - name: Update manifest + env: + VERSION: ${{ steps.find_latest_release.outputs.result }} + run: | + echo "Updating manifest..." + komac update --version ${{ env.VERSION }} \ + --urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \ + --token ${{ secrets.WINGET_GITHUB_TOKEN }} \ + --submit \ + ggml.llamacpp diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f8ceb1560a1dfe33183378abaf9115cd9eabb1f9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,148 @@ +# Extensions + +*.a +*.bat +*.bin +*.d +*.dll +*.dot +*.etag +*.exe +*.gcda +*.gcno +*.gcov +*.gguf +*.gguf.json +*.lastModified +*.log +*.metallib +*.o +*.so +*.swp +*.tmp + +# IDE / OS + +.cache/ +.ccls-cache/ +.direnv/ +.DS_Store +.envrc +.idea/ +.swiftpm +.vs/ +.vscode/ +nppBackup + + +# Coverage + +gcovr-report/ +lcov-report/ + +# Build Artifacts + +tags +.build/ +build* +release +debug +!build-info.cmake +!build-info.cpp.in +!build-info.sh +!build.zig +!docs/build.md +/libllama.so +/llama-* +/vulkan-shaders-gen +android-ndk-* +arm_neon.h +cmake-build-* +CMakeSettings.json +compile_commands.json +ggml-metal-embed.metal +llama-batched-swift +/rpc-server +out/ +tmp/ +autogen-*.md + +# Deprecated + +/main +/server + +# CI + +!.github/workflows/*.yml + +# Models + +models/* +models-mnt +!models/.editorconfig +!models/ggml-vocab-*.gguf* + +# Zig +zig-out/ +zig-cache/ + +# Logs + +ppl-*.txt +qnt-*.txt +perf-*.txt + +# Examples + +examples/jeopardy/results.txt +tools/server/*.css.hpp +tools/server/*.html.hpp +tools/server/*.js.hpp +tools/server/*.mjs.hpp +tools/server/*.gz.hpp +!build_64.sh +!examples/*.bat +!examples/*/*.kts +!examples/*/*/*.kts +!examples/sycl/*.bat +!examples/sycl/*.sh + +# Server Web UI temporary files +node_modules +tools/server/webui/dist + +# Python + +/.venv +__pycache__/ +*/poetry.lock +poetry.toml + +# Nix +/result + +# Test binaries +/tests/test-backend-ops +/tests/test-double-float +/tests/test-grad0 +/tests/test-grammar-parser +/tests/test-llama-grammar +/tests/test-opt +/tests/test-quantize-fns +/tests/test-quantize-perf +/tests/test-rope +/tests/test-sampling +/tests/test-tokenizer-0 +/tests/test-tokenizer-1-bpe +/tests/test-tokenizer-1-spm + +# Scripts +!/scripts/install-oneapi.bat + +# Test models for lora adapters +/lora-tests + +# Local scripts +/run-vim.sh +/run-chat.sh diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..23ce5ff059b1bd9a25266c963d2e0ba271cce693 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "kompute"] + path = ggml/src/ggml-kompute/kompute + url = https://github.com/nomic-ai/kompute.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91d7916285081aa14f5b801935cb35d031f2e601 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,16 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +exclude: prompts/.*.txt +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files +- repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + additional_dependencies: [flake8-no-print] diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 0000000000000000000000000000000000000000..0af9f44ad4a16bd57d739acad7bd1e3f45256677 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,1106 @@ +# date: Sat Mar 8 18:23:52 EET 2025 +# this file is auto-generated by scripts/gen-authors.sh + +0cc4m +0xspringtime <110655352+0xspringtime@users.noreply.github.com> +20kdc +2f38b454 +3ooabkhxtn <31479382+3ooabkhxtn@users.noreply.github.com> +44670 <44670@users.noreply.github.com> +65a <10104049+65a@users.noreply.github.com> +708-145 <40387547+708-145@users.noreply.github.com> +AN Long +AT +Aarni Koskela +Aaron Miller +Aaron Teo <57927438+taronaeo@users.noreply.github.com> +Aaryaman Vasishta +Abheek Gulati +Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> +Abhishek Gopinath K <31348521+overtunned@users.noreply.github.com> +Adithya Balaji +AdithyanI +Adrian +Adrian Hesketh +Adrian Kretz +Adrien Gallouët +Adrien Gallouët +Ahmad Tameem <113388789+Tameem-10xE@users.noreply.github.com> +Ahmet Zeer +AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> +AidanBeltonS +Aisuko +Akarshan Biswas +Akarshan Biswas +Akarshan Biswas +Al Mochkin <14274697+amochkin@users.noreply.github.com> +Albert Jin +Alberto <57916483+albbus-stack@users.noreply.github.com> +Alberto Cabrera Pérez +Alberto Cabrera Pérez +Aleksei Nikiforov <103434461+AlekseiNikiforovIBM@users.noreply.github.com> +Alex +Alex Azarov +Alex Azarov +Alex Brooks +Alex Klinkhamer +Alex Klinkhamer +Alex Nguyen +Alex O'Connell <35843486+acon96@users.noreply.github.com> +Alex Petenchea +Alex Renda +Alex Tuddenham <61622354+AlexsCode@users.noreply.github.com> +Alex von Gluck IV +Alexey Parfenov +Ali Chraghi <63465728+alichraghi@users.noreply.github.com> +Ali Nehzat +Ali Tariq +Alon +AlpinDale <52078762+AlpinDale@users.noreply.github.com> +Amir +AmirAli Mirian <37371367+amiralimi@users.noreply.github.com> +Ananta Bastola +Anas Ahouzi <112881240+aahouzi@users.noreply.github.com> +András Salamon +Andreas (Andi) Kunar +Andreas Kieslinger <47689530+aendk@users.noreply.github.com> +Andrei +Andrew Canis +Andrew Downing +Andrew Duffy +Andrew Godfrey +Andrew Minh Nguyen <40281306+amqdn@users.noreply.github.com> +Andy Salerno +Andy Tai +Anthony Van de Gejuchte +Antoine Viallon +Antonis Makropoulos +Arik Poznanski +Armen Kaleshian +Artem +Artem Zinnatullin +Artyom Lebedev +Asbjørn Olling +Ásgeir Bjarni Ingvarsson +Asghar Ghorbani +Ashish <1856117+ashishdatta@users.noreply.github.com> +Ashok Gelal <401055+ashokgelal@users.noreply.github.com> +Ashraful Islam +Atsushi Tatsuma +Austin <77757836+teleprint-me@users.noreply.github.com> +AustinMroz +BADR +BB-fat <45072480+BB-fat@users.noreply.github.com> +Bach Le +Bailey Chittle <39804642+bachittle@users.noreply.github.com> +BarfingLemurs <128182951+BarfingLemurs@users.noreply.github.com> +Bartowski +Behnam M <58621210+ibehnam@users.noreply.github.com> +Ben Ashbaugh +Ben Garney +Ben Siraphob +Ben Williams +Benjamin Findley <39356821+Kartoffelsaft@users.noreply.github.com> +Benjamin Lecaillon <84293038+blecaillon@users.noreply.github.com> +Benson Wong +Bernat Vadell +Bernhard M. Wiedemann +Bert Wagner +Billel Mokeddem +Bingan <70050083+binganao@users.noreply.github.com> +Bjarke Viksøe <164612031+bviksoe@users.noreply.github.com> +Bodhi <3882561+BodhiHu@users.noreply.github.com> +Bodo Graumann +Bono Lv +Borislav Stanimirov +Borislav Stanimirov +Branden Butler +Brandon Squizzato <35474886+bsquizz@users.noreply.github.com> +Brian +Brian Cunnie +Bruce MacDonald +Bryan Honof +CJ Pais +CRD716 +Calvin Laurenson +Cameron +Cameron Kaiser +Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> +CarryFun <76023481+CarryFun@users.noreply.github.com> +Carsten Kragelund Jørgensen +CarterLi999 <664681047@qq.com> +Casey Primozic +Casey Primozic +CausalLM <148736309+CausalLM@users.noreply.github.com> +Cebtenzzre +CentricStorm +Chad Brewbaker +Changyeon Kim +Chao Jiang +Charles Duffy +Charles Xu <63788048+chaxu01@users.noreply.github.com> +Charles Xu +Chen Xi +Chen Xi +Cheng Shao +Chenguang Li <87689256+noemotiovon@users.noreply.github.com> +Chris Elrod +Chris Kuehl +Christian Demsar +Christian Demsar +Christian Falch <875252+chrfalch@users.noreply.github.com> +Christian Fillion +Christian Kastner +Christian Kögler +Christian Köhnenkamp +Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com> +Christopher Nielsen <62156882+mascguy@users.noreply.github.com> +Clark Saben <76020733+csaben@users.noreply.github.com> +Clauszy +Clint Herron +Conrad Kramer +Corentin REGAL +CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> +Csaba Kecskemeti +Cuong Trinh Manh +DAN™ +Damian Stewart +Dan Johansson <164997844+eddnjjn@users.noreply.github.com> +Dan Johansson +Dane Madsen +DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> +Daniel Bevenius +Daniel Drake +Daniel Hiltgen +Daniel Illescas Romero +Daniel Kleine <53251018+d-kleine@users.noreply.github.com> +Daniele <57776841+daniandtheweb@users.noreply.github.com> +Danny Milosavljevic +DannyDaemonic +Dat Quoc Nguyen <2412555+datquocnguyen@users.noreply.github.com> +Dave +Dave Airlie +Dave Airlie +Dave Della Costa +David Friehs +David Huang <1969802+hjc4869@users.noreply.github.com> +David Kennedy +David Pflug +David Renshaw +David Sommers <12738+databyte@users.noreply.github.com> +David Yang +DavidKorczynski +Dawid Potocki +Dawid Wysocki <62249621+TortillaZHawaii@users.noreply.github.com> +Dean +Deins +Denis Spasyuk <34203011+dspasyuk@users.noreply.github.com> +Derrick T. Woolworth +Deven Mistry <31466137+deven367@users.noreply.github.com> +Dibakar Gope +Didzis Gosko +Diego Devesa +Diogo Teles Sant'Anna +Djip007 <3705339+Djip007@users.noreply.github.com> +Djip007 +Don Mahurin +DooWoong Lee (David) +Doomsdayrs <38189170+Doomsdayrs@users.noreply.github.com> +Dou Xinpeng <15529241576@163.com> +Dou Xinpeng <81913537+Dou-Git@users.noreply.github.com> +Douglas Hanley +Dr. Tom Murphy VII Ph.D <499244+tom7@users.noreply.github.com> +Ebey Abraham +Echo Nolan +Ed Lee +Ed Lepedus +Eddie-Wang +Edward Taylor +Elaine +Elbios <141279586+Elbios@users.noreply.github.com> +Elton Kola +Emreerdog <34742675+Emreerdog@users.noreply.github.com> +Engininja2 <139037756+Engininja2@users.noreply.github.com> +Equim +Eric Curtin +Eric Curtin +Eric Sommerlade +Eric Zhang <34133756+EZForever@users.noreply.github.com> +Erik Garrison +Erik Scholz +Esko Toivonen +Ettore Di Giacinto +Evan Jones +Evan Miller +Eve <139727413+netrunnereve@users.noreply.github.com> +Evgeny Kurnevsky +Ewout ter Hoeven +ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> +FK +Fabian +Fabio R. Sluzala +Faez Shakil +Faisal Zaghloul +Faisal Zaghloul +Fan Shupei +FantasyGmm <16450052+FantasyGmm@users.noreply.github.com> +Farbod Bijary <110523279+farbodbj@users.noreply.github.com> +Fattire <528174+fat-tire@users.noreply.github.com> +Felix +Finn Voorhees +Firat +FirstTimeEZ <179362031+FirstTimeEZ@users.noreply.github.com> +Florent BENOIT +Folko-Ven <71110216+Folko-Ven@users.noreply.github.com> +Foul-Tarnished <107711110+Foul-Tarnished@users.noreply.github.com> +Francisco Melo <43780565+francis2tm@users.noreply.github.com> +Frank Mai +FrankHB +Frankie Robertson +Fred Douglas <43351173+fredlas@users.noreply.github.com> +Frederik Vogel +Gabe Goodhart +Gabe Goodhart +Gaetan Bisson +GainLee +Galunid +Gary Linscott +Gary Mulder +Gavin Zhao +Genkagaku.GPT +Georgi Gerganov +Gian-Carlo Pascutto +Gilad S +Gilad S. <7817232+giladgd@users.noreply.github.com> +Giuseppe Scrivano +GiviMAD +Govlzkoy +Guillaume "Vermeille" Sanchez +Guillaume Wenzek +Guoliang Hua <32868157+nbcsm@users.noreply.github.com> +Guoteng <32697156+SolenoidWGT@users.noreply.github.com> +Guspan Tanadi <36249910+guspan-tanadi@users.noreply.github.com> +Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com> +Haggai Nuchi +Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com> +Hale Chan +Hamdoud Hakem <90524568+hamdoudhakem@users.noreply.github.com> +Han Yin +HanishKVC +Haohui Mai +Haoxiang Fei +Harald Fernengel +Hatsune Miku <129688334+at8u@users.noreply.github.com> +HatsuneMikuUwU33 <173229399+HatsuneMikuUwU33@users.noreply.github.com> +Haus1 +Henk Poley +Henri Vasserman +Henrik Forstén +Henry Linjamäki +Herman Semenov +Hesen Peng +HimariO +Hoang Nguyen +Hong Bo PENG +Hongyu Ouyang <96765450+casavaca@users.noreply.github.com> +Howard Su +Hua Jiang +Huang Qi +Huawei Lin +Hugo Roussel +Huifeng Ou <79071290+ho2103@users.noreply.github.com> +Ian Bull +Ian Bull +Ian Scrivener +Icecream95 +Ido S +IgnacioFDM +Igor Okulist +Ihar Hrachyshka +Ikko Eltociear Ashimine +Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com> +Ionoclast Laboratories +Isaac McFadyen +IsaacDynamo <61521674+IsaacDynamo@users.noreply.github.com> +Ivan +Ivan Filipov <159561759+vanaka11@users.noreply.github.com> +Ivan Komarov +Ivan Stepanov +JC <43374599+MrSMlT@users.noreply.github.com> +JFLFY2255 +JH23X <165871467+JH23X@users.noreply.github.com> +Jack Mousseau +Jack Mousseau +JackJollimore <130917767+JackJollimore@users.noreply.github.com> +Jaeden Amero +Jaemin Son +Jafar Uruç +Jag Chadha +Jakub N +James A Capozzoli <157492257+jac-jim@users.noreply.github.com> +James Reynolds +Jan Boon +Jan Boon +Jan Ploski +Jannis Schönleber +Jared Van Bortel +Jared Van Bortel +Jason C.H +Jason McCartney +Jason Stillerman +Jean-Christophe Hoelt +Jean-Michaël Celerier +Jed Fox +Jeff Bolz +Jeffrey Morgan +Jeffrey Quesnelle +Jeroen Mostert +Jesse Jojo Johnson +Jett Janiak +Jeximo +Jhen-Jie Hong +Jiahao Li +Jian Liao +JidongZhang-THU <1119708529@qq.com> +Jinwoo Jeong <33892306+williamjeong2@users.noreply.github.com> +Jinyang He +Jiří Podivín <66251151+jpodivin@users.noreply.github.com> +Jiří Sejkora +Joan Fontanals +Joan Fontanals +João Dinis Ferreira +Joe Eli McIlvain +Joe Todd +Johan +Johannes Gäßler +Johannes Rudolph +John <78893154+cmp-nct@users.noreply.github.com> +John Balis +John Smith <67539080+kingsidelee@users.noreply.github.com> +JohnnyB +Jonas Wunderlich <32615971+jonas-w@users.noreply.github.com> +Jorge A <161275481+jorgealias@users.noreply.github.com> +Jose Maldonado <63384398+yukiteruamano@users.noreply.github.com> +Joseph Stahl <1269177+josephst@users.noreply.github.com> +Josh Ramer +Joyce +Juan Calderon-Perez <835733+gaby@users.noreply.github.com> +Judd +Juk Armstrong <69222624+jukofyork@users.noreply.github.com> +Julius Arkenberg +Jun Hee Yoo +Jun Jie <71215065+junnjiee16@users.noreply.github.com> +Junil Kim +Junyang Lin +Juraj Bednar +Justin Parker +Justin Suess +Justina Cho +Justine Tunney +Justine Tunney +Juuso Alasuutari +KASR +Kamil Tomšík +Kante Yin +Karol Kontny <82021046+kkontny@users.noreply.github.com> +Karsten Weiss +Karthick +Karthik Kumar Viswanathan <195178+guilt@users.noreply.github.com> +Karthik Sethuraman +Kasumi <90275229+kasumi-1@users.noreply.github.com> +Kawrakow <48489457+ikawrakow@users.noreply.github.com> +Keiichi Tabata +Keke Han +Kenvix ⭐ +Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> +Kevin Gibbons +Kevin Ji <1146876+kevinji@users.noreply.github.com> +Kevin Kwok +Kevin Lo +Kevin Wang +Kolen Cheung +Konstantin Herud +Konstantin Zhuravlyov +Kunshang Ji +Kyle Bruene +Kyle Liang +Kyle Mistele +Kylin <56434533+KyL0N@users.noreply.github.com> +Lars Grammel +Laura +Lee <44310445+lx200916@users.noreply.github.com> +Lee Drake +Leng Yue +Leon Knauer +LeonEricsson <70749762+LeonEricsson@users.noreply.github.com> +Leonardo Neumann +Li Tan +Linwei Wang +Liu Jia <109258120+Septa2112@users.noreply.github.com> +Liu Jia +LoganDark +Loïc Carrère +LostRuins <39025047+LostRuins@users.noreply.github.com> +LostRuins Concedo <39025047+LostRuins@users.noreply.github.com> +Lucas Moura Belo +Luciano +Luo Tian +Lyle Dean +M-A +M. Yusuf Sarıgöz +Ma Mingfei +Maarten ter Huurne +Mack Straight +Maël Kerbiriou +MaggotHATE +Mahesh Madhav <67384846+heshpdx@users.noreply.github.com> +Manuel <44313466+makuche@users.noreply.github.com> +Marc Köhlbrugge +Marco Matthies <71844+marcom@users.noreply.github.com> +Marcus Dunn <51931484+MarcusDunn@users.noreply.github.com> +Marian Cepok +Mark Fairbairn +Mark Zhuang +Marko Tasic +Markus Tavenrath +Martin Delille +Martin Krasser +Martin Schwaighofer +Marvin Gießing +Masaya, Kato <62578291+msy-kato@users.noreply.github.com> +MasterYi1024 <39848311+MasterYi1024@users.noreply.github.com> +Mateusz Charytoniuk +Matheus C. França +Matheus Gabriel Alves Silva +Mathieu Baudier +Mathieu Geli +Mathieu Nayrolles +Mathijs Henquet +Mathijs de Bruin +Matt Clayton <156335168+mattjcly@users.noreply.github.com> +Matt Pulver +Matt Stephenson +Matteo Boschini <12133566+mbosc@users.noreply.github.com> +Matteo Mortari +Mattheus Chediak +Matthew Tejo +Matvey Soloviev +Max Krasnyansky +Max Krasnyansky +Maxim Evtush <154841002+maximevtush@users.noreply.github.com> +Maxime <672982+maximegmd@users.noreply.github.com> +Maximilian Winter +Meng Zhang +Meng, Hengyu +Mengqing Cao +Merrick Christensen +Michael Coppola +Michael Engel +Michael Francis +Michael Hueschen +Michael Kesper +Michael Klimenko +Michael Podvitskiy +Michael Potter +Michael de Gans +Michaël de Vries +Michał Moskal +Michał Tuszyński +Michelle Tan <41475767+MichelleTanPY@users.noreply.github.com> +Mihai +Mike +Mikko Juola +Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> +Minsoo Cheong +Mirko185 +Mirror Azure <54669636+MirrorAzure@users.noreply.github.com> +MistApproach <98988043+MistApproach@users.noreply.github.com> +Miwa / Ensan <63481257+ensan-hcl@users.noreply.github.com> +Mohammadreza Hendiani +Mohammadreza Hendiani +Molly Sophia +MoonRide303 <130458190+MoonRide303@users.noreply.github.com> +MorganRO8 <47795945+MorganRO8@users.noreply.github.com> +Murilo Santana +Musab Gultekin +Nam D. Tran <42194884+namtranase@users.noreply.github.com> +Nathan Epstein +Natsu +NawafAlansari <72708095+NawafAlansari@users.noreply.github.com> +Nebula +Neo Zhang <14088817+arthw@users.noreply.github.com> +Neo Zhang +Neo Zhang Jianyu +Neuman Vong +NeverLucky <92274250+nvrxq@users.noreply.github.com> +Nexes the Old <124105151+Nexesenex@users.noreply.github.com> +Nexesenex <124105151+Nexesenex@users.noreply.github.com> +Niall Coates <1349685+Niall-@users.noreply.github.com> +Nicholai Tukanov +Nico Bosshard +Nicolai Weitkemper +Nicolás Pérez +Nicolò Scipione +Nigel Bosch +Nikita Sarychev <42014488+sARY77@users.noreply.github.com> +Niklas Korz +NikolaiLyssogor <59844691+NikolaiLyssogor@users.noreply.github.com> +Nikolaos Pothitos +Nikolas <127742645+nneubacher@users.noreply.github.com> +Nindaleth +Nuno +OSecret <135510162+OLSecret@users.noreply.github.com> +Oleksandr Kuvshynov <661042+okuvshynov@users.noreply.github.com> +Oleksandr Nikitin +Oleksii Maryshchenko +Olivier Chafik +Ondřej Čertík +Ouadie EL FAROUKI +PAB +Pablo Duboue +Pascal Patry +Patrice Ferlet +Patrick Peng +Paul Tsochantaris +Pavel Zloi +Pavol Rusnak +Paweł Wodnicki <151604+32bitmicro@users.noreply.github.com> +Pedro Cuenca +Peter +Peter Sugihara +Phil H <5756783+phiharri@users.noreply.github.com> +Philip Taron +Phillip Kravtsov +Pierre Alexandre SCHEMBRI +Pierrick Hymbert +Pieter Ouwerkerk +Plamen Minev +Prashant Vithule <119530321+Vithulep@users.noreply.github.com> +Przemysław Pawełczyk +PureJourney +Qin Yue Chen <71813199+chenqiny@users.noreply.github.com> +Qingyou Meng +Qu Zongfu <43257352+yancaoweidaode@users.noreply.github.com> +R0CKSTAR +R0CKSTAR +RJ Adriaansen +Radoslav Gerganov +Radosław Gryta +Rahul Vivek Nair <68507071+RahulVivekNair@users.noreply.github.com> +Raj Hammeer Singh Hada +Ralph Soika +Rand Xie +Randall Fitzgerald +Random Fly +Reinforce-II +Rémy O +Rémy Oudompheng +Ren Xuancheng +Rene Leonhardt <65483435+reneleonhardt@users.noreply.github.com> +Reza Kakhki +Reza Rahemtola <49811529+RezaRahemtola@users.noreply.github.com> +RhinoDevel +Riccardo Orlando +Riceball LEE +Rich Dougherty +Richard +Richard Kiss +Richard Roberson +Rick G <26732651+TheFlipbook@users.noreply.github.com> +Rickard Edén +Rickard Hallerbäck +Rickey Bowers Jr +Riley Stewart +Rinne +Rinne +Robert Brisita <986796+rbrisita@users.noreply.github.com> +Robert Collins +Robert Ormandi <52251610+ormandi@users.noreply.github.com> +Robert Sung-wook Shin +Robey Holderith +Robyn +Roger Meier +Rohanjames1997 +Roland <14355895+rbur0425@users.noreply.github.com> +Romain Biessy +Romain D <90720+Artefact2@users.noreply.github.com> +Romain Neutron +Roman Parykin +Ron Evans +Ron Jailall +Roni +Ronny Brendel +Ronsor +Rowan Hart +Ruan <47767371+ruanych@users.noreply.github.com> +Ruchira Hasaranga +Rudi Servo +Ruixin Huang <18860020911@163.com> +Rune <43761327+Rune-AI@users.noreply.github.com> +RunningLeon +RunningLeon +Ryan Landay +Ryder Wishart +Ryuei +Rőczey Barnabás <31726601+An0nie@users.noreply.github.com> +SAMI +SRHMorris <69468379+SRHMorris@users.noreply.github.com> +SXX +SakuraUmi +Salvador E. Tropea +Salvatore Mesoraca +Sam Spilsbury +Sami Farin <3876865+Safari77@users.noreply.github.com> +Samuel Maynard +Sang-Kil Park +Seb C <47074056+Sebby37@users.noreply.github.com> +Sebastián A +SebastianApel <13675545+SebastianApel@users.noreply.github.com> +Senemu <10880819+Senemu@users.noreply.github.com> +Sergey Alirzaev +Sergio López +Sergio López +Sertaç Özercan <852750+sozercan@users.noreply.github.com> +SeungWon Jeong <65549245+redlion0929@users.noreply.github.com> +ShadovvBeast +Shakhar Dasgupta +Shane A +Shangning Xu <32517059+xushangning@users.noreply.github.com> +Shankar +Shanshan Shen <467638484@qq.com> +Shelby Jenkins <47464908+ShelbyJenkins@users.noreply.github.com> +Sheldon Robinson +Shijie <821898965@qq.com> +Shintarou Okada +Shouzheng Liu <61452103+lshzh-ww@users.noreply.github.com> +Shouzheng Liu +Shuichi Tsutsumi +Shupei Fan +Sigbjørn Skjæret +Simon Willison +Siwen Yu +Sky Yan +Slaren <2141330+slaren@users.noreply.github.com> +Slava Primenko +Small Grass Forest +SoftwareRenderer <138734813+SoftwareRenderer@users.noreply.github.com> +Someone +Someone Serge +Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> +Spencer Sutton +Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com> +Srinivas Billa +Stefan Sydow +Steffen Röcker +Stephan Walter +Stephen Nichols +Steve Bonds +Steve Grubb +Steven Prichard +Steven Roussey +Steward Garcia <57494570+FSSRepo@users.noreply.github.com> +StrangeBytesDev <141275258+StrangeBytesDev@users.noreply.github.com> +Suaj Carrot <72162667+SuajCarrot@users.noreply.github.com> +Sukriti Sharma +SuperUserNameMan +Sutou Kouhei +Tai Duc Nguyen +Taikono-Himazin +Tameem <113388789+AhmadTameem@users.noreply.github.com> +Tamotsu Takahashi +Tei Home +Thái Hoàng Tâm <75922889+RoyalHeart@users.noreply.github.com> +Thatcher Chamberlin +Theia Vogel +Thérence <13496987+Royalphax@users.noreply.github.com> +Thibault Terrasson +Thomas Klausner +Thorsten Sommer +Tim Miller +Tim Wang +Timmy Knight +Timothy Cronin <40186632+4imothy@users.noreply.github.com> +Ting Lou +Ting Lou +Ting Sun +Tobias Lütke +Tom C +Tom Jobbins <784313+TheBloke@users.noreply.github.com> +Tomas +Tomáš Pazdiora +Tony Wasserka <4840017+neobrain@users.noreply.github.com> +Tristan Druyen +Tristan Ross +Trivikram Kamat <16024985+trivikr@users.noreply.github.com> +Tungsten842 <886724vf@anonaddy.me> +Tungsten842 +Tushar +UEXTM.com <84163508+uextm@users.noreply.github.com> +Ujjawal Panchal <31011628+Ujjawal-K-Panchal@users.noreply.github.com> +Ulrich Drepper +Uzo Nweke +Vaibhav Srivastav +Val Kharitonov +Valentin Konovalov +Valentin Mamedov <45292985+Inf1delis@users.noreply.github.com> +Valentyn Bezshapkin <61702053+valentynbez@users.noreply.github.com> +Vali Malinoiu <0x4139@gmail.com> +Victor Nogueira +Victor Z. Peng +Viet-Anh NGUYEN (Andrew) +Vinesh Janarthanan <36610342+VJHack@users.noreply.github.com> +Vitali Lovich +Vivian +Vlad +Vladimir +Vladimir Malyutin +Vladimir Vuksanovic <109677816+vvuksanovic@users.noreply.github.com> +Vladimir Zorin +VoidIsVoid <343750470@qq.com> +Volodymyr Vitvitskyi <72226+signalpillar@users.noreply.github.com> +Wagner Bruna +Wang Qin <37098874+wangqin0@users.noreply.github.com> +Wang Ran (汪然) +WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com> +Weird Constructor +Weizhao Ouyang +Welby Seely +Wentai Zhang +Wilken Gottwalt <12194808+wgottwalt@users.noreply.github.com> +WillCorticesAI <150854901+WillCorticesAI@users.noreply.github.com> +William Tambellini +William Tambellini +Willy Tarreau +Woof Dog <197125663+woof-dog@users.noreply.github.com> +Wouter <9594229+DifferentialityDevelopment@users.noreply.github.com> +Wu Jian Ping +Wu Jian Ping +Xiake Sun +Xiang (Kevin) Li +Xiao-Yong Jin +XiaotaoChen +Xiaoyi Chen +Xie Yanbo +Xingchen Song(宋星辰) +Xinpeng Dou <81913537+Dou-Git@users.noreply.github.com> +Xuan Son Nguyen +Xuan-Son Nguyen +Yaiko +Yann Follet <131855179+YannFollet@users.noreply.github.com> +Yaroslav +Yazan Agha-Schrader +Yiming Cui +Yishuo Wang +Yoshi Suhara +Yoshi Suhara +Younes Belkada <49240599+younesbelkada@users.noreply.github.com> +Yueh-Po Peng <94939112+y10ab1@users.noreply.github.com> +Yüg +Yui +Yun Dou +Yuri Khrustalev +Yusuf Kağan Hanoğlu +Yuval Peled <31162840+Yuval-Peled@users.noreply.github.com> +ZHAOKAI WANG +Zane Shannon +Zay <95888118+isaiahbjork@users.noreply.github.com> +Zenix +Zhang Peiyuan +Zheng.Deng <32841220+dengzheng-cloud@users.noreply.github.com> +Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com> +Zhiyuan Li +Zhiyuan Li +ZhouYuChen +Ziad Ben Hadj-Alouane +Ziang Wu <97337387+ZiangWu-77@users.noreply.github.com> +Zsapi +a-n-n-a-l-e-e <150648636+a-n-n-a-l-e-e@users.noreply.github.com> +a3sh <38979186+A3shTnT@users.noreply.github.com> +adel boussaken +afrideva <95653597+afrideva@users.noreply.github.com> +ag2s20150909 <19373730+ag2s20150909@users.noreply.github.com> +agray3 +akawrykow <142945436+akawrykow@users.noreply.github.com> +alek3y <44779186+alek3y@users.noreply.github.com> +alexpinel <93524949+alexpinel@users.noreply.github.com> +alonfaraj +alwqx +amd-dwang +amd-lalithnc +amritahs-ibm +andrijdavid +anon998 <131767832+anon998@users.noreply.github.com> +anzz1 +apaz +apcameron <37645737+apcameron@users.noreply.github.com> +arch-btw <57669023+arch-btw@users.noreply.github.com> +arcrank +ardfork <134447697+ardfork@users.noreply.github.com> +arlo-phoenix <140345165+arlo-phoenix@users.noreply.github.com> +aryantandon01 <80969509+aryantandon01@users.noreply.github.com> +at8u <129688334+at8u@users.noreply.github.com> +automaticcat +awatuna <23447591+awatuna@users.noreply.github.com> +b4b4o +bandoti <141645996+bandoti@users.noreply.github.com> +beiller +bhubbb <79117352+bhubbb@users.noreply.github.com> +bmwl +bobqianic <129547291+bobqianic@users.noreply.github.com> +brucepro +bryanSwk <93190252+bryanSwk@users.noreply.github.com> +bsilvereagle +bssrdf +byte-6174 <88070277+byte-6174@users.noreply.github.com> +cduk <19917266+cduk@users.noreply.github.com> +cebtenzzre +chaihahaha +chiranko <96988916+chiranko@users.noreply.github.com> +clibdev <52199778+clibdev@users.noreply.github.com> +clyang +cmdr2 +cmdr2 +cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com> +codezjx +coezbek +comex +compilade <113953597+compilade@users.noreply.github.com> +compilade +cpumaxx <163466046+cpumaxx@users.noreply.github.com> +crasm +crasm +daboe01 +daghanerdonmez <44506702+daghanerdonmez@users.noreply.github.com> +daminho <37615795+daminho@users.noreply.github.com> +david raistrick +ddh0 +ddpasa <112642920+ddpasa@users.noreply.github.com> +deepdiffuser <112834445+deepdiffuser@users.noreply.github.com> +devojony <61173062+devojony@users.noreply.github.com> +ditsuke +divinity76 +dm4 +dm4 +dotpy314 <33351922+dotpy314@users.noreply.github.com> +drbh +ds5t5 <145942675+ds5t5@users.noreply.github.com> +dylan +eastriver +ebraminio +ebraminio +eiery <19350831+eiery@users.noreply.github.com> +eric8607242 +fairydreaming <166155368+fairydreaming@users.noreply.github.com> +fengerhu1 <2748250768@qq.com> +fj-y-saito <85871716+fj-y-saito@users.noreply.github.com> +fraxy-v <65565042+fraxy-v@users.noreply.github.com> +fxzjshm <11426482+fxzjshm@users.noreply.github.com> +github-actions[bot] +gliptic +gn64 +goerch +grahameth <96447521+grahameth@users.noreply.github.com> +gtygo +gwjr <502526+gwjr@users.noreply.github.com> +h-h-h-h <13482553+h-h-h-h@users.noreply.github.com> +hankcs +haopeng <657407891@qq.com> +hipudding +hoangmit +hongbo.mo <352280764@qq.com> +hopkins385 <98618192+hopkins385@users.noreply.github.com> +howlger +howlger +hutli <6594598+hutli@users.noreply.github.com> +hutli +hutli +hxer7963 +hydai +iSma +iacore <74560659+iacore@users.noreply.github.com> +icppWorld <124377669+icppWorld@users.noreply.github.com> +igardev <49397134+igardev@users.noreply.github.com> +igarnier +intelmatt <61025942+intelmatt@users.noreply.github.com> +iohub +issixx <46835150+issixx@users.noreply.github.com> +jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com> +jaime-m-p <167997752+jaime-m-p@users.noreply.github.com> +jameswu2014 <545426914@qq.com> +jason_w +jdomke <28772296+jdomke@users.noreply.github.com> +jiahao su +jiez <373447296@qq.com> +jneem +joecryptotoo <80373433+joecryptotoo@users.noreply.github.com> +johnson442 <56517414+johnson442@users.noreply.github.com> +jojorne +jon-chuang <9093549+jon-chuang@users.noreply.github.com> +jp-x-g +jukofyork <69222624+jukofyork@users.noreply.github.com> +junchao-loongson <68935141+junchao-loongson@users.noreply.github.com> +junchao-zhao <68935141+junchao-loongson@users.noreply.github.com> +jwj7140 <32943891+jwj7140@users.noreply.github.com> +k.h.lai +kaizau +kallewoof +kalomaze <66376113+kalomaze@users.noreply.github.com> +kang +katsu560 <118887472+katsu560@users.noreply.github.com> +kchro3 <62481661+kchro3@users.noreply.github.com> +khimaros +kiltyj +klosax <131523366+klosax@users.noreply.github.com> +krystiancha +kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> +kunnis +kuronekosaiko +kustaaya <58045274+kustaaya@users.noreply.github.com> +kuvaus <22169537+kuvaus@users.noreply.github.com> +kwin1412 <42286931+kwin1412@users.noreply.github.com> +l3utterfly +laik +ldwang +le.chang +leejet +leo-pony +lexasub +lhez +limitedAtonement +liuwei-git <14815172+liuwei-git@users.noreply.github.com> +lon <114724657+longregen@users.noreply.github.com> +loonerin <132926317+loonerin@users.noreply.github.com> +ltoniazzi <61414566+ltoniazzi@users.noreply.github.com> +luoyu-intel +m3ndax +maddes8cht <55592906+maddes8cht@users.noreply.github.com> +magicse +mahorozte <41834471+mahorozte@users.noreply.github.com> +makomk +manikbhandari +maor-ps <154728172+maor-ps@users.noreply.github.com> +mashdragon <122402293+mashdragon@users.noreply.github.com> +matiaslin <45382001+matiaslin@users.noreply.github.com> +matt23654 +matteo +mdrokz +mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com> +midnight +minarchist +mj-shifu <77107165+mj-shifu@users.noreply.github.com> +mmyjona +momonga <115213907+mmnga@users.noreply.github.com> +momonga <146910567+mmngays@users.noreply.github.com> +moritzbrantner <31051084+moritzbrantner@users.noreply.github.com> +musoles <135031143+musoles@users.noreply.github.com> +mzcu +nanahi <130121847+na-na-hi@users.noreply.github.com> +ngc92 <7938269+ngc92@users.noreply.github.com> +nhamanasu <45545786+nhamanasu@users.noreply.github.com> +niansa/tuxifan +niansa/tuxifan +nickp27 +ningshanwutuobang +nold +nopperl <54780682+nopperl@users.noreply.github.com> +nusu-github <29514220+nusu-github@users.noreply.github.com> +olexiyb +omahs <73983677+omahs@users.noreply.github.com> +oobabooga <112222186+oobabooga@users.noreply.github.com> +opparco +ostix360 <55257054+ostix360@users.noreply.github.com> +pascal-lc <49066376+pascal-lc@users.noreply.github.com> +pculliton +peidaqi +pengxin99 +perserk +petterreinholdtsen +piDack <104877312+piDack@users.noreply.github.com> +pmysl +postmasters +pudepiedj +qingfengfenga <41416092+qingfengfenga@users.noreply.github.com> +qingy1337 +qouoq +qunash +rabidcopy +rankaiyx +redbeard +rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com> +rhuddleston +rimoliga <53384203+rimoliga@users.noreply.github.com> +runfuture +sandyiscool +sasha0552 +semidark +serhii-nakon <57632032+serhii-nakon@users.noreply.github.com> +sharpHL <132747147+sharpHL@users.noreply.github.com> +shibe2 +simon886212 <37953122+simon886212@users.noreply.github.com> +singularity <12184989+singularity-s0@users.noreply.github.com> +sjinzh +sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com> +slaren <2141330+slaren@users.noreply.github.com> +slaren +snadampal <87143774+snadampal@users.noreply.github.com> +someone13574 <81528246+someone13574@users.noreply.github.com> +standby24x7 +staviq +stduhpf +strawberrymelonpanda <152940198+strawberrymelonpanda@users.noreply.github.com> +swittk +takov751 <40316768+takov751@users.noreply.github.com> +tarcey +tc-mb <157115220+tc-mb@users.noreply.github.com> +texmex76 <40733439+texmex76@users.noreply.github.com> +thement <40525767+thement@users.noreply.github.com> +theraininsky <76763719+theraininsky@users.noreply.github.com> +thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> +tjohnman +toyer <2042519524@qq.com> +tslmy +tv1wnd <55383215+tv1wnd@users.noreply.github.com> +ubik2 +uint256_t +uint256_t +unbounded +uvos +uvos +valiray <133289098+valiray@users.noreply.github.com> +vb +vik +viric +vmobilis <75476228+vmobilis@users.noreply.github.com> +vodkaslime <646329483@qq.com> +vvhg1 <94630311+vvhg1@users.noreply.github.com> +vxiiduu <73044267+vxiiduu@users.noreply.github.com> +wangshuai09 <391746016@qq.com> +wbpxre150 <100937007+wbpxre150@users.noreply.github.com> +whoreson <139810751+whoreson@users.noreply.github.com> +woachk <24752637+woachk@users.noreply.github.com> +wonjun Jang +woodx <124784234+woodx9@users.noreply.github.com> +wwoodsTM <104587230+wwoodsTM@users.noreply.github.com> +wzy <32936898+Freed-Wu@users.noreply.github.com> +xaedes +xaedes +xctan +xiaobing318 <71554036+xiaobing318@users.noreply.github.com> +xiaofei +xloem <0xloem@gmail.com> +yangli2 +ymcki <84055651+ymcki@users.noreply.github.com> +yuiseki +yuri@FreeBSD +zakkor +zhangkaihuo +zhentaoyu +zhouwg <6889919+zhouwg@users.noreply.github.com> +zhouwg +zrm +Ștefan-Gabriel Muscalu +杨朱 · Kiki +源文雨 <41315874+fumiama@users.noreply.github.com> +蕭澧邦 <45505768+shou692199@users.noreply.github.com> +谢乃闻 +Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com> diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..50801cdc637bd366a656e4095b5beba74cf9ded5 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,265 @@ +cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. +project("llama.cpp" C CXX) +include(CheckIncludeFileCXX) + +#set(CMAKE_WARN_DEPRECATED YES) +set(CMAKE_WARN_UNUSED_CLI YES) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +# Add path to modules +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) + set(LLAMA_STANDALONE ON) + + include(git-vars) + + # configure project version + # TODO +else() + set(LLAMA_STANDALONE OFF) +endif() + +option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF) + +if (EMSCRIPTEN) + set(BUILD_SHARED_LIBS_DEFAULT OFF) + + option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON) +else() + if (MINGW) + set(BUILD_SHARED_LIBS_DEFAULT OFF) + else() + set(BUILD_SHARED_LIBS_DEFAULT ON) + endif() +endif() + +option(BUILD_SHARED_LIBS "build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT}) + +if (WIN32) + add_compile_definitions(_CRT_SECURE_NO_WARNINGS) +endif() + +if (MSVC) + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/bigobj>") + add_compile_options("$<$:/bigobj>") +endif() + +# +# option list +# + +# debug +option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) +option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) + +# build +option(LLAMA_FATAL_WARNINGS "llama: enable -Werror flag" OFF) + +# sanitizers +option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) +option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) +option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) + +# utils +option(LLAMA_BUILD_COMMON "llama: build common utils library" ${LLAMA_STANDALONE}) + +# extra artifacts +option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE}) + +# 3rd party libs +option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON) +option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF) + +# Required for relocatable CMake package +include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake) + +if (NOT DEFINED LLAMA_BUILD_NUMBER) + set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER}) +endif() +if (NOT DEFINED LLAMA_BUILD_COMMIT) + set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT}) +endif() +set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER}) + +# override ggml options +set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS}) +set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS}) + +# change the default for these ggml options +if (NOT DEFINED GGML_LLAMAFILE) + set(GGML_LLAMAFILE_DEFAULT ON) +endif() + +if (NOT DEFINED GGML_CUDA_GRAPHS) + set(GGML_CUDA_GRAPHS_DEFAULT ON) +endif() + +# transition helpers +function (llama_option_depr TYPE OLD NEW) + if (${OLD}) + message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n") + set(${NEW} ON PARENT_SCOPE) + endif() +endfunction() + +llama_option_depr(FATAL_ERROR LLAMA_CUBLAS GGML_CUDA) +llama_option_depr(WARNING LLAMA_CUDA GGML_CUDA) +llama_option_depr(WARNING LLAMA_KOMPUTE GGML_KOMPUTE) +llama_option_depr(WARNING LLAMA_METAL GGML_METAL) +llama_option_depr(WARNING LLAMA_METAL_EMBED_LIBRARY GGML_METAL_EMBED_LIBRARY) +llama_option_depr(WARNING LLAMA_NATIVE GGML_NATIVE) +llama_option_depr(WARNING LLAMA_RPC GGML_RPC) +llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL) +llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16) +llama_option_depr(WARNING LLAMA_CANN GGML_CANN) + +if (NOT MSVC) + if (LLAMA_SANITIZE_THREAD) + message(STATUS "Using -fsanitize=thread") + + add_compile_options(-fsanitize=thread) + link_libraries (-fsanitize=thread) + endif() + + if (LLAMA_SANITIZE_ADDRESS) + message(STATUS "Using -fsanitize=address") + + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries (-fsanitize=address) + endif() + + if (LLAMA_SANITIZE_UNDEFINED) + message(STATUS "Using -fsanitize=undefined") + + add_compile_options(-fsanitize=undefined) + link_libraries (-fsanitize=undefined) + endif() +endif() + +# +# 3rd-party +# + +if (LLAMA_USE_SYSTEM_GGML) + message(STATUS "Using system-provided libggml, skipping ggml build") + find_package(ggml REQUIRED) + add_library(ggml ALIAS ggml::ggml) +endif() + +if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML) + set(GGML_BUILD_NUMBER ${LLAMA_BUILD_NUMBER}) + set(GGML_BUILD_COMMIT ${LLAMA_BUILD_COMMIT}) + add_subdirectory(ggml) + # ... otherwise assume ggml is added by a parent CMakeLists.txt +endif() + +if (MINGW) + # Target Windows 8 for PrefetchVirtualMemory + add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) +endif() + +# +# build the library +# + +add_subdirectory(src) + +# +# utils, programs, examples and tests +# + +if (NOT LLAMA_BUILD_COMMON) + message(STATUS "LLAMA_BUILD_COMMON is OFF, disabling LLAMA_CURL") + set(LLAMA_CURL OFF) +endif() + +if (LLAMA_BUILD_COMMON) + add_subdirectory(common) +endif() + +if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) + include(CTest) + add_subdirectory(tests) +endif() + +if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES) + add_subdirectory(examples) + add_subdirectory(pocs) +endif() + +if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TOOLS) + add_subdirectory(tools) +endif() + +# +# install +# + +include(GNUInstallDirs) +include(CMakePackageConfigHelpers) + +set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +set(LLAMA_PUBLIC_HEADERS + ${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h + ${CMAKE_CURRENT_SOURCE_DIR}/include/llama-cpp.h) + +set_target_properties(llama + PROPERTIES + PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}") + +install(TARGETS llama LIBRARY PUBLIC_HEADER) + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/llama-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/llama-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/llama + PATH_VARS LLAMA_INCLUDE_INSTALL_DIR + LLAMA_LIB_INSTALL_DIR + LLAMA_BIN_INSTALL_DIR ) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/llama-version.cmake + VERSION ${LLAMA_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/llama-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/llama-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/llama) + +install( + FILES convert_hf_to_gguf.py + PERMISSIONS + OWNER_READ + OWNER_WRITE + OWNER_EXECUTE + GROUP_READ + GROUP_EXECUTE + WORLD_READ + WORLD_EXECUTE + DESTINATION ${CMAKE_INSTALL_BINDIR}) + +configure_file(cmake/llama.pc.in + "${CMAKE_CURRENT_BINARY_DIR}/llama.pc" + @ONLY) + +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc" + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 0000000000000000000000000000000000000000..e9844701304fcaf952466507c1b03cba8a06546e --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,84 @@ +{ + "version": 4, + "configurePresets": [ + { + "name": "base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." + } + }, + { + "name": "sycl-base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "CMAKE_CXX_COMPILER": "icx", + "CMAKE_C_COMPILER": "cl", + "GGML_SYCL": "ON", + "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." + } + }, + { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } }, + { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, + { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, + { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, + { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, + { "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } }, + + { + "name": "x64-windows-llvm", "hidden": true, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/x64-windows-llvm.cmake" + } + }, + + { + "name": "arm64-windows-llvm", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-llvm.cmake" + } + }, + + { + "name": "arm64-apple-clang", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake" + } + }, + + { "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] }, + { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] }, + { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] }, + + { "name": "arm64-apple-clang-debug", "inherits": [ "base", "arm64-apple-clang", "debug" ] }, + { "name": "arm64-apple-clang-release", "inherits": [ "base", "arm64-apple-clang", "reldbg" ] }, + { "name": "arm64-apple-clang+static-release", "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] }, + + { "name": "x64-windows-llvm-debug", "inherits": [ "base", "x64-windows-llvm", "debug" ] }, + { "name": "x64-windows-llvm-release", "inherits": [ "base", "x64-windows-llvm", "release" ] }, + { "name": "x64-windows-llvm-reldbg", "inherits": [ "base", "x64-windows-llvm", "reldbg" ] }, + { "name": "x64-windows-llvm+static-release", "inherits": [ "base", "x64-windows-llvm", "reldbg", "static" ] }, + + { "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] }, + { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] }, + { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, + + { "name": "x64-windows-sycl-debug", "inherits": [ "sycl-base", "debug" ] }, + { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, + { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, + { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }, + + { "name": "x64-windows-vulkan-debug", "inherits": [ "base", "vulkan", "debug" ] }, + { "name": "x64-windows-vulkan-release", "inherits": [ "base", "vulkan", "release" ] } + ] +} diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..3186f8eb1c5148948a7408f384e5aa6b74591646 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,11 @@ +# collaborators can optionally add themselves here to indicate their availability for reviewing related PRs + +/ci/ @ggerganov +/.devops/*.Dockerfile @ngxson +/tools/server/ @ngxson +/ggml/src/ggml-cuda/fattn* @JohannesGaessler +/ggml/src/ggml-cuda/mmq.* @JohannesGaessler +/ggml/src/ggml-cuda/mmv.* @JohannesGaessler +/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler +/ggml/src/ggml-opt.cpp @JohannesGaessler +/ggml/src/gguf.cpp @JohannesGaessler diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..e68ff92445828c5ea0a225e3eb6122f4e266fb06 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,127 @@ +# Pull requests (for contributors) + +- llama.cpp uses the ggml tensor library for model evaluation. If you are unfamiliar with ggml, consider taking a look at the [examples in the ggml repository](https://github.com/ggml-org/ggml/tree/master/examples/). [simple](https://github.com/ggml-org/ggml/tree/master/examples/simple) shows the bare minimum for using ggml. [gpt-2](https://github.com/ggml-org/ggml/tree/master/examples/gpt-2) has minimal implementations for language model inference using GPT-2. [mnist](https://github.com/ggml-org/ggml/tree/master/examples/mnist) demonstrates how to train and evaluate a simple image classifier +- Test your changes: + - Execute [the full CI locally on your machine](ci/README.md) before publishing + - Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`) + - If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends) + - If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops` +- Create separate PRs for each feature or fix. Avoid combining unrelated changes in a single PR +- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly +- If your PR becomes stale, don't hesitate to ping the maintainers in the comments + +# Pull requests (for collaborators) + +- Squash-merge PRs +- Use the following format for the squashed commit title: ` : (#)`. For example: `utils : fix typo in utils.py (#1234)` +- Optionally pick a `` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules +- Consider adding yourself to [CODEOWNERS](CODEOWNERS) + +# Coding guidelines + +- Avoid adding third-party dependencies, extra files, extra headers, etc. +- Always consider cross-compatibility with other operating systems and architectures +- Avoid fancy-looking modern STL constructs, use basic `for` loops, avoid templates, keep it simple +- Vertical alignment makes things more readable and easier to batch edit +- Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a` +- Use sized integer types such as `int32_t` in the public API, e.g. `size_t` may also be appropriate for allocation sizes or byte offsets +- Declare structs with `struct foo {}` instead of `typedef struct foo {} foo` + - In C++ code omit optional `struct` and `enum` keyword whenever they are not necessary + ```cpp + // OK + llama_context * ctx; + const llama_rope_type rope_type; + + // not OK + struct llama_context * ctx; + const enum llama_rope_type rope_type; + ``` + + _(NOTE: this guideline is yet to be applied to the `llama.cpp` codebase. New code should follow this guideline.)_ + +- Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` (from clang-tools v15+) to format the added code +- For anything not covered in the current guidelines, refer to the [C++ Core Guidelines](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines) +- Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices +- Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggml-org/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ + +![matmul](media/matmul.png) + +# Naming guidelines + +- Use `snake_case` for function, variable and type names +- Naming usually optimizes for longest common prefix (see https://github.com/ggml-org/ggml/pull/302#discussion_r1243240963) + + ```cpp + // not OK + int small_number; + int big_number; + + // OK + int number_small; + int number_big; + ``` + +- Enum values are always in upper case and prefixed with the enum name + + ```cpp + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_NONE = 0, + LLAMA_VOCAB_TYPE_SPM = 1, + LLAMA_VOCAB_TYPE_BPE = 2, + LLAMA_VOCAB_TYPE_WPM = 3, + LLAMA_VOCAB_TYPE_UGM = 4, + LLAMA_VOCAB_TYPE_RWKV = 5, + }; + ``` + +- The general naming pattern is `_`, with `` being `_` + + ```cpp + llama_model_init(); // class: "llama_model", method: "init" + llama_sampler_chain_remove(); // class: "llama_sampler_chain", method: "remove" + llama_sampler_get_seed(); // class: "llama_sampler", method: "get_seed" + llama_set_embeddings(); // class: "llama_context", method: "set_embeddings" + llama_n_threads(); // class: "llama_context", method: "n_threads" + llama_adapter_lora_free(); // class: "llama_adapter_lora", method: "free" + ``` + + - The `get` `` can be omitted + - The `` can be omitted if not necessary + - The `_context` suffix of the `` is optional. Use it to disambiguate symbols when needed + - Use `init`/`free` for constructor/destructor `` + +- Use the `_t` suffix when a type is supposed to be opaque to the user - it's not relevant to them if it is a struct or anything else + + ```cpp + typedef struct llama_context * llama_context_t; + + enum llama_pooling_type llama_pooling_type(const llama_context_t ctx); + ``` + + _(NOTE: this guideline is yet to be applied to the `llama.cpp` codebase. New code should follow this guideline)_ + +- C/C++ filenames are all lowercase with dashes. Headers use the `.h` extension. Source files use the `.c` or `.cpp` extension +- Python filenames are all lowercase with underscores + +- _(TODO: abbreviations usage)_ + +# Preprocessor directives + +- _(TODO: add guidelines with examples and apply them to the codebase)_ + + ```cpp + #ifdef FOO + #endif // FOO + ``` + +# Documentation + +- Documentation is a community effort +- When you need to look into the source code to figure out how to use an API consider adding a short summary to the header file for future reference +- When you notice incorrect or outdated documentation, please update it + +# Resources + +The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from Github projects: + +https://github.com/ggml-org/llama.cpp/projects diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..acb96ce78e0486e9dc7602cdcdd2f491c34f335a --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023-2024 The ggml authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..ac442aec095d63ed7bd41554ed37c687957fd54a --- /dev/null +++ b/Makefile @@ -0,0 +1,1608 @@ +ifndef LLAMA_MAKEFILE +$(error The Makefile build is deprecated. Use the CMake build instead. For more details, see https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md) +endif + +# Define the default target now so that it is always the first target +BUILD_TARGETS = \ + libllava.a \ + llama-batched \ + llama-batched-bench \ + llama-bench \ + llama-cli \ + llama-convert-llama2c-to-ggml \ + llama-embedding \ + llama-eval-callback \ + llama-export-lora \ + llama-gbnf-validator \ + llama-gguf \ + llama-gguf-hash \ + llama-gguf-split \ + llama-gritlm \ + llama-imatrix \ + llama-infill \ + llama-llava-cli \ + llama-minicpmv-cli\ + llama-qwen2vl-cli\ + llama-lookahead \ + llama-lookup \ + llama-lookup-create \ + llama-lookup-merge \ + llama-lookup-stats \ + llama-parallel \ + llama-passkey \ + llama-perplexity \ + llama-q8dot \ + llama-quantize \ + llama-quantize-stats \ + llama-retrieval \ + llama-save-load-state \ + llama-server \ + llama-simple \ + llama-simple-chat \ + llama-run \ + llama-speculative \ + llama-tokenize \ + llama-vdot \ + llama-cvector-generator \ + llama-gen-docs \ + tests/test-c.o + +# Binaries only useful for tests +TEST_TARGETS = \ + tests/test-arg-parser \ + tests/test-autorelease \ + tests/test-backend-ops \ + tests/test-chat \ + tests/test-chat-template \ + tests/test-double-float \ + tests/test-grammar-integration \ + tests/test-grammar-parser \ + tests/test-json-schema-to-grammar \ + tests/test-llama-grammar \ + tests/test-log \ + tests/test-model-load-cancel \ + tests/test-quantize-fns \ + tests/test-quantize-perf \ + tests/test-rope \ + tests/test-sampling \ + tests/test-tokenizer-0 \ + tests/test-tokenizer-1-bpe \ + tests/test-tokenizer-1-spm +# tests/test-opt \ + +# Legacy build targets that were renamed in #7809, but should still be removed when the project is cleaned +LEGACY_TARGETS_CLEAN = main quantize quantize-stats perplexity imatrix embedding vdot q8dot convert-llama2c-to-ggml \ + simple batched batched-bench save-load-state server gguf gguf-split eval-callback llama-bench libllava.a llava-cli baby-llama \ + retrieval speculative infill tokenize parallel export-lora lookahead lookup passkey gritlm + +# Legacy build targets that were renamed in #7809, but we want to build binaries that for them that output a deprecation warning if people try to use them. +# We don't want to clutter things too much, so we only build replacements for the most commonly used binaries. +LEGACY_TARGETS_BUILD = main quantize perplexity embedding server + +# Deprecation aliases +ifdef LLAMA_CUBLAS +$(error LLAMA_CUBLAS is removed. Use GGML_CUDA instead.) +endif + +ifdef LLAMA_CUDA +GGML_CUDA := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_KOMPUTE +GGML_KOMPUTE := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_METAL +GGML_METAL := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_RPC +GGML_RPC := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_SYCL +GGML_SYCL := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_SYCL_F16 +GGML_SYCL_F16 := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_OPENBLAS +GGML_OPENBLAS := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_OPENBLAS64 +GGML_OPENBLAS64 := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_BLIS +GGML_BLIS := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_NO_LLAMAFILE +GGML_NO_LLAMAFILE := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_NO_ACCELERATE +GGML_NO_ACCELERATE := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_NO_OPENMP +GGML_NO_OPENMP := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_NO_METAL +GGML_NO_METAL := 1 +DEPRECATE_WARNING := 1 +endif + +ifdef LLAMA_DISABLE_LOGS +REMOVE_WARNING := 1 +endif + +ifdef LLAMA_SERVER_VERBOSE +REMOVE_WARNING := 1 +endif + +ifndef UNAME_S +UNAME_S := $(shell uname -s) +endif + +ifndef UNAME_P +UNAME_P := $(shell uname -p) +endif + +ifndef UNAME_M +UNAME_M := $(shell uname -m) +endif + +# In GNU make default CXX is g++ instead of c++. Let's fix that so that users +# of non-gcc compilers don't have to provide g++ alias or wrapper. +DEFCC := cc +DEFCXX := c++ +ifeq ($(origin CC),default) +CC := $(DEFCC) +endif +ifeq ($(origin CXX),default) +CXX := $(DEFCXX) +endif + +# Mac OS + Arm can report x86_64 +# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789 +ifeq ($(UNAME_S),Darwin) + ifndef GGML_NO_METAL + GGML_METAL := 1 + endif + + GGML_NO_OPENMP := 1 + + ifneq ($(UNAME_P),arm) + SYSCTL_M := $(shell sysctl -n hw.optional.arm64 2>/dev/null) + ifeq ($(SYSCTL_M),1) + # UNAME_P := arm + # UNAME_M := arm64 + warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789) + endif + endif +endif + +ifdef GGML_METAL + GGML_METAL_EMBED_LIBRARY := 1 +endif + +ifdef GGML_RPC + BUILD_TARGETS += rpc-server +endif + +ifdef GGML_VULKAN + BUILD_TARGETS += vulkan-shaders-gen +endif + +default: $(BUILD_TARGETS) $(LEGACY_TARGETS_BUILD) + +test: $(TEST_TARGETS) + @failures=0; \ + for test_target in $(TEST_TARGETS); do \ + if [ "$$test_target" = "tests/test-tokenizer-0" ]; then \ + ./$$test_target $(CURDIR)/models/ggml-vocab-llama-spm.gguf; \ + ./$$test_target $(CURDIR)/models/ggml-vocab-llama-bpe.gguf; \ + ./$$test_target $(CURDIR)/models/ggml-vocab-phi-3.gguf; \ + ./$$test_target $(CURDIR)/models/ggml-vocab-falcon.gguf; \ + ./$$test_target $(CURDIR)/models/ggml-vocab-bert-bge.gguf; \ + ./$$test_target $(CURDIR)/models/ggml-vocab-starcoder.gguf; \ + ./$$test_target $(CURDIR)/models/ggml-vocab-gpt-2.gguf; \ + ./$$test_target $(CURDIR)/models/ggml-vocab-refact.gguf; \ + elif [ "$$test_target" = "tests/test-tokenizer-1-spm" ]; then \ + continue; \ + elif [ "$$test_target" = "tests/test-tokenizer-1-bpe" ]; then \ + continue; \ + else \ + echo "Running test $$test_target..."; \ + ./$$test_target; \ + fi; \ + if [ $$? -ne 0 ]; then \ + printf 'Test %s FAILED!\n\n' $$test_target; \ + failures=$$(( failures + 1 )); \ + else \ + printf 'Test %s passed.\n\n' $$test_target; \ + fi; \ + done; \ + if [ $$failures -gt 0 ]; then \ + printf '\n%s tests failed.\n' $$failures; \ + exit 1; \ + fi + @echo 'All tests passed.' + +all: $(BUILD_TARGETS) $(TEST_TARGETS) $(LEGACY_TARGETS_BUILD) + +ifdef RISCV_CROSS_COMPILE +CC := riscv64-unknown-linux-gnu-gcc +CXX := riscv64-unknown-linux-gnu-g++ +endif + +# +# Compile flags +# + +# keep standard at C11 and C++17 +MK_CPPFLAGS = -Iggml/include -Iggml/src -Iinclude -Isrc -Icommon -DGGML_USE_CPU +MK_CFLAGS = -std=c11 -fPIC +MK_CXXFLAGS = -std=c++17 -fPIC +MK_NVCCFLAGS = -std=c++17 + +ifdef LLAMA_NO_CCACHE +GGML_NO_CCACHE := 1 +DEPRECATE_WARNING := 1 +endif + +ifndef GGML_NO_CCACHE +CCACHE := $(shell which ccache) +ifdef CCACHE +export CCACHE_SLOPPINESS = time_macros +$(info I ccache found, compilation results will be cached. Disable with GGML_NO_CCACHE.) +CC := $(CCACHE) $(CC) +CXX := $(CCACHE) $(CXX) +else +$(info I ccache not found. Consider installing it for faster compilation.) +endif # CCACHE +endif # GGML_NO_CCACHE + +# clock_gettime came in POSIX.1b (1993) +# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional +# posix_memalign came in POSIX.1-2001 / SUSv3 +# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985) +MK_CPPFLAGS += -D_XOPEN_SOURCE=600 + +# Somehow in OpenBSD whenever POSIX conformance is specified +# some string functions rely on locale_t availability, +# which was introduced in POSIX.1-2008, forcing us to go higher +ifeq ($(UNAME_S),OpenBSD) + MK_CPPFLAGS += -U_XOPEN_SOURCE -D_XOPEN_SOURCE=700 +endif + +# Data types, macros and functions related to controlling CPU affinity and +# some memory allocation are available on Linux through GNU extensions in libc +ifeq ($(UNAME_S),Linux) + MK_CPPFLAGS += -D_GNU_SOURCE + MK_LDFLAGS += -ldl +endif + +# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1, +# and on macOS its availability depends on enabling Darwin extensions +# similarly on DragonFly, enabling BSD extensions is necessary +ifeq ($(UNAME_S),Darwin) + MK_CPPFLAGS += -D_DARWIN_C_SOURCE +endif +ifeq ($(UNAME_S),DragonFly) + MK_CPPFLAGS += -D__BSD_VISIBLE +endif + +# alloca is a non-standard interface that is not visible on BSDs when +# POSIX conformance is specified, but not all of them provide a clean way +# to enable it in such cases +ifeq ($(UNAME_S),FreeBSD) + MK_CPPFLAGS += -D__BSD_VISIBLE +endif +ifeq ($(UNAME_S),NetBSD) + MK_CPPFLAGS += -D_NETBSD_SOURCE +endif +ifeq ($(UNAME_S),OpenBSD) + MK_CPPFLAGS += -D_BSD_SOURCE +endif + +ifdef GGML_SCHED_MAX_COPIES + MK_CPPFLAGS += -DGGML_SCHED_MAX_COPIES=$(GGML_SCHED_MAX_COPIES) +endif + +ifdef LLAMA_DEBUG + MK_CFLAGS += -O0 -g + MK_CXXFLAGS += -O0 -g + MK_LDFLAGS += -g + MK_NVCCFLAGS += -O0 -g + + ifeq ($(UNAME_S),Linux) + MK_CPPFLAGS += -D_GLIBCXX_ASSERTIONS + endif +else + MK_CPPFLAGS += -DNDEBUG + MK_CFLAGS += -O3 -g + MK_CXXFLAGS += -O3 -g + MK_NVCCFLAGS += -O3 -g +endif + +ifdef LLAMA_SANITIZE_THREAD + MK_CFLAGS += -fsanitize=thread -g + MK_CXXFLAGS += -fsanitize=thread -g + MK_LDFLAGS += -fsanitize=thread -g +endif + +ifdef LLAMA_SANITIZE_ADDRESS + MK_CFLAGS += -fsanitize=address -fno-omit-frame-pointer -g + MK_CXXFLAGS += -fsanitize=address -fno-omit-frame-pointer -g + MK_LDFLAGS += -fsanitize=address -fno-omit-frame-pointer -g +endif + +ifdef LLAMA_SANITIZE_UNDEFINED + MK_CFLAGS += -fsanitize=undefined -g + MK_CXXFLAGS += -fsanitize=undefined -g + MK_LDFLAGS += -fsanitize=undefined -g +endif + +ifdef LLAMA_SERVER_SSL + MK_CPPFLAGS += -DCPPHTTPLIB_OPENSSL_SUPPORT + MK_LDFLAGS += -lssl -lcrypto +endif + +ifndef GGML_NO_CPU_AARCH64 + MK_CPPFLAGS += -DGGML_USE_CPU_REPACK +endif + +# warnings +WARN_FLAGS = \ + -Wall \ + -Wextra \ + -Wpedantic \ + -Wcast-qual \ + -Wno-unused-function + +MK_CFLAGS += \ + $(WARN_FLAGS) \ + -Wshadow \ + -Wstrict-prototypes \ + -Wpointer-arith \ + -Wmissing-prototypes \ + -Werror=implicit-int \ + -Werror=implicit-function-declaration + +MK_CXXFLAGS += \ + $(WARN_FLAGS) \ + -Wmissing-declarations \ + -Wmissing-noreturn + +ifeq ($(LLAMA_FATAL_WARNINGS),1) + MK_CFLAGS += -Werror + MK_CXXFLAGS += -Werror +endif + +# this version of Apple ld64 is buggy +ifneq '' '$(findstring dyld-1015.7,$(shell $(CC) $(LDFLAGS) -Wl,-v 2>&1))' + MK_CPPFLAGS += -DHAVE_BUGGY_APPLE_LINKER +endif + +# OS specific +# TODO: support Windows +ifneq '' '$(filter $(UNAME_S),Linux Darwin FreeBSD NetBSD OpenBSD Haiku)' + MK_CFLAGS += -pthread + MK_CXXFLAGS += -pthread +endif + +# detect Windows +ifneq ($(findstring _NT,$(UNAME_S)),) + _WIN32 := 1 +endif + +# library name prefix +ifneq ($(_WIN32),1) + LIB_PRE := lib +endif + +# Dynamic Shared Object extension +ifneq ($(_WIN32),1) + DSO_EXT := .so +else + DSO_EXT := .dll +endif + +# Windows Sockets 2 (Winsock) for network-capable apps +ifeq ($(_WIN32),1) + LWINSOCK2 := -lws2_32 +endif + +ifdef LLAMA_GPROF + MK_CFLAGS += -pg + MK_CXXFLAGS += -pg +endif + +# Architecture specific +# TODO: probably these flags need to be tweaked on some architectures +# feel free to update the Makefile for your architecture and send a pull request or issue + +ifndef RISCV_CROSS_COMPILE + +ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64)) + # Use all CPU extensions that are available: + MK_CFLAGS += -march=native -mtune=native + HOST_CXXFLAGS += -march=native -mtune=native + + # Usage AMX build test + #MK_CFLAGS += -march=graniterapids -mtune=graniterapids + #HOST_CXXFLAGS += -march=graniterapids -mtune=graniterapids + + # Usage AVX-only + #MK_CFLAGS += -mfma -mf16c -mavx + #MK_CXXFLAGS += -mfma -mf16c -mavx + + # Usage SSSE3-only (Not is SSE3!) + #MK_CFLAGS += -mssse3 + #MK_CXXFLAGS += -mssse3 +endif + +ifneq '' '$(findstring mingw,$(shell $(CC) -dumpmachine))' + # The stack is only 16-byte aligned on Windows, so don't let gcc emit aligned moves. + # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412 + # https://github.com/ggml-org/llama.cpp/issues/2922 + MK_CFLAGS += -Xassembler -muse-unaligned-vector-move + MK_CXXFLAGS += -Xassembler -muse-unaligned-vector-move + + # Target Windows 8 for PrefetchVirtualMemory + MK_CPPFLAGS += -D_WIN32_WINNT=0x602 +endif + +ifneq ($(filter aarch64%,$(UNAME_M)),) + # Apple M1, M2, etc. + # Raspberry Pi 3, 4, Zero 2 (64-bit) + # Nvidia Jetson + MK_CFLAGS += -mcpu=native + MK_CXXFLAGS += -mcpu=native + JETSON_RELEASE_INFO = $(shell jetson_release) + ifdef JETSON_RELEASE_INFO + ifneq ($(filter TX2%,$(JETSON_RELEASE_INFO)),) + JETSON_EOL_MODULE_DETECT = 1 + CC = aarch64-unknown-linux-gnu-gcc + cxx = aarch64-unknown-linux-gnu-g++ + endif + endif +endif + +ifneq ($(filter armv6%,$(UNAME_M)),) + # Raspberry Pi 1, Zero + MK_CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access + MK_CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access +endif + +ifneq ($(filter armv7%,$(UNAME_M)),) + # Raspberry Pi 2 + MK_CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations + MK_CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations +endif + +ifneq ($(filter armv8%,$(UNAME_M)),) + # Raspberry Pi 3, 4, Zero 2 (32-bit) + MK_CFLAGS += -mfp16-format=ieee -mno-unaligned-access + MK_CXXFLAGS += -mfp16-format=ieee -mno-unaligned-access +endif + +ifneq ($(filter ppc64%,$(UNAME_M)),) + POWER9_M := $(shell grep "POWER9" /proc/cpuinfo) + ifneq (,$(findstring POWER9,$(POWER9_M))) + MK_CFLAGS += -mcpu=power9 + MK_CXXFLAGS += -mcpu=power9 + endif +endif + +ifneq ($(filter ppc64le%,$(UNAME_M)),) + MK_CFLAGS += -mcpu=powerpc64le + MK_CXXFLAGS += -mcpu=powerpc64le + CUDA_POWER_ARCH = 1 +endif + +ifneq ($(filter loongarch64%,$(UNAME_M)),) + MK_CFLAGS += -mlasx + MK_CXXFLAGS += -mlasx +endif + +ifneq ($(filter riscv64%,$(UNAME_M)),) + MK_CFLAGS += -march=rv64gcv -mabi=lp64d + MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d +endif + +else # RISC-V CROSS COMPILATION + MK_CFLAGS += -march=rv64gcv -mabi=lp64d + MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d +endif + +ifndef GGML_NO_ACCELERATE + # Mac OS - include Accelerate framework. + # `-framework Accelerate` works both with Apple Silicon and Mac Intel + ifeq ($(UNAME_S),Darwin) + MK_CPPFLAGS += -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE + MK_CPPFLAGS += -DACCELERATE_NEW_LAPACK + MK_CPPFLAGS += -DACCELERATE_LAPACK_ILP64 + MK_LDFLAGS += -framework Accelerate + OBJ_GGML_EXT += ggml/src/ggml-blas/ggml-blas.o + endif +endif # GGML_NO_ACCELERATE + +ifndef GGML_NO_OPENMP + MK_CPPFLAGS += -DGGML_USE_OPENMP + MK_CFLAGS += -fopenmp + MK_CXXFLAGS += -fopenmp +endif # GGML_NO_OPENMP + +ifdef GGML_OPENBLAS + MK_CPPFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas) + MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas) + MK_LDFLAGS += $(shell pkg-config --libs openblas) + OBJ_GGML_EXT += ggml/src/ggml-blas/ggml-blas.o +endif # GGML_OPENBLAS + +ifdef GGML_OPENBLAS64 + MK_CPPFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas64) + MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas64) + MK_LDFLAGS += $(shell pkg-config --libs openblas64) + OBJ_GGML_EXT += ggml/src/ggml-blas/ggml-blas.o +endif # GGML_OPENBLAS64 + +ifdef GGML_BLIS + MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis + MK_LDFLAGS += -lblis -L/usr/local/lib + OBJ_GGML_EXT += ggml/src/ggml-blas/ggml-blas.o +endif # GGML_BLIS + +ifdef GGML_NVPL + MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas + MK_LDFLAGS += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp + OBJ_GGML_EXT += ggml/src/ggml-blas/ggml-blas.o +endif # GGML_NVPL + +ifndef GGML_NO_LLAMAFILE + MK_CPPFLAGS += -DGGML_USE_LLAMAFILE + OBJ_GGML_EXT += ggml/src/ggml-cpu/llamafile/sgemm.o +endif + +ifndef GGML_NO_AMX + MK_CPPFLAGS += -DGGML_USE_AMX + OBJ_GGML_EXT += ggml/src/ggml-cpu/amx/amx.o ggml/src/ggml-cpu/amx/mmq.o +endif + +# only necessary for the CPU backend files +MK_CPPFLAGS += -Iggml/src/ggml-cpu + +ifdef GGML_RPC + MK_CPPFLAGS += -DGGML_USE_RPC + OBJ_GGML_EXT += ggml/src/ggml-rpc.o +endif # GGML_RPC + +OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu)) +OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu)) + +ifdef GGML_CUDA_FA_ALL_QUANTS + OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*.cu)) +else + OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu)) + OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu)) + OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu)) +endif # GGML_CUDA_FA_ALL_QUANTS + +ifdef GGML_CUDA + ifneq ('', '$(wildcard /opt/cuda)') + CUDA_PATH ?= /opt/cuda + else + CUDA_PATH ?= /usr/local/cuda + endif + + MK_CPPFLAGS += -DGGML_USE_CUDA -DGGML_CUDA_USE_GRAPHS -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include + MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib + MK_NVCCFLAGS += -use_fast_math + + OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o + OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) + OBJ_GGML_EXT += $(OBJ_CUDA_TMPL) + +ifdef LLAMA_FATAL_WARNINGS + MK_NVCCFLAGS += -Werror all-warnings +endif # LLAMA_FATAL_WARNINGS + +ifndef JETSON_EOL_MODULE_DETECT + MK_NVCCFLAGS += --forward-unknown-to-host-compiler +endif # JETSON_EOL_MODULE_DETECT + +ifdef LLAMA_DEBUG + MK_NVCCFLAGS += -lineinfo +endif # LLAMA_DEBUG + +ifdef GGML_CUDA_DEBUG + MK_NVCCFLAGS += --device-debug +endif # GGML_CUDA_DEBUG + +ifdef GGML_CUDA_NVCC + NVCC = $(CCACHE) $(GGML_CUDA_NVCC) +else + NVCC = $(CCACHE) nvcc +endif # GGML_CUDA_NVCC + +ifdef CUDA_DOCKER_ARCH + MK_NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH) +else ifndef CUDA_POWER_ARCH + MK_NVCCFLAGS += -arch=native +endif # CUDA_DOCKER_ARCH + +ifdef GGML_CUDA_FORCE_MMQ + MK_NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ +endif # GGML_CUDA_FORCE_MMQ + +ifdef GGML_CUDA_FORCE_CUBLAS + MK_NVCCFLAGS += -DGGML_CUDA_FORCE_CUBLAS +endif # GGML_CUDA_FORCE_CUBLAS + +ifdef GGML_CUDA_F16 + MK_NVCCFLAGS += -DGGML_CUDA_F16 +endif # GGML_CUDA_F16 + +ifdef GGML_CUDA_DMMV_F16 + MK_NVCCFLAGS += -DGGML_CUDA_F16 +endif # GGML_CUDA_DMMV_F16 + +ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE + MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE) +else + MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 +endif # GGML_CUDA_PEER_MAX_BATCH_SIZE + +ifdef GGML_CUDA_NO_PEER_COPY + MK_NVCCFLAGS += -DGGML_CUDA_NO_PEER_COPY +endif # GGML_CUDA_NO_PEER_COPY + +ifdef GGML_CUDA_CCBIN + MK_NVCCFLAGS += -ccbin $(GGML_CUDA_CCBIN) +endif # GGML_CUDA_CCBIN + +ifdef GGML_CUDA_NO_FA + MK_NVCCFLAGS += -DGGML_CUDA_NO_FA +endif # GGML_CUDA_NO_FA + +ifdef GGML_CUDA_FA_ALL_QUANTS + MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS +endif # GGML_CUDA_FA_ALL_QUANTS + +ifdef JETSON_EOL_MODULE_DETECT +define NVCC_COMPILE + $(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ +endef # NVCC_COMPILE +else +define NVCC_COMPILE + $(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ +endef # NVCC_COMPILE +endif # JETSON_EOL_MODULE_DETECT + +ggml/src/ggml-cuda/%.o: \ + ggml/src/ggml-cuda/%.cu \ + ggml/include/ggml.h \ + ggml/src/ggml-common.h \ + ggml/src/ggml-cuda/common.cuh + $(NVCC_COMPILE) + +ggml/src/ggml-cuda/ggml-cuda.o: \ + ggml/src/ggml-cuda/ggml-cuda.cu \ + ggml/include/ggml-cuda.h \ + ggml/include/ggml.h \ + ggml/include/ggml-backend.h \ + ggml/src/ggml-backend-impl.h \ + ggml/src/ggml-common.h \ + $(wildcard ggml/src/ggml-cuda/*.cuh) + $(NVCC_COMPILE) +endif # GGML_CUDA + +ifdef GGML_VULKAN + MK_CPPFLAGS += -DGGML_USE_VULKAN + MK_LDFLAGS += $(shell pkg-config --libs vulkan) + OBJ_GGML_EXT += ggml/src/ggml-vulkan.o ggml/src/ggml-vulkan-shaders.o + +ifdef GGML_VULKAN_CHECK_RESULTS + MK_CPPFLAGS += -DGGML_VULKAN_CHECK_RESULTS +endif + +ifdef GGML_VULKAN_DEBUG + MK_CPPFLAGS += -DGGML_VULKAN_DEBUG +endif + +ifdef GGML_VULKAN_MEMORY_DEBUG + MK_CPPFLAGS += -DGGML_VULKAN_MEMORY_DEBUG +endif + +ifdef GGML_VULKAN_PERF + MK_CPPFLAGS += -DGGML_VULKAN_PERF +endif + +ifdef GGML_VULKAN_VALIDATE + MK_CPPFLAGS += -DGGML_VULKAN_VALIDATE +endif + +ifdef GGML_VULKAN_RUN_TESTS + MK_CPPFLAGS += -DGGML_VULKAN_RUN_TESTS +endif + +GLSLC_CMD = glslc +_ggml_vk_genshaders_cmd = $(shell pwd)/vulkan-shaders-gen +_ggml_vk_header = ggml/src/ggml-vulkan-shaders.hpp +_ggml_vk_source = ggml/src/ggml-vulkan-shaders.cpp +_ggml_vk_input_dir = ggml/src/ggml-vulkan/vulkan-shaders +_ggml_vk_shader_deps = $(echo $(_ggml_vk_input_dir)/*.comp) + +ggml/src/ggml-vulkan.o: ggml/src/ggml-vulkan/ggml-vulkan.cpp ggml/include/ggml-vulkan.h $(_ggml_vk_header) $(_ggml_vk_source) + $(CXX) $(CXXFLAGS) $(shell pkg-config --cflags vulkan) -c $< -o $@ + +$(_ggml_vk_header): $(_ggml_vk_source) + +$(_ggml_vk_source): $(_ggml_vk_shader_deps) vulkan-shaders-gen + $(_ggml_vk_genshaders_cmd) \ + --glslc $(GLSLC_CMD) \ + --input-dir $(_ggml_vk_input_dir) \ + --target-hpp $(_ggml_vk_header) \ + --target-cpp $(_ggml_vk_source) + +vulkan-shaders-gen: ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp + $(CXX) $(CXXFLAGS) -o $@ $(LDFLAGS) ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp + +endif # GGML_VULKAN + +ifdef GGML_HIP + ifeq ($(wildcard /opt/rocm),) + ROCM_PATH ?= /usr + AMDGPU_TARGETS ?= $(shell $(shell which amdgpu-arch)) + else + ROCM_PATH ?= /opt/rocm + AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch) + endif + + MK_CPPFLAGS += -DGGML_USE_HIP -DGGML_USE_CUDA + + MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib + MK_LDFLAGS += -L$(ROCM_PATH)/lib64 -Wl,-rpath=$(ROCM_PATH)/lib64 + MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas + + HIPCC ?= $(CCACHE) $(ROCM_PATH)/bin/hipcc + + HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS)) + +ifdef GGML_CUDA_FORCE_MMQ + HIPFLAGS += -DGGML_CUDA_FORCE_MMQ +endif # GGML_CUDA_FORCE_MMQ + +ifdef GGML_CUDA_FORCE_CUBLAS + HIPFLAGS += -DGGML_CUDA_FORCE_CUBLAS +endif # GGML_CUDA_FORCE_CUBLAS + +ifdef GGML_CUDA_NO_PEER_COPY + HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY +endif # GGML_CUDA_NO_PEER_COPY + +ifdef GGML_CUDA_NO_FA + HIPFLAGS += -DGGML_CUDA_NO_FA +endif # GGML_CUDA_NO_FA + + OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o + OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) + OBJ_GGML_EXT += $(OBJ_CUDA_TMPL) + +ggml/src/ggml-cuda/ggml-cuda.o: \ + ggml/src/ggml-cuda/ggml-cuda.cu \ + ggml/include/ggml-cuda.h \ + ggml/include/ggml.h \ + ggml/include/ggml-backend.h \ + ggml/src/ggml-backend-impl.h \ + ggml/src/ggml-common.h \ + $(wildcard ggml/src/ggml-cuda/*.cuh) + $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $< + +ggml/src/ggml-cuda/%.o: \ + ggml/src/ggml-cuda/%.cu \ + ggml/include/ggml.h \ + ggml/src/ggml-common.h \ + ggml/src/ggml-cuda/common.cuh + $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $< +endif # GGML_HIP + +ifdef GGML_MUSA + ifeq ($(wildcard /opt/musa),) + MUSA_PATH ?= /usr/local/musa + else + MUSA_PATH ?= /opt/musa + endif + MUSA_ARCHITECTURES ?= 21;22;31 + + MK_CPPFLAGS += -DGGML_USE_MUSA -DGGML_USE_CUDA + MK_LDFLAGS += -L$(MUSA_PATH)/lib -Wl,-rpath=$(MUSA_PATH)/lib + MK_LDFLAGS += -lmusa -lmusart -lmublas + + ifndef GGML_NO_OPENMP + # For Ubuntu Focal + MK_CPPFLAGS += -I/usr/lib/llvm-10/include/openmp + MK_LDFLAGS += -L/usr/lib/llvm-10/lib + # For Ubuntu Jammy + MK_CPPFLAGS += -I/usr/lib/llvm-14/lib/clang/14.0.0/include + MK_LDFLAGS += -L/usr/lib/llvm-14/lib + endif # GGML_NO_OPENMP + + CC := $(MUSA_PATH)/bin/clang + CXX := $(MUSA_PATH)/bin/clang++ + MCC := $(CCACHE) $(MUSA_PATH)/bin/mcc + + MUSAFLAGS = -fsigned-char -x musa -mtgpu + MUSAFLAGS += $(foreach arch,$(subst ;, ,$(MUSA_ARCHITECTURES)),--cuda-gpu-arch=mp_$(arch)) + +ifdef GGML_CUDA_FORCE_MMQ + MUSAFLAGS += -DGGML_CUDA_FORCE_MMQ +endif # GGML_CUDA_FORCE_MMQ + +ifdef GGML_CUDA_FORCE_CUBLAS + MUSAFLAGS += -DGGML_CUDA_FORCE_CUBLAS +endif # GGML_CUDA_FORCE_CUBLAS + +ifdef GGML_CUDA_F16 + MUSAFLAGS += -DGGML_CUDA_F16 +endif # GGML_CUDA_F16 + +ifdef GGML_CUDA_DMMV_F16 + MUSAFLAGS += -DGGML_CUDA_F16 +endif # GGML_CUDA_DMMV_F16 + +ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE + MUSAFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE) +else + MUSAFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 +endif # GGML_CUDA_PEER_MAX_BATCH_SIZE + +ifdef GGML_CUDA_NO_PEER_COPY + MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY +endif # GGML_CUDA_NO_PEER_COPY + +ifdef GGML_CUDA_NO_FA + MUSAFLAGS += -DGGML_CUDA_NO_FA +endif # GGML_CUDA_NO_FA + +ifdef GGML_CUDA_FA_ALL_QUANTS + MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS +endif # GGML_CUDA_FA_ALL_QUANTS + + OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o + OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) + OBJ_GGML_EXT += $(OBJ_CUDA_TMPL) + +ggml/src/ggml-cuda/ggml-cuda.o: \ + ggml/src/ggml-cuda/ggml-cuda.cu \ + ggml/include/ggml-cuda.h \ + ggml/include/ggml.h \ + ggml/include/ggml-backend.h \ + ggml/src/ggml-backend-impl.h \ + ggml/src/ggml-common.h \ + $(wildcard ggml/src/ggml-cuda/*.cuh) + $(MCC) $(CXXFLAGS) $(MUSAFLAGS) -c -o $@ $< + +ggml/src/ggml-cuda/%.o: \ + ggml/src/ggml-cuda/%.cu \ + ggml/include/ggml.h \ + ggml/src/ggml-common.h \ + ggml/src/ggml-cuda/common.cuh + $(MCC) $(CXXFLAGS) $(MUSAFLAGS) -c -o $@ $< +endif # GGML_MUSA + +ifdef GGML_METAL + MK_CPPFLAGS += -DGGML_USE_METAL + MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit + OBJ_GGML_EXT += ggml/src/ggml-metal/ggml-metal.o + +ifdef GGML_METAL_USE_BF16 + MK_CPPFLAGS += -DGGML_METAL_USE_BF16 +endif # GGML_METAL_USE_BF16 +ifdef GGML_METAL_NDEBUG + MK_CPPFLAGS += -DGGML_METAL_NDEBUG +endif +ifdef GGML_METAL_EMBED_LIBRARY + MK_CPPFLAGS += -DGGML_METAL_EMBED_LIBRARY + OBJ_GGML_EXT += ggml/src/ggml-metal-embed.o +endif +endif # GGML_METAL + +ifdef GGML_METAL +ggml/src/ggml-metal/ggml-metal.o: \ + ggml/src/ggml-metal/ggml-metal.m \ + ggml/src/ggml-metal/ggml-metal-impl.h \ + ggml/include/ggml-metal.h \ + ggml/include/ggml.h + $(CC) $(CFLAGS) -c $< -o $@ + +ifdef GGML_METAL_EMBED_LIBRARY +ggml/src/ggml-metal-embed.o: \ + ggml/src/ggml-metal/ggml-metal.metal \ + ggml/src/ggml-metal/ggml-metal-impl.h \ + ggml/src/ggml-common.h + @echo "Embedding Metal library" + @sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp + @sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal + $(eval TEMP_ASSEMBLY=$(shell mktemp -d)) + @echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo ".incbin \"ggml/src/ggml-metal/ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s + $(CC) $(CFLAGS) -c $(TEMP_ASSEMBLY)/ggml-metal-embed.s -o $@ + @rm -f ${TEMP_ASSEMBLY}/ggml-metal-embed.s + @rmdir ${TEMP_ASSEMBLY} +endif +endif # GGML_METAL + +DIR_GGML = ggml +DIR_LLAMA = src +DIR_COMMON = common + +OBJ_GGML = \ + $(DIR_GGML)/src/ggml.o \ + $(DIR_GGML)/src/ggml-alloc.o \ + $(DIR_GGML)/src/ggml-backend.o \ + $(DIR_GGML)/src/ggml-backend-reg.o \ + $(DIR_GGML)/src/ggml-opt.o \ + $(DIR_GGML)/src/ggml-quants.o \ + $(DIR_GGML)/src/ggml-threading.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu_cpp.o \ + $(DIR_GGML)/src/ggml-cpu/repack.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu-hbm.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu-quants.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu-traits.o \ + $(OBJ_GGML_EXT) + +OBJ_LLAMA = \ + $(DIR_LLAMA)/llama.o \ + $(DIR_LLAMA)/llama-vocab.o \ + $(DIR_LLAMA)/llama-grammar.o \ + $(DIR_LLAMA)/llama-sampling.o \ + $(DIR_LLAMA)/unicode.o \ + $(DIR_LLAMA)/unicode-data.o + +OBJ_COMMON = \ + $(DIR_COMMON)/common.o \ + $(DIR_COMMON)/arg.o \ + $(DIR_COMMON)/log.o \ + $(DIR_COMMON)/console.o \ + $(DIR_COMMON)/ngram-cache.o \ + $(DIR_COMMON)/sampling.o \ + $(DIR_COMMON)/speculative.o \ + $(DIR_COMMON)/chat.o \ + $(DIR_COMMON)/build-info.o \ + $(DIR_COMMON)/json-schema-to-grammar.o + +OBJ_ALL = $(OBJ_GGML) $(OBJ_LLAMA) $(OBJ_COMMON) + +LIB_GGML = $(LIB_PRE)ggml$(DSO_EXT) +LIB_GGML_S = $(LIB_PRE)ggml.a + +LIB_LLAMA = $(LIB_PRE)llama$(DSO_EXT) +LIB_LLAMA_S = $(LIB_PRE)llama.a + +LIB_COMMON = $(LIB_PRE)common$(DSO_EXT) +LIB_COMMON_S = $(LIB_PRE)common.a + +LIB_ALL = $(LIB_GGML) $(LIB_LLAMA) $(LIB_COMMON) +LIB_ALL_S = $(LIB_GGML_S) $(LIB_LLAMA_S) $(LIB_COMMON_S) + +GF_CC := $(CC) +include scripts/get-flags.mk + +# combine build flags with cmdline overrides +override CPPFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) +override CFLAGS := $(CPPFLAGS) $(MK_CFLAGS) $(GF_CFLAGS) $(CFLAGS) +BASE_CXXFLAGS := $(MK_CXXFLAGS) $(CXXFLAGS) +override CXXFLAGS := $(BASE_CXXFLAGS) $(HOST_CXXFLAGS) $(GF_CXXFLAGS) $(CPPFLAGS) +override NVCCFLAGS := $(MK_NVCCFLAGS) $(NVCCFLAGS) +override LDFLAGS := $(MK_LDFLAGS) $(LDFLAGS) + +# identify CUDA host compiler +ifdef GGML_CUDA +GF_CC := $(NVCC) $(NVCCFLAGS) 2>/dev/null .c -Xcompiler +include scripts/get-flags.mk +CUDA_CXXFLAGS := $(BASE_CXXFLAGS) $(GF_CXXFLAGS) -Wno-pedantic +endif + +ifdef LLAMA_CURL +override CXXFLAGS := $(CXXFLAGS) -DLLAMA_USE_CURL +override LDFLAGS := $(LDFLAGS) -lcurl +endif + +# +# Print build information +# + +$(info I llama.cpp build info: ) +$(info I UNAME_S: $(UNAME_S)) +$(info I UNAME_P: $(UNAME_P)) +$(info I UNAME_M: $(UNAME_M)) +$(info I CFLAGS: $(CFLAGS)) +$(info I CXXFLAGS: $(CXXFLAGS)) +$(info I NVCCFLAGS: $(NVCCFLAGS)) +$(info I LDFLAGS: $(LDFLAGS)) +$(info I CC: $(shell $(CC) --version | head -n 1)) +$(info I CXX: $(shell $(CXX) --version | head -n 1)) +ifdef GGML_CUDA +$(info I NVCC: $(shell $(NVCC) --version | tail -n 1)) +CUDA_VERSION := $(shell $(NVCC) --version | grep -oP 'release (\K[0-9]+\.[0-9])') +ifeq ($(shell awk -v "v=$(CUDA_VERSION)" 'BEGIN { print (v < 11.7) }'),1) + +ifndef CUDA_DOCKER_ARCH +ifndef CUDA_POWER_ARCH +$(error I ERROR: For CUDA versions < 11.7 a target CUDA architecture must be explicitly provided via environment variable CUDA_DOCKER_ARCH, e.g. by running "export CUDA_DOCKER_ARCH=compute_XX" on Unix-like systems, where XX is the minimum compute capability that the code needs to run on. A list with compute capabilities can be found here: https://developer.nvidia.com/cuda-gpus ) +endif # CUDA_POWER_ARCH +endif # CUDA_DOCKER_ARCH + +endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1) +endif # GGML_CUDA +$(info ) + +ifdef DEPRECATE_WARNING +$(info !!! DEPRECATION WARNING !!!) +$(info The following LLAMA_ options are deprecated and will be removed in the future. Use the GGML_ prefix instead) +$(info - LLAMA_CUDA) +$(info - LLAMA_METAL) +$(info - LLAMA_METAL_EMBED_LIBRARY) +$(info - LLAMA_OPENMP) +$(info - LLAMA_RPC) +$(info - LLAMA_SYCL) +$(info - LLAMA_SYCL_F16) +$(info - LLAMA_OPENBLAS) +$(info - LLAMA_OPENBLAS64) +$(info - LLAMA_BLIS) +$(info - LLAMA_NO_LLAMAFILE) +$(info - LLAMA_NO_ACCELERATE) +$(info - LLAMA_NO_OPENMP) +$(info - LLAMA_NO_METAL) +$(info - LLAMA_NO_CCACHE) +$(info ) +endif + +ifdef REMOVE_WARNING +$(info !!! REMOVAL WARNING !!!) +$(info The following LLAMA_ options have been removed and are no longer supported) +$(info - LLAMA_DISABLE_LOGS (https://github.com/ggml-org/llama.cpp/pull/9418)) +$(info - LLAMA_SERVER_VERBOSE (https://github.com/ggml-org/llama.cpp/pull/9418)) +$(info ) +endif + +# +# Build libraries +# + +# Libraries +LIB_GGML = libggml.so +LIB_GGML_S = libggml.a + +LIB_LLAMA = libllama.so +LIB_LLAMA_S = libllama.a + +LIB_COMMON = libcommon.so +LIB_COMMON_S = libcommon.a + +# Targets +BUILD_TARGETS += $(LIB_GGML) $(LIB_GGML_S) $(LIB_LLAMA) $(LIB_LLAMA_S) $(LIB_COMMON) $(LIB_COMMON_S) + +# Dependency files +DEP_FILES = $(OBJ_GGML:.o=.d) $(OBJ_LLAMA:.o=.d) $(OBJ_COMMON:.o=.d) + +# Default target +all: $(BUILD_TARGETS) + +# force c++ build for source file that have same name as c file +# Note: need this exception because `ggml-cpu.c` and `ggml-cpu.cpp` both produce the same obj/dep files +$(DIR_GGML)/%_cpp.o: $(DIR_GGML)/%.cpp + $(CXX) $(CXXFLAGS) -MMD -c $< -o $@ + +# Rules for building object files +$(DIR_GGML)/%.o: $(DIR_GGML)/%.c + $(CC) $(CFLAGS) -MMD -c $< -o $@ + +$(DIR_GGML)/%.o: $(DIR_GGML)/%.cpp + $(CXX) $(CXXFLAGS) -MMD -c $< -o $@ + +$(DIR_LLAMA)/%.o: $(DIR_LLAMA)/%.cpp + $(CXX) $(CXXFLAGS) -MMD -c $< -o $@ + +$(DIR_COMMON)/%.o: $(DIR_COMMON)/%.cpp + $(CXX) $(CXXFLAGS) -MMD -c $< -o $@ + +# Rules for building libraries +$(LIB_GGML): $(OBJ_GGML) + $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) + +$(LIB_GGML_S): $(OBJ_GGML) + ar rcs $(LIB_GGML_S) $^ + +$(LIB_LLAMA): $(OBJ_LLAMA) $(LIB_GGML) + $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) + +$(LIB_LLAMA_S): $(OBJ_LLAMA) + ar rcs $(LIB_LLAMA_S) $^ + +$(LIB_COMMON): $(OBJ_COMMON) $(LIB_LLAMA) $(LIB_GGML) + $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) + +$(LIB_COMMON_S): $(OBJ_COMMON) + ar rcs $(LIB_COMMON_S) $^ + +# Include dependency files +-include $(DEP_FILES) + +# Clean generated server assets +clean-server-assets: + find tools/server -type f -name "*.js.hpp" -delete + find tools/server -type f -name "*.mjs.hpp" -delete + find tools/server -type f -name "*.css.hpp" -delete + find tools/server -type f -name "*.html.hpp" -delete + +# Clean rule +clean: clean-server-assets + rm -vrf $(BUILD_TARGETS) $(TEST_TARGETS) + rm -rvf *.a *.dll *.so *.dot + find ggml src common tests examples pocs -type f -name "*.o" -delete + find ggml src common tests examples pocs -type f -name "*.d" -delete + +# +# Examples +# + +# $< is the first prerequisite, i.e. the source file. +# Explicitly compile this to an object file so that it can be cached with ccache. +# The source file is then filtered out from $^ (the list of all prerequisites) and the object file is added instead. + +# Helper function that replaces .c, .cpp, and .cu file endings with .o: +GET_OBJ_FILE = $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(patsubst %.cu,%.o,$(1)))) + +llama-cli: tools/main/main.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + @echo + @echo '==== Run ./llama-cli -h for help. ====' + @echo + +llama-run: tools/run/run.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-simple: examples/simple/simple.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-simple-chat: examples/simple-chat/simple-chat.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-tokenize: tools/tokenize/tokenize.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-batched: examples/batched/batched.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-batched-bench: tools/batched-bench/batched-bench.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-quantize: tools/quantize/quantize.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-quantize-stats: tools/quantize-stats/quantize-stats.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-perplexity: tools/perplexity/perplexity.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-imatrix: tools/imatrix/imatrix.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-embedding: examples/embedding/embedding.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-gritlm: examples/gritlm/gritlm.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-save-load-state: examples/save-load-state/save-load-state.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-gguf: examples/gguf/gguf.cpp \ + $(OBJ_GGML) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +examples/gguf-hash/deps/sha1/sha1.o: \ + examples/gguf-hash/deps/sha1/sha1.c + $(CC) $(CFLAGS) -Iexamples/gguf-hash/deps -c $< -o $@ + +examples/gguf-hash/deps/xxhash/xxhash.o: \ + examples/gguf-hash/deps/xxhash/xxhash.c + $(CC) $(CFLAGS) -Iexamples/gguf-hash/deps -c $< -o $@ + +examples/gguf-hash/deps/sha256/sha256.o: \ + examples/gguf-hash/deps/sha256/sha256.c + $(CC) $(CFLAGS) -Iexamples/gguf-hash/deps -c $< -o $@ + +llama-gguf-hash: examples/gguf-hash/gguf-hash.cpp examples/gguf-hash/deps/sha1/sha1.o examples/gguf-hash/deps/xxhash/xxhash.o examples/gguf-hash/deps/sha256/sha256.o\ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -Iexamples/gguf-hash/deps -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-gguf-split: tools/gguf-split/gguf-split.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-eval-callback: examples/eval-callback/eval-callback.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-cvector-generator: tools/cvector-generator/cvector-generator.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-bench: tools/llama-bench/llama-bench.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-export-lora: tools/export-lora/export-lora.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-retrieval: examples/retrieval/retrieval.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-speculative: examples/speculative/speculative.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-parallel: examples/parallel/parallel.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-lookahead: examples/lookahead/lookahead.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-lookup: examples/lookup/lookup.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-lookup-create: examples/lookup/lookup-create.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-lookup-merge: examples/lookup/lookup-merge.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-lookup-stats: examples/lookup/lookup-stats.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-passkey: examples/passkey/passkey.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-gbnf-validator: examples/gbnf-validator/gbnf-validator.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +ifdef GGML_RPC +rpc-server: tools/rpc/rpc-server.cpp \ + $(OBJ_GGML) + $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) +endif # GGML_RPC + +llama-server: \ + tools/server/server.cpp \ + tools/server/utils.hpp \ + tools/server/httplib.h \ + tools/server/index.html.hpp \ + tools/server/loading.html.hpp \ + common/chat.cpp \ + common/chat.h \ + common/chat-template.hpp \ + common/json.hpp \ + common/minja.hpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Itools/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) + +# Portable equivalent of `cd tools/server/public && xxd -i $(notdir $<) ../$(notdir $<).hpp`: +tools/server/%.hpp: tools/server/public/% FORCE Makefile + @( export NAME=$(subst .,_,$(subst -,_,$(notdir $<))) && \ + echo "unsigned char $${NAME}[] = {" && \ + cat $< | od -v -t x1 -An | sed -E 's/([0-9a-fA-F]+)/0x\1, /g' && \ + echo "};" && \ + echo "unsigned int $${NAME}_len = $(shell cat $< | wc -c );" \ + ) > $@ + +llama-gen-docs: examples/gen-docs/gen-docs.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +libllava.a: tools/mtmd/llava.cpp \ + tools/mtmd/llava.h \ + tools/mtmd/clip.cpp \ + tools/mtmd/clip.h \ + common/stb_image.h \ + common/base64.hpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -static -fPIC -c $< -o $@ -Wno-cast-qual + +llama-llava-cli: tools/mtmd/llava-cli.cpp \ + tools/mtmd/llava.cpp \ + tools/mtmd/llava.h \ + tools/mtmd/clip.cpp \ + tools/mtmd/clip.h \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual + +llama-minicpmv-cli: tools/mtmd/minicpmv-cli.cpp \ + tools/mtmd/llava.cpp \ + tools/mtmd/llava.h \ + tools/mtmd/clip.cpp \ + tools/mtmd/clip.h \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual + +llama-qwen2vl-cli: tools/mtmd/qwen2vl-cli.cpp \ + tools/mtmd/llava.cpp \ + tools/mtmd/llava.h \ + tools/mtmd/clip.cpp \ + tools/mtmd/clip.h \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual + +ifeq ($(UNAME_S),Darwin) +swift: examples/batched.swift + (cd examples/batched.swift; make build) +endif + +common/build-info.cpp: $(wildcard .git/index) scripts/build-info.sh + @sh scripts/build-info.sh "$(CC)" > $@.tmp + @if ! cmp -s $@.tmp $@; then \ + mv $@.tmp $@; \ + else \ + rm $@.tmp; \ + fi + +common/build-info.o: common/build-info.cpp + $(CXX) $(CXXFLAGS) -c $(filter-out %.h,$^) -o $@ + +# +# Tests +# + +tests: $(TEST_TARGETS) + +tests/test-arg-parser: tests/test-arg-parser.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-llama-grammar: tests/test-llama-grammar.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-log: tests/test-log.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-grammar-parser: tests/test-grammar-parser.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-grammar-integration: tests/test-grammar-integration.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-double-float: tests/test-double-float.cpp + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -Itools/server -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-chat: tests/test-chat.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -Itools/server -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-opt: tests/test-opt.cpp \ + $(OBJ_GGML) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-quantize-fns: tests/test-quantize-fns.cpp \ + $(OBJ_GGML) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-quantize-perf: tests/test-quantize-perf.cpp \ + $(OBJ_GGML) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-sampling: tests/test-sampling.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-tokenizer-0: tests/test-tokenizer-0.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-tokenizer-1-spm: tests/test-tokenizer-1-spm.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-rope: tests/test-rope.cpp ggml/src/ggml.o \ + $(OBJ_GGML) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-c.o: tests/test-c.c include/llama.h + $(CC) $(CFLAGS) -c $(filter-out %.h,$^) -o $@ + +tests/test-backend-ops: tests/test-backend-ops.cpp \ + $(OBJ_GGML) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-model-load-cancel: tests/test-model-load-cancel.cpp tests/get-model.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-autorelease: tests/test-autorelease.cpp tests/get-model.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-chat-template: tests/test-chat-template.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +# +# PoCs +# + +llama-vdot: pocs/vdot/vdot.cpp ggml/src/ggml.o \ + $(OBJ_GGML) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +llama-q8dot: pocs/vdot/q8dot.cpp ggml/src/ggml.o \ + $(OBJ_GGML) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +# +# Deprecated binaries that we want to keep around long enough for people to migrate to the new filenames, then these can be removed. +# +# Mark legacy binary targets as .PHONY so that they are always checked. +.PHONY: FORCE main quantize perplexity embedding server + +# Define the object file target +examples/deprecation-warning/deprecation-warning.o: examples/deprecation-warning/deprecation-warning.cpp + $(CXX) $(CXXFLAGS) -c $< -o $@ + +# NOTE: We currently will always build the deprecation-warning `main` and `server` binaries to help users migrate. +# Eventually we will want to remove these target from building all the time. +main: examples/deprecation-warning/deprecation-warning.o + $(CXX) $(CXXFLAGS) $< -o $@ $(LDFLAGS) + @echo "NOTICE: The 'main' binary is deprecated. Please use 'llama-cli' instead." + +server: examples/deprecation-warning/deprecation-warning.o + $(CXX) $(CXXFLAGS) $< -o $@ $(LDFLAGS) + @echo "NOTICE: The 'server' binary is deprecated. Please use 'llama-server' instead." + +quantize: examples/deprecation-warning/deprecation-warning.o +ifneq (,$(wildcard quantize)) + $(CXX) $(CXXFLAGS) $< -o $@ $(LDFLAGS) + @echo "#########" + @echo "WARNING: The 'quantize' binary is deprecated. Please use 'llama-quantize' instead." + @echo " Remove the 'quantize' binary to remove this warning." + @echo "#########" +endif + +perplexity: examples/deprecation-warning/deprecation-warning.o +ifneq (,$(wildcard perplexity)) + $(CXX) $(CXXFLAGS) $< -o $@ $(LDFLAGS) + @echo "#########" + @echo "WARNING: The 'perplexity' binary is deprecated. Please use 'llama-perplexity' instead." + @echo " Remove the 'perplexity' binary to remove this warning." + @echo "#########" +endif + +embedding: examples/deprecation-warning/deprecation-warning.o +ifneq (,$(wildcard embedding)) + $(CXX) $(CXXFLAGS) $< -o $@ $(LDFLAGS) + @echo "#########" + @echo "WARNING: The 'embedding' binary is deprecated. Please use 'llama-embedding' instead." + @echo " Remove the 'embedding' binary to remove this warning." + @echo "#########" +endif diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..90c7364dfcba0e2dc0a8e1a076e0f963bb7ca2ab --- /dev/null +++ b/README.md @@ -0,0 +1,603 @@ +# llama.cpp + +![llama](https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png) + +[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) +[![Release](https://img.shields.io/github/v/release/ggml-org/llama.cpp)](https://github.com/ggml-org/llama.cpp/releases) +[![Server](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml/badge.svg)](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml) + +[Roadmap](https://github.com/users/ggerganov/projects/7) / [Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml) + +Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++ + +## Recent API changes + +- [Changelog for `libllama` API](https://github.com/ggml-org/llama.cpp/issues/9289) +- [Changelog for `llama-server` REST API](https://github.com/ggml-org/llama.cpp/issues/9291) + +## Hot topics + +- 🔥 Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md) +- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)), `libllava` will be deprecated +- VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode +- Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639 +- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim +- Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123 +- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669 +- Hugging Face GGUF editor: [discussion](https://github.com/ggml-org/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor) + +---- + +## Quick start + +Getting started with llama.cpp is straightforward. Here are several ways to install it on your machine: + +- Install `llama.cpp` using [brew, nix or winget](docs/install.md) +- Run with Docker - see our [Docker documentation](docs/docker.md) +- Download pre-built binaries from the [releases page](https://github.com/ggml-org/llama.cpp/releases) +- Build from source by cloning this repository - check out [our build guide](docs/build.md) + +Once installed, you'll need a model to work with. Head to the [Obtaining and quantizing models](#obtaining-and-quantizing-models) section to learn more. + +Example command: + +```sh +# Use a local model file +llama-cli -m my_model.gguf + +# Or download and run a model directly from Hugging Face +llama-cli -hf ggml-org/gemma-3-1b-it-GGUF + +# Launch OpenAI-compatible API server +llama-server -hf ggml-org/gemma-3-1b-it-GGUF +``` + +## Description + +The main goal of `llama.cpp` is to enable LLM inference with minimal setup and state-of-the-art performance on a wide +range of hardware - locally and in the cloud. + +- Plain C/C++ implementation without any dependencies +- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks +- AVX, AVX2, AVX512 and AMX support for x86 architectures +- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use +- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA) +- Vulkan and SYCL backend support +- CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity + +The `llama.cpp` project is the main playground for developing new features for the [ggml](https://github.com/ggml-org/ggml) library. + +
+Models + +Typically finetunes of the base models below are supported as well. + +Instructions for adding support for new models: [HOWTO-add-model.md](docs/development/HOWTO-add-model.md) + +#### Text-only + +- [X] LLaMA 🦙 +- [x] LLaMA 2 🦙🦙 +- [x] LLaMA 3 🦙🦙🦙 +- [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) +- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral) +- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct) +- [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon) +- [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2) +- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne) +- [X] [BERT](https://github.com/ggml-org/llama.cpp/pull/5423) +- [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/) +- [X] [Baichuan 1 & 2](https://huggingface.co/models?search=baichuan-inc/Baichuan) + [derivations](https://huggingface.co/hiyouga/baichuan-7b-sft) +- [X] [Aquila 1 & 2](https://huggingface.co/models?search=BAAI/Aquila) +- [X] [Starcoder models](https://github.com/ggml-org/llama.cpp/pull/3187) +- [X] [Refact](https://huggingface.co/smallcloudai/Refact-1_6B-fim) +- [X] [MPT](https://github.com/ggml-org/llama.cpp/pull/3417) +- [X] [Bloom](https://github.com/ggml-org/llama.cpp/pull/3553) +- [x] [Yi models](https://huggingface.co/models?search=01-ai/Yi) +- [X] [StableLM models](https://huggingface.co/stabilityai) +- [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek) +- [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen) +- [x] [PLaMo-13B](https://github.com/ggml-org/llama.cpp/pull/3557) +- [x] [Phi models](https://huggingface.co/models?search=microsoft/phi) +- [x] [PhiMoE](https://github.com/ggml-org/llama.cpp/pull/11003) +- [x] [GPT-2](https://huggingface.co/gpt2) +- [x] [Orion 14B](https://github.com/ggml-org/llama.cpp/pull/5118) +- [x] [InternLM2](https://huggingface.co/models?search=internlm2) +- [x] [CodeShell](https://github.com/WisdomShell/codeshell) +- [x] [Gemma](https://ai.google.dev/gemma) +- [x] [Mamba](https://github.com/state-spaces/mamba) +- [x] [Grok-1](https://huggingface.co/keyfan/grok-1-hf) +- [x] [Xverse](https://huggingface.co/models?search=xverse) +- [x] [Command-R models](https://huggingface.co/models?search=CohereForAI/c4ai-command-r) +- [x] [SEA-LION](https://huggingface.co/models?search=sea-lion) +- [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B) +- [x] [OLMo](https://allenai.org/olmo) +- [x] [OLMo 2](https://allenai.org/olmo) +- [x] [OLMoE](https://huggingface.co/allenai/OLMoE-1B-7B-0924) +- [x] [Granite models](https://huggingface.co/collections/ibm-granite/granite-code-models-6624c5cec322e4c148c8b330) +- [x] [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) + [Pythia](https://github.com/EleutherAI/pythia) +- [x] [Snowflake-Arctic MoE](https://huggingface.co/collections/Snowflake/arctic-66290090abe542894a5ac520) +- [x] [Smaug](https://huggingface.co/models?search=Smaug) +- [x] [Poro 34B](https://huggingface.co/LumiOpen/Poro-34B) +- [x] [Bitnet b1.58 models](https://huggingface.co/1bitLLM) +- [x] [Flan T5](https://huggingface.co/models?search=flan-t5) +- [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca) +- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) + [GLMEdge-1.5b](https://huggingface.co/THUDM/glm-edge-1.5b-chat) + [GLMEdge-4b](https://huggingface.co/THUDM/glm-edge-4b-chat) +- [x] [GLM-4-0414](https://huggingface.co/collections/THUDM/glm-4-0414-67f3cbcb34dd9d252707cb2e) +- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) +- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct) +- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) +- [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat) +- [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a) +- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM) +- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1) +- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct) +- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview) +- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32) + +#### Multimodal + +- [x] [LLaVA 1.5 models](https://huggingface.co/collections/liuhaotian/llava-15-653aac15d994e992e2677a7e), [LLaVA 1.6 models](https://huggingface.co/collections/liuhaotian/llava-16-65b9e40155f60fd046a5ccf2) +- [x] [BakLLaVA](https://huggingface.co/models?search=SkunkworksAI/Bakllava) +- [x] [Obsidian](https://huggingface.co/NousResearch/Obsidian-3B-V0.5) +- [x] [ShareGPT4V](https://huggingface.co/models?search=Lin-Chen/ShareGPT4V) +- [x] [MobileVLM 1.7B/3B models](https://huggingface.co/models?search=mobileVLM) +- [x] [Yi-VL](https://huggingface.co/models?search=Yi-VL) +- [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM) +- [x] [Moondream](https://huggingface.co/vikhyatk/moondream2) +- [x] [Bunny](https://github.com/BAAI-DCAI/Bunny) +- [x] [GLM-EDGE](https://huggingface.co/models?search=glm-edge) +- [x] [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d) + +
+ +
+Bindings + +- Python: [ddh0/easy-llama](https://github.com/ddh0/easy-llama) +- Python: [abetlen/llama-cpp-python](https://github.com/abetlen/llama-cpp-python) +- Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) +- Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp) +- JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp) +- JS/TS (Programmable Prompt Engine CLI): [offline-ai/cli](https://github.com/offline-ai/cli) +- JavaScript/Wasm (works in browser): [tangledgroup/llama-cpp-wasm](https://github.com/tangledgroup/llama-cpp-wasm) +- Typescript/Wasm (nicer API, available on npm): [ngxson/wllama](https://github.com/ngxson/wllama) +- Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb) +- Rust (more features): [edgenai/llama_cpp-rs](https://github.com/edgenai/llama_cpp-rs) +- Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp) +- Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs) +- Rust (automated build from crates.io): [ShelbyJenkins/llm_client](https://github.com/ShelbyJenkins/llm_client) +- C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp) +- C#/VB.NET (more features - community license): [LM-Kit.NET](https://docs.lm-kit.com/lm-kit-net/index.html) +- Scala 3: [donderom/llm4s](https://github.com/donderom/llm4s) +- Clojure: [phronmophobic/llama.clj](https://github.com/phronmophobic/llama.clj) +- React Native: [mybigday/llama.rn](https://github.com/mybigday/llama.rn) +- Java: [kherud/java-llama.cpp](https://github.com/kherud/java-llama.cpp) +- Zig: [deins/llama.cpp.zig](https://github.com/Deins/llama.cpp.zig) +- Flutter/Dart: [netdur/llama_cpp_dart](https://github.com/netdur/llama_cpp_dart) +- Flutter: [xuegao-tzx/Fllama](https://github.com/xuegao-tzx/Fllama) +- PHP (API bindings and features built on top of llama.cpp): [distantmagic/resonance](https://github.com/distantmagic/resonance) [(more info)](https://github.com/ggml-org/llama.cpp/pull/6326) +- Guile Scheme: [guile_llama_cpp](https://savannah.nongnu.org/projects/guile-llama-cpp) +- Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift) +- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama) +- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi) + +
+ +
+UIs + +*(to have a project listed here, it should clearly state that it depends on `llama.cpp`)* + +- [AI Sublime Text plugin](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (MIT) +- [cztomsik/ava](https://github.com/cztomsik/ava) (MIT) +- [Dot](https://github.com/alexpinel/Dot) (GPL) +- [eva](https://github.com/ylsdamxssjxxdd/eva) (MIT) +- [iohub/collama](https://github.com/iohub/coLLaMA) (Apache-2.0) +- [janhq/jan](https://github.com/janhq/jan) (AGPL) +- [johnbean393/Sidekick](https://github.com/johnbean393/Sidekick) (MIT) +- [KanTV](https://github.com/zhouwg/kantv?tab=readme-ov-file) (Apache-2.0) +- [KodiBot](https://github.com/firatkiral/kodibot) (GPL) +- [llama.vim](https://github.com/ggml-org/llama.vim) (MIT) +- [LARS](https://github.com/abgulati/LARS) (AGPL) +- [Llama Assistant](https://github.com/vietanhdev/llama-assistant) (GPL) +- [LLMFarm](https://github.com/guinmoon/LLMFarm?tab=readme-ov-file) (MIT) +- [LLMUnity](https://github.com/undreamai/LLMUnity) (MIT) +- [LMStudio](https://lmstudio.ai/) (proprietary) +- [LocalAI](https://github.com/mudler/LocalAI) (MIT) +- [LostRuins/koboldcpp](https://github.com/LostRuins/koboldcpp) (AGPL) +- [MindMac](https://mindmac.app) (proprietary) +- [MindWorkAI/AI-Studio](https://github.com/MindWorkAI/AI-Studio) (FSL-1.1-MIT) +- [Mobile-Artificial-Intelligence/maid](https://github.com/Mobile-Artificial-Intelligence/maid) (MIT) +- [Mozilla-Ocho/llamafile](https://github.com/Mozilla-Ocho/llamafile) (Apache-2.0) +- [nat/openplayground](https://github.com/nat/openplayground) (MIT) +- [nomic-ai/gpt4all](https://github.com/nomic-ai/gpt4all) (MIT) +- [ollama/ollama](https://github.com/ollama/ollama) (MIT) +- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (AGPL) +- [PocketPal AI](https://github.com/a-ghorbani/pocketpal-ai) (MIT) +- [psugihara/FreeChat](https://github.com/psugihara/FreeChat) (MIT) +- [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal) (MIT) +- [pythops/tenere](https://github.com/pythops/tenere) (AGPL) +- [ramalama](https://github.com/containers/ramalama) (MIT) +- [semperai/amica](https://github.com/semperai/amica) (MIT) +- [withcatai/catai](https://github.com/withcatai/catai) (MIT) +- [Autopen](https://github.com/blackhole89/autopen) (GPL) + +
+ +
+Tools + +- [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML +- [akx/ollama-dl](https://github.com/akx/ollama-dl) – download models from the Ollama library to be used directly with llama.cpp +- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption +- [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage +- [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with pre-built Mobile and Web platform wrappers and a model example) + +
+ +
+Infrastructure + +- [Paddler](https://github.com/distantmagic/paddler) - Stateful load balancer custom-tailored for llama.cpp +- [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs +- [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly +- [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server +- [Kalavai](https://github.com/kalavai-net/kalavai-client) - Crowdsource end to end LLM deployment at any scale +- [llmaz](https://github.com/InftyAI/llmaz) - ☸️ Easy, advanced inference platform for large language models on Kubernetes. +
+ +
+Games + +- [Lucy's Labyrinth](https://github.com/MorganRO8/Lucys_Labyrinth) - A simple maze game where agents controlled by an AI model will try to trick you. + +
+ + +## Supported backends + +| Backend | Target devices | +| --- | --- | +| [Metal](docs/build.md#metal-build) | Apple Silicon | +| [BLAS](docs/build.md#blas-build) | All | +| [BLIS](docs/backend/BLIS.md) | All | +| [SYCL](docs/backend/SYCL.md) | Intel and Nvidia GPU | +| [MUSA](docs/build.md#musa) | Moore Threads GPU | +| [CUDA](docs/build.md#cuda) | Nvidia GPU | +| [HIP](docs/build.md#hip) | AMD GPU | +| [Vulkan](docs/build.md#vulkan) | GPU | +| [CANN](docs/build.md#cann) | Ascend NPU | +| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU | +| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All | + +## Obtaining and quantizing models + +The [Hugging Face](https://huggingface.co) platform hosts a [number of LLMs](https://huggingface.co/models?library=gguf&sort=trending) compatible with `llama.cpp`: + +- [Trending](https://huggingface.co/models?library=gguf&sort=trending) +- [LLaMA](https://huggingface.co/models?sort=trending&search=llama+gguf) + +You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from [Hugging Face](https://huggingface.co/) or other model hosting sites, such as [ModelScope](https://modelscope.cn/), by using this CLI argument: `-hf /[:quant]`. For example: + +```sh +llama-cli -hf ggml-org/gemma-3-1b-it-GGUF +``` + +By default, the CLI would download from Hugging Face, you can switch to other options with the environment variable `MODEL_ENDPOINT`. For example, you may opt to downloading model checkpoints from ModelScope or other model sharing communities by setting the environment variable, e.g. `MODEL_ENDPOINT=https://www.modelscope.cn/`. + +After downloading a model, use the CLI tools to run it locally - see below. + +`llama.cpp` requires the model to be stored in the [GGUF](https://github.com/ggml-org/ggml/blob/master/docs/gguf.md) file format. Models in other data formats can be converted to GGUF using the `convert_*.py` Python scripts in this repo. + +The Hugging Face platform provides a variety of online tools for converting, quantizing and hosting models with `llama.cpp`: + +- Use the [GGUF-my-repo space](https://huggingface.co/spaces/ggml-org/gguf-my-repo) to convert to GGUF format and quantize model weights to smaller sizes +- Use the [GGUF-my-LoRA space](https://huggingface.co/spaces/ggml-org/gguf-my-lora) to convert LoRA adapters to GGUF format (more info: https://github.com/ggml-org/llama.cpp/discussions/10123) +- Use the [GGUF-editor space](https://huggingface.co/spaces/CISCai/gguf-editor) to edit GGUF meta data in the browser (more info: https://github.com/ggml-org/llama.cpp/discussions/9268) +- Use the [Inference Endpoints](https://ui.endpoints.huggingface.co/) to directly host `llama.cpp` in the cloud (more info: https://github.com/ggml-org/llama.cpp/discussions/9669) + +To learn more about model quantization, [read this documentation](tools/quantize/README.md) + +## [`llama-cli`](tools/main) + +#### A CLI tool for accessing and experimenting with most of `llama.cpp`'s functionality. + +-
+ Run in conversation mode + + Models with a built-in chat template will automatically activate conversation mode. If this doesn't occur, you can manually enable it by adding `-cnv` and specifying a suitable chat template with `--chat-template NAME` + + ```bash + llama-cli -m model.gguf + + # > hi, who are you? + # Hi there! I'm your helpful assistant! I'm an AI-powered chatbot designed to assist and provide information to users like you. I'm here to help answer your questions, provide guidance, and offer support on a wide range of topics. I'm a friendly and knowledgeable AI, and I'm always happy to help with anything you need. What's on your mind, and how can I assist you today? + # + # > what is 1+1? + # Easy peasy! The answer to 1+1 is... 2! + ``` + +
+ +-
+ Run in conversation mode with custom chat template + + ```bash + # use the "chatml" template (use -h to see the list of supported templates) + llama-cli -m model.gguf -cnv --chat-template chatml + + # use a custom template + llama-cli -m model.gguf -cnv --in-prefix 'User: ' --reverse-prompt 'User:' + ``` + +
+ +-
+ Run simple text completion + + To disable conversation mode explicitly, use `-no-cnv` + + ```bash + llama-cli -m model.gguf -p "I believe the meaning of life is" -n 128 -no-cnv + + # I believe the meaning of life is to find your own truth and to live in accordance with it. For me, this means being true to myself and following my passions, even if they don't align with societal expectations. I think that's what I love about yoga – it's not just a physical practice, but a spiritual one too. It's about connecting with yourself, listening to your inner voice, and honoring your own unique journey. + ``` + +
+ +-
+ Constrain the output with a custom grammar + + ```bash + llama-cli -m model.gguf -n 256 --grammar-file grammars/json.gbnf -p 'Request: schedule a call at 8pm; Command:' + + # {"appointmentTime": "8pm", "appointmentDetails": "schedule a a call"} + ``` + + The [grammars/](grammars/) folder contains a handful of sample grammars. To write your own, check out the [GBNF Guide](grammars/README.md). + + For authoring more complex JSON grammars, check out https://grammar.intrinsiclabs.ai/ + +
+ + +## [`llama-server`](tools/server) + +#### A lightweight, [OpenAI API](https://github.com/openai/openai-openapi) compatible, HTTP server for serving LLMs. + +-
+ Start a local HTTP server with default configuration on port 8080 + + ```bash + llama-server -m model.gguf --port 8080 + + # Basic web UI can be accessed via browser: http://localhost:8080 + # Chat completion endpoint: http://localhost:8080/v1/chat/completions + ``` + +
+ +-
+ Support multiple-users and parallel decoding + + ```bash + # up to 4 concurrent requests, each with 4096 max context + llama-server -m model.gguf -c 16384 -np 4 + ``` + +
+ +-
+ Enable speculative decoding + + ```bash + # the draft.gguf model should be a small variant of the target model.gguf + llama-server -m model.gguf -md draft.gguf + ``` + +
+ +-
+ Serve an embedding model + + ```bash + # use the /embedding endpoint + llama-server -m model.gguf --embedding --pooling cls -ub 8192 + ``` + +
+ +-
+ Serve a reranking model + + ```bash + # use the /reranking endpoint + llama-server -m model.gguf --reranking + ``` + +
+ +-
+ Constrain all outputs with a grammar + + ```bash + # custom grammar + llama-server -m model.gguf --grammar-file grammar.gbnf + + # JSON + llama-server -m model.gguf --grammar-file grammars/json.gbnf + ``` + +
+ + +## [`llama-perplexity`](tools/perplexity) + +#### A tool for measuring the perplexity [^1][^2] (and other quality metrics) of a model over a given text. + +-
+ Measure the perplexity over a text file + + ```bash + llama-perplexity -m model.gguf -f file.txt + + # [1]15.2701,[2]5.4007,[3]5.3073,[4]6.2965,[5]5.8940,[6]5.6096,[7]5.7942,[8]4.9297, ... + # Final estimate: PPL = 5.4007 +/- 0.67339 + ``` + +
+ +-
+ Measure KL divergence + + ```bash + # TODO + ``` + +
+ +[^1]: [tools/perplexity/README.md](./tools/perplexity/README.md) +[^2]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity) + +## [`llama-bench`](tools/llama-bench) + +#### Benchmark the performance of the inference for various parameters. + +-
+ Run default benchmark + + ```bash + llama-bench -m model.gguf + + # Output: + # | model | size | params | backend | threads | test | t/s | + # | ------------------- | ---------: | ---------: | ---------- | ------: | ------------: | -------------------: | + # | qwen2 1.5B Q4_0 | 885.97 MiB | 1.54 B | Metal,BLAS | 16 | pp512 | 5765.41 ± 20.55 | + # | qwen2 1.5B Q4_0 | 885.97 MiB | 1.54 B | Metal,BLAS | 16 | tg128 | 197.71 ± 0.81 | + # + # build: 3e0ba0e60 (4229) + ``` + +
+ +## [`llama-run`](tools/run) + +#### A comprehensive example for running `llama.cpp` models. Useful for inferencing. Used with RamaLama [^3]. + +-
+ Run a model with a specific prompt (by default it's pulled from Ollama registry) + + ```bash + llama-run granite-code + ``` + +
+ +[^3]: [RamaLama](https://github.com/containers/ramalama) + +## [`llama-simple`](examples/simple) + +#### A minimal example for implementing apps with `llama.cpp`. Useful for developers. + +-
+ Basic text completion + + ```bash + llama-simple -m model.gguf + + # Hello my name is Kaitlyn and I am a 16 year old girl. I am a junior in high school and I am currently taking a class called "The Art of + ``` + +
+ + +## Contributing + +- Contributors can open PRs +- Collaborators can push to branches in the `llama.cpp` repo and merge PRs into the `master` branch +- Collaborators will be invited based on contributions +- Any help with managing issues, PRs and projects is very appreciated! +- See [good first issues](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions +- Read the [CONTRIBUTING.md](CONTRIBUTING.md) for more information +- Make sure to read this: [Inference at the edge](https://github.com/ggml-org/llama.cpp/discussions/205) +- A bit of backstory for those who are interested: [Changelog podcast](https://changelog.com/podcast/532) + +## Other documentation + +- [main (cli)](tools/main/README.md) +- [server](tools/server/README.md) +- [GBNF grammars](grammars/README.md) + +#### Development documentation + +- [How to build](docs/build.md) +- [Running on Docker](docs/docker.md) +- [Build on Android](docs/android.md) +- [Performance troubleshooting](docs/development/token_generation_performance_tips.md) +- [GGML tips & tricks](https://github.com/ggml-org/llama.cpp/wiki/GGML-Tips-&-Tricks) + +#### Seminal papers and background on the models + +If your issue is with model generation quality, then please at least scan the following links and papers to understand the limitations of LLaMA models. This is especially important when choosing an appropriate model size and appreciating both the significant and subtle differences between LLaMA models and ChatGPT: +- LLaMA: + - [Introducing LLaMA: A foundational, 65-billion-parameter large language model](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) + - [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) +- GPT-3 + - [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165) +- GPT-3.5 / InstructGPT / ChatGPT: + - [Aligning language models to follow instructions](https://openai.com/research/instruction-following) + - [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) + +## XCFramework +The XCFramework is a precompiled version of the library for iOS, visionOS, tvOS, +and macOS. It can be used in Swift projects without the need to compile the +library from source. For example: +```swift +// swift-tools-version: 5.10 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import PackageDescription + +let package = Package( + name: "MyLlamaPackage", + targets: [ + .executableTarget( + name: "MyLlamaPackage", + dependencies: [ + "LlamaFramework" + ]), + .binaryTarget( + name: "LlamaFramework", + url: "https://github.com/ggml-org/llama.cpp/releases/download/b5046/llama-b5046-xcframework.zip", + checksum: "c19be78b5f00d8d29a25da41042cb7afa094cbf6280a225abe614b03b20029ab" + ) + ] +) +``` +The above example is using an intermediate build `b5046` of the library. This can be modified +to use a different version by changing the URL and checksum. + +## Completions +Command-line completion is available for some environments. + +#### Bash Completion +```bash +$ build/bin/llama-cli --completion-bash > ~/.llama-completion.bash +$ source ~/.llama-completion.bash +``` +Optionally this can be added to your `.bashrc` or `.bash_profile` to load it +automatically. For example: +```console +$ echo "source ~/.llama-completion.bash" >> ~/.bashrc +``` + +## Dependencies + +- [yhirose/cpp-httplib](https://github.com/yhirose/cpp-httplib) - Single-header HTTP server, used by `llama-server` - MIT license +- [stb-image](https://github.com/nothings/stb) - Single-header image format decoder, used by multimodal subsystem - Public domain +- [nlohmann/json](https://github.com/nlohmann/json) - Single-header JSON library, used by various tools/examples - MIT License +- [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License +- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License +- [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html) +- [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..9749e95b715a736c7e31a4daa75d489b001d6895 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,68 @@ +# Security Policy + + - [**Using llama.cpp securely**](#using-llamacpp-securely) + - [Untrusted models](#untrusted-models) + - [Untrusted inputs](#untrusted-inputs) + - [Data privacy](#data-privacy) + - [Untrusted environments or networks](#untrusted-environments-or-networks) + - [Multi-Tenant environments](#multi-tenant-environments) + - [**Reporting a vulnerability**](#reporting-a-vulnerability) + +## Using llama.cpp securely + +### Untrusted models +Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources. + +*Always execute untrusted models within a secure, isolated environment such as a sandbox* (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. + +> [!NOTE] +> The trustworthiness of a model is not binary. You must always determine the proper level of caution depending on the specific model and how it matches your use case and risk tolerance. + +### Untrusted inputs + +Some models accept various input formats (text, images, audio, etc.). The libraries converting these inputs have varying security levels, so it's crucial to isolate the model and carefully pre-process inputs to mitigate script injection risks. + +For maximum security when handling untrusted inputs, you may need to employ the following: + +* Sandboxing: Isolate the environment where the inference happens. +* Pre-analysis: Check how the model performs by default when exposed to prompt injection (e.g. using [fuzzing for prompt injection](https://github.com/FonduAI/awesome-prompt-injection?tab=readme-ov-file#tools)). This will give you leads on how hard you will have to work on the next topics. +* Updates: Keep both LLaMA C++ and your libraries updated with the latest security patches. +* Input Sanitation: Before feeding data to the model, sanitize inputs rigorously. This involves techniques such as: + * Validation: Enforce strict rules on allowed characters and data types. + * Filtering: Remove potentially malicious scripts or code fragments. + * Encoding: Convert special characters into safe representations. + * Verification: Run tooling that identifies potential script injections (e.g. [models that detect prompt injection attempts](https://python.langchain.com/docs/guides/safety/hugging_face_prompt_injection)). + +### Data privacy + +To protect sensitive data from potential leaks or unauthorized access, it is crucial to sandbox the model execution. This means running the model in a secure, isolated environment, which helps mitigate many attack vectors. + +### Untrusted environments or networks + +If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions: +* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061). +* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value. +* Encrypt your data if sending it over the network. + +### Multi-Tenant environments + +If you intend to run multiple models in parallel with shared memory, it is your responsibility to ensure the models do not interact or access each other's data. The primary areas of concern are tenant isolation, resource allocation, model sharing and hardware attacks. + +1. Tenant Isolation: Models should run separately with strong isolation methods to prevent unwanted data access. Separating networks is crucial for isolation, as it prevents unauthorized access to data or models and malicious users from sending graphs to execute under another tenant's identity. + +2. Resource Allocation: A denial of service caused by one model can impact the overall system health. Implement safeguards like rate limits, access controls, and health monitoring. + +3. Model Sharing: In a multitenant model sharing design, tenants and users must understand the security risks of running code provided by others. Since there are no reliable methods to detect malicious models, sandboxing the model execution is the recommended approach to mitigate the risk. + +4. Hardware Attacks: GPUs or TPUs can also be attacked. [Researches](https://scholar.google.com/scholar?q=gpu+side+channel) has shown that side channel attacks on GPUs are possible, which can make data leak from other models or processes running on the same system at the same time. + +## Reporting a vulnerability + +Beware that none of the topics under [Using llama.cpp securely](#using-llamacpp-securely) are considered vulnerabilities of LLaMA C++. + + +However, If you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released. + +Please disclose it as a private [security advisory](https://github.com/ggml-org/llama.cpp/security/advisories/new). + +A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure. diff --git a/ci/README.md b/ci/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6e297f1a82788096aa803848766a350c7fa81af8 --- /dev/null +++ b/ci/README.md @@ -0,0 +1,68 @@ +# CI + +In addition to [Github Actions](https://github.com/ggml-org/llama.cpp/actions) `llama.cpp` uses a custom CI framework: + +https://github.com/ggml-org/ci + +It monitors the `master` branch for new commits and runs the +[ci/run.sh](https://github.com/ggml-org/llama.cpp/blob/master/ci/run.sh) script on dedicated cloud instances. This allows us +to execute heavier workloads compared to just using Github Actions. Also with time, the cloud instances will be scaled +to cover various hardware architectures, including GPU and Apple Silicon instances. + +Collaborators can optionally trigger the CI run by adding the `ggml-ci` keyword to their commit message. +Only the branches of this repo are monitored for this keyword. + +It is a good practice, before publishing changes to execute the full CI locally on your machine: + +```bash +mkdir tmp + +# CPU-only build +bash ./ci/run.sh ./tmp/results ./tmp/mnt + +# with CUDA support +GG_BUILD_CUDA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + +# with SYCL support +source /opt/intel/oneapi/setvars.sh +GG_BUILD_SYCL=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + +# with MUSA support +GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +``` + +## Running MUSA CI in a Docker Container + +Assuming `$PWD` is the root of the `llama.cpp` repository, follow these steps to set up and run MUSA CI in a Docker container: + +### 1. Create a local directory to store cached models, configuration files and venv: + +```bash +mkdir -p $HOME/llama.cpp/ci-cache +``` + +### 2. Create a local directory to store CI run results: + +```bash +mkdir -p $HOME/llama.cpp/ci-results +``` + +### 3. Start a Docker container and run the CI: + +```bash +docker run --privileged -it \ + -v $HOME/llama.cpp/ci-cache:/ci-cache \ + -v $HOME/llama.cpp/ci-results:/ci-results \ + -v $PWD:/ws -w /ws \ + mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04 +``` + +Inside the container, execute the following commands: + +```bash +apt update -y && apt install -y bc cmake ccache git python3.10-venv time unzip wget +git config --global --add safe.directory /ws +GG_BUILD_MUSA=1 bash ./ci/run.sh /ci-results /ci-cache +``` + +This setup ensures that the CI runs within an isolated Docker environment while maintaining cached files and results across runs. diff --git a/ci/run.sh b/ci/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..94005570511b6a3d2fe7ea3c95bff76112419c6f --- /dev/null +++ b/ci/run.sh @@ -0,0 +1,887 @@ +#!/bin/bash +# +# sample usage: +# +# mkdir tmp +# +# # CPU-only build +# bash ./ci/run.sh ./tmp/results ./tmp/mnt +# +# # with CUDA support +# GG_BUILD_CUDA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +# +# # with SYCL support +# GG_BUILD_SYCL=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +# +# # with VULKAN support +# GG_BUILD_VULKAN=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +# +# # with MUSA support +# GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +# + +if [ -z "$2" ]; then + echo "usage: $0 " + exit 1 +fi + +mkdir -p "$1" +mkdir -p "$2" + +OUT=$(realpath "$1") +MNT=$(realpath "$2") + +rm -f "$OUT/*.log" +rm -f "$OUT/*.exit" +rm -f "$OUT/*.md" + +sd=`dirname $0` +cd $sd/../ +SRC=`pwd` + +CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON" + +if [ ! -z ${GG_BUILD_METAL} ]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON -DGGML_METAL_USE_BF16=ON" +fi + +if [ ! -z ${GG_BUILD_CUDA} ]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON" + + if command -v nvidia-smi >/dev/null 2>&1; then + CUDA_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d '.') + if [[ -n "$CUDA_ARCH" && "$CUDA_ARCH" =~ ^[0-9]+$ ]]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH}" + else + echo "Warning: Using fallback CUDA architectures" + CMAKE_EXTRA="${CMAKE_EXTRA} -DCMAKE_CUDA_ARCHITECTURES=61;70;75;80;86;89" + fi + else + echo "Error: nvidia-smi not found, cannot build with CUDA" + exit 1 + fi +fi + +if [ ! -z ${GG_BUILD_SYCL} ]; then + if [ -z ${ONEAPI_ROOT} ]; then + echo "Not detected ONEAPI_ROOT, please install oneAPI base toolkit and enable it by:" + echo "source /opt/intel/oneapi/setvars.sh" + exit 1 + fi + # Use only main GPU + export ONEAPI_DEVICE_SELECTOR="level_zero:0" + # Enable sysman for correct memory reporting + export ZES_ENABLE_SYSMAN=1 + # to circumvent precision issues on CPY operations + export SYCL_PROGRAM_COMPILE_OPTIONS="-cl-fp32-correctly-rounded-divide-sqrt" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON" +fi + +if [ ! -z ${GG_BUILD_VULKAN} ]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_VULKAN=1" +fi + +if [ ! -z ${GG_BUILD_MUSA} ]; then + # Use qy1 by default (MTT S80) + MUSA_ARCH=${MUSA_ARCH:-21} + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}" +fi +## helpers + +# download a file if it does not exist or if it is outdated +function gg_wget { + local out=$1 + local url=$2 + + local cwd=`pwd` + + mkdir -p $out + cd $out + + # should not re-download if file is the same + wget -nv -N $url + + cd $cwd +} + +function gg_printf { + printf -- "$@" >> $OUT/README.md +} + +function gg_run { + ci=$1 + + set -o pipefail + set -x + + gg_run_$ci | tee $OUT/$ci.log + cur=$? + echo "$cur" > $OUT/$ci.exit + + set +x + set +o pipefail + + gg_sum_$ci + + ret=$((ret | cur)) +} + +## ci + +# ctest_debug + +function gg_run_ctest_debug { + cd ${SRC} + + rm -rf build-ci-debug && mkdir build-ci-debug && cd build-ci-debug + + set -e + + # Check cmake, make and ctest are installed + gg_check_build_requirements + + (time cmake -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + (time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log + + set +e +} + +function gg_sum_ctest_debug { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Runs ctest in debug mode\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '```\n' + gg_printf '%s\n' "$(cat $OUT/${ci}-ctest.log)" + gg_printf '```\n' + gg_printf '\n' +} + +# ctest_release + +function gg_run_ctest_release { + cd ${SRC} + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + # Check cmake, make and ctest are installed + gg_check_build_requirements + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + if [ -z ${GG_BUILD_LOW_PERF} ]; then + (time ctest --output-on-failure -L main ) 2>&1 | tee -a $OUT/${ci}-ctest.log + else + (time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log + fi + + set +e +} + +function gg_sum_ctest_release { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Runs ctest in release mode\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '```\n' + gg_printf '%s\n' "$(cat $OUT/${ci}-ctest.log)" + gg_printf '```\n' +} + +# test_scripts_debug + +function gg_run_test_scripts_debug { + cd ${SRC} + + set -e + + (cd ./tools/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log + (cd ./tools/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log + + set +e +} + +function gg_sum_test_scripts_debug { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Runs test scripts in debug mode\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '```\n' + gg_printf '%s\n' "$(cat $OUT/${ci}-scripts.log)" + gg_printf '```\n' + gg_printf '\n' +} + +# test_scripts_release + +function gg_run_test_scripts_release { + cd ${SRC} + + set -e + + (cd ./tools/gguf-split && time bash tests.sh "$SRC/build-ci-release/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log + (cd ./tools/quantize && time bash tests.sh "$SRC/build-ci-release/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log + + set +e +} + +function gg_sum_test_scripts_release { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Runs test scripts in release mode\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '```\n' + gg_printf '%s\n' "$(cat $OUT/${ci}-scripts.log)" + gg_printf '```\n' + gg_printf '\n' +} + +function gg_get_model { + local gguf_0="$MNT/models/pythia/1.4B/ggml-model-f16.gguf" + local gguf_1="$MNT/models/pythia/2.8B/ggml-model-f16.gguf" + local gguf_2="$MNT/models/open-llama/7B-v2/ggml-model-f16.gguf" + if [[ -s $gguf_0 ]]; then + echo -n "$gguf_0" + elif [[ -s $gguf_1 ]]; then + echo -n "$gguf_1" + elif [[ -s $gguf_2 ]]; then + echo -n "$gguf_2" + else + echo >&2 "No model found. Can't run gg_run_ctest_with_model." + exit 1 + fi +} + +function gg_run_ctest_with_model_debug { + cd ${SRC} + + local model; model=$(gg_get_model) + cd build-ci-debug + set -e + (LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log + set +e + cd .. +} + +function gg_run_ctest_with_model_release { + cd ${SRC} + + local model; model=$(gg_get_model) + cd build-ci-release + set -e + (LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log + set +e + cd .. +} + +function gg_sum_ctest_with_model_debug { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Runs ctest with model files in debug mode\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '```\n' + gg_printf '%s\n' "$(cat $OUT/${ci}-ctest.log)" + gg_printf '```\n' +} + +function gg_sum_ctest_with_model_release { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Runs ctest with model files in release mode\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '```\n' + gg_printf '%s\n' "$(cat $OUT/${ci}-ctest.log)" + gg_printf '```\n' +} + +# open_llama_7b_v2 + +function gg_run_open_llama_7b_v2 { + cd ${SRC} + + gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/raw/main/config.json + gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/resolve/main/tokenizer.model + gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/raw/main/tokenizer_config.json + gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/raw/main/special_tokens_map.json + gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/raw/main/pytorch_model.bin.index.json + gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/resolve/main/pytorch_model-00001-of-00002.bin + gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/resolve/main/pytorch_model-00002-of-00002.bin + gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/raw/main/generation_config.json + + gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip + unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/ + + path_models="../models-mnt/open-llama/7B-v2" + path_wiki="../models-mnt/wikitext/wikitext-2-raw" + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + python3 ../examples/convert_legacy_llama.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf + + model_f16="${path_models}/ggml-model-f16.gguf" + model_q8_0="${path_models}/ggml-model-q8_0.gguf" + model_q4_0="${path_models}/ggml-model-q4_0.gguf" + model_q4_1="${path_models}/ggml-model-q4_1.gguf" + model_q5_0="${path_models}/ggml-model-q5_0.gguf" + model_q5_1="${path_models}/ggml-model-q5_1.gguf" + model_q2_k="${path_models}/ggml-model-q2_k.gguf" + model_q3_k="${path_models}/ggml-model-q3_k.gguf" + model_q4_k="${path_models}/ggml-model-q4_k.gguf" + model_q5_k="${path_models}/ggml-model-q5_k.gguf" + model_q6_k="${path_models}/ggml-model-q6_k.gguf" + + wiki_test="${path_wiki}/wiki.test.raw" + + ./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0 + ./bin/llama-quantize ${model_f16} ${model_q4_0} q4_0 + ./bin/llama-quantize ${model_f16} ${model_q4_1} q4_1 + ./bin/llama-quantize ${model_f16} ${model_q5_0} q5_0 + ./bin/llama-quantize ${model_f16} ${model_q5_1} q5_1 + ./bin/llama-quantize ${model_f16} ${model_q2_k} q2_k + ./bin/llama-quantize ${model_f16} ${model_q3_k} q3_k + ./bin/llama-quantize ${model_f16} ${model_q4_k} q4_k + ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k + ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k + + (time ./bin/llama-cli -no-cnv --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + + (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-perplexity --model ${model_q4_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-perplexity --model ${model_q4_1} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-perplexity --model ${model_q5_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-perplexity --model ${model_q5_1} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-perplexity --model ${model_q2_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-perplexity --model ${model_q3_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-perplexity --model ${model_q4_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-perplexity --model ${model_q5_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-perplexity --model ${model_q6_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + + (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log + + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + + function check_ppl { + qnt="$1" + ppl=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) + + if [ $(echo "$ppl > 20.0" | bc) -eq 1 ]; then + printf ' - %s @ %s (FAIL: ppl > 20.0)\n' "$qnt" "$ppl" + return 20 + fi + + printf ' - %s @ %s OK\n' "$qnt" "$ppl" + return 0 + } + + check_ppl "f16" "$(cat $OUT/${ci}-tg-f16.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q8_0" "$(cat $OUT/${ci}-tg-q8_0.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q4_0" "$(cat $OUT/${ci}-tg-q4_0.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q4_1" "$(cat $OUT/${ci}-tg-q4_1.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q5_0" "$(cat $OUT/${ci}-tg-q5_0.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q5_1" "$(cat $OUT/${ci}-tg-q5_1.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q2_k" "$(cat $OUT/${ci}-tg-q2_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q3_k" "$(cat $OUT/${ci}-tg-q3_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q4_k" "$(cat $OUT/${ci}-tg-q4_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q5_k" "$(cat $OUT/${ci}-tg-q5_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q6_k" "$(cat $OUT/${ci}-tg-q6_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + + cat $OUT/${ci}-imatrix.log | grep "Final" >> $OUT/${ci}-imatrix-sum.log + + set +e +} + +function gg_sum_open_llama_7b_v2 { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'OpenLLaMA 7B-v2:\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- perplexity:\n%s\n' "$(cat $OUT/${ci}-ppl.log)" + gg_printf '- imatrix:\n```\n%s\n```\n' "$(cat $OUT/${ci}-imatrix-sum.log)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" + gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" + gg_printf '- q4_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_0.log)" + gg_printf '- q4_1:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_1.log)" + gg_printf '- q5_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_0.log)" + gg_printf '- q5_1:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_1.log)" + gg_printf '- q2_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q2_k.log)" + gg_printf '- q3_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q3_k.log)" + gg_printf '- q4_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_k.log)" + gg_printf '- q5_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_k.log)" + gg_printf '- q6_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q6_k.log)" + gg_printf '- save-load-state: \n```\n%s\n```\n' "$(cat $OUT/${ci}-save-load-state.log)" +} + +# pythia_1.4b + +function gg_run_pythia_1_4b { + cd ${SRC} + + gg_wget models-mnt/pythia/1.4B/ https://huggingface.co/EleutherAI/pythia-1.4b/raw/main/config.json + gg_wget models-mnt/pythia/1.4B/ https://huggingface.co/EleutherAI/pythia-1.4b/raw/main/tokenizer.json + gg_wget models-mnt/pythia/1.4B/ https://huggingface.co/EleutherAI/pythia-1.4b/raw/main/tokenizer_config.json + gg_wget models-mnt/pythia/1.4B/ https://huggingface.co/EleutherAI/pythia-1.4b/raw/main/special_tokens_map.json + gg_wget models-mnt/pythia/1.4B/ https://huggingface.co/EleutherAI/pythia-1.4b/resolve/main/pytorch_model.bin + + gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip + unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/ + head -n 60 models-mnt/wikitext/wikitext-2-raw/wiki.test.raw > models-mnt/wikitext/wikitext-2-raw/wiki.test-60.raw + + path_models="../models-mnt/pythia/1.4B" + path_wiki="../models-mnt/wikitext/wikitext-2-raw" + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf + + model_f16="${path_models}/ggml-model-f16.gguf" + model_q8_0="${path_models}/ggml-model-q8_0.gguf" + model_q4_0="${path_models}/ggml-model-q4_0.gguf" + model_q4_1="${path_models}/ggml-model-q4_1.gguf" + model_q5_0="${path_models}/ggml-model-q5_0.gguf" + model_q5_1="${path_models}/ggml-model-q5_1.gguf" + model_q2_k="${path_models}/ggml-model-q2_k.gguf" + model_q3_k="${path_models}/ggml-model-q3_k.gguf" + model_q4_k="${path_models}/ggml-model-q4_k.gguf" + model_q5_k="${path_models}/ggml-model-q5_k.gguf" + model_q6_k="${path_models}/ggml-model-q6_k.gguf" + + wiki_test_60="${path_wiki}/wiki.test-60.raw" + + ./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0 + ./bin/llama-quantize ${model_f16} ${model_q4_0} q4_0 + ./bin/llama-quantize ${model_f16} ${model_q4_1} q4_1 + ./bin/llama-quantize ${model_f16} ${model_q5_0} q5_0 + ./bin/llama-quantize ${model_f16} ${model_q5_1} q5_1 + ./bin/llama-quantize ${model_f16} ${model_q2_k} q2_k + ./bin/llama-quantize ${model_f16} ${model_q3_k} q3_k + ./bin/llama-quantize ${model_f16} ${model_q4_k} q4_k + ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k + ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k + + (time ./bin/llama-cli -no-cnv --model ${model_f16} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + + (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-perplexity --model ${model_q4_0} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-perplexity --model ${model_q4_1} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-perplexity --model ${model_q5_0} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-perplexity --model ${model_q5_1} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-perplexity --model ${model_q2_k} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-perplexity --model ${model_q3_k} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-perplexity --model ${model_q4_k} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-perplexity --model ${model_q5_k} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-perplexity --model ${model_q6_k} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + + (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log + + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + + function check_ppl { + qnt="$1" + ppl=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) + + if [ $(echo "$ppl > 20.0" | bc) -eq 1 ]; then + printf ' - %s @ %s (FAIL: ppl > 20.0)\n' "$qnt" "$ppl" + return 20 + fi + + printf ' - %s @ %s OK\n' "$qnt" "$ppl" + return 0 + } + + check_ppl "f16" "$(cat $OUT/${ci}-tg-f16.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q8_0" "$(cat $OUT/${ci}-tg-q8_0.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q4_0" "$(cat $OUT/${ci}-tg-q4_0.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q4_1" "$(cat $OUT/${ci}-tg-q4_1.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q5_0" "$(cat $OUT/${ci}-tg-q5_0.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q5_1" "$(cat $OUT/${ci}-tg-q5_1.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + #check_ppl "q2_k" "$(cat $OUT/${ci}-tg-q2_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log # note: ppl > 20.0 for this quant and model + check_ppl "q3_k" "$(cat $OUT/${ci}-tg-q3_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q4_k" "$(cat $OUT/${ci}-tg-q4_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q5_k" "$(cat $OUT/${ci}-tg-q5_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q6_k" "$(cat $OUT/${ci}-tg-q6_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + + cat $OUT/${ci}-imatrix.log | grep "Final" >> $OUT/${ci}-imatrix-sum.log + + set +e +} + +function gg_sum_pythia_1_4b { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Pythia 1.4B:\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- perplexity:\n%s\n' "$(cat $OUT/${ci}-ppl.log)" + gg_printf '- imatrix:\n```\n%s\n```\n' "$(cat $OUT/${ci}-imatrix-sum.log)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" + gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" + gg_printf '- q4_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_0.log)" + gg_printf '- q4_1:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_1.log)" + gg_printf '- q5_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_0.log)" + gg_printf '- q5_1:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_1.log)" + gg_printf '- q2_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q2_k.log)" + gg_printf '- q3_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q3_k.log)" + gg_printf '- q4_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_k.log)" + gg_printf '- q5_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_k.log)" + gg_printf '- q6_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q6_k.log)" + gg_printf '- save-load-state: \n```\n%s\n```\n' "$(cat $OUT/${ci}-save-load-state.log)" +} + +# pythia_2_8b + +function gg_run_pythia_2_8b { + cd ${SRC} + + gg_wget models-mnt/pythia/2.8B/ https://huggingface.co/EleutherAI/pythia-2.8b/raw/main/config.json + gg_wget models-mnt/pythia/2.8B/ https://huggingface.co/EleutherAI/pythia-2.8b/raw/main/tokenizer.json + gg_wget models-mnt/pythia/2.8B/ https://huggingface.co/EleutherAI/pythia-2.8b/raw/main/tokenizer_config.json + gg_wget models-mnt/pythia/2.8B/ https://huggingface.co/EleutherAI/pythia-2.8b/raw/main/special_tokens_map.json + gg_wget models-mnt/pythia/2.8B/ https://huggingface.co/EleutherAI/pythia-2.8b/resolve/main/pytorch_model.bin + + gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip + unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/ + + path_models="../models-mnt/pythia/2.8B" + path_wiki="../models-mnt/wikitext/wikitext-2-raw" + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf + + model_f16="${path_models}/ggml-model-f16.gguf" + model_q8_0="${path_models}/ggml-model-q8_0.gguf" + model_q4_0="${path_models}/ggml-model-q4_0.gguf" + model_q4_1="${path_models}/ggml-model-q4_1.gguf" + model_q5_0="${path_models}/ggml-model-q5_0.gguf" + model_q5_1="${path_models}/ggml-model-q5_1.gguf" + model_q2_k="${path_models}/ggml-model-q2_k.gguf" + model_q3_k="${path_models}/ggml-model-q3_k.gguf" + model_q4_k="${path_models}/ggml-model-q4_k.gguf" + model_q5_k="${path_models}/ggml-model-q5_k.gguf" + model_q6_k="${path_models}/ggml-model-q6_k.gguf" + + wiki_test="${path_wiki}/wiki.test.raw" + + ./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0 + ./bin/llama-quantize ${model_f16} ${model_q4_0} q4_0 + ./bin/llama-quantize ${model_f16} ${model_q4_1} q4_1 + ./bin/llama-quantize ${model_f16} ${model_q5_0} q5_0 + ./bin/llama-quantize ${model_f16} ${model_q5_1} q5_1 + ./bin/llama-quantize ${model_f16} ${model_q2_k} q2_k + ./bin/llama-quantize ${model_f16} ${model_q3_k} q3_k + ./bin/llama-quantize ${model_f16} ${model_q4_k} q4_k + ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k + ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k + + (time ./bin/llama-cli -no-cnv --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + + (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-perplexity --model ${model_q4_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-perplexity --model ${model_q4_1} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-perplexity --model ${model_q5_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-perplexity --model ${model_q5_1} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-perplexity --model ${model_q2_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-perplexity --model ${model_q3_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-perplexity --model ${model_q4_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-perplexity --model ${model_q5_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-perplexity --model ${model_q6_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + + (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log + + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + + function check_ppl { + qnt="$1" + ppl=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) + + if [ $(echo "$ppl > 20.0" | bc) -eq 1 ]; then + printf ' - %s @ %s (FAIL: ppl > 20.0)\n' "$qnt" "$ppl" + return 20 + fi + + printf ' - %s @ %s OK\n' "$qnt" "$ppl" + return 0 + } + + check_ppl "f16" "$(cat $OUT/${ci}-tg-f16.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q8_0" "$(cat $OUT/${ci}-tg-q8_0.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q4_0" "$(cat $OUT/${ci}-tg-q4_0.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q4_1" "$(cat $OUT/${ci}-tg-q4_1.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q5_0" "$(cat $OUT/${ci}-tg-q5_0.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q5_1" "$(cat $OUT/${ci}-tg-q5_1.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + #check_ppl "q2_k" "$(cat $OUT/${ci}-tg-q2_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log # note: ppl > 20.0 for this quant and model + check_ppl "q3_k" "$(cat $OUT/${ci}-tg-q3_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q4_k" "$(cat $OUT/${ci}-tg-q4_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q5_k" "$(cat $OUT/${ci}-tg-q5_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + check_ppl "q6_k" "$(cat $OUT/${ci}-tg-q6_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + + cat $OUT/${ci}-imatrix.log | grep "Final" >> $OUT/${ci}-imatrix-sum.log + + set +e +} + +function gg_sum_pythia_2_8b { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Pythia 2.8B:\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- perplexity:\n%s\n' "$(cat $OUT/${ci}-ppl.log)" + gg_printf '- imatrix:\n```\n%s\n```\n' "$(cat $OUT/${ci}-imatrix-sum.log)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" + gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" + gg_printf '- q4_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_0.log)" + gg_printf '- q4_1:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_1.log)" + gg_printf '- q5_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_0.log)" + gg_printf '- q5_1:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_1.log)" + gg_printf '- q2_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q2_k.log)" + gg_printf '- q3_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q3_k.log)" + gg_printf '- q4_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_k.log)" + gg_printf '- q5_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_k.log)" + gg_printf '- q6_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q6_k.log)" + gg_printf '- save-load-state: \n```\n%s\n```\n' "$(cat $OUT/${ci}-save-load-state.log)" +} + +# bge-small + +function gg_run_embd_bge_small { + cd ${SRC} + + gg_wget models-mnt/bge-small/ https://huggingface.co/BAAI/bge-small-en-v1.5/raw/main/config.json + gg_wget models-mnt/bge-small/ https://huggingface.co/BAAI/bge-small-en-v1.5/raw/main/tokenizer.json + gg_wget models-mnt/bge-small/ https://huggingface.co/BAAI/bge-small-en-v1.5/raw/main/tokenizer_config.json + gg_wget models-mnt/bge-small/ https://huggingface.co/BAAI/bge-small-en-v1.5/raw/main/special_tokens_map.json + gg_wget models-mnt/bge-small/ https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/pytorch_model.bin + gg_wget models-mnt/bge-small/ https://huggingface.co/BAAI/bge-small-en-v1.5/raw/main/sentence_bert_config.json + gg_wget models-mnt/bge-small/ https://huggingface.co/BAAI/bge-small-en-v1.5/raw/main/vocab.txt + gg_wget models-mnt/bge-small/ https://huggingface.co/BAAI/bge-small-en-v1.5/raw/main/modules.json + gg_wget models-mnt/bge-small/ https://huggingface.co/BAAI/bge-small-en-v1.5/raw/main/config.json + + gg_wget models-mnt/bge-small/1_Pooling https://huggingface.co/BAAI/bge-small-en-v1.5/raw/main/1_Pooling/config.json + + path_models="../models-mnt/bge-small" + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf + + model_f16="${path_models}/ggml-model-f16.gguf" + model_q8_0="${path_models}/ggml-model-q8_0.gguf" + + ./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0 + + (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + + set +e +} + +function gg_sum_embd_bge_small { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'BGE Small (BERT):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" + gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" +} + +# rerank_tiny + +function gg_run_rerank_tiny { + cd ${SRC} + + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + + gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json + + path_models="../models-mnt/rerank-tiny" + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf + + model_f16="${path_models}/ggml-model-f16.gguf" + + # for this model, the SEP token is "" + (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?hi\nwhat is panda?it's a bear\nwhat is panda?The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log + + # sample output + # rerank score 0: 0.029 + # rerank score 1: 0.029 + # rerank score 2: 0.135 + + # check that the score is in the range [$3, $4] + function check_score { + qnt="$1" + score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) + + if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then + printf ' - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4" + return 20 + fi + + printf ' - %s @ %s OK\n' "$qnt" "$score" + return 0 + } + + check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.30" | tee -a $OUT/${ci}-rk-f16.log + + set +e +} + +function gg_sum_rerank_tiny { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Rerank Tiny (Jina):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)" +} + +function gg_check_build_requirements { + if ! command -v cmake &> /dev/null; then + gg_printf 'cmake not found, please install' + fi + + if ! command -v make &> /dev/null; then + gg_printf 'make not found, please install' + fi + + if ! command -v ctest &> /dev/null; then + gg_printf 'ctest not found, please install' + fi +} + +## main + +export LLAMA_LOG_PREFIX=1 +export LLAMA_LOG_TIMESTAMPS=1 + +if [ -z ${GG_BUILD_LOW_PERF} ]; then + # Create symlink: ./llama.cpp/models-mnt -> $MNT/models + rm -rf ${SRC}/models-mnt + mnt_models=${MNT}/models + mkdir -p ${mnt_models} + ln -sfn ${mnt_models} ${SRC}/models-mnt + + # Create a fresh python3 venv and enter it + if ! python3 -m venv "$MNT/venv"; then + echo "Error: Failed to create Python virtual environment at $MNT/venv." + exit 1 + fi + source "$MNT/venv/bin/activate" + + pip install -r ${SRC}/requirements.txt --disable-pip-version-check + pip install --editable gguf-py --disable-pip-version-check +fi + +ret=0 +if [ -z ${GG_BUILD_SYCL} ]; then + # SYCL build breaks with debug build flags + test $ret -eq 0 && gg_run ctest_debug +fi +test $ret -eq 0 && gg_run ctest_release + +if [ -z ${GG_BUILD_LOW_PERF} ]; then + test $ret -eq 0 && gg_run embd_bge_small + test $ret -eq 0 && gg_run rerank_tiny + + if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then + if [ -z ${GG_BUILD_SYCL} ]; then + test $ret -eq 0 && gg_run test_scripts_debug + fi + test $ret -eq 0 && gg_run test_scripts_release + fi + + if [ -z ${GG_BUILD_VRAM_GB} ] || [ ${GG_BUILD_VRAM_GB} -ge 8 ]; then + if [ -z ${GG_BUILD_CUDA} ] && [ -z ${GG_BUILD_VULKAN} ]; then + test $ret -eq 0 && gg_run pythia_1_4b + else + test $ret -eq 0 && gg_run pythia_2_8b + #test $ret -eq 0 && gg_run open_llama_7b_v2 + fi + if [ -z ${GG_BUILD_SYCL} ]; then + test $ret -eq 0 && gg_run ctest_with_model_debug + fi + test $ret -eq 0 && gg_run ctest_with_model_release + fi +fi + +exit $ret diff --git a/cmake/arm64-apple-clang.cmake b/cmake/arm64-apple-clang.cmake new file mode 100644 index 0000000000000000000000000000000000000000..5fcd2882afc9d96ccc1c28396b91e5e46c93f66c --- /dev/null +++ b/cmake/arm64-apple-clang.cmake @@ -0,0 +1,16 @@ +set( CMAKE_SYSTEM_NAME Darwin ) +set( CMAKE_SYSTEM_PROCESSOR arm64 ) + +set( target arm64-apple-darwin-macho ) + +set( CMAKE_C_COMPILER clang ) +set( CMAKE_CXX_COMPILER clang++ ) + +set( CMAKE_C_COMPILER_TARGET ${target} ) +set( CMAKE_CXX_COMPILER_TARGET ${target} ) + +set( arch_c_flags "-march=armv8.4-a -fvectorize -ffp-model=fast -fno-finite-math-only" ) +set( warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function" ) + +set( CMAKE_C_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" ) +set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" ) diff --git a/cmake/arm64-windows-llvm.cmake b/cmake/arm64-windows-llvm.cmake new file mode 100644 index 0000000000000000000000000000000000000000..8023796800683240e32048180a1eb874ddab1963 --- /dev/null +++ b/cmake/arm64-windows-llvm.cmake @@ -0,0 +1,16 @@ +set( CMAKE_SYSTEM_NAME Windows ) +set( CMAKE_SYSTEM_PROCESSOR arm64 ) + +set( target arm64-pc-windows-msvc ) + +set( CMAKE_C_COMPILER clang ) +set( CMAKE_CXX_COMPILER clang++ ) + +set( CMAKE_C_COMPILER_TARGET ${target} ) +set( CMAKE_CXX_COMPILER_TARGET ${target} ) + +set( arch_c_flags "-march=armv8.7-a -fvectorize -ffp-model=fast -fno-finite-math-only" ) +set( warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function -Wno-gnu-zero-variadic-macro-arguments" ) + +set( CMAKE_C_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" ) +set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" ) diff --git a/cmake/build-info.cmake b/cmake/build-info.cmake new file mode 100644 index 0000000000000000000000000000000000000000..75c78222f2e7f30c9500f4ad67929ce572c78f63 --- /dev/null +++ b/cmake/build-info.cmake @@ -0,0 +1,64 @@ +set(BUILD_NUMBER 0) +set(BUILD_COMMIT "unknown") +set(BUILD_COMPILER "unknown") +set(BUILD_TARGET "unknown") + +# Look for git +find_package(Git) +if(NOT Git_FOUND) + find_program(GIT_EXECUTABLE NAMES git git.exe) + if(GIT_EXECUTABLE) + set(Git_FOUND TRUE) + message(STATUS "Found Git: ${GIT_EXECUTABLE}") + else() + message(WARNING "Git not found. Build info will not be accurate.") + endif() +endif() + +# Get the commit count and hash +if(Git_FOUND) + execute_process( + COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE HEAD + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE RES + ) + if (RES EQUAL 0) + set(BUILD_COMMIT ${HEAD}) + endif() + execute_process( + COMMAND ${GIT_EXECUTABLE} rev-list --count HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE COUNT + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE RES + ) + if (RES EQUAL 0) + set(BUILD_NUMBER ${COUNT}) + endif() +endif() + +if(MSVC) + set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") + if (CMAKE_VS_PLATFORM_NAME) + set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME}) + else() + set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}") + endif() +else() + execute_process( + COMMAND ${CMAKE_C_COMPILER} --version + OUTPUT_VARIABLE OUT + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + string(REGEX REPLACE " *\n.*" "" OUT "${OUT}") + set(BUILD_COMPILER ${OUT}) + + execute_process( + COMMAND ${CMAKE_C_COMPILER} -dumpmachine + OUTPUT_VARIABLE OUT + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + set(BUILD_TARGET ${OUT}) +endif() diff --git a/cmake/common.cmake b/cmake/common.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a5bb787f1519dce0f19618f31f20edea072372ad --- /dev/null +++ b/cmake/common.cmake @@ -0,0 +1,35 @@ +include("ggml/cmake/common.cmake") + +function(llama_add_compile_flags) + if (LLAMA_FATAL_WARNINGS) + if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + list(APPEND C_FLAGS -Werror) + list(APPEND CXX_FLAGS -Werror) + elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + add_compile_options(/WX) + endif() + endif() + + if (LLAMA_ALL_WARNINGS) + if (NOT MSVC) + list(APPEND C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes + -Werror=implicit-int -Werror=implicit-function-declaration) + + list(APPEND CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn) + + list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function) + + list(APPEND C_FLAGS ${WARNING_FLAGS}) + list(APPEND CXX_FLAGS ${WARNING_FLAGS}) + + ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}) + + add_compile_options("$<$:${C_FLAGS};${GF_C_FLAGS}>" + "$<$:${CXX_FLAGS};${GF_CXX_FLAGS}>") + else() + # todo : msvc + set(C_FLAGS "" PARENT_SCOPE) + set(CXX_FLAGS "" PARENT_SCOPE) + endif() + endif() +endfunction() diff --git a/cmake/git-vars.cmake b/cmake/git-vars.cmake new file mode 100644 index 0000000000000000000000000000000000000000..1a4c24ebf6adeb1126e626f56de601621179353d --- /dev/null +++ b/cmake/git-vars.cmake @@ -0,0 +1,22 @@ +find_package(Git) + +# the commit's SHA1 +execute_process(COMMAND + "${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8 + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_SHA1 + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the date of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%ad --date=local + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_DATE + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the subject of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%s + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_COMMIT_SUBJECT + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) diff --git a/cmake/llama-config.cmake.in b/cmake/llama-config.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..90cbec5b6f1331b9149a5ae3aa03fea45d69246d --- /dev/null +++ b/cmake/llama-config.cmake.in @@ -0,0 +1,30 @@ +set(LLAMA_VERSION @LLAMA_INSTALL_VERSION@) +set(LLAMA_BUILD_COMMIT @LLAMA_BUILD_COMMIT@) +set(LLAMA_BUILD_NUMBER @LLAMA_BUILD_NUMBER@) +set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@) + +@PACKAGE_INIT@ + +set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@") +set_and_check(LLAMA_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@") +set_and_check(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@") + +find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake) + +find_library(llama_LIBRARY llama + REQUIRED + HINTS ${LLAMA_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) + +add_library(llama UNKNOWN IMPORTED) +set_target_properties(llama + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${llama_LIBRARY}" + INTERFACE_COMPILE_FEATURES c_std_90 + POSITION_INDEPENDENT_CODE ON) + +check_required_components(Llama) diff --git a/cmake/llama.pc.in b/cmake/llama.pc.in new file mode 100644 index 0000000000000000000000000000000000000000..6fb58b5f6881b7f204d9094a8f2d3ed1d93906c2 --- /dev/null +++ b/cmake/llama.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=@CMAKE_INSTALL_PREFIX@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: llama +Description: Port of Facebook's LLaMA model in C/C++ +Version: @LLAMA_INSTALL_VERSION@ +Libs: -L${libdir} -lggml -lggml-base -lllama +Cflags: -I${includedir} diff --git a/cmake/x64-windows-llvm.cmake b/cmake/x64-windows-llvm.cmake new file mode 100644 index 0000000000000000000000000000000000000000..77e79140798b2fcee2d84a3abdeab315c08b7a03 --- /dev/null +++ b/cmake/x64-windows-llvm.cmake @@ -0,0 +1,5 @@ +set( CMAKE_SYSTEM_NAME Windows ) +set( CMAKE_SYSTEM_PROCESSOR x86_64 ) + +set( CMAKE_C_COMPILER clang ) +set( CMAKE_CXX_COMPILER clang++ ) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f43a630c900ff15d244c42f5876ee185905e78c7 --- /dev/null +++ b/common/CMakeLists.txt @@ -0,0 +1,163 @@ +# common + +find_package(Threads REQUIRED) + +llama_add_compile_flags() + +# Build info header +# + +if(EXISTS "${PROJECT_SOURCE_DIR}/.git") + set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git") + + # Is git submodule + if(NOT IS_DIRECTORY "${GIT_DIR}") + file(READ ${GIT_DIR} REAL_GIT_DIR_LINK) + string(REGEX REPLACE "gitdir: (.*)\n$" "\\1" REAL_GIT_DIR ${REAL_GIT_DIR_LINK}) + string(FIND "${REAL_GIT_DIR}" "/" SLASH_POS) + if (SLASH_POS EQUAL 0) + set(GIT_DIR "${REAL_GIT_DIR}") + else() + set(GIT_DIR "${PROJECT_SOURCE_DIR}/${REAL_GIT_DIR}") + endif() + endif() + + if(EXISTS "${GIT_DIR}/index") + # For build-info.cpp below + set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${GIT_DIR}/index") + else() + message(WARNING "Git index not found in git repository.") + endif() +else() + message(WARNING "Git repository not found; to enable automatic generation of build info, make sure Git is installed and the project is a Git repository.") +endif() + +set(TEMPLATE_FILE "${CMAKE_CURRENT_SOURCE_DIR}/build-info.cpp.in") +set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/build-info.cpp") +configure_file(${TEMPLATE_FILE} ${OUTPUT_FILE}) + +set(TARGET build_info) +add_library(${TARGET} OBJECT ${OUTPUT_FILE}) +if (BUILD_SHARED_LIBS) + set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) +endif() + +set(TARGET common) + +add_library(${TARGET} STATIC + arg.cpp + arg.h + base64.hpp + chat-parser.cpp + chat-parser.h + chat.cpp + chat.h + common.cpp + common.h + console.cpp + console.h + json-partial.cpp + json-partial.h + json-schema-to-grammar.cpp + llguidance.cpp + log.cpp + log.h + ngram-cache.cpp + ngram-cache.h + regex-partial.cpp + regex-partial.h + sampling.cpp + sampling.h + speculative.cpp + speculative.h + ) + +if (BUILD_SHARED_LIBS) + set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) +endif() + +set(LLAMA_COMMON_EXTRA_LIBS build_info) + +# Use curl to download model url +if (LLAMA_CURL) + find_package(CURL) + if (NOT CURL_FOUND) + message(FATAL_ERROR "Could NOT find CURL. Hint: to disable this feature, set -DLLAMA_CURL=OFF") + endif() + target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL) + include_directories(${CURL_INCLUDE_DIRS}) + find_library(CURL_LIBRARY curl REQUIRED) + set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY}) +endif () + +if (LLAMA_LLGUIDANCE) + include(ExternalProject) + set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source) + set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release) + + # Set the correct library file extension based on platform + if (WIN32) + set(LLGUIDANCE_LIB_NAME "llguidance.lib") + # Add Windows-specific libraries + set(LLGUIDANCE_PLATFORM_LIBS + ws2_32 # Windows Sockets API + userenv # For GetUserProfileDirectoryW + ntdll # For NT functions + bcrypt # For BCryptGenRandom + ) + else() + set(LLGUIDANCE_LIB_NAME "libllguidance.a") + set(LLGUIDANCE_PLATFORM_LIBS "") + endif() + + ExternalProject_Add(llguidance_ext + GIT_REPOSITORY https://github.com/guidance-ai/llguidance + # v0.7.20 (+ fix to build on GCC 15): + GIT_TAG b5b8b64dba11c4e4ee6b1d1450d3a3ae279891e8 + PREFIX ${CMAKE_BINARY_DIR}/llguidance + SOURCE_DIR ${LLGUIDANCE_SRC} + BUILD_IN_SOURCE TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND cargo build --release + INSTALL_COMMAND "" + BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME} ${LLGUIDANCE_PATH}/llguidance.h + UPDATE_COMMAND "" + ) + target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE) + + add_library(llguidance STATIC IMPORTED) + set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME}) + add_dependencies(llguidance llguidance_ext) + + target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH}) + # Add platform libraries to the main target + set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS}) +endif () + +target_include_directories(${TARGET} PUBLIC . ../vendor) +target_compile_features (${TARGET} PUBLIC cxx_std_17) +target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads) + + +# +# copy the license files +# + +# Check if running in GitHub Actions +if (DEFINED ENV{GITHUB_ACTIONS} AND "$ENV{GITHUB_ACTIONS}" STREQUAL "true") + message(STATUS "Running inside GitHub Actions - copying license files") + + # Copy all files from licenses/ to build/bin/ + file(GLOB LICENSE_FILES "${CMAKE_SOURCE_DIR}/licenses/*") + foreach(LICENSE_FILE ${LICENSE_FILES}) + get_filename_component(FILENAME ${LICENSE_FILE} NAME) + add_custom_command( + POST_BUILD + TARGET ${TARGET} + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${LICENSE_FILE}" + "$/${FILENAME}" + COMMENT "Copying ${FILENAME} to ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}") + message(STATUS "Copying ${LICENSE_FILE} to ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${FILENAME}") + endforeach() +endif() diff --git a/common/arg.cpp b/common/arg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..231de227a9122815f2e28489dcb407095d6738bb --- /dev/null +++ b/common/arg.cpp @@ -0,0 +1,3377 @@ +#include "arg.h" + +#include "chat.h" +#include "common.h" +#include "gguf.h" // for reading GGUF splits +#include "json-schema-to-grammar.h" +#include "log.h" +#include "sampling.h" + +// fix problem with std::min and std::max +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + +#define JSON_ASSERT GGML_ASSERT +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//#define LLAMA_USE_CURL + +#if defined(LLAMA_USE_CURL) +#include +#include +#include +#endif + +using json = nlohmann::ordered_json; + +std::initializer_list mmproj_examples = { + LLAMA_EXAMPLE_MTMD, + LLAMA_EXAMPLE_SERVER, +}; + +static std::string read_file(const std::string & fname) { + std::ifstream file(fname); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); + } + std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + file.close(); + return content; +} + +static void write_file(const std::string & fname, const std::string & content) { + std::ofstream file(fname); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); + } + file << content; + file.close(); +} + +common_arg & common_arg::set_examples(std::initializer_list examples) { + this->examples = std::move(examples); + return *this; +} + +common_arg & common_arg::set_excludes(std::initializer_list excludes) { + this->excludes = std::move(excludes); + return *this; +} + +common_arg & common_arg::set_env(const char * env) { + help = help + "\n(env: " + env + ")"; + this->env = env; + return *this; +} + +common_arg & common_arg::set_sparam() { + is_sparam = true; + return *this; +} + +bool common_arg::in_example(enum llama_example ex) { + return examples.find(ex) != examples.end(); +} + +bool common_arg::is_exclude(enum llama_example ex) { + return excludes.find(ex) != excludes.end(); +} + +bool common_arg::get_value_from_env(std::string & output) { + if (env == nullptr) return false; + char * value = std::getenv(env); + if (value) { + output = value; + return true; + } + return false; +} + +bool common_arg::has_value_from_env() { + return env != nullptr && std::getenv(env); +} + +static std::vector break_str_into_lines(std::string input, size_t max_char_per_line) { + std::vector result; + std::istringstream iss(input); + std::string line; + auto add_line = [&](const std::string& l) { + if (l.length() <= max_char_per_line) { + result.push_back(l); + } else { + std::istringstream line_stream(l); + std::string word, current_line; + while (line_stream >> word) { + if (current_line.length() + !current_line.empty() + word.length() > max_char_per_line) { + if (!current_line.empty()) result.push_back(current_line); + current_line = word; + } else { + current_line += (!current_line.empty() ? " " : "") + word; + } + } + if (!current_line.empty()) result.push_back(current_line); + } + }; + while (std::getline(iss, line)) { + add_line(line); + } + return result; +} + +std::string common_arg::to_string() { + // params for printing to console + const static int n_leading_spaces = 40; + const static int n_char_per_line_help = 70; // TODO: detect this based on current console + std::string leading_spaces(n_leading_spaces, ' '); + + std::ostringstream ss; + for (const auto arg : args) { + if (arg == args.front()) { + if (args.size() == 1) { + ss << arg; + } else { + // first arg is usually abbreviation, we need padding to make it more beautiful + auto tmp = std::string(arg) + ", "; + auto spaces = std::string(std::max(0, 7 - (int)tmp.size()), ' '); + ss << tmp << spaces; + } + } else { + ss << arg << (arg != args.back() ? ", " : ""); + } + } + if (value_hint) ss << " " << value_hint; + if (value_hint_2) ss << " " << value_hint_2; + if (ss.tellp() > n_leading_spaces - 3) { + // current line is too long, add new line + ss << "\n" << leading_spaces; + } else { + // padding between arg and help, same line + ss << std::string(leading_spaces.size() - ss.tellp(), ' '); + } + const auto help_lines = break_str_into_lines(help, n_char_per_line_help); + for (const auto & line : help_lines) { + ss << (&line == &help_lines.front() ? "" : leading_spaces) << line << "\n"; + } + return ss.str(); +} + +// +// downloader +// + +struct common_hf_file_res { + std::string repo; // repo name with ":tag" removed + std::string ggufFile; + std::string mmprojFile; +}; + +#ifdef LLAMA_USE_CURL + +bool common_has_curl() { + return true; +} + +#ifdef __linux__ +#include +#elif defined(_WIN32) +# if !defined(PATH_MAX) +# define PATH_MAX MAX_PATH +# endif +#elif defined(_AIX) +#include +#else +#include +#endif +#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 + +// +// CURL utils +// + +using curl_ptr = std::unique_ptr; + +// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one +struct curl_slist_ptr { + struct curl_slist * ptr = nullptr; + ~curl_slist_ptr() { + if (ptr) { + curl_slist_free_all(ptr); + } + } +}; + +#define CURL_MAX_RETRY 3 +#define CURL_RETRY_DELAY_SECONDS 2 + +static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds, const char * method_name) { + int remaining_attempts = max_attempts; + + while (remaining_attempts > 0) { + LOG_INF("%s: %s %s (attempt %d of %d)...\n", __func__ , method_name, url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); + + CURLcode res = curl_easy_perform(curl); + if (res == CURLE_OK) { + return true; + } + + int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000; + LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); + + remaining_attempts--; + if (remaining_attempts == 0) break; + std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); + } + + LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts); + + return false; +} + +// download one single file from remote URL to local path +static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) { + // Check if the file already exists locally + auto file_exists = std::filesystem::exists(path); + + // If the file exists, check its JSON metadata companion file. + std::string metadata_path = path + ".json"; + nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead + std::string etag; + std::string last_modified; + + if (file_exists) { + if (offline) { + LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str()); + return true; // skip verification/downloading + } + // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). + std::ifstream metadata_in(metadata_path); + if (metadata_in.good()) { + try { + metadata_in >> metadata; + LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); + if (metadata.contains("etag") && metadata.at("etag").is_string()) { + etag = metadata.at("etag"); + } + if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { + last_modified = metadata.at("lastModified"); + } + } catch (const nlohmann::json::exception & e) { + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + } + } + // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again) + } else { + if (offline) { + LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str()); + return false; + } + LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); + } + + // Send a HEAD request to retrieve the etag and last-modified headers + struct common_load_model_from_url_headers { + std::string etag; + std::string last_modified; + }; + + common_load_model_from_url_headers headers; + bool head_request_ok = false; + bool should_download = !file_exists; // by default, we should download if the file does not exist + + // Initialize libcurl + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + if (!curl) { + LOG_ERR("%s: error initializing libcurl\n", __func__); + return false; + } + + // Set the URL, allow to follow http redirection + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + // Check if hf-token or bearer-token was specified + if (!bearer_token.empty()) { + std::string auth_header = "Authorization: Bearer " + bearer_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + } + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + +#if defined(_WIN32) + // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of + // operating system. Currently implemented under MS-Windows. + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + + typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); + auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { + common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; + + static std::regex header_regex("([^:]+): (.*)\r\n"); + static std::regex etag_regex("ETag", std::regex_constants::icase); + static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); + + std::string header(buffer, n_items); + std::smatch match; + if (std::regex_match(header, match, header_regex)) { + const std::string & key = match[1]; + const std::string & value = match[2]; + if (std::regex_match(key, match, etag_regex)) { + headers->etag = value; + } else if (std::regex_match(key, match, last_modified_regex)) { + headers->last_modified = value; + } + } + return n_items; + }; + + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); + curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); + + // we only allow retrying once for HEAD requests + // this is for the use case of using running offline (no internet), retrying can be annoying + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD"); + if (!was_perform_successful) { + head_request_ok = false; + } + + long http_code = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code == 200) { + head_request_ok = true; + } else { + LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + head_request_ok = false; + } + + // if head_request_ok is false, we don't have the etag or last-modified headers + // we leave should_download as-is, which is true if the file does not exist + if (head_request_ok) { + // check if ETag or Last-Modified headers are different + // if it is, we need to download the file again + if (!etag.empty() && etag != headers.etag) { + LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); + should_download = true; + } else if (!last_modified.empty() && last_modified != headers.last_modified) { + LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); + should_download = true; + } + } + + if (should_download) { + std::string path_temporary = path + ".downloadInProgress"; + if (file_exists) { + LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); + if (remove(path.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); + return false; + } + } + + // Set the output file + + struct FILE_deleter { + void operator()(FILE * f) const { + fclose(f); + } + }; + + std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); + if (!outfile) { + LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); + return false; + } + + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd); + auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { + return fwrite(data, size, nmemb, (FILE *)fd); + }; + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L); + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); + + // display download progress + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L); + + // helper function to hide password in URL + auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string { + std::size_t protocol_pos = url.find("://"); + if (protocol_pos == std::string::npos) { + return url; // Malformed URL + } + + std::size_t at_pos = url.find('@', protocol_pos + 3); + if (at_pos == std::string::npos) { + return url; // No password in URL + } + + return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos); + }; + + // start the download + LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, + llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET"); + if (!was_perform_successful) { + return false; + } + + long http_code = 0; + curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code < 200 || http_code >= 400) { + LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); + return false; + } + + // Causes file to be closed explicitly here before we rename it. + outfile.reset(); + + // Write the updated JSON metadata file. + metadata.update({ + {"url", url}, + {"etag", headers.etag}, + {"lastModified", headers.last_modified} + }); + write_file(metadata_path, metadata.dump(4)); + LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); + + if (rename(path_temporary.c_str(), path.c_str()) != 0) { + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); + return false; + } + } else { + LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); + } + + return true; +} + +// download multiple files from remote URLs to local paths +// the input is a vector of pairs +static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token, bool offline) { + // Prepare download in parallel + std::vector> futures_download; + for (auto const & item : urls) { + futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { + return common_download_file_single(it.first, it.second, bearer_token, offline); + }, item)); + } + + // Wait for all downloads to complete + for (auto & f : futures_download) { + if (!f.get()) { + return false; + } + } + + return true; +} + +static bool common_download_model( + const common_params_model & model, + const std::string & bearer_token, + bool offline) { + // Basic validation of the model.url + if (model.url.empty()) { + LOG_ERR("%s: invalid model url\n", __func__); + return false; + } + + if (!common_download_file_single(model.url, model.path, bearer_token, offline)) { + return false; + } + + // check for additional GGUFs split to download + int n_split = 0; + { + struct gguf_init_params gguf_params = { + /*.no_alloc = */ true, + /*.ctx = */ NULL, + }; + auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params); + if (!ctx_gguf) { + LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str()); + return false; + } + + auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT); + if (key_n_split >= 0) { + n_split = gguf_get_val_u16(ctx_gguf, key_n_split); + } + + gguf_free(ctx_gguf); + } + + if (n_split > 1) { + char split_prefix[PATH_MAX] = {0}; + char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0}; + + // Verify the first split file format + // and extract split URL and PATH prefixes + { + if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) { + LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split); + return false; + } + + if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) { + LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split); + return false; + } + } + + std::vector> urls; + for (int idx = 1; idx < n_split; idx++) { + char split_path[PATH_MAX] = {0}; + llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split); + + char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; + llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split); + + if (std::string(split_path) == model.path) { + continue; // skip the already downloaded file + } + + urls.push_back({split_url, split_path}); + } + + // Download in parallel + common_download_file_multiple(urls, bearer_token, offline); + } + + return true; +} + +std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params) { + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + std::vector res_buffer; + + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); + auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { + auto data_vec = static_cast *>(data); + data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer); +#if defined(_WIN32) + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + if (params.timeout > 0) { + curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout); + } + if (params.max_size > 0) { + curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size); + } + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + for (const auto & header : params.headers) { + http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str()); + } + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + + CURLcode res = curl_easy_perform(curl.get()); + + if (res != CURLE_OK) { + std::string error_msg = curl_easy_strerror(res); + throw std::runtime_error("error: cannot make GET request: " + error_msg); + } + + long res_code; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); + + return { res_code, std::move(res_buffer) }; +} + +/** + * Allow getting the HF file from the HF repo with tag (like ollama), for example: + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 + * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s + * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) + * + * Return pair of (with "repo" already having tag removed) + * + * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. + */ +static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) { + auto parts = string_split(hf_repo_with_tag, ':'); + std::string tag = parts.size() > 1 ? parts.back() : "latest"; + std::string hf_repo = parts[0]; + if (string_split(hf_repo, '/').size() != 2) { + throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); + } + + std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; + + // headers + std::vector headers; + headers.push_back("Accept: application/json"); + if (!bearer_token.empty()) { + headers.push_back("Authorization: Bearer " + bearer_token); + } + // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + // User-Agent header is already set in common_remote_get_content, no need to set it here + + // we use "=" to avoid clashing with other component, while still being allowed on windows + std::string cached_response_fname = "manifest=" + hf_repo + "=" + tag + ".json"; + string_replace_all(cached_response_fname, "/", "_"); + std::string cached_response_path = fs_get_cache_file(cached_response_fname); + + // make the request + common_remote_params params; + params.headers = headers; + long res_code = 0; + std::string res_str; + bool use_cache = false; + if (!offline) { + try { + auto res = common_remote_get_content(url, params); + res_code = res.first; + res_str = std::string(res.second.data(), res.second.size()); + } catch (const std::exception & e) { + LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what()); + } + } + if (res_code == 0) { + if (std::filesystem::exists(cached_response_path)) { + LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str()); + res_str = read_file(cached_response_path); + res_code = 200; + use_cache = true; + } else { + throw std::runtime_error( + offline ? "error: failed to get manifest (offline mode)" + : "error: failed to get manifest (check your internet connection)"); + } + } + std::string ggufFile; + std::string mmprojFile; + + if (res_code == 200 || res_code == 304) { + // extract ggufFile.rfilename in json, using regex + { + std::regex pattern("\"ggufFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\""); + std::smatch match; + if (std::regex_search(res_str, match, pattern)) { + ggufFile = match[1].str(); + } + } + // extract mmprojFile.rfilename in json, using regex + { + std::regex pattern("\"mmprojFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\""); + std::smatch match; + if (std::regex_search(res_str, match, pattern)) { + mmprojFile = match[1].str(); + } + } + if (!use_cache) { + // if not using cached response, update the cache file + write_file(cached_response_path, res_str); + } + } else if (res_code == 401) { + throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); + } else { + throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); + } + + // check response + if (ggufFile.empty()) { + throw std::runtime_error("error: model does not have ggufFile"); + } + + return { hf_repo, ggufFile, mmprojFile }; +} + +#else + +bool common_has_curl() { + return false; +} + +static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool) { + LOG_ERR("error: built without CURL, cannot download model from internet\n"); + return false; +} + +static bool common_download_file_multiple(const std::vector> &, const std::string &, bool) { + LOG_ERR("error: built without CURL, cannot download model from the internet\n"); + return false; +} + +static bool common_download_model( + const common_params_model &, + const std::string &, + bool) { + LOG_ERR("error: built without CURL, cannot download model from the internet\n"); + return false; +} + +static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) { + LOG_ERR("error: built without CURL, cannot download model from the internet\n"); + return {}; +} + +std::pair> common_remote_get_content(const std::string & url, const common_remote_params &) { + if (!url.empty()) { + throw std::runtime_error("error: built without CURL, cannot download model from the internet"); + } + + return {}; +} + +#endif // LLAMA_USE_CURL + +// +// utils +// + +struct handle_model_result { + bool found_mmproj = false; + common_params_model mmproj; +}; + +static handle_model_result common_params_handle_model( + struct common_params_model & model, + const std::string & bearer_token, + const std::string & model_path_default, + bool offline) { + handle_model_result result; + // handle pre-fill default model path and url based on hf_repo and hf_file + { + if (!model.hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (model.hf_file.empty()) { + if (model.path.empty()) { + auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline); + if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { + exit(1); // built without CURL, error message already printed + } + model.hf_repo = auto_detected.repo; + model.hf_file = auto_detected.ggufFile; + if (!auto_detected.mmprojFile.empty()) { + result.found_mmproj = true; + result.mmproj.hf_repo = model.hf_repo; + result.mmproj.hf_file = auto_detected.mmprojFile; + } + } else { + model.hf_file = model.path; + } + } + + std::string model_endpoint = get_model_endpoint(); + model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; + // make sure model path is present (for caching purposes) + if (model.path.empty()) { + // this is to avoid different repo having same file name, or same file name in different subdirs + std::string filename = model.hf_repo + "_" + model.hf_file; + // to make sure we don't have any slashes in the filename + string_replace_all(filename, "/", "_"); + model.path = fs_get_cache_file(filename); + } + + } else if (!model.url.empty()) { + if (model.path.empty()) { + auto f = string_split(model.url, '#').front(); + f = string_split(f, '?').front(); + model.path = fs_get_cache_file(string_split(f, '/').back()); + } + + } else if (model.path.empty()) { + model.path = model_path_default; + } + } + + // then, download it if needed + if (!model.url.empty()) { + bool ok = common_download_model(model, bearer_token, offline); + if (!ok) { + LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); + exit(1); + } + } + + return result; +} + +const std::vector kv_cache_types = { + GGML_TYPE_F32, + GGML_TYPE_F16, + GGML_TYPE_BF16, + GGML_TYPE_Q8_0, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, + GGML_TYPE_IQ4_NL, + GGML_TYPE_Q5_0, + GGML_TYPE_Q5_1, +}; + +static ggml_type kv_cache_type_from_str(const std::string & s) { + for (const auto & type : kv_cache_types) { + if (ggml_type_name(type) == s) { + return type; + } + } + throw std::runtime_error("Unsupported cache type: " + s); +} + +static std::string get_all_kv_cache_types() { + std::ostringstream msg; + for (const auto & type : kv_cache_types) { + msg << ggml_type_name(type) << (&type == &kv_cache_types.back() ? "" : ", "); + } + return msg.str(); +} + +// +// CLI argument parsing functions +// + +static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) { + std::string arg; + const std::string arg_prefix = "--"; + common_params & params = ctx_arg.params; + + std::unordered_map arg_to_options; + for (auto & opt : ctx_arg.options) { + for (const auto & arg : opt.args) { + arg_to_options[arg] = &opt; + } + } + + // handle environment variables + for (auto & opt : ctx_arg.options) { + std::string value; + if (opt.get_value_from_env(value)) { + try { + if (opt.handler_void && (value == "1" || value == "true")) { + opt.handler_void(params); + } + if (opt.handler_int) { + opt.handler_int(params, std::stoi(value)); + } + if (opt.handler_string) { + opt.handler_string(params, value); + continue; + } + } catch (std::exception & e) { + throw std::invalid_argument(string_format( + "error while handling environment variable \"%s\": %s\n\n", opt.env, e.what())); + } + } + } + + // handle command line arguments + auto check_arg = [&](int i) { + if (i+1 >= argc) { + throw std::invalid_argument("expected value for argument"); + } + }; + + for (int i = 1; i < argc; i++) { + const std::string arg_prefix = "--"; + + std::string arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + if (arg_to_options.find(arg) == arg_to_options.end()) { + throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str())); + } + auto opt = *arg_to_options[arg]; + if (opt.has_value_from_env()) { + fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str()); + } + try { + if (opt.handler_void) { + opt.handler_void(params); + continue; + } + + // arg with single value + check_arg(i); + std::string val = argv[++i]; + if (opt.handler_int) { + opt.handler_int(params, std::stoi(val)); + continue; + } + if (opt.handler_string) { + opt.handler_string(params, val); + continue; + } + + // arg with 2 values + check_arg(i); + std::string val2 = argv[++i]; + if (opt.handler_str_str) { + opt.handler_str_str(params, val, val2); + continue; + } + } catch (std::exception & e) { + throw std::invalid_argument(string_format( + "error while handling argument \"%s\": %s\n\n" + "usage:\n%s\n\nto show complete usage, run with -h", + arg.c_str(), e.what(), arg_to_options[arg]->to_string().c_str())); + } + } + + postprocess_cpu_params(params.cpuparams, nullptr); + postprocess_cpu_params(params.cpuparams_batch, ¶ms.cpuparams); + + postprocess_cpu_params(params.speculative.cpuparams, ¶ms.cpuparams); + postprocess_cpu_params(params.speculative.cpuparams_batch, ¶ms.cpuparams_batch); + + if (params.prompt_cache_all && (params.interactive || params.interactive_first)) { + throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); + } + + // handle model and download + { + auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline); + if (params.no_mmproj) { + params.mmproj = {}; + } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { + // optionally, handle mmproj model when -hf is specified + params.mmproj = res.mmproj; + } + // only download mmproj if the current example is using it + for (auto & ex : mmproj_examples) { + if (ctx_arg.ex == ex) { + common_params_handle_model(params.mmproj, params.hf_token, "", params.offline); + break; + } + } + common_params_handle_model(params.speculative.model, params.hf_token, "", params.offline); + common_params_handle_model(params.vocoder.model, params.hf_token, "", params.offline); + } + + if (params.escape) { + string_process_escapes(params.prompt); + string_process_escapes(params.input_prefix); + string_process_escapes(params.input_suffix); + for (auto & antiprompt : params.antiprompt) { + string_process_escapes(antiprompt); + } + for (auto & seq_breaker : params.sampling.dry_sequence_breakers) { + string_process_escapes(seq_breaker); + } + } + + if (!params.kv_overrides.empty()) { + params.kv_overrides.emplace_back(); + params.kv_overrides.back().key[0] = 0; + } + + if (!params.tensor_buft_overrides.empty()) { + params.tensor_buft_overrides.push_back({nullptr, nullptr}); + } + + if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) { + throw std::runtime_error(string_format( + "error: the supplied chat template is not supported: %s%s\n", + params.chat_template.c_str(), + params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates" + )); + } + + return true; +} + +static void common_params_print_usage(common_params_context & ctx_arg) { + auto print_options = [](std::vector & options) { + for (common_arg * opt : options) { + printf("%s", opt->to_string().c_str()); + } + }; + + std::vector common_options; + std::vector sparam_options; + std::vector specific_options; + for (auto & opt : ctx_arg.options) { + // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example + if (opt.is_sparam) { + sparam_options.push_back(&opt); + } else if (opt.in_example(ctx_arg.ex)) { + specific_options.push_back(&opt); + } else { + common_options.push_back(&opt); + } + } + printf("----- common params -----\n\n"); + print_options(common_options); + printf("\n\n----- sampling params -----\n\n"); + print_options(sparam_options); + // TODO: maybe convert enum llama_example to string + printf("\n\n----- example-specific params -----\n\n"); + print_options(specific_options); +} + +static void common_params_print_completion(common_params_context & ctx_arg) { + std::vector common_options; + std::vector sparam_options; + std::vector specific_options; + + for (auto & opt : ctx_arg.options) { + if (opt.is_sparam) { + sparam_options.push_back(&opt); + } else if (opt.in_example(ctx_arg.ex)) { + specific_options.push_back(&opt); + } else { + common_options.push_back(&opt); + } + } + + printf("_llama_completions() {\n"); + printf(" local cur prev opts\n"); + printf(" COMPREPLY=()\n"); + printf(" cur=\"${COMP_WORDS[COMP_CWORD]}\"\n"); + printf(" prev=\"${COMP_WORDS[COMP_CWORD-1]}\"\n\n"); + + printf(" opts=\""); + auto print_options = [](const std::vector & options) { + for (const common_arg * opt : options) { + for (const char * arg : opt->args) { + printf("%s ", arg); + } + } + }; + + print_options(common_options); + print_options(sparam_options); + print_options(specific_options); + printf("\"\n\n"); + + printf(" case \"$prev\" in\n"); + printf(" --model)\n"); + printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" --grammar-file)\n"); + printf(" COMPREPLY=( $(compgen -f -X '!*.gbnf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" --chat-template-file)\n"); + printf(" COMPREPLY=( $(compgen -f -X '!*.jinja' -- \"$cur\") $(compgen -d -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" *)\n"); + printf(" COMPREPLY=( $(compgen -W \"${opts}\" -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" esac\n"); + printf("}\n\n"); + + std::set executables = { + "llama-batched", + "llama-batched-bench", + "llama-bench", + "llama-cli", + "llama-convert-llama2c-to-ggml", + "llama-cvector-generator", + "llama-embedding", + "llama-eval-callback", + "llama-export-lora", + "llama-gen-docs", + "llama-gguf", + "llama-gguf-hash", + "llama-gguf-split", + "llama-gritlm", + "llama-imatrix", + "llama-infill", + "llama-mtmd-cli", + "llama-llava-clip-quantize-cli", + "llama-lookahead", + "llama-lookup", + "llama-lookup-create", + "llama-lookup-merge", + "llama-lookup-stats", + "llama-parallel", + "llama-passkey", + "llama-perplexity", + "llama-q8dot", + "llama-quantize", + "llama-qwen2vl-cli", + "llama-retrieval", + "llama-run", + "llama-save-load-state", + "llama-server", + "llama-simple", + "llama-simple-chat", + "llama-speculative", + "llama-speculative-simple", + "llama-tokenize", + "llama-tts", + "llama-vdot" + }; + + for (const auto& exe : executables) { + printf("complete -F _llama_completions %s\n", exe.c_str()); + } +} + +static std::vector parse_device_list(const std::string & value) { + std::vector devices; + auto dev_names = string_split(value, ','); + if (dev_names.empty()) { + throw std::invalid_argument("no devices specified"); + } + if (dev_names.size() == 1 && dev_names[0] == "none") { + devices.push_back(nullptr); + } else { + for (const auto & device : dev_names) { + auto * dev = ggml_backend_dev_by_name(device.c_str()); + if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) { + throw std::invalid_argument(string_format("invalid device: %s", device.c_str())); + } + devices.push_back(dev); + } + devices.push_back(nullptr); + } + return devices; +} + +static void add_rpc_devices(std::string servers) { + auto rpc_servers = string_split(servers, ','); + if (rpc_servers.empty()) { + throw std::invalid_argument("no RPC servers specified"); + } + ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); + if (!rpc_reg) { + throw std::invalid_argument("failed to find RPC backend"); + } + typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); + ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); + if (!ggml_backend_rpc_add_device_fn) { + throw std::invalid_argument("failed to find RPC device add function"); + } + for (const auto & server : rpc_servers) { + ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); + if (dev) { + ggml_backend_device_register(dev); + } else { + throw std::invalid_argument("failed to register RPC device"); + } + } +} + +bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) { + auto ctx_arg = common_params_parser_init(params, ex, print_usage); + const common_params params_org = ctx_arg.params; // the example can modify the default params + + try { + if (!common_params_parse_ex(argc, argv, ctx_arg)) { + ctx_arg.params = params_org; + return false; + } + if (ctx_arg.params.usage) { + common_params_print_usage(ctx_arg); + if (ctx_arg.print_usage) { + ctx_arg.print_usage(argc, argv); + } + exit(0); + } + if (ctx_arg.params.completion) { + common_params_print_completion(ctx_arg); + exit(0); + } + } catch (const std::invalid_argument & ex) { + fprintf(stderr, "%s\n", ex.what()); + ctx_arg.params = params_org; + return false; + } catch (std::exception & ex) { + fprintf(stderr, "%s\n", ex.what()); + exit(1); // for other exceptions, we exit with status code 1 + } + + return true; +} + +static std::string list_builtin_chat_templates() { + std::vector supported_tmpl; + int32_t res = llama_chat_builtin_templates(nullptr, 0); + supported_tmpl.resize(res); + res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size()); + std::ostringstream msg; + for (auto & tmpl : supported_tmpl) { + msg << tmpl << (&tmpl == &supported_tmpl.back() ? "" : ", "); + } + return msg.str(); +} + +common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { + // load dynamic backends + ggml_backend_load_all(); + + common_params_context ctx_arg(params); + ctx_arg.print_usage = print_usage; + ctx_arg.ex = ex; + + std::string sampler_type_chars; + std::string sampler_type_names; + for (const auto & sampler : params.sampling.samplers) { + sampler_type_chars += common_sampler_type_to_chr(sampler); + sampler_type_names += common_sampler_type_to_str(sampler) + ";"; + } + sampler_type_names.pop_back(); + + + /** + * filter options by example + * rules: + * - all examples inherit options from LLAMA_EXAMPLE_COMMON + * - if LLAMA_EXAMPLE_* is set (other than COMMON), we only show the option in the corresponding example + * - if both {LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_*,} are set, we will prioritize the LLAMA_EXAMPLE_* matching current example + */ + auto add_opt = [&](common_arg arg) { + if ((arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) && !arg.is_exclude(ex)) { + ctx_arg.options.push_back(std::move(arg)); + } + }; + + + add_opt(common_arg( + {"-h", "--help", "--usage"}, + "print usage and exit", + [](common_params & params) { + params.usage = true; + } + )); + add_opt(common_arg( + {"--version"}, + "show version and build info", + [](common_params &) { + fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); + fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); + exit(0); + } + )); + add_opt(common_arg( + {"--completion-bash"}, + "print source-able bash completion script for llama.cpp", + [](common_params & params) { + params.completion = true; + } + )); + add_opt(common_arg( + {"--verbose-prompt"}, + string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"), + [](common_params & params) { + params.verbose_prompt = true; + } + )); + add_opt(common_arg( + {"--no-display-prompt"}, + string_format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"), + [](common_params & params) { + params.display_prompt = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-co", "--color"}, + string_format("colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false"), + [](common_params & params) { + params.use_color = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP})); + add_opt(common_arg( + {"-t", "--threads"}, "N", + string_format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads), + [](common_params & params, int value) { + params.cpuparams.n_threads = value; + if (params.cpuparams.n_threads <= 0) { + params.cpuparams.n_threads = std::thread::hardware_concurrency(); + } + } + ).set_env("LLAMA_ARG_THREADS")); + add_opt(common_arg( + {"-tb", "--threads-batch"}, "N", + "number of threads to use during batch and prompt processing (default: same as --threads)", + [](common_params & params, int value) { + params.cpuparams_batch.n_threads = value; + if (params.cpuparams_batch.n_threads <= 0) { + params.cpuparams_batch.n_threads = std::thread::hardware_concurrency(); + } + } + )); + add_opt(common_arg( + {"-C", "--cpu-mask"}, "M", + "CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")", + [](common_params & params, const std::string & mask) { + params.cpuparams.mask_valid = true; + if (!parse_cpu_mask(mask, params.cpuparams.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + )); + add_opt(common_arg( + {"-Cr", "--cpu-range"}, "lo-hi", + "range of CPUs for affinity. Complements --cpu-mask", + [](common_params & params, const std::string & range) { + params.cpuparams.mask_valid = true; + if (!parse_cpu_range(range, params.cpuparams.cpumask)) { + throw std::invalid_argument("invalid range"); + } + } + )); + add_opt(common_arg( + {"--cpu-strict"}, "<0|1>", + string_format("use strict CPU placement (default: %u)\n", (unsigned) params.cpuparams.strict_cpu), + [](common_params & params, const std::string & value) { + params.cpuparams.strict_cpu = std::stoul(value); + } + )); + add_opt(common_arg( + {"--prio"}, "N", + string_format("set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: %d)\n", params.cpuparams.priority), + [](common_params & params, int prio) { + if (prio < GGML_SCHED_PRIO_LOW || prio > GGML_SCHED_PRIO_REALTIME) { + throw std::invalid_argument("invalid value"); + } + params.cpuparams.priority = (enum ggml_sched_priority) prio; + } + )); + add_opt(common_arg( + {"--poll"}, "<0...100>", + string_format("use polling level to wait for work (0 - no polling, default: %u)\n", (unsigned) params.cpuparams.poll), + [](common_params & params, const std::string & value) { + params.cpuparams.poll = std::stoul(value); + } + )); + add_opt(common_arg( + {"-Cb", "--cpu-mask-batch"}, "M", + "CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask)", + [](common_params & params, const std::string & mask) { + params.cpuparams_batch.mask_valid = true; + if (!parse_cpu_mask(mask, params.cpuparams_batch.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + )); + add_opt(common_arg( + {"-Crb", "--cpu-range-batch"}, "lo-hi", + "ranges of CPUs for affinity. Complements --cpu-mask-batch", + [](common_params & params, const std::string & range) { + params.cpuparams_batch.mask_valid = true; + if (!parse_cpu_range(range, params.cpuparams_batch.cpumask)) { + throw std::invalid_argument("invalid range"); + } + } + )); + add_opt(common_arg( + {"--cpu-strict-batch"}, "<0|1>", + "use strict CPU placement (default: same as --cpu-strict)", + [](common_params & params, int value) { + params.cpuparams_batch.strict_cpu = value; + } + )); + add_opt(common_arg( + {"--prio-batch"}, "N", + string_format("set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams_batch.priority), + [](common_params & params, int prio) { + if (prio < 0 || prio > 3) { + throw std::invalid_argument("invalid value"); + } + params.cpuparams_batch.priority = (enum ggml_sched_priority) prio; + } + )); + add_opt(common_arg( + {"--poll-batch"}, "<0|1>", + "use polling to wait for work (default: same as --poll)", + [](common_params & params, int value) { + params.cpuparams_batch.poll = value; + } + )); + add_opt(common_arg( + {"-lcs", "--lookup-cache-static"}, "FNAME", + "path to static lookup cache to use for lookup decoding (not updated by generation)", + [](common_params & params, const std::string & value) { + params.lookup_cache_static = value; + } + ).set_examples({LLAMA_EXAMPLE_LOOKUP})); + add_opt(common_arg( + {"-lcd", "--lookup-cache-dynamic"}, "FNAME", + "path to dynamic lookup cache to use for lookup decoding (updated by generation)", + [](common_params & params, const std::string & value) { + params.lookup_cache_dynamic = value; + } + ).set_examples({LLAMA_EXAMPLE_LOOKUP})); + add_opt(common_arg( + {"-c", "--ctx-size"}, "N", + string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx), + [](common_params & params, int value) { + params.n_ctx = value; + } + ).set_env("LLAMA_ARG_CTX_SIZE")); + add_opt(common_arg( + {"-n", "--predict", "--n-predict"}, "N", + string_format( + ex == LLAMA_EXAMPLE_MAIN + ? "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)" + : "number of tokens to predict (default: %d, -1 = infinity)", + params.n_predict), + [](common_params & params, int value) { + params.n_predict = value; + } + ).set_env("LLAMA_ARG_N_PREDICT")); + add_opt(common_arg( + {"-b", "--batch-size"}, "N", + string_format("logical maximum batch size (default: %d)", params.n_batch), + [](common_params & params, int value) { + params.n_batch = value; + } + ).set_env("LLAMA_ARG_BATCH")); + add_opt(common_arg( + {"-ub", "--ubatch-size"}, "N", + string_format("physical maximum batch size (default: %d)", params.n_ubatch), + [](common_params & params, int value) { + params.n_ubatch = value; + } + ).set_env("LLAMA_ARG_UBATCH")); + add_opt(common_arg( + {"--keep"}, "N", + string_format("number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep), + [](common_params & params, int value) { + params.n_keep = value; + } + )); + add_opt(common_arg( + {"--swa-full"}, + string_format("use full-size SWA cache (default: %s)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)", params.swa_full ? "true" : "false"), + [](common_params & params) { + params.swa_full = true; + } + ).set_env("LLAMA_ARG_SWA_FULL")); + add_opt(common_arg( + {"--no-context-shift"}, + string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), + [](common_params & params) { + params.ctx_shift = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT")); + add_opt(common_arg( + {"--chunks"}, "N", + string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks), + [](common_params & params, int value) { + params.n_chunks = value; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL})); + add_opt(common_arg( + {"-fa", "--flash-attn"}, + string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"), + [](common_params & params) { + params.flash_attn = true; + } + ).set_env("LLAMA_ARG_FLASH_ATTN")); + add_opt(common_arg( + {"-p", "--prompt"}, "PROMPT", + "prompt to start generation with; for system message, use -sys", + [](common_params & params, const std::string & value) { + params.prompt = value; + } + ).set_excludes({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-sys", "--system-prompt"}, "PROMPT", + "system prompt to use with model (if applicable, depending on chat template)", + [](common_params & params, const std::string & value) { + params.system_prompt = value; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--no-perf"}, + string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"), + [](common_params & params) { + params.no_perf = true; + params.sampling.no_perf = true; + } + ).set_env("LLAMA_ARG_NO_PERF")); + add_opt(common_arg( + {"-f", "--file"}, "FNAME", + "a file containing the prompt (default: none)", + [](common_params & params, const std::string & value) { + params.prompt = read_file(value); + // store the external file name in params + params.prompt_file = value; + if (!params.prompt.empty() && params.prompt.back() == '\n') { + params.prompt.pop_back(); + } + } + ).set_excludes({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-sysf", "--system-prompt-file"}, "FNAME", + "a file containing the system prompt (default: none)", + [](common_params & params, const std::string & value) { + params.system_prompt = read_file(value); + if (!params.system_prompt.empty() && params.system_prompt.back() == '\n') { + params.system_prompt.pop_back(); + } + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--in-file"}, "FNAME", + "an input file (repeat to specify multiple files)", + [](common_params & params, const std::string & value) { + std::ifstream file(value); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + params.in_files.push_back(value); + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"-bf", "--binary-file"}, "FNAME", + "binary file containing the prompt (default: none)", + [](common_params & params, const std::string & value) { + std::ifstream file(value, std::ios::binary); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + // store the external file name in params + params.prompt_file = value; + std::ostringstream ss; + ss << file.rdbuf(); + params.prompt = ss.str(); + fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), value.c_str()); + } + ).set_excludes({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-e", "--escape"}, + string_format("process escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\) (default: %s)", params.escape ? "true" : "false"), + [](common_params & params) { + params.escape = true; + } + )); + add_opt(common_arg( + {"--no-escape"}, + "do not process escape sequences", + [](common_params & params) { + params.escape = false; + } + )); + add_opt(common_arg( + {"-ptc", "--print-token-count"}, "N", + string_format("print token count every N tokens (default: %d)", params.n_print), + [](common_params & params, int value) { + params.n_print = value; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--prompt-cache"}, "FNAME", + "file to cache prompt state for faster startup (default: none)", + [](common_params & params, const std::string & value) { + params.path_prompt_cache = value; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--prompt-cache-all"}, + "if specified, saves user input and generations to cache as well\n", + [](common_params & params) { + params.prompt_cache_all = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--prompt-cache-ro"}, + "if specified, uses the prompt cache but does not update it", + [](common_params & params) { + params.prompt_cache_ro = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-r", "--reverse-prompt"}, "PROMPT", + "halt generation at PROMPT, return control in interactive mode\n", + [](common_params & params, const std::string & value) { + params.antiprompt.emplace_back(value); + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-sp", "--special"}, + string_format("special tokens output enabled (default: %s)", params.special ? "true" : "false"), + [](common_params & params) { + params.special = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-cnv", "--conversation"}, + "run in conversation mode:\n" + "- does not print special tokens and suffix/prefix\n" + "- interactive mode is also enabled\n" + "(default: auto enabled if chat template is available)", + [](common_params & params) { + params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-no-cnv", "--no-conversation"}, + "force disable conversation mode (default: false)", + [](common_params & params) { + params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-st", "--single-turn"}, + "run conversation for a single turn only, then exit when done\n" + "will not be interactive if first turn is predefined with --prompt\n" + "(default: false)", + [](common_params & params) { + params.single_turn = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-i", "--interactive"}, + string_format("run in interactive mode (default: %s)", params.interactive ? "true" : "false"), + [](common_params & params) { + params.interactive = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-if", "--interactive-first"}, + string_format("run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false"), + [](common_params & params) { + params.interactive_first = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-mli", "--multiline-input"}, + "allows you to write or paste multiple lines without ending each in '\\'", + [](common_params & params) { + params.multiline_input = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--in-prefix-bos"}, + "prefix BOS to user inputs, preceding the `--in-prefix` string", + [](common_params & params) { + params.input_prefix_bos = true; + params.enable_chat_template = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--in-prefix"}, "STRING", + "string to prefix user inputs with (default: empty)", + [](common_params & params, const std::string & value) { + params.input_prefix = value; + params.enable_chat_template = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--in-suffix"}, "STRING", + "string to suffix after user inputs with (default: empty)", + [](common_params & params, const std::string & value) { + params.input_suffix = value; + params.enable_chat_template = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--no-warmup"}, + "skip warming up the model with an empty run", + [](common_params & params) { + params.warmup = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL})); + add_opt(common_arg( + {"--spm-infill"}, + string_format( + "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", + params.spm_infill ? "enabled" : "disabled" + ), + [](common_params & params) { + params.spm_infill = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--samplers"}, "SAMPLERS", + string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), + [](common_params & params, const std::string & value) { + const auto sampler_names = string_split(value, ';'); + params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); + } + ).set_sparam()); + add_opt(common_arg( + {"-s", "--seed"}, "SEED", + string_format("RNG seed (default: %d, use random seed for %d)", params.sampling.seed, LLAMA_DEFAULT_SEED), + [](common_params & params, const std::string & value) { + params.sampling.seed = std::stoul(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--sampling-seq", "--sampler-seq"}, "SEQUENCE", + string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()), + [](common_params & params, const std::string & value) { + params.sampling.samplers = common_sampler_types_from_chars(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--ignore-eos"}, + "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)", + [](common_params & params) { + params.sampling.ignore_eos = true; + } + ).set_sparam()); + add_opt(common_arg( + {"--temp"}, "N", + string_format("temperature (default: %.1f)", (double)params.sampling.temp), + [](common_params & params, const std::string & value) { + params.sampling.temp = std::stof(value); + params.sampling.temp = std::max(params.sampling.temp, 0.0f); + } + ).set_sparam()); + add_opt(common_arg( + {"--top-k"}, "N", + string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k), + [](common_params & params, int value) { + params.sampling.top_k = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--top-p"}, "N", + string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), + [](common_params & params, const std::string & value) { + params.sampling.top_p = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--min-p"}, "N", + string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), + [](common_params & params, const std::string & value) { + params.sampling.min_p = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--top-nsigma"}, "N", + string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma), + [](common_params & params, const std::string & value) { + params.sampling.top_n_sigma = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam()); + add_opt(common_arg( + {"--xtc-probability"}, "N", + string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), + [](common_params & params, const std::string & value) { + params.sampling.xtc_probability = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--xtc-threshold"}, "N", + string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), + [](common_params & params, const std::string & value) { + params.sampling.xtc_threshold = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--typical"}, "N", + string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p), + [](common_params & params, const std::string & value) { + params.sampling.typ_p = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--repeat-last-n"}, "N", + string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n), + [](common_params & params, int value) { + if (value < -1) { + throw std::runtime_error(string_format("error: invalid repeat-last-n = %d\n", value)); + } + params.sampling.penalty_last_n = value; + params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n); + } + ).set_sparam()); + add_opt(common_arg( + {"--repeat-penalty"}, "N", + string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), + [](common_params & params, const std::string & value) { + params.sampling.penalty_repeat = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--presence-penalty"}, "N", + string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present), + [](common_params & params, const std::string & value) { + params.sampling.penalty_present = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--frequency-penalty"}, "N", + string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq), + [](common_params & params, const std::string & value) { + params.sampling.penalty_freq = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-multiplier"}, "N", + string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier), + [](common_params & params, const std::string & value) { + params.sampling.dry_multiplier = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-base"}, "N", + string_format("set DRY sampling base value (default: %.2f)", (double)params.sampling.dry_base), + [](common_params & params, const std::string & value) { + float potential_base = std::stof(value); + if (potential_base >= 1.0f) + { + params.sampling.dry_base = potential_base; + } + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-allowed-length"}, "N", + string_format("set allowed length for DRY sampling (default: %d)", params.sampling.dry_allowed_length), + [](common_params & params, int value) { + params.sampling.dry_allowed_length = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-penalty-last-n"}, "N", + string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n), + [](common_params & params, int value) { + if (value < -1) { + throw std::runtime_error(string_format("error: invalid dry-penalty-last-n = %d\n", value)); + } + params.sampling.dry_penalty_last_n = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-sequence-breaker"}, "STRING", + string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n", + params.sampling.dry_sequence_breakers.empty() ? "none" : + std::accumulate(std::next(params.sampling.dry_sequence_breakers.begin()), + params.sampling.dry_sequence_breakers.end(), + std::string("'") + (params.sampling.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sampling.dry_sequence_breakers[0]) + "'", + [](const std::string& a, const std::string& b) { + std::string formatted_b = (b == "\n") ? "\\n" : b; + return a + ", '" + formatted_b + "'"; + }).c_str()), + [](common_params & params, const std::string & value) { + static bool defaults_cleared = false; + + if (!defaults_cleared) { + params.sampling.dry_sequence_breakers.clear(); + defaults_cleared = true; + } + + if (value == "none") { + params.sampling.dry_sequence_breakers.clear(); + } else { + params.sampling.dry_sequence_breakers.emplace_back(value); + } + } + ).set_sparam()); + add_opt(common_arg( + {"--dynatemp-range"}, "N", + string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range), + [](common_params & params, const std::string & value) { + params.sampling.dynatemp_range = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--dynatemp-exp"}, "N", + string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent), + [](common_params & params, const std::string & value) { + params.sampling.dynatemp_exponent = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--mirostat"}, "N", + string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n" + "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat), + [](common_params & params, int value) { + params.sampling.mirostat = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--mirostat-lr"}, "N", + string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), + [](common_params & params, const std::string & value) { + params.sampling.mirostat_eta = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--mirostat-ent"}, "N", + string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), + [](common_params & params, const std::string & value) { + params.sampling.mirostat_tau = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"-l", "--logit-bias"}, "TOKEN_ID(+/-)BIAS", + "modifies the likelihood of token appearing in the completion,\n" + "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" + "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'", + [](common_params & params, const std::string & value) { + std::stringstream ss(value); + llama_token key; + char sign; + std::string value_str; + try { + if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { + const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); + params.sampling.logit_bias.push_back({key, bias}); + } else { + throw std::invalid_argument("invalid input format"); + } + } catch (const std::exception&) { + throw std::invalid_argument("invalid input format"); + } + } + ).set_sparam()); + add_opt(common_arg( + {"--grammar"}, "GRAMMAR", + string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()), + [](common_params & params, const std::string & value) { + params.sampling.grammar = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--grammar-file"}, "FNAME", + "file to read grammar from", + [](common_params & params, const std::string & value) { + params.sampling.grammar = read_file(value); + } + ).set_sparam()); + add_opt(common_arg( + {"-j", "--json-schema"}, "SCHEMA", + "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", + [](common_params & params, const std::string & value) { + params.sampling.grammar = json_schema_to_grammar(json::parse(value)); + } + ).set_sparam()); + add_opt(common_arg( + {"-jf", "--json-schema-file"}, "FILE", + "File containing a JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", + [](common_params & params, const std::string & value) { + std::ifstream file(value); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + std::string schema; + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(schema) + ); + params.sampling.grammar = json_schema_to_grammar(json::parse(schema)); + } + ).set_sparam()); + add_opt(common_arg( + {"--pooling"}, "{none,mean,cls,last,rank}", + "pooling type for embeddings, use model default if unspecified", + [](common_params & params, const std::string & value) { + /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } + else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); + add_opt(common_arg( + {"--attention"}, "{causal,non-causal}", + "attention type for embeddings, use model default if unspecified", + [](common_params & params, const std::string & value) { + /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; } + else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + add_opt(common_arg( + {"--rope-scaling"}, "{none,linear,yarn}", + "RoPE frequency scaling method, defaults to linear unless specified by the model", + [](common_params & params, const std::string & value) { + /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } + else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } + else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_env("LLAMA_ARG_ROPE_SCALING_TYPE")); + add_opt(common_arg( + {"--rope-scale"}, "N", + "RoPE context scaling factor, expands context by a factor of N", + [](common_params & params, const std::string & value) { + params.rope_freq_scale = 1.0f / std::stof(value); + } + ).set_env("LLAMA_ARG_ROPE_SCALE")); + add_opt(common_arg( + {"--rope-freq-base"}, "N", + "RoPE base frequency, used by NTK-aware scaling (default: loaded from model)", + [](common_params & params, const std::string & value) { + params.rope_freq_base = std::stof(value); + } + ).set_env("LLAMA_ARG_ROPE_FREQ_BASE")); + add_opt(common_arg( + {"--rope-freq-scale"}, "N", + "RoPE frequency scaling factor, expands context by a factor of 1/N", + [](common_params & params, const std::string & value) { + params.rope_freq_scale = std::stof(value); + } + ).set_env("LLAMA_ARG_ROPE_FREQ_SCALE")); + add_opt(common_arg( + {"--yarn-orig-ctx"}, "N", + string_format("YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx), + [](common_params & params, int value) { + params.yarn_orig_ctx = value; + } + ).set_env("LLAMA_ARG_YARN_ORIG_CTX")); + add_opt(common_arg( + {"--yarn-ext-factor"}, "N", + string_format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor), + [](common_params & params, const std::string & value) { + params.yarn_ext_factor = std::stof(value); + } + ).set_env("LLAMA_ARG_YARN_EXT_FACTOR")); + add_opt(common_arg( + {"--yarn-attn-factor"}, "N", + string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor), + [](common_params & params, const std::string & value) { + params.yarn_attn_factor = std::stof(value); + } + ).set_env("LLAMA_ARG_YARN_ATTN_FACTOR")); + add_opt(common_arg( + {"--yarn-beta-slow"}, "N", + string_format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow), + [](common_params & params, const std::string & value) { + params.yarn_beta_slow = std::stof(value); + } + ).set_env("LLAMA_ARG_YARN_BETA_SLOW")); + add_opt(common_arg( + {"--yarn-beta-fast"}, "N", + string_format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast), + [](common_params & params, const std::string & value) { + params.yarn_beta_fast = std::stof(value); + } + ).set_env("LLAMA_ARG_YARN_BETA_FAST")); + add_opt(common_arg( + {"-gan", "--grp-attn-n"}, "N", + string_format("group-attention factor (default: %d)", params.grp_attn_n), + [](common_params & params, int value) { + params.grp_attn_n = value; + } + ).set_env("LLAMA_ARG_GRP_ATTN_N").set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_PASSKEY})); + add_opt(common_arg( + {"-gaw", "--grp-attn-w"}, "N", + string_format("group-attention width (default: %d)", params.grp_attn_w), + [](common_params & params, int value) { + params.grp_attn_w = value; + } + ).set_env("LLAMA_ARG_GRP_ATTN_W").set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-nkvo", "--no-kv-offload"}, + "disable KV offload", + [](common_params & params) { + params.no_kv_offload = true; + } + ).set_env("LLAMA_ARG_NO_KV_OFFLOAD")); + add_opt(common_arg( + {"-ctk", "--cache-type-k"}, "TYPE", + string_format( + "KV cache data type for K\n" + "allowed values: %s\n" + "(default: %s)", + get_all_kv_cache_types().c_str(), + ggml_type_name(params.cache_type_k) + ), + [](common_params & params, const std::string & value) { + params.cache_type_k = kv_cache_type_from_str(value); + } + ).set_env("LLAMA_ARG_CACHE_TYPE_K")); + add_opt(common_arg( + {"-ctv", "--cache-type-v"}, "TYPE", + string_format( + "KV cache data type for V\n" + "allowed values: %s\n" + "(default: %s)", + get_all_kv_cache_types().c_str(), + ggml_type_name(params.cache_type_v) + ), + [](common_params & params, const std::string & value) { + params.cache_type_v = kv_cache_type_from_str(value); + } + ).set_env("LLAMA_ARG_CACHE_TYPE_V")); + add_opt(common_arg( + {"--hellaswag"}, + "compute HellaSwag score over random tasks from datafile supplied with -f", + [](common_params & params) { + params.hellaswag = true; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--hellaswag-tasks"}, "N", + string_format("number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks), + [](common_params & params, int value) { + params.hellaswag_tasks = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--winogrande"}, + "compute Winogrande score over random tasks from datafile supplied with -f", + [](common_params & params) { + params.winogrande = true; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--winogrande-tasks"}, "N", + string_format("number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks), + [](common_params & params, int value) { + params.winogrande_tasks = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--multiple-choice"}, + "compute multiple choice score over random tasks from datafile supplied with -f", + [](common_params & params) { + params.multiple_choice = true; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--multiple-choice-tasks"}, "N", + string_format("number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks), + [](common_params & params, int value) { + params.multiple_choice_tasks = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--kl-divergence"}, + "computes KL-divergence to logits provided via --kl-divergence-base", + [](common_params & params) { + params.kl_divergence = true; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--save-all-logits", "--kl-divergence-base"}, "FNAME", + "set logits file", + [](common_params & params, const std::string & value) { + params.logits_file = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--ppl-stride"}, "N", + string_format("stride for perplexity calculation (default: %d)", params.ppl_stride), + [](common_params & params, int value) { + params.ppl_stride = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--ppl-output-type"}, "<0|1>", + string_format("output type for perplexity calculation (default: %d)", params.ppl_output_type), + [](common_params & params, int value) { + params.ppl_output_type = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"-dt", "--defrag-thold"}, "N", + string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold), + [](common_params & params, const std::string & value) { + params.defrag_thold = std::stof(value); + } + ).set_env("LLAMA_ARG_DEFRAG_THOLD")); + add_opt(common_arg( + {"-np", "--parallel"}, "N", + string_format("number of parallel sequences to decode (default: %d)", params.n_parallel), + [](common_params & params, int value) { + params.n_parallel = value; + } + ).set_env("LLAMA_ARG_N_PARALLEL")); + add_opt(common_arg( + {"-ns", "--sequences"}, "N", + string_format("number of sequences to decode (default: %d)", params.n_sequences), + [](common_params & params, int value) { + params.n_sequences = value; + } + ).set_examples({LLAMA_EXAMPLE_PARALLEL})); + add_opt(common_arg( + {"-cb", "--cont-batching"}, + string_format("enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled"), + [](common_params & params) { + params.cont_batching = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CONT_BATCHING")); + add_opt(common_arg( + {"-nocb", "--no-cont-batching"}, + "disable continuous batching", + [](common_params & params) { + params.cont_batching = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONT_BATCHING")); + add_opt(common_arg( + {"--mmproj"}, "FILE", + "path to a multimodal projector file. see tools/mtmd/README.md\n" + "note: if -hf is used, this argument can be omitted", + [](common_params & params, const std::string & value) { + params.mmproj.path = value; + } + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ")); + add_opt(common_arg( + {"--mmproj-url"}, "URL", + "URL to a multimodal projector file. see tools/mtmd/README.md", + [](common_params & params, const std::string & value) { + params.mmproj.url = value; + } + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_URL")); + add_opt(common_arg( + {"--no-mmproj"}, + "explicitly disable multimodal projector, useful when using -hf", + [](common_params & params) { + params.no_mmproj = true; + } + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ")); + add_opt(common_arg( + {"--no-mmproj-offload"}, + "do not offload multimodal projector to GPU", + [](common_params & params) { + params.mmproj_use_gpu = false; + } + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD")); + add_opt(common_arg( + {"--image", "--audio"}, "FILE", + "path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n", + [](common_params & params, const std::string & value) { + params.image.emplace_back(value); + } + ).set_examples({LLAMA_EXAMPLE_MTMD})); + if (llama_supports_rpc()) { + add_opt(common_arg( + {"--rpc"}, "SERVERS", + "comma separated list of RPC servers", + [](common_params & params, const std::string & value) { + add_rpc_devices(value); + GGML_UNUSED(params); + } + ).set_env("LLAMA_ARG_RPC")); + } + add_opt(common_arg( + {"--mlock"}, + "force system to keep model in RAM rather than swapping or compressing", + [](common_params & params) { + params.use_mlock = true; + } + ).set_env("LLAMA_ARG_MLOCK")); + add_opt(common_arg( + {"--no-mmap"}, + "do not memory-map model (slower load but may reduce pageouts if not using mlock)", + [](common_params & params) { + params.use_mmap = false; + } + ).set_env("LLAMA_ARG_NO_MMAP")); + add_opt(common_arg( + {"--numa"}, "TYPE", + "attempt optimizations that help on some NUMA systems\n" + "- distribute: spread execution evenly over all nodes\n" + "- isolate: only spawn threads on CPUs on the node that execution started on\n" + "- numactl: use the CPU map provided by numactl\n" + "if run without this previously, it is recommended to drop the system page cache before using this\n" + "see https://github.com/ggml-org/llama.cpp/issues/1437", + [](common_params & params, const std::string & value) { + /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } + else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } + else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_env("LLAMA_ARG_NUMA")); + add_opt(common_arg( + {"-dev", "--device"}, "", + "comma-separated list of devices to use for offloading (none = don't offload)\n" + "use --list-devices to see a list of available devices", + [](common_params & params, const std::string & value) { + params.devices = parse_device_list(value); + } + ).set_env("LLAMA_ARG_DEVICE")); + add_opt(common_arg( + {"--list-devices"}, + "print list of available devices and exit", + [](common_params &) { + std::vector rpc_devices; + std::vector all_devices; + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + if (ggml_backend_reg_name(reg) == std::string("RPC")) { + rpc_devices.push_back(dev); + } else { + all_devices.push_back(dev); + } + } + } + // insert RPC devices in front + all_devices.insert(all_devices.begin(), rpc_devices.begin(), rpc_devices.end()); + printf("Available devices:\n"); + for (size_t i = 0; i < all_devices.size(); ++i) { + auto * dev = all_devices[i]; + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + } + exit(0); + } + )); + add_opt(common_arg( + {"--override-tensor", "-ot"}, "=,...", + "override tensor buffer type", [](common_params & params, const std::string & value) { + /* static */ std::map buft_list; + if (buft_list.empty()) { + // enumerate all the devices and add their buffer types to the list + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + auto * buft = ggml_backend_dev_buffer_type(dev); + if (buft) { + buft_list[ggml_backend_buft_name(buft)] = buft; + } + } + } + + for (const auto & override : string_split(value, ',')) { + std::string::size_type pos = override.find('='); + if (pos == std::string::npos) { + throw std::invalid_argument("invalid value"); + } + std::string tensor_name = override.substr(0, pos); + std::string buffer_type = override.substr(pos + 1); + + if (buft_list.find(buffer_type) == buft_list.end()) { + printf("Available buffer types:\n"); + for (const auto & it : buft_list) { + printf(" %s\n", ggml_backend_buft_name(it.second)); + } + throw std::invalid_argument("unknown buffer type"); + } + // FIXME: this leaks memory + params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)}); + } + } + )); + add_opt(common_arg( + {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", + "number of layers to store in VRAM", + [](common_params & params, int value) { + params.n_gpu_layers = value; + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: no usable GPU found, --gpu-layers option will be ignored\n"); + fprintf(stderr, "warning: one possible reason is that llama.cpp was compiled without GPU support\n"); + fprintf(stderr, "warning: consult docs/build.md for compilation instructions\n"); + } + } + ).set_env("LLAMA_ARG_N_GPU_LAYERS")); + add_opt(common_arg( + {"-sm", "--split-mode"}, "{none,layer,row}", + "how to split the model across multiple GPUs, one of:\n" + "- none: use one GPU only\n" + "- layer (default): split layers and KV across GPUs\n" + "- row: split rows across GPUs", + [](common_params & params, const std::string & value) { + std::string arg_next = value; + if (arg_next == "none") { + params.split_mode = LLAMA_SPLIT_MODE_NONE; + } else if (arg_next == "layer") { + params.split_mode = LLAMA_SPLIT_MODE_LAYER; + } else if (arg_next == "row") { + params.split_mode = LLAMA_SPLIT_MODE_ROW; + } else { + throw std::invalid_argument("invalid value"); + } + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the split mode has no effect.\n"); + } + } + ).set_env("LLAMA_ARG_SPLIT_MODE")); + add_opt(common_arg( + {"-ts", "--tensor-split"}, "N0,N1,N2,...", + "fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1", + [](common_params & params, const std::string & value) { + std::string arg_next = value; + + // split string by , and / + const std::regex regex{ R"([,/]+)" }; + std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 }; + std::vector split_arg{ it, {} }; + if (split_arg.size() >= llama_max_devices()) { + throw std::invalid_argument( + string_format("got %d input configs, but system only has %d devices", (int)split_arg.size(), (int)llama_max_devices()) + ); + } + for (size_t i = 0; i < llama_max_devices(); ++i) { + if (i < split_arg.size()) { + params.tensor_split[i] = std::stof(split_arg[i]); + } else { + params.tensor_split[i] = 0.0f; + } + } + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting a tensor split has no effect.\n"); + } + } + ).set_env("LLAMA_ARG_TENSOR_SPLIT")); + add_opt(common_arg( + {"-mg", "--main-gpu"}, "INDEX", + string_format("the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu), + [](common_params & params, int value) { + params.main_gpu = value; + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the main GPU has no effect.\n"); + } + } + ).set_env("LLAMA_ARG_MAIN_GPU")); + add_opt(common_arg( + {"--check-tensors"}, + string_format("check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false"), + [](common_params & params) { + params.check_tensors = true; + } + )); + add_opt(common_arg( + {"--override-kv"}, "KEY=TYPE:VALUE", + "advanced option to override model metadata by key. may be specified multiple times.\n" + "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false", + [](common_params & params, const std::string & value) { + if (!string_parse_kv_override(value.c_str(), params.kv_overrides)) { + throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", value.c_str())); + } + } + )); + add_opt(common_arg( + {"--no-op-offload"}, + string_format("disable offloading host tensor operations to device (default: %s)", params.no_op_offload ? "true" : "false"), + [](common_params & params) { + params.no_op_offload = true; + } + )); + add_opt(common_arg( + {"--lora"}, "FNAME", + "path to LoRA adapter (can be repeated to use multiple adapters)", + [](common_params & params, const std::string & value) { + params.lora_adapters.push_back({ std::string(value), 1.0, nullptr }); + } + // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); + add_opt(common_arg( + {"--lora-scaled"}, "FNAME", "SCALE", + "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)", + [](common_params & params, const std::string & fname, const std::string & scale) { + params.lora_adapters.push_back({ fname, std::stof(scale), nullptr }); + } + // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); + add_opt(common_arg( + {"--control-vector"}, "FNAME", + "add a control vector\nnote: this argument can be repeated to add multiple control vectors", + [](common_params & params, const std::string & value) { + params.control_vectors.push_back({ 1.0f, value, }); + } + )); + add_opt(common_arg( + {"--control-vector-scaled"}, "FNAME", "SCALE", + "add a control vector with user defined scaling SCALE\n" + "note: this argument can be repeated to add multiple scaled control vectors", + [](common_params & params, const std::string & fname, const std::string & scale) { + params.control_vectors.push_back({ std::stof(scale), fname }); + } + )); + add_opt(common_arg( + {"--control-vector-layer-range"}, "START", "END", + "layer range to apply the control vector(s) to, start and end inclusive", + [](common_params & params, const std::string & start, const std::string & end) { + params.control_vector_layer_start = std::stoi(start); + params.control_vector_layer_end = std::stoi(end); + } + )); + add_opt(common_arg( + {"-a", "--alias"}, "STRING", + "set alias for model name (to be used by REST API)", + [](common_params & params, const std::string & value) { + params.model_alias = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ALIAS")); + add_opt(common_arg( + {"-m", "--model"}, "FNAME", + ex == LLAMA_EXAMPLE_EXPORT_LORA + ? std::string("model path from which to load base model") + : string_format( + "model path (default: `models/$filename` with filename from `--hf-file` " + "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH + ), + [](common_params & params, const std::string & value) { + params.model.path = value; + } + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL")); + add_opt(common_arg( + {"-mu", "--model-url"}, "MODEL_URL", + "model download url (default: unused)", + [](common_params & params, const std::string & value) { + params.model.url = value; + } + ).set_env("LLAMA_ARG_MODEL_URL")); + add_opt(common_arg( + {"-hf", "-hfr", "--hf-repo"}, "/[:quant]", + "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n" + "mmproj is also downloaded automatically if available. to disable, add --no-mmproj\n" + "example: unsloth/phi-4-GGUF:q4_k_m\n" + "(default: unused)", + [](common_params & params, const std::string & value) { + params.model.hf_repo = value; + } + ).set_env("LLAMA_ARG_HF_REPO")); + add_opt(common_arg( + {"-hfd", "-hfrd", "--hf-repo-draft"}, "/[:quant]", + "Same as --hf-repo, but for the draft model (default: unused)", + [](common_params & params, const std::string & value) { + params.speculative.model.hf_repo = value; + } + ).set_env("LLAMA_ARG_HFD_REPO")); + add_opt(common_arg( + {"-hff", "--hf-file"}, "FILE", + "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)", + [](common_params & params, const std::string & value) { + params.model.hf_file = value; + } + ).set_env("LLAMA_ARG_HF_FILE")); + add_opt(common_arg( + {"-hfv", "-hfrv", "--hf-repo-v"}, "/[:quant]", + "Hugging Face model repository for the vocoder model (default: unused)", + [](common_params & params, const std::string & value) { + params.vocoder.model.hf_repo = value; + } + ).set_env("LLAMA_ARG_HF_REPO_V")); + add_opt(common_arg( + {"-hffv", "--hf-file-v"}, "FILE", + "Hugging Face model file for the vocoder model (default: unused)", + [](common_params & params, const std::string & value) { + params.vocoder.model.hf_file = value; + } + ).set_env("LLAMA_ARG_HF_FILE_V")); + add_opt(common_arg( + {"-hft", "--hf-token"}, "TOKEN", + "Hugging Face access token (default: value from HF_TOKEN environment variable)", + [](common_params & params, const std::string & value) { + params.hf_token = value; + } + ).set_env("HF_TOKEN")); + add_opt(common_arg( + {"--context-file"}, "FNAME", + "file to load context from (repeat to specify multiple files)", + [](common_params & params, const std::string & value) { + std::ifstream file(value, std::ios::binary); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + params.context_files.push_back(value); + } + ).set_examples({LLAMA_EXAMPLE_RETRIEVAL})); + add_opt(common_arg( + {"--chunk-size"}, "N", + string_format("minimum length of embedded text chunks (default: %d)", params.chunk_size), + [](common_params & params, int value) { + params.chunk_size = value; + } + ).set_examples({LLAMA_EXAMPLE_RETRIEVAL})); + add_opt(common_arg( + {"--chunk-separator"}, "STRING", + string_format("separator between chunks (default: '%s')", params.chunk_separator.c_str()), + [](common_params & params, const std::string & value) { + params.chunk_separator = value; + } + ).set_examples({LLAMA_EXAMPLE_RETRIEVAL})); + add_opt(common_arg( + {"--junk"}, "N", + string_format("number of times to repeat the junk text (default: %d)", params.n_junk), + [](common_params & params, int value) { + params.n_junk = value; + } + ).set_examples({LLAMA_EXAMPLE_PASSKEY, LLAMA_EXAMPLE_PARALLEL})); + add_opt(common_arg( + {"--pos"}, "N", + string_format("position of the passkey in the junk text (default: %d)", params.i_pos), + [](common_params & params, int value) { + params.i_pos = value; + } + ).set_examples({LLAMA_EXAMPLE_PASSKEY})); + add_opt(common_arg( + {"-o", "--output", "--output-file"}, "FNAME", + string_format("output file (default: '%s')", params.out_file.c_str()), + [](common_params & params, const std::string & value) { + params.out_file = value; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS})); + add_opt(common_arg( + {"-ofreq", "--output-frequency"}, "N", + string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq), + [](common_params & params, int value) { + params.n_out_freq = value; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--save-frequency"}, "N", + string_format("save an imatrix copy every N iterations (default: %d)", params.n_save_freq), + [](common_params & params, int value) { + params.n_save_freq = value; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--process-output"}, + string_format("collect data for the output tensor (default: %s)", params.process_output ? "true" : "false"), + [](common_params & params) { + params.process_output = true; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--no-ppl"}, + string_format("do not compute perplexity (default: %s)", params.compute_ppl ? "true" : "false"), + [](common_params & params) { + params.compute_ppl = false; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--chunk", "--from-chunk"}, "N", + string_format("start processing the input from chunk N (default: %d)", params.i_chunk), + [](common_params & params, int value) { + params.i_chunk = value; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--parse-special"}, + string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"), + [](common_params & params) { + params.parse_special = true; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"-pps"}, + string_format("is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false"), + [](common_params & params) { + params.is_pp_shared = true; + } + ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL})); + add_opt(common_arg( + {"-npp"}, "n0,n1,...", + "number of prompt tokens", + [](common_params & params, const std::string & value) { + auto p = string_split(value, ','); + params.n_pp.insert(params.n_pp.end(), p.begin(), p.end()); + } + ).set_examples({LLAMA_EXAMPLE_BENCH})); + add_opt(common_arg( + {"-ntg"}, "n0,n1,...", + "number of text generation tokens", + [](common_params & params, const std::string & value) { + auto p = string_split(value, ','); + params.n_tg.insert(params.n_tg.end(), p.begin(), p.end()); + } + ).set_examples({LLAMA_EXAMPLE_BENCH})); + add_opt(common_arg( + {"-npl"}, "n0,n1,...", + "number of parallel prompts", + [](common_params & params, const std::string & value) { + auto p = string_split(value, ','); + params.n_pl.insert(params.n_pl.end(), p.begin(), p.end()); + } + ).set_examples({LLAMA_EXAMPLE_BENCH})); + add_opt(common_arg( + {"--embd-normalize"}, "N", + string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize), + [](common_params & params, int value) { + params.embd_normalize = value; + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + add_opt(common_arg( + {"--embd-output-format"}, "FORMAT", + "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix", + [](common_params & params, const std::string & value) { + params.embd_out = value; + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + add_opt(common_arg( + {"--embd-separator"}, "STRING", + "separator of embeddings (default \\n) for example \"<#sep#>\"", + [](common_params & params, const std::string & value) { + params.embd_sep = value; + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + add_opt(common_arg( + {"--host"}, "HOST", + string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()), + [](common_params & params, const std::string & value) { + params.hostname = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_HOST")); + add_opt(common_arg( + {"--port"}, "PORT", + string_format("port to listen (default: %d)", params.port), + [](common_params & params, int value) { + params.port = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT")); + add_opt(common_arg( + {"--path"}, "PATH", + string_format("path to serve static files from (default: %s)", params.public_path.c_str()), + [](common_params & params, const std::string & value) { + params.public_path = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH")); + add_opt(common_arg( + {"--no-webui"}, + string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"), + [](common_params & params) { + params.webui = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_WEBUI")); + add_opt(common_arg( + {"--embedding", "--embeddings"}, + string_format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"), + [](common_params & params) { + params.embedding = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); + add_opt(common_arg( + {"--reranking", "--rerank"}, + string_format("enable reranking endpoint on server (default: %s)", "disabled"), + [](common_params & params) { + params.embedding = true; + params.pooling_type = LLAMA_POOLING_TYPE_RANK; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); + add_opt(common_arg( + {"--api-key"}, "KEY", + "API key to use for authentication (default: none)", + [](common_params & params, const std::string & value) { + params.api_keys.push_back(value); + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY")); + add_opt(common_arg( + {"--api-key-file"}, "FNAME", + "path to file containing API keys (default: none)", + [](common_params & params, const std::string & value) { + std::ifstream key_file(value); + if (!key_file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + std::string key; + while (std::getline(key_file, key)) { + if (!key.empty()) { + params.api_keys.push_back(key); + } + } + key_file.close(); + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--ssl-key-file"}, "FNAME", + "path to file a PEM-encoded SSL private key", + [](common_params & params, const std::string & value) { + params.ssl_file_key = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_KEY_FILE")); + add_opt(common_arg( + {"--ssl-cert-file"}, "FNAME", + "path to file a PEM-encoded SSL certificate", + [](common_params & params, const std::string & value) { + params.ssl_file_cert = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE")); + add_opt(common_arg( + {"-to", "--timeout"}, "N", + string_format("server read/write timeout in seconds (default: %d)", params.timeout_read), + [](common_params & params, int value) { + params.timeout_read = value; + params.timeout_write = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TIMEOUT")); + add_opt(common_arg( + {"--threads-http"}, "N", + string_format("number of threads used to process HTTP requests (default: %d)", params.n_threads_http), + [](common_params & params, int value) { + params.n_threads_http = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP")); + add_opt(common_arg( + {"--cache-reuse"}, "N", + string_format( + "min chunk size to attempt reusing from the cache via KV shifting (default: %d)\n" + "[(card)](https://ggml.ai/f0.png)", params.n_cache_reuse + ), + [](common_params & params, int value) { + params.n_cache_reuse = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CACHE_REUSE")); + add_opt(common_arg( + {"--metrics"}, + string_format("enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled"), + [](common_params & params) { + params.endpoint_metrics = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS")); + add_opt(common_arg( + {"--slots"}, + string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"), + [](common_params & params) { + params.endpoint_slots = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS")); + add_opt(common_arg( + {"--props"}, + string_format("enable changing global properties via POST /props (default: %s)", params.endpoint_props ? "enabled" : "disabled"), + [](common_params & params) { + params.endpoint_props = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_PROPS")); + add_opt(common_arg( + {"--no-slots"}, + "disables slots monitoring endpoint", + [](common_params & params) { + params.endpoint_slots = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_ENDPOINT_SLOTS")); + add_opt(common_arg( + {"--slot-save-path"}, "PATH", + "path to save slot kv cache (default: disabled)", + [](common_params & params, const std::string & value) { + params.slot_save_path = value; + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { + params.slot_save_path += DIRECTORY_SEPARATOR; + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--jinja"}, + "use jinja template for chat (default: disabled)", + [](common_params & params) { + params.use_jinja = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA")); + add_opt(common_arg( + {"--reasoning-format"}, "FORMAT", + "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" + "- none: leaves thoughts unparsed in `message.content`\n" + "- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n" + "(default: deepseek)", + [](common_params & params, const std::string & value) { + /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; } + else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; } + else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK")); + add_opt(common_arg( + {"--reasoning-budget"}, "N", + "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)", + [](common_params & params, int value) { + if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); } + params.reasoning_budget = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET")); + add_opt(common_arg( + {"--chat-template"}, "JINJA_TEMPLATE", + string_format( + "set custom jinja chat template (default: template taken from model's metadata)\n" + "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted (unless --jinja is set before this flag):\n" + "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() + ), + [](common_params & params, const std::string & value) { + params.chat_template = value; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); + add_opt(common_arg( + {"--chat-template-file"}, "JINJA_TEMPLATE_FILE", + string_format( + "set custom jinja chat template file (default: template taken from model's metadata)\n" + "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted (unless --jinja is set before this flag):\n" + "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() + ), + [](common_params & params, const std::string & value) { + params.chat_template = read_file(value); + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); + add_opt(common_arg( + {"--no-prefill-assistant"}, + string_format( + "whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n" + "when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n" + ), + [](common_params & params) { + params.prefill_assistant = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_PREFILL_ASSISTANT")); + add_opt(common_arg( + {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", + string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), + [](common_params & params, const std::string & value) { + params.slot_prompt_similarity = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--lora-init-without-apply"}, + string_format("load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"), + [](common_params & params) { + params.lora_init_without_apply = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--simple-io"}, + "use basic IO for better compatibility in subprocesses and limited consoles", + [](common_params & params) { + params.simple_io = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--positive-file"}, "FNAME", + string_format("positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str()), + [](common_params & params, const std::string & value) { + params.cvector_positive_file = value; + } + ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); + add_opt(common_arg( + {"--negative-file"}, "FNAME", + string_format("negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str()), + [](common_params & params, const std::string & value) { + params.cvector_negative_file = value; + } + ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); + add_opt(common_arg( + {"--pca-batch"}, "N", + string_format("batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch), + [](common_params & params, int value) { + params.n_pca_batch = value; + } + ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); + add_opt(common_arg( + {"--pca-iter"}, "N", + string_format("number of iterations used for PCA (default: %d)", params.n_pca_iterations), + [](common_params & params, int value) { + params.n_pca_iterations = value; + } + ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); + add_opt(common_arg( + {"--method"}, "{pca, mean}", + "dimensionality reduction method to be used (default: pca)", + [](common_params & params, const std::string & value) { + /**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; } + else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); + add_opt(common_arg( + {"--output-format"}, "{md,jsonl}", + "output format for batched-bench results (default: md)", + [](common_params & params, const std::string & value) { + /**/ if (value == "jsonl") { params.batched_bench_output_jsonl = true; } + else if (value == "md") { params.batched_bench_output_jsonl = false; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_BENCH})); + add_opt(common_arg( + {"--log-disable"}, + "Log disable", + [](common_params &) { + common_log_pause(common_log_main()); + } + )); + add_opt(common_arg( + {"--log-file"}, "FNAME", + "Log to file", + [](common_params &, const std::string & value) { + common_log_set_file(common_log_main(), value.c_str()); + } + )); + add_opt(common_arg( + {"--log-colors"}, + "Enable colored logging", + [](common_params &) { + common_log_set_colors(common_log_main(), true); + } + ).set_env("LLAMA_LOG_COLORS")); + add_opt(common_arg( + {"-v", "--verbose", "--log-verbose"}, + "Set verbosity level to infinity (i.e. log all messages, useful for debugging)", + [](common_params & params) { + params.verbosity = INT_MAX; + common_log_set_verbosity_thold(INT_MAX); + } + )); + add_opt(common_arg( + {"--offline"}, + "Offline mode: forces use of cache, prevents network access", + [](common_params & params) { + params.offline = true; + } + ).set_env("LLAMA_OFFLINE")); + add_opt(common_arg( + {"-lv", "--verbosity", "--log-verbosity"}, "N", + "Set the verbosity threshold. Messages with a higher verbosity will be ignored.", + [](common_params & params, int value) { + params.verbosity = value; + common_log_set_verbosity_thold(value); + } + ).set_env("LLAMA_LOG_VERBOSITY")); + add_opt(common_arg( + {"--log-prefix"}, + "Enable prefix in log messages", + [](common_params &) { + common_log_set_prefix(common_log_main(), true); + } + ).set_env("LLAMA_LOG_PREFIX")); + add_opt(common_arg( + {"--log-timestamps"}, + "Enable timestamps in log messages", + [](common_params &) { + common_log_set_timestamps(common_log_main(), true); + } + ).set_env("LLAMA_LOG_TIMESTAMPS")); + + // speculative parameters + add_opt(common_arg( + {"-td", "--threads-draft"}, "N", + "number of threads to use during generation (default: same as --threads)", + [](common_params & params, int value) { + params.speculative.cpuparams.n_threads = value; + if (params.speculative.cpuparams.n_threads <= 0) { + params.speculative.cpuparams.n_threads = std::thread::hardware_concurrency(); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-tbd", "--threads-batch-draft"}, "N", + "number of threads to use during batch and prompt processing (default: same as --threads-draft)", + [](common_params & params, int value) { + params.speculative.cpuparams_batch.n_threads = value; + if (params.speculative.cpuparams_batch.n_threads <= 0) { + params.speculative.cpuparams_batch.n_threads = std::thread::hardware_concurrency(); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Cd", "--cpu-mask-draft"}, "M", + "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", + [](common_params & params, const std::string & mask) { + params.speculative.cpuparams.mask_valid = true; + if (!parse_cpu_mask(mask, params.speculative.cpuparams.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Crd", "--cpu-range-draft"}, "lo-hi", + "Ranges of CPUs for affinity. Complements --cpu-mask-draft", + [](common_params & params, const std::string & range) { + params.speculative.cpuparams.mask_valid = true; + if (!parse_cpu_range(range, params.speculative.cpuparams.cpumask)) { + throw std::invalid_argument("invalid range"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--cpu-strict-draft"}, "<0|1>", + "Use strict CPU placement for draft model (default: same as --cpu-strict)", + [](common_params & params, int value) { + params.speculative.cpuparams.strict_cpu = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--prio-draft"}, "N", + string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams.priority), + [](common_params & params, int prio) { + if (prio < 0 || prio > 3) { + throw std::invalid_argument("invalid value"); + } + params.speculative.cpuparams.priority = (enum ggml_sched_priority) prio; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--poll-draft"}, "<0|1>", + "Use polling to wait for draft model work (default: same as --poll])", + [](common_params & params, int value) { + params.speculative.cpuparams.poll = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Cbd", "--cpu-mask-batch-draft"}, "M", + "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", + [](common_params & params, const std::string & mask) { + params.speculative.cpuparams_batch.mask_valid = true; + if (!parse_cpu_mask(mask, params.speculative.cpuparams_batch.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Crbd", "--cpu-range-batch-draft"}, "lo-hi", + "Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)", + [](common_params & params, const std::string & range) { + params.speculative.cpuparams_batch.mask_valid = true; + if (!parse_cpu_range(range, params.speculative.cpuparams_batch.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--cpu-strict-batch-draft"}, "<0|1>", + "Use strict CPU placement for draft model (default: --cpu-strict-draft)", + [](common_params & params, int value) { + params.speculative.cpuparams_batch.strict_cpu = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--prio-batch-draft"}, "N", + string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams_batch.priority), + [](common_params & params, int prio) { + if (prio < 0 || prio > 3) { + throw std::invalid_argument("invalid value"); + } + params.speculative.cpuparams_batch.priority = (enum ggml_sched_priority) prio; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--poll-batch-draft"}, "<0|1>", + "Use polling to wait for draft model work (default: --poll-draft)", + [](common_params & params, int value) { + params.speculative.cpuparams_batch.poll = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--draft-max", "--draft", "--draft-n"}, "N", + string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max), + [](common_params & params, int value) { + params.speculative.n_max = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DRAFT_MAX")); + add_opt(common_arg( + {"--draft-min", "--draft-n-min"}, "N", + string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.speculative.n_min), + [](common_params & params, int value) { + params.speculative.n_min = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DRAFT_MIN")); + add_opt(common_arg( + {"--draft-p-split"}, "P", + string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split), + [](common_params & params, const std::string & value) { + params.speculative.p_split = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT")); + add_opt(common_arg( + {"--draft-p-min"}, "P", + string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min), + [](common_params & params, const std::string & value) { + params.speculative.p_min = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DRAFT_P_MIN")); + add_opt(common_arg( + {"-cd", "--ctx-size-draft"}, "N", + string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx), + [](common_params & params, int value) { + params.speculative.n_ctx = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CTX_SIZE_DRAFT")); + add_opt(common_arg( + {"-devd", "--device-draft"}, "", + "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n" + "use --list-devices to see a list of available devices", + [](common_params & params, const std::string & value) { + params.speculative.devices = parse_device_list(value); + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N", + "number of layers to store in VRAM for the draft model", + [](common_params & params, int value) { + params.speculative.n_gpu_layers = value; + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: no usable GPU found, --gpu-layers-draft option will be ignored\n"); + fprintf(stderr, "warning: one possible reason is that llama.cpp was compiled without GPU support\n"); + fprintf(stderr, "warning: consult docs/build.md for compilation instructions\n"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_GPU_LAYERS_DRAFT")); + add_opt(common_arg( + {"-md", "--model-draft"}, "FNAME", + "draft model for speculative decoding (default: unused)", + [](common_params & params, const std::string & value) { + params.speculative.model.path = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT")); + + add_opt(common_arg( + {"-mv", "--model-vocoder"}, "FNAME", + "vocoder model for audio generation (default: unused)", + [](common_params & params, const std::string & value) { + params.vocoder.model.path = value; + } + ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--tts-use-guide-tokens"}, + "Use guide tokens to improve TTS word recall", + [](common_params & params) { + params.vocoder.use_guide_tokens = true; + } + ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--tts-speaker-file"}, "FNAME", + "speaker file path for audio generation", + [](common_params & params, const std::string & value) { + params.vocoder.speaker_file = value; + } + ).set_examples({LLAMA_EXAMPLE_TTS})); + + // model-specific + add_opt(common_arg( + {"--tts-oute-default"}, + string_format("use default OuteTTS models (note: can download weights from the internet)"), + [](common_params & params) { + params.model.hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF"; + params.model.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf"; + params.vocoder.model.hf_repo = "ggml-org/WavTokenizer"; + params.vocoder.model.hf_file = "WavTokenizer-Large-75-F16.gguf"; + } + ).set_examples({LLAMA_EXAMPLE_TTS})); + + add_opt(common_arg( + {"--embd-bge-small-en-default"}, + string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"), + [](common_params & params) { + params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF"; + params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf"; + params.pooling_type = LLAMA_POOLING_TYPE_NONE; + params.embd_normalize = 2; + params.n_ctx = 512; + params.verbose_prompt = true; + params.embedding = true; + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER})); + + add_opt(common_arg( + {"--embd-e5-small-en-default"}, + string_format("use default e5-small-v2 model (note: can download weights from the internet)"), + [](common_params & params) { + params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF"; + params.model.hf_file = "e5-small-v2-q8_0.gguf"; + params.pooling_type = LLAMA_POOLING_TYPE_NONE; + params.embd_normalize = 2; + params.n_ctx = 512; + params.verbose_prompt = true; + params.embedding = true; + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER})); + + add_opt(common_arg( + {"--embd-gte-small-default"}, + string_format("use default gte-small model (note: can download weights from the internet)"), + [](common_params & params) { + params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF"; + params.model.hf_file = "gte-small-q8_0.gguf"; + params.pooling_type = LLAMA_POOLING_TYPE_NONE; + params.embd_normalize = 2; + params.n_ctx = 512; + params.verbose_prompt = true; + params.embedding = true; + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER})); + + add_opt(common_arg( + {"--fim-qwen-1.5b-default"}, + string_format("use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet)"), + [](common_params & params) { + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf"; + params.port = 8012; + params.n_gpu_layers = 99; + params.flash_attn = true; + params.n_ubatch = 1024; + params.n_batch = 1024; + params.n_ctx = 0; + params.n_cache_reuse = 256; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + + add_opt(common_arg( + {"--fim-qwen-3b-default"}, + string_format("use default Qwen 2.5 Coder 3B (note: can download weights from the internet)"), + [](common_params & params) { + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf"; + params.port = 8012; + params.n_gpu_layers = 99; + params.flash_attn = true; + params.n_ubatch = 1024; + params.n_batch = 1024; + params.n_ctx = 0; + params.n_cache_reuse = 256; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + + add_opt(common_arg( + {"--fim-qwen-7b-default"}, + string_format("use default Qwen 2.5 Coder 7B (note: can download weights from the internet)"), + [](common_params & params) { + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; + params.port = 8012; + params.n_gpu_layers = 99; + params.flash_attn = true; + params.n_ubatch = 1024; + params.n_batch = 1024; + params.n_ctx = 0; + params.n_cache_reuse = 256; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + + add_opt(common_arg( + {"--fim-qwen-7b-spec"}, + string_format("use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet)"), + [](common_params & params) { + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; + params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.speculative.n_gpu_layers = 99; + params.port = 8012; + params.n_gpu_layers = 99; + params.flash_attn = true; + params.n_ubatch = 1024; + params.n_batch = 1024; + params.n_ctx = 0; + params.n_cache_reuse = 256; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + + add_opt(common_arg( + {"--fim-qwen-14b-spec"}, + string_format("use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet)"), + [](common_params & params) { + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf"; + params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.speculative.n_gpu_layers = 99; + params.port = 8012; + params.n_gpu_layers = 99; + params.flash_attn = true; + params.n_ubatch = 1024; + params.n_batch = 1024; + params.n_ctx = 0; + params.n_cache_reuse = 256; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + + return ctx_arg; +} diff --git a/common/arg.h b/common/arg.h new file mode 100644 index 0000000000000000000000000000000000000000..70bea100fd4f268dbb31e84e3c1dae664102f8c2 --- /dev/null +++ b/common/arg.h @@ -0,0 +1,89 @@ +#pragma once + +#include "common.h" + +#include +#include +#include + +// +// CLI argument parsing +// + +struct common_arg { + std::set examples = {LLAMA_EXAMPLE_COMMON}; + std::set excludes = {}; + std::vector args; + const char * value_hint = nullptr; // help text or example for arg value + const char * value_hint_2 = nullptr; // for second arg value + const char * env = nullptr; + std::string help; + bool is_sparam = false; // is current arg a sampling param? + void (*handler_void) (common_params & params) = nullptr; + void (*handler_string) (common_params & params, const std::string &) = nullptr; + void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr; + void (*handler_int) (common_params & params, int) = nullptr; + + common_arg( + const std::initializer_list & args, + const char * value_hint, + const std::string & help, + void (*handler)(common_params & params, const std::string &) + ) : args(args), value_hint(value_hint), help(help), handler_string(handler) {} + + common_arg( + const std::initializer_list & args, + const char * value_hint, + const std::string & help, + void (*handler)(common_params & params, int) + ) : args(args), value_hint(value_hint), help(help), handler_int(handler) {} + + common_arg( + const std::initializer_list & args, + const std::string & help, + void (*handler)(common_params & params) + ) : args(args), help(help), handler_void(handler) {} + + // support 2 values for arg + common_arg( + const std::initializer_list & args, + const char * value_hint, + const char * value_hint_2, + const std::string & help, + void (*handler)(common_params & params, const std::string &, const std::string &) + ) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {} + + common_arg & set_examples(std::initializer_list examples); + common_arg & set_excludes(std::initializer_list excludes); + common_arg & set_env(const char * env); + common_arg & set_sparam(); + bool in_example(enum llama_example ex); + bool is_exclude(enum llama_example ex); + bool get_value_from_env(std::string & output); + bool has_value_from_env(); + std::string to_string(); +}; + +struct common_params_context { + enum llama_example ex = LLAMA_EXAMPLE_COMMON; + common_params & params; + std::vector options; + void(*print_usage)(int, char **) = nullptr; + common_params_context(common_params & params) : params(params) {} +}; + +// parse input arguments from CLI +// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message) +bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); + +// function to be used by test-arg-parser +common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); +bool common_has_curl(); + +struct common_remote_params { + std::vector headers; + long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout + long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB +}; +// get remote file content, returns +std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); diff --git a/common/base64.hpp b/common/base64.hpp new file mode 100644 index 0000000000000000000000000000000000000000..563247a6e5f7dba837c07a509026d8b36e61387c --- /dev/null +++ b/common/base64.hpp @@ -0,0 +1,392 @@ +/* +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or +distribute this software, either in source code form or as a compiled +binary, for any purpose, commercial or non-commercial, and by any +means. + +In jurisdictions that recognize copyright laws, the author or authors +of this software dedicate any and all copyright interest in the +software to the public domain. We make this dedication for the benefit +of the public at large and to the detriment of our heirs and +successors. We intend this dedication to be an overt act of +relinquishment in perpetuity of all present and future rights to this +software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to +*/ + +#ifndef PUBLIC_DOMAIN_BASE64_HPP_ +#define PUBLIC_DOMAIN_BASE64_HPP_ + +#include +#include +#include +#include + +class base64_error : public std::runtime_error +{ +public: + using std::runtime_error::runtime_error; +}; + +class base64 +{ +public: + enum class alphabet + { + /** the alphabet is detected automatically */ + auto_, + /** the standard base64 alphabet is used */ + standard, + /** like `standard` except that the characters `+` and `/` are replaced by `-` and `_` respectively*/ + url_filename_safe + }; + + enum class decoding_behavior + { + /** if the input is not padded, the remaining bits are ignored */ + moderate, + /** if a padding character is encounter decoding is finished */ + loose + }; + + /** + Encodes all the elements from `in_begin` to `in_end` to `out`. + + @warning The source and destination cannot overlap. The destination must be able to hold at least + `required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator. + + @tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than + 8 bits + @tparam Output_iterator the destination; the elements written to it are from the type `char` + @param in_begin the beginning of the source + @param in_end the ending of the source + @param out the destination iterator + @param alphabet which alphabet should be used + @returns the iterator to the next element past the last element copied + @throws see `Input_iterator` and `Output_iterator` + */ + template + static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out, + alphabet alphabet = alphabet::standard) + { + constexpr auto pad = '='; + const char* alpha = alphabet == alphabet::url_filename_safe + ? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" + : "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + while (in_begin != in_end) { + std::uint8_t i0 = 0, i1 = 0, i2 = 0; + + // first character + i0 = static_cast(*in_begin); + ++in_begin; + + *out = alpha[i0 >> 2 & 0x3f]; + ++out; + + // part of first character and second + if (in_begin != in_end) { + i1 = static_cast(*in_begin); + ++in_begin; + + *out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)]; + ++out; + } else { + *out = alpha[(i0 & 0x3) << 4]; + ++out; + + // last padding + *out = pad; + ++out; + + // last padding + *out = pad; + ++out; + + break; + } + + // part of second character and third + if (in_begin != in_end) { + i2 = static_cast(*in_begin); + ++in_begin; + + *out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)]; + ++out; + } else { + *out = alpha[(i1 & 0xf) << 2]; + ++out; + + // last padding + *out = pad; + ++out; + + break; + } + + // rest of third + *out = alpha[i2 & 0x3f]; + ++out; + } + + return out; + } + /** + Encodes a string. + + @param str the string that should be encoded + @param alphabet which alphabet should be used + @returns the encoded base64 string + @throws see base64::encode() + */ + static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard) + { + std::string result; + + result.reserve(required_encode_size(str.length()) + 1); + + encode(str.begin(), str.end(), std::back_inserter(result), alphabet); + + return result; + } + /** + Encodes a char array. + + @param buffer the char array + @param size the size of the array + @param alphabet which alphabet should be used + @returns the encoded string + */ + static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard) + { + std::string result; + + result.reserve(required_encode_size(size) + 1); + + encode(buffer, buffer + size, std::back_inserter(result), alphabet); + + return result; + } + /** + Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`, + in other words: inplace decoding is possible. + + @warning The destination must be able to hold at least `required_decode_size(std::distance(in_begin, in_end))`, + otherwise the behavior depends on the output iterator. + + @tparam Input_iterator the source; the returned elements are cast to `char` + @tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t` + @param in_begin the beginning of the source + @param in_end the ending of the source + @param out the destination iterator + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the iterator to the next element past the last element copied + @throws base64_error depending on the set behavior + @throws see `Input_iterator` and `Output_iterator` + */ + template + static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out, + alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + //constexpr auto pad = '='; + std::uint8_t last = 0; + auto bits = 0; + + while (in_begin != in_end) { + auto c = *in_begin; + ++in_begin; + + if (c == '=') { + break; + } + + auto part = _base64_value(alphabet, c); + + // enough bits for one byte + if (bits + 6 >= 8) { + *out = (last << (8 - bits)) | (part >> (bits - 2)); + ++out; + + bits -= 2; + } else { + bits += 6; + } + + last = part; + } + + // check padding + if (behavior != decoding_behavior::loose) { + while (in_begin != in_end) { + auto c = *in_begin; + ++in_begin; + + if (c != '=') { + throw base64_error("invalid base64 character."); + } + } + } + + return out; + } + /** + Decodes a string. + + @param str the base64 encoded string + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the decoded string + @throws see base64::decode() + */ + static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + std::string result; + + result.reserve(max_decode_size(str.length())); + + decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior); + + return result; + } + /** + Decodes a string. + + @param buffer the base64 encoded buffer + @param size the size of the buffer + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the decoded string + @throws see base64::decode() + */ + static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + std::string result; + + result.reserve(max_decode_size(size)); + + decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior); + + return result; + } + /** + Decodes a string inplace. + + @param[in,out] str the base64 encoded string + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @throws base64::decode_inplace() + */ + static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin()); + } + /** + Decodes a char array inplace. + + @param[in,out] str the string array + @param size the length of the array + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the pointer to the next element past the last element decoded + @throws base64::decode_inplace() + */ + static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + return decode(str, str + size, str, alphabet, behavior); + } + /** + Returns the required decoding size for a given size. The value is calculated with the following formula: + + $$ + \lceil \frac{size}{4} \rceil \cdot 3 + $$ + + @param size the size of the encoded input + @returns the size of the resulting decoded buffer; this the absolute maximum + */ + static std::size_t max_decode_size(std::size_t size) noexcept + { + return (size / 4 + (size % 4 ? 1 : 0)) * 3; + } + /** + Returns the required encoding size for a given size. The value is calculated with the following formula: + + $$ + \lceil \frac{size}{3} \rceil \cdot 4 + $$ + + @param size the size of the decoded input + @returns the size of the resulting encoded buffer + */ + static std::size_t required_encode_size(std::size_t size) noexcept + { + return (size / 3 + (size % 3 ? 1 : 0)) * 4; + } + +private: + static std::uint8_t _base64_value(alphabet& alphabet, char c) + { + if (c >= 'A' && c <= 'Z') { + return c - 'A'; + } else if (c >= 'a' && c <= 'z') { + return c - 'a' + 26; + } else if (c >= '0' && c <= '9') { + return c - '0' + 52; + } + + // comes down to alphabet + if (alphabet == alphabet::standard) { + if (c == '+') { + return 62; + } else if (c == '/') { + return 63; + } + } else if (alphabet == alphabet::url_filename_safe) { + if (c == '-') { + return 62; + } else if (c == '_') { + return 63; + } + } // auto detect + else { + if (c == '+') { + alphabet = alphabet::standard; + + return 62; + } else if (c == '/') { + alphabet = alphabet::standard; + + return 63; + } else if (c == '-') { + alphabet = alphabet::url_filename_safe; + + return 62; + } else if (c == '_') { + alphabet = alphabet::url_filename_safe; + + return 63; + } + } + + throw base64_error("invalid base64 character."); + } +}; + +#endif // !PUBLIC_DOMAIN_BASE64_HPP_ diff --git a/common/build-info.cpp.in b/common/build-info.cpp.in new file mode 100644 index 0000000000000000000000000000000000000000..aee9d7eafd681c922f27d8e6ecbf4e44f44d1dcd --- /dev/null +++ b/common/build-info.cpp.in @@ -0,0 +1,4 @@ +int LLAMA_BUILD_NUMBER = @LLAMA_BUILD_NUMBER@; +char const *LLAMA_COMMIT = "@LLAMA_BUILD_COMMIT@"; +char const *LLAMA_COMPILER = "@BUILD_COMPILER@"; +char const *LLAMA_BUILD_TARGET = "@BUILD_TARGET@"; diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..18a30e49aa578f89d8fd0f22a17dc38fd8764b50 --- /dev/null +++ b/common/chat-parser.cpp @@ -0,0 +1,385 @@ +#include "chat-parser.h" +#include "common.h" +#include "log.h" +#include "regex-partial.h" + +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax) + : input_(input), is_partial_(is_partial), syntax_(syntax) +{ + result_.role = "assistant"; + + while (true) { + std::string id = std::to_string(std::rand()); + if (input.find(id) == std::string::npos) { + healing_marker_ = id; + break; + } + } +} + +std::string common_chat_msg_parser::str(const common_string_range & rng) const { + GGML_ASSERT(rng.begin <= rng.end); + return input_.substr(rng.begin, rng.end - rng.begin); +} + +void common_chat_msg_parser::add_content(const std::string &content) { + result_.content += content; +} + +void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) { + result_.reasoning_content += reasoning_content; +} + +bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) { + if (name.empty()) { + return false; + } + + common_chat_tool_call tool_call; + tool_call.name = name; + tool_call.arguments = arguments; + tool_call.id = id; + + // LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str()); + result_.tool_calls.emplace_back(tool_call); + + return true; +} +bool common_chat_msg_parser::add_tool_call(const json & tool_call) { + std::string name = tool_call.contains("name") ? tool_call.at("name") : ""; + std::string id = tool_call.contains("id") ? tool_call.at("id") : ""; + std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : ""; + return add_tool_call(name, id, arguments); +} + +bool common_chat_msg_parser::add_tool_calls(const json & arr) { + for (const auto & item : arr) { + if (!add_tool_call(item)) { + return false; + } + } + return true; +} +void common_chat_msg_parser::finish() { + if (!is_partial_ && pos_ != input_.size()) { + throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_)); + } +} + +bool common_chat_msg_parser::consume_spaces() { + const auto length = input_.size(); + auto consumed = false; + while (pos_ < length && std::isspace(input_[pos_])) { + ++pos_; + consumed = true; + } + return consumed; +} + +bool common_chat_msg_parser::try_consume_literal(const std::string & literal) { + auto pos = pos_; + for (auto i = 0u; i < literal.size(); ++i) { + if (pos >= input_.size()) { + return false; + } + if (input_[pos] != literal[i]) { + return false; + } + ++pos; + } + pos_ = pos; + return true; +} + +std::optional common_chat_msg_parser::try_find_literal(const std::string & literal) { + auto idx = input_.find(literal, pos_); + if (idx != std::string::npos) { + find_regex_result res; + res.prelude = input_.substr(pos_, idx - pos_); + auto end = idx + literal.size(); + res.groups.emplace_back(common_string_range{idx, end}); + move_to(end); + return res; + } + if (is_partial_) { + idx = string_find_partial_stop(input_, literal); + if (idx != std::string::npos && idx >= pos_) { + find_regex_result res; + res.prelude = input_.substr(pos_, idx - pos_); + auto end = input_.size(); + res.groups.emplace_back(common_string_range{idx, end}); + move_to(end); + return res; + } + } + return std::nullopt; +} + +void common_chat_msg_parser::consume_literal(const std::string & literal) { + if (!try_consume_literal(literal)) { + throw common_chat_msg_partial_exception(literal); + } +} + +bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) { + auto handle_reasoning = [&](const std::string & reasoning, bool closed) { + auto stripped_reasoning = string_strip(reasoning); + if (stripped_reasoning.empty()) { + return; + } + if (syntax_.reasoning_in_content) { + add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : start_think); + add_content(stripped_reasoning); + if (closed) { + add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : end_think); + } + } else { + add_reasoning_content(stripped_reasoning); + } + }; + if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) { + if (syntax_.thinking_forced_open || try_consume_literal(start_think)) { + if (auto res = try_find_literal(end_think)) { + handle_reasoning(res->prelude, /* closed */ true); + consume_spaces(); + return true; + } + auto rest = consume_rest(); + if (!rest.empty()) { + handle_reasoning(rest, /* closed */ !is_partial()); + } + // Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877) + // if (!syntax_.thinking_forced_open) { + // throw common_chat_msg_partial_exception(end_think); + // } + return true; + } + } + return false; +} + +std::string common_chat_msg_parser::consume_rest() { + auto rest = input_.substr(pos_); + pos_ = input_.size(); + return rest; +} + +// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback. +std::optional common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) { + auto m = regex.search(input_, from == std::string::npos ? pos_ : from); + if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { + return std::nullopt; + } + auto prelude = input_.substr(pos_, m.groups[0].begin - pos_); + pos_ = m.groups[0].end; + + if (add_prelude_to_content) { + add_content(prelude); + } + if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { + if (is_partial()) { + throw common_chat_msg_partial_exception(regex.str()); + } + return std::nullopt; + } + return find_regex_result{prelude, m.groups}; +} + +common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) { + if (auto result = try_consume_regex(regex)) { + return *result; + } + throw common_chat_msg_partial_exception(regex.str()); +} + +std::optional common_chat_msg_parser::try_consume_regex(const common_regex & regex) { + auto m = regex.search(input_, pos_); + if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { + return std::nullopt; + } + if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { + if (is_partial()) { + throw common_chat_msg_partial_exception(regex.str()); + } + return std::nullopt; + } + if (m.groups[0].begin != pos_) { + // Didn't match at the current position. + return std::nullopt; + } + pos_ = m.groups[0].end; + + return find_regex_result { + /* .prelude = */ "", + m.groups, + }; +} + +std::optional common_chat_msg_parser::try_consume_json() { + auto it = input_.cbegin() + pos_; + const auto end = input_.cend(); + common_json result; + if (!common_json_parse(it, end, healing_marker_, result)) { + return std::nullopt; + } + pos_ = std::distance(input_.cbegin(), it); + if (result.healing_marker.marker.empty()) { + // No healing marker, just return the parsed json + return result; + } + if (!is_partial()) { + throw common_chat_msg_partial_exception("JSON"); + } + return result; +} + +common_json common_chat_msg_parser::consume_json() { + if (auto result = try_consume_json()) { + return *result; + } + throw common_chat_msg_partial_exception("JSON"); +} + +common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args( + const std::vector> & args_paths, + const std::vector> & content_paths +) { + if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) { + return *result; + } + throw common_chat_msg_partial_exception("JSON"); +} + +std::optional common_chat_msg_parser::try_consume_json_with_dumped_args( + const std::vector> & args_paths, + const std::vector> & content_paths +) { + auto partial = try_consume_json(); + if (!partial) { + return std::nullopt; + } + auto is_arguments_path = [&](const std::vector & path) { + return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end(); + }; + auto is_content_path = [&](const std::vector & path) { + return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end(); + }; + + if (partial->healing_marker.marker.empty()) { + if (args_paths.empty()) { + // No arguments to dump, and JSON was parsed fully. + return consume_json_result { + partial->json, + /* .is_partial = */ false, + }; + } + if (is_arguments_path({})) { + // Entire JSON is the arguments and was parsed fully. + return consume_json_result { + partial->json.dump(), + /* .is_partial = */ false, + }; + } + } + + LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str()); + + auto found_healing_marker = false; + std::vector path; + std::function remove_unsupported_healings_and_dump_args = [&](const json & j) -> json { + if (is_arguments_path(path)) { + auto arguments = j.dump(); + if (is_partial() && !partial->healing_marker.marker.empty()) { + auto idx = arguments.find(partial->healing_marker.json_dump_marker); + if (idx != std::string::npos) { + arguments.resize(idx); + found_healing_marker = true; + } + if (arguments == "\"") { + // This happens because of completing `:"$magic` after `"arguments"` + arguments = ""; + } + } + return arguments; + } + if (is_content_path(path)) { + if (!j.is_string()) { + throw std::runtime_error("Content path must be a string"); + } + std::string str = j; + auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string + if (idx != std::string::npos) { + str.resize(idx); + found_healing_marker = true; + } + return str; + } + if (j.is_object()) { + auto obj = json::object(); + for (const auto & p : j.items()) { + const auto & key = p.key(); + const auto & value = p.value(); + const std::string key_str = key; // NOLINT + auto idx = key_str.find(healing_marker_); + if (idx != std::string::npos) { + found_healing_marker = true; + break; + } + path.push_back(key_str); + if (value.is_string()) { + const std::string value_str = value; + if (value_str.find(healing_marker_) != std::string::npos) { + found_healing_marker = true; + if (is_content_path(path)) { + if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) { + // The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair. + obj[key] = remove_unsupported_healings_and_dump_args(value); + } + } + break; + } + obj[key] = value; + } else { + obj[key] = remove_unsupported_healings_and_dump_args(value); + } + path.pop_back(); + } + return obj; + } + if (j.is_array()) { + auto arr = json::array(); + for (const auto & value : j) { + if (value.is_string()) { + std::string str = value; + auto idx = str.find(healing_marker_); + if (idx != std::string::npos) { + // Don't heal array values that aren't in the arguments. + found_healing_marker = true; + break; + } + } + arr.push_back(remove_unsupported_healings_and_dump_args(value)); + } + return arr; + } + return j; + }; + + auto cleaned = remove_unsupported_healings_and_dump_args(partial->json); + LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str()); + return consume_json_result { + cleaned, + /* .is_partial = */ found_healing_marker, + }; +} + +void common_chat_msg_parser::clear_tools() { + result_.tool_calls.clear(); +} diff --git a/common/chat-parser.h b/common/chat-parser.h new file mode 100644 index 0000000000000000000000000000000000000000..0e64c341a50aae93f3468ffcc69839775daeb399 --- /dev/null +++ b/common/chat-parser.h @@ -0,0 +1,120 @@ +#pragma once + +#include "chat.h" +#include "json-partial.h" +#include "regex-partial.h" + +#include + +#include +#include +#include + +class common_chat_msg_partial_exception : public std::runtime_error { + public: + common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {} +}; + +class common_chat_msg_parser { + std::string input_; + bool is_partial_; + common_chat_syntax syntax_; + std::string healing_marker_; + + size_t pos_ = 0; + common_chat_msg result_; + + public: + common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax); + const std::string & input() const { return input_; } + size_t pos() const { return pos_; } + const std::string & healing_marker() const { return healing_marker_; } + const bool & is_partial() const { return is_partial_; } + const common_chat_msg & result() const { return result_; } + const common_chat_syntax & syntax() const { return syntax_; } + + void move_to(size_t pos) { + if (pos > input_.size()) { + throw std::runtime_error("Invalid position!"); + } + pos_ = pos; + } + void move_back(size_t n) { + if (pos_ < n) { + throw std::runtime_error("Can't move back that far!"); + } + pos_ -= n; + } + + // Get the substring of the input at the given range + std::string str(const common_string_range & rng) const; + + // Appends to the result.content field + void add_content(const std::string & content); + + // Appends to the result.reasoning_content field + void add_reasoning_content(const std::string & reasoning_content); + + // Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything. + bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments); + + // Adds a tool call using the "name", "id" and "arguments" fields of the json object + bool add_tool_call(const nlohmann::ordered_json & tool_call); + + // Adds an array of tool calls using their "name", "id" and "arguments" fields. + bool add_tool_calls(const nlohmann::ordered_json & arr); + + void finish(); + + bool consume_spaces(); + + void consume_literal(const std::string & literal); + + bool try_parse_reasoning(const std::string & start_think, const std::string & end_think); + + std::string consume_rest(); + + struct find_regex_result { + std::string prelude; + std::vector groups; + }; + + std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true); + + bool try_consume_literal(const std::string & literal); + + std::optional try_find_literal(const std::string & literal); + + find_regex_result consume_regex(const common_regex & regex); + + std::optional try_consume_regex(const common_regex & regex); + + std::optional try_consume_json(); + common_json consume_json(); + + struct consume_json_result { + nlohmann::ordered_json value; + bool is_partial; + }; + + /* + Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings. + + By default, object keys can't be truncated, nor can string values (their corresponding key is removed, + e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}` + + But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings + - with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}` + - with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}` + */ + consume_json_result consume_json_with_dumped_args( + const std::vector> & args_paths = {}, + const std::vector> & content_paths = {} + ); + std::optional try_consume_json_with_dumped_args( + const std::vector> & args_paths = {}, + const std::vector> & content_paths = {} + ); + + void clear_tools(); +}; diff --git a/common/chat.cpp b/common/chat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7d9aaeb12a1906115d68fba577740bf679978386 --- /dev/null +++ b/common/chat.cpp @@ -0,0 +1,1932 @@ +#include "chat.h" +#include "chat-parser.h" +#include "common.h" +#include "json-partial.h" +#include "json-schema-to-grammar.h" +#include "log.h" +#include "regex-partial.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) { + auto time = std::chrono::system_clock::to_time_t(now); + auto local_time = *std::localtime(&time); + std::ostringstream ss; + ss << std::put_time(&local_time, format.c_str()); + auto res = ss.str(); + return res; +} + +static std::string string_diff(const std::string & last, const std::string & current) { + if (last.empty()) { + return current; + } + if (!string_starts_with(current, last)) { + if (string_starts_with(last, current)) { + // This happens if the last generation ended on a partial stop word (not erased), + // and the current ended on a stop word (erased). + return ""; + } + throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'"); + } + return current.substr(last.size()); +} + +static bool has_content_or_tool_calls(const common_chat_msg & msg) { + return !msg.content.empty() || !msg.tool_calls.empty(); +} + +template <> +json common_chat_msg::to_json_oaicompat() const +{ + json message { + {"role", "assistant"}, + }; + if (!reasoning_content.empty()) { + message["reasoning_content"] = reasoning_content; + } + if (content.empty() && !tool_calls.empty()) { + message["content"] = json(); + } else { + message["content"] = content; + } + if (!tool_calls.empty()) { + auto arr = json::array(); + for (const auto & tc : tool_calls) { + arr.push_back({ + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments}, + }}, + {"id", tc.id}, + // // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo). + // // We only generate a random id for the ones that don't generate one by themselves + // // (they also won't get to see it as their template likely doesn't use it, so it's all for the client) + // {"id", tc.id.empty() ? gen_tool_call_id() : tc.id}, + }); + } + message["tool_calls"] = arr; + } + return message; +} + +std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) { + std::vector diffs; + if (previous_msg.reasoning_content != new_msg.reasoning_content) { + auto & diff = diffs.emplace_back(); + diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content); + } + if (previous_msg.content != new_msg.content) { + auto & diff = diffs.emplace_back(); + diff.content_delta = string_diff(previous_msg.content, new_msg.content); + } + + if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) { + throw std::runtime_error("Invalid diff: now finding less tool calls!"); + } + + if (!previous_msg.tool_calls.empty()) { + auto idx = previous_msg.tool_calls.size() - 1; + const auto & pref = previous_msg.tool_calls[idx]; + const auto & newf = new_msg.tool_calls[idx]; + if (pref.name != newf.name) { + throw std::runtime_error("Invalid diff: tool call mismatch!"); + } + auto args_diff = string_diff(pref.arguments, newf.arguments); + if (!args_diff.empty() || pref.id != newf.id) { + auto & diff = diffs.emplace_back(); + diff.tool_call_index = idx; + if (pref.id != newf.id) { + diff.tool_call_delta.id = newf.id; + diff.tool_call_delta.name = newf.name; + } + diff.tool_call_delta.arguments = args_diff; + } + } + for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) { + auto & diff = diffs.emplace_back(); + diff.tool_call_index = idx; + diff.tool_call_delta = new_msg.tool_calls[idx]; + } + return diffs; +} + +typedef minja::chat_template common_chat_template; + +struct common_chat_templates { + bool has_explicit_template; // Model had builtin template or template overridde was specified. + std::unique_ptr template_default; // always set (defaults to chatml) + std::unique_ptr template_tool_use; +}; + +struct templates_params { + json messages; + json tools; + common_chat_tool_choice tool_choice; + json json_schema; + bool parallel_tool_calls; + bool stream; + std::string grammar; + bool add_generation_prompt = true; + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); +}; + +common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { + if (tool_choice == "auto") { + return COMMON_CHAT_TOOL_CHOICE_AUTO; + } + if (tool_choice == "none") { + return COMMON_CHAT_TOOL_CHOICE_NONE; + } + if (tool_choice == "required") { + return COMMON_CHAT_TOOL_CHOICE_REQUIRED; + } + throw std::runtime_error("Invalid tool_choice: " + tool_choice); +} + +template <> +std::vector common_chat_msgs_parse_oaicompat(const json & messages) { + std::vector msgs; + + try { + + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + } + + for (const auto & message : messages) { + if (!message.is_object()) { + throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + } + + common_chat_msg msg; + if (!message.contains("role")) { + throw std::runtime_error("Missing 'role' in message: " + message.dump()); + } + msg.role = message.at("role"); + + auto has_content = message.contains("content"); + auto has_tool_calls = message.contains("tool_calls"); + if (has_content) { + const auto & content = message.at("content"); + if (content.is_string()) { + msg.content = content; + } else if (content.is_array()) { + for (const auto & part : content) { + if (!part.contains("type")) { + throw std::runtime_error("Missing content part type: " + part.dump()); + } + const auto & type = part.at("type"); + if (type != "text") { + throw std::runtime_error("Unsupported content part type: " + type.dump()); + } + common_chat_msg_content_part msg_part; + msg_part.type = type; + msg_part.text = part.at("text"); + msg.content_parts.push_back(msg_part); + } + } else if (!content.is_null()) { + throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); + } + } + if (has_tool_calls) { + for (const auto & tool_call : message.at("tool_calls")) { + common_chat_tool_call tc; + if (!tool_call.contains("type")) { + throw std::runtime_error("Missing tool call type: " + tool_call.dump()); + } + const auto & type = tool_call.at("type"); + if (type != "function") { + throw std::runtime_error("Unsupported tool call type: " + tool_call.dump()); + } + if (!tool_call.contains("function")) { + throw std::runtime_error("Missing tool call function: " + tool_call.dump()); + } + const auto & fc = tool_call.at("function"); + if (!fc.contains("name")) { + throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + } + tc.name = fc.at("name"); + tc.arguments = fc.at("arguments"); + if (tool_call.contains("id")) { + tc.id = tool_call.at("id"); + } + msg.tool_calls.push_back(tc); + } + } + if (!has_content && !has_tool_calls) { + throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)"); + } + if (message.contains("reasoning_content")) { + msg.reasoning_content = message.at("reasoning_content"); + } + if (message.contains("name")) { + msg.tool_name = message.at("name"); + } + if (message.contains("tool_call_id")) { + msg.tool_call_id = message.at("tool_call_id"); + } + + msgs.push_back(msg); + } + } catch (const std::exception & e) { + // @ngxson : disable otherwise it's bloating the API response + // printf("%s\n", std::string("; messages = ") + messages.dump(2)); + throw std::runtime_error("Failed to parse messages: " + std::string(e.what())); + } + + return msgs; +} + +template <> +json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text) { + json messages = json::array(); + for (const auto & msg : msgs) { + if (!msg.content.empty() && !msg.content_parts.empty()) { + throw std::runtime_error("Cannot specify both content and content_parts"); + } + json jmsg { + {"role", msg.role}, + }; + if (!msg.content.empty()) { + jmsg["content"] = msg.content; + } else if (!msg.content_parts.empty()) { + if (concat_typed_text) { + std::string text; + for (const auto & part : msg.content_parts) { + if (part.type != "text") { + LOG_WRN("Ignoring content part type: %s\n", part.type.c_str()); + continue; + } + if (!text.empty()) { + text += '\n'; + } + text += part.text; + } + jmsg["content"] = text; + } else { + auto & parts = jmsg["content"] = json::array(); + for (const auto & part : msg.content_parts) { + parts.push_back({ + {"type", part.type}, + {"text", part.text}, + }); + } + } + } else { + jmsg["content"] = json(); // null + } + if (!msg.reasoning_content.empty()) { + jmsg["reasoning_content"] = msg.reasoning_content; + } + if (!msg.tool_name.empty()) { + jmsg["name"] = msg.tool_name; + } + if (!msg.tool_call_id.empty()) { + jmsg["tool_call_id"] = msg.tool_call_id; + } + if (!msg.tool_calls.empty()) { + auto & tool_calls = jmsg["tool_calls"] = json::array(); + for (const auto & tool_call : msg.tool_calls) { + json tc { + {"type", "function"}, + {"function", { + {"name", tool_call.name}, + {"arguments", tool_call.arguments}, + }}, + }; + if (!tool_call.id.empty()) { + tc["id"] = tool_call.id; + } + tool_calls.push_back(tc); + } + } + messages.push_back(jmsg); + } + return messages; +} + +template <> +std::vector common_chat_msgs_parse_oaicompat(const std::string & messages) { + return common_chat_msgs_parse_oaicompat(json::parse(messages)); +} + +template <> +std::vector common_chat_tools_parse_oaicompat(const json & tools) { + std::vector result; + + try { + if (!tools.is_null()) { + if (!tools.is_array()) { + throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); + } + for (const auto & tool : tools) { + if (!tool.contains("type")) { + throw std::runtime_error("Missing tool type: " + tool.dump()); + } + const auto & type = tool.at("type"); + if (!type.is_string() || type != "function") { + throw std::runtime_error("Unsupported tool type: " + tool.dump()); + } + if (!tool.contains("function")) { + throw std::runtime_error("Missing tool function: " + tool.dump()); + } + + const auto & function = tool.at("function"); + result.push_back({ + /* .name = */ function.at("name"), + /* .description = */ function.at("description"), + /* .parameters = */ function.at("parameters").dump(), + }); + } + } + } catch (const std::exception & e) { + throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2)); + } + + return result; +} + +template <> +std::vector common_chat_tools_parse_oaicompat(const std::string & tools) { + return common_chat_tools_parse_oaicompat(json::parse(tools)); +} + +template <> +json common_chat_tools_to_json_oaicompat(const std::vector & tools) { + if (tools.empty()) { + return json(); + } + + auto result = json::array(); + for (const auto & tool : tools) { + result.push_back({ + {"type", "function"}, + {"function", { + {"name", tool.name}, + {"description", tool.description}, + {"parameters", json::parse(tool.parameters)}, + }}, + }); + } + return result; +} + +template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { + json delta = json::object(); + if (!diff.reasoning_content_delta.empty()) { + delta["reasoning_content"] = diff.reasoning_content_delta; + } + if (!diff.content_delta.empty()) { + delta["content"] = diff.content_delta; + } + if (diff.tool_call_index != std::string::npos) { + json tool_call; + tool_call["index"] = diff.tool_call_index; + if (!diff.tool_call_delta.id.empty()) { + tool_call["id"] = diff.tool_call_delta.id; + tool_call["type"] = "function"; + } + json function = json::object(); + if (!diff.tool_call_delta.name.empty()) { + function["name"] = diff.tool_call_delta.name; + } + function["arguments"] = diff.tool_call_delta.arguments; + tool_call["function"] = function; + delta["tool_calls"] = json::array({tool_call}); + } + return delta; +} + +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { + if (use_jinja) { + try { + common_chat_msg msg; + msg.role = "user"; + msg.content = "test"; + + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl); + + common_chat_templates_inputs inputs; + inputs.messages = {msg}; + + common_chat_templates_apply(tmpls.get(), inputs); + return true; + } catch (const std::exception & e) { + LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); + return false; + } + } + llama_chat_message chat[] = {{"user", "test"}}; + const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); + return res >= 0; +} + +std::string common_chat_format_single( + const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja) { + + common_chat_templates_inputs inputs; + inputs.use_jinja = use_jinja; + + std::string fmt_past_msg; + if (!past_msg.empty()) { + inputs.messages = past_msg; + inputs.add_generation_prompt = false; + fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt; + } + std::ostringstream ss; + // if the past_msg ends with a newline, we must preserve it in the formatted version + if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { + ss << "\n"; + }; + // format chat with new_msg + inputs.messages.push_back(new_msg); + inputs.add_generation_prompt = add_ass; + auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt; + // get the diff part + ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); + return ss.str(); +} + +std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) { + common_chat_templates_inputs inputs; + inputs.use_jinja = use_jinja; + auto add_simple_msg = [&](auto role, auto content) { + common_chat_msg msg; + msg.role = role; + msg.content = content; + inputs.messages.push_back(msg); + }; + add_simple_msg("system", "You are a helpful assistant"); + add_simple_msg("user", "Hello"); + add_simple_msg("assistant", "Hi there"); + add_simple_msg("user", "How are you?"); + return common_chat_templates_apply(tmpls, inputs).prompt; +} + +#define CHATML_TEMPLATE_SRC \ + "{%- for message in messages -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + +void common_chat_templates_free(struct common_chat_templates * tmpls) { + delete tmpls; +} + +bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) { + return tmpls->has_explicit_template; +} + +const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) { + if (variant != nullptr) { + if (strcmp(variant, "tool_use") == 0) { + if (tmpls->template_tool_use) { + return tmpls->template_tool_use->source().c_str(); + } + return nullptr; + } else { + LOG_DBG("%s: unknown template variant: %s\n", __func__, variant); + } + } + return tmpls->template_default->source().c_str(); +} + +common_chat_templates_ptr common_chat_templates_init( + const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override, + const std::string & eos_token_override) +{ + std::string default_template_src; + std::string template_tool_use_src; + + bool has_explicit_template = !chat_template_override.empty(); + if (chat_template_override.empty()) { + GGML_ASSERT(model != nullptr); + const auto * str = llama_model_chat_template(model, /* name */ nullptr); + if (str) { + default_template_src = str; + has_explicit_template = true; + } + str = llama_model_chat_template(model, /* name */ "tool_use"); + if (str) { + template_tool_use_src = str; + has_explicit_template = true; + } + } else { + default_template_src = chat_template_override; + } + if (default_template_src.empty() || default_template_src == "chatml") { + if (!template_tool_use_src.empty()) { + default_template_src = template_tool_use_src; + } else { + default_template_src = CHATML_TEMPLATE_SRC; + } + } + std::string token_bos = bos_token_override; + std::string token_eos = eos_token_override; + if (model) { + const auto * vocab = llama_model_get_vocab(model); + const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { + if (token == LLAMA_TOKEN_NULL) { + if (default_template_src.find(jinja_variable_name) != std::string::npos + || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { + LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name); + } + return std::string(); + } + return common_token_to_piece(vocab, token, true); + }; + token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); + token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); + } + common_chat_templates_ptr tmpls(new common_chat_templates()); + tmpls->has_explicit_template = has_explicit_template; + try { + tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); + } catch (const std::exception & e) { + LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what()); + tmpls->template_default = std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos); + } + if (!template_tool_use_src.empty()) { + try { + tmpls->template_tool_use = std::make_unique(template_tool_use_src, token_bos, token_eos); + } catch (const std::exception & e) { + LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what()); + } + } + return tmpls; +} + +const char * common_chat_format_name(common_chat_format format) { + switch (format) { + case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only"; + case COMMON_CHAT_FORMAT_GENERIC: return "Generic"; + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo"; + case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; + case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; + case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; + default: + throw std::runtime_error("Unknown chat format"); + } +} + +const char * common_reasoning_format_name(common_reasoning_format format) { + switch (format) { + case COMMON_REASONING_FORMAT_NONE: return "none"; + case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; + case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy"; + default: + throw std::runtime_error("Unknown reasoning format"); + } +} + +static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { + std::string arguments; + if (builder.is_partial()) { + arguments = (json {{"code", code + builder.healing_marker()}}).dump(); + auto idx = arguments.find(builder.healing_marker()); + if (idx != std::string::npos) { + arguments.resize(idx); + } + } else { + arguments = (json {{"code", code}}).dump(); + } + return arguments; +} + +/** + * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. + * Aggregates the prefix, suffix and in-between text into the content. + */ +static void parse_json_tool_calls( + common_chat_msg_parser & builder, + const std::optional & block_open, + const std::optional & function_regex_start_only, + const std::optional & function_regex, + const common_regex & close_regex, + const std::optional & block_close, + bool allow_raw_python = false, + const std::function & get_function_name = nullptr) { + + auto parse_tool_calls = [&]() { + size_t from = std::string::npos; + auto first = true; + while (true) { + auto res = function_regex_start_only && first + ? builder.try_consume_regex(*function_regex_start_only) + : function_regex + ? builder.try_find_regex(*function_regex, from) + : std::nullopt; + if (res) { + std::string name; + if (get_function_name) { + name = get_function_name(*res); + } else { + GGML_ASSERT(res->groups.size() == 2); + name = builder.str(res->groups[1]); + } + first = false; + if (name.empty()) { + // get_function_name signalled us that we should skip this match and treat it as content. + from = res->groups[0].begin + 1; + continue; + } + from = std::string::npos; + + auto maybe_raw_python = name == "python" && allow_raw_python; + if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(close_regex); + } + continue; + } + if (maybe_raw_python) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + if (!builder.add_tool_call(name, "", arguments)) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + return; + } + throw common_chat_msg_partial_exception("incomplete tool call"); + } + break; + } + if (block_close) { + builder.consume_regex(*block_close); + } + builder.consume_spaces(); + builder.add_content(builder.consume_rest()); + }; + if (block_open) { + if (auto res = builder.try_find_regex(*block_open)) { + parse_tool_calls(); + } else { + builder.add_content(builder.consume_rest()); + } + } else { + parse_tool_calls(); + } +} + +static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) { + static const std::vector> args_paths = {{"arguments"}}; + if (auto res = builder.try_find_regex(prefix)) { + builder.move_back(rstrip_prefix); + auto tool_calls = builder.consume_json_with_dumped_args(args_paths); + if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call array"); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void foreach_function(const json & tools, const std::function & fn) { + for (const auto & tool : tools) { + if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { + LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str()); + continue; + } + fn(tool); + } +} + +static std::string apply( + const common_chat_template & tmpl, + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) +{ + minja::chat_template_inputs tmpl_inputs; + tmpl_inputs.messages = messages; + tmpl_inputs.tools = tools; + tmpl_inputs.add_generation_prompt = add_generation_prompt; + tmpl_inputs.extra_context = extra_context; + // TODO: add flag to control date/time, if only for testing purposes. + // tmpl_inputs.now = std::chrono::system_clock::now(); + + minja::chat_template_options tmpl_opts; + // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens + // instead of using `chat_template_options.use_bos_token = false`, since these tokens + // may be needed inside the template / between messages too. + auto result = tmpl.apply(tmpl_inputs, tmpl_opts); + if (string_starts_with(result, tmpl.bos_token())) { + result = result.substr(tmpl.bos_token().size()); + } + if (string_ends_with(result, tmpl.eos_token())) { + result = result.substr(0, result.size() - tmpl.eos_token().size()); + } + return result; +} + +static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + auto tool_call_schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + auto tool_schema = json { + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function.at("name")}, + }}, + {"arguments", function.at("parameters")}, + }}, + {"required", json::array({"name", "arguments"})}, + }; + if (function.contains("description")) { + tool_schema["description"] = function.at("description"); + } + if (inputs.parallel_tool_calls) { + tool_schema.at("properties")["id"] = { + {"type", "string"}, + {"minLength", 4}, + }; + tool_schema.at("required").push_back("id"); + } + tool_call_schemas.emplace_back(tool_schema); + }); + const auto tool_call = + inputs.parallel_tool_calls + ? json { + {"type", "object"}, + {"properties", { + {"tool_calls", { + {"type", "array"}, + {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { + {"anyOf", tool_call_schemas}, + }}, + {"minItems", 1}, + }}, + }}, + {"required", json::array({"tool_calls"})}, + } + : json { + {"type", "object"}, + {"properties", { + {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { + {"anyOf", tool_call_schemas}, + }}, + }}, + {"required", json::array({"tool_call"})}, + }; + const auto schema = + inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED + ? json { + {"anyOf", json::array({ + tool_call, + { + {"type", "object"}, + {"properties", { + {"response", inputs.json_schema.is_null() + ? json {{"type", "string"}} + : inputs.json_schema + }, + }}, + {"required", json::array({"response"})}, + }, + })} + } + : tool_call; + + data.grammar_lazy = false; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + builder.add_schema("root", schema); + }); + + auto tweaked_messages = common_chat_template::add_system( + inputs.messages, + "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request"); + + data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_GENERIC; + return data; +} +static void common_chat_parse_generic(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const std::vector> content_paths = { + {"response"}, + }; + static const std::vector> args_paths = { + {"tool_call", "arguments"}, + {"tool_calls", "arguments"}, + }; + auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); + if (data.value.contains("tool_calls")) { + if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool calls"); + } + } else if (data.value.contains("tool_call")) { + if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (data.value.contains("response")) { + const auto & response = data.value.at("response"); + builder.add_content(response.is_string() ? response.template get() : response.dump(2)); + if (data.is_partial) { + throw common_chat_msg_partial_exception("incomplete response"); + } + } else { + throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); + } +} + +static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + {"type", "object"}, + {"properties", { + // Important note: the model is probably trained to take a JSON stringified arguments value. + // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object. + {"name", { + {"type", "string"}, + {"const", function.at("name")}, + }}, + {"arguments", function.at("parameters")}, + {"id", { + {"type", "string"}, + // Nemo's template expects a 9-character alphanumeric ID. + {"pattern", "^[a-zA-Z0-9]{9}$"}, + }}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); + }); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"}); + data.preserved_tokens = { + "[TOOL_CALLS]", + }; + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; + return data; +} +static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + auto adjusted_messages = json::array(); + for (const auto & msg : inputs.messages) { + auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); + auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); + if (has_reasoning_content && has_tool_calls) { + auto adjusted_message = msg; + adjusted_message["tool_plan"] = msg.at("reasoning_content"); + adjusted_message.erase("reasoning_content"); + adjusted_messages.push_back(adjusted_message); + } else { + adjusted_messages.push_back(msg); + } + } + data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); + data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; + if (string_ends_with(data.prompt, "<|START_THINKING|>")) { + if (!inputs.enable_thinking) { + data.prompt += "<|END_THINKING|>"; + } else { + data.thinking_forced_open = true; + } + } else if (!inputs.enable_thinking && string_ends_with(data.prompt, "<|CHATBOT_TOKEN|>")) { + data.prompt += "<|START_THINKING|><|END_THINKING|>"; + } + + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"tool_call_id", { + {"type", "string"}, + // Command-R's template expects an integer string. + {"pattern", "^[0-9]{1,10}$"}, + }}, + {"tool_name", { + {"type", "string"}, + {"const", function.at("name")}, + }}, + {"parameters", function.at("parameters")}, + }}, + {"required", json::array({"tool_call_id", "tool_name", "parameters"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"<|END_THINKING|>\" space )? " : "") + + "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\""); + }); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? "[\\s\\S]*?(<\\|END_THINKING\\|>\\s*)" : "(?:<\\|START_THINKING\\|>[\\s\\S]*?<\\|END_THINKING\\|>\\s*)?") + + "(<\\|START_ACTION\\|>)[\\s\\S]*" + }); + data.preserved_tokens = { + "<|START_ACTION|>", + "<|END_ACTION|>", + "<|START_RESPONSE|>", + "<|END_RESPONSE|>", + "<|START_THINKING|>", + "<|END_THINKING|>", + }; + return data; +} + +static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { + if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) { + throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); + } + const auto & parameters_properties = parameters.at("properties"); + const auto & parameters_required = parameters.at("required"); + for (const auto & prop : expected_properties) { + if (!parameters_properties.contains(prop)) { + throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT + } + if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) { + throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT + } + } + if (parameters_properties.size() != expected_properties.size()) { + throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", ")); + } +} + +static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) { + auto builtin_tools = json::array(); + common_chat_params data; + if (!inputs.tools.is_null()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + + auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { + if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") { + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py + expect_tool_parameters(name, parameters, {"query"}); + } else if (name == "python" || name == "code_interpreter") { + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py + expect_tool_parameters(name, parameters, {"code"}); + } else { + return false; + } + + std::vector kvs; + for (const auto & [key, value] : parameters.at("properties").items()) { + kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT + } + + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\"")); + builtin_tools.push_back(name); + + return true; + }; + + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + + // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime + if (allow_python_tag_builtin_tools) { + handle_builtin_tool(name, parameters); + } + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"{\" space " + "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? " + " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " + " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " " + "\"}\" space")); + }); + // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name. + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + "(\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*", + }); + if (!builtin_tools.empty()) { + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); + data.preserved_tokens.push_back("<|python_tag|>"); + } + // Allow a few empty lines on top of the usual constrained json schema space rule. + builder.add_rule("root", string_join(tool_rules, " | ")); + data.additional_stops.push_back("<|eom_id|>"); + }); + data.format = allow_python_tag_builtin_tools && !builtin_tools.empty() + ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS + : COMMON_CHAT_FORMAT_LLAMA_3_X; + } else { + data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + } + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { + {"date_string", format_time(inputs.now, "%d %b %Y")}, + {"tools_in_user_message", false}, + {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, + }); + return data; +} +static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex function_regex( + "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); + static const common_regex close_regex("\\}\\s*"); + + static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); + static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); + + if (with_builtin_tools) { + static const common_regex builtin_call_regex("<\\|python_tag\\|>"); + if (auto res = builder.try_find_regex(builtin_call_regex)) { + auto fun_res = builder.consume_regex(function_name_regex); + auto function_name = builder.str(fun_res.groups[1]); + + common_healing_marker healing_marker; + json args = json::object(); + while (true) { + if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { + auto arg_name = builder.str(arg_res->groups[1]); + auto partial = builder.consume_json(); + args[arg_name] = partial.json; + healing_marker.marker = partial.healing_marker.marker; + healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; + builder.consume_spaces(); + if (!builder.try_consume_literal(",")) { + break; + } + } else { + break; + } + } + builder.consume_literal(")"); + builder.consume_spaces(); + + auto arguments = args.dump(); + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + return; + } + } + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ function_regex, + /* function_regex= */ std::nullopt, + close_regex, + std::nullopt); + +} + +static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + + // Hacks to fix the official (broken) prompt. + // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead, + // until the official template is fixed. + if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) { + // Don't leave the chat dangling after tool results + if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) { + prompt += "<|end▁of▁sentence|>"; + if (inputs.add_generation_prompt) { + prompt += "<|Assistant|>"; + } + } + // Fix up tool call delta example added by Minja + prompt = std::regex_replace( + prompt, + std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"), + "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2"); + } + data.prompt = prompt; + data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; + if (string_ends_with(data.prompt, "\n")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + if (inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "( \"<|tool▁call▁begin|>\" )? \"function<|tool▁sep|>" + name + "\\n" + "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " " + "\"```<|tool▁call▁end|>\"")); + }); + // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, + // so we accept common variants (then it's all constrained) + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + + "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) " + "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " + "\"<|tool▁calls▁end|>\"" + " space"); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + + "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*" + }); + data.preserved_tokens = { + "", + "", + "<|tool▁calls▁begin|>", + "<|tool▁call▁begin|>", + "<|tool▁sep|>", + "<|tool▁call▁end|>", + "<|tool▁calls▁end|", + }; + }); + } + return data; +} +static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); + static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { + LOG_DBG("%s\n", __func__); + common_chat_params data; + data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { + {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")}, + {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, + }); + if (inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function.at("name")}, + }}, + {"arguments", function.at("parameters")}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); + }); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["}); + data.preserved_tokens = { + " functools[", + }; + data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2; + } else { + data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + } + return data; +} +static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex prefix(regex_escape(" functools[")); + parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); +} + +static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { + // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... + // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar + // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code. + common_chat_params data; + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; + if (inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector first_tool_rules; + std::vector subsequent_tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + std::string args_pattern = "[\\s\\S]*"; + auto args_rule = builder.add_schema(name + "-args", parameters); + if (name == "python") { + args_rule = builder.add_rule(name + "-maybe-raw-args", args_rule + " | [^{] .*"); + } else { + args_pattern = "\\{" + args_pattern; + } + auto call_rule = builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule); + first_tool_rules.push_back(call_rule); + if (inputs.parallel_tool_calls) { + subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>\" " + call_rule)); + } + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + "((?:[\\s\\S]+?>>>)?" + regex_escape(name) + "\n)" + args_pattern, + }); + }); + data.preserved_tokens = { + "<|end_header_id|>", + }; + auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; + if (inputs.parallel_tool_calls) { + auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; + builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); + } else { + builder.add_rule("root", first_rule); + } + + }); + } + return data; +} +static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { + static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); + static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); + static const common_regex close_regex(R"(\s*)"); + + parse_json_tool_calls( + builder, + std::nullopt, + function_regex_start_only, + function_regex, + close_regex, + std::nullopt, + /* allow_raw_python= */ true, + /* get_function_name= */ [&](const auto & res) -> std::string { + auto at_start = res.groups[0].begin == 0; + auto name = builder.str(res.groups[1]); + if (!name.empty() && name.back() == '{') { + // Unconsume the opening brace '{' to ensure the JSON parsing goes well. + builder.move_back(1); + } + auto idx = name.find_last_not_of("\n{"); + name = name.substr(0, idx + 1); + if (at_start && name == "all") { + return ""; + } + return name; + }); +} + +static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { + // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt + common_chat_params data; + + if (!inputs.tools.is_null()) { + std::string python_code_argument_name; + auto has_raw_python = false; + + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + const auto & parameters = function.at("parameters"); + std::string name = function.at("name"); + if (name == "python" || name == "ipython") { + if (!parameters.contains("type")) { + throw std::runtime_error("Missing type in python tool"); + } + has_raw_python = true; + const auto & type = parameters.at("type"); + if (type == "object") { + auto properties = parameters.at("properties"); + for (auto it = properties.begin(); it != properties.end(); ++it) { + if (it.value().at("type") == "string") { + if (!python_code_argument_name.empty()) { + throw std::runtime_error("Multiple string arguments found in python tool"); + } + python_code_argument_name = it.key(); + } + } + if (python_code_argument_name.empty()) { + throw std::runtime_error("No string argument found in python tool"); + } + } else if (type != "string") { + throw std::runtime_error("Invalid type in python tool: " + type.dump()); + } + } + tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); + }); + if (has_raw_python) { + tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); + data.preserved_tokens.push_back("<|python_tag|>"); + } + auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; + builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "")); + + static const common_regex function_regex(R"()"); + static const common_regex close_regex(R"()"); + + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + std::nullopt); + + if (auto res = builder.try_find_regex(python_tag_regex)) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + builder.add_tool_call("python", "", arguments); + return; + } +} + +static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + json additional_context = { + {"enable_thinking", inputs.enable_thinking}, + }; + + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context); + data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; + if (string_ends_with(data.prompt, "\n")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + if (!inputs.tools.is_null()) { + // (content)?({"name": "foo", "arguments": {"a": 1}})* + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + std::vector tool_call_alts; + std::vector escaped_names; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_schema(name + "-call", { + {"type", "object"}, + {"properties", json { + {"name", json {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + })); + tool_call_alts.push_back(builder.add_rule( + name + "-function-tag", + "\"\" space " + + builder.add_schema(name + "-args", parameters) + " " + "\"\" space")); + + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + "", + }); + auto escaped_name = regex_escape(name); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, + " alt_tags { + any_tool_call, + "\"\" space " + any_tool_call + " \"\"", + // The rest is just to accommodate common "good bad" outputs. + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + }; + auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space"); + tool_call_alts.push_back(wrappable_tool_call); + tool_call_alts.push_back( + "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); + auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + + (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); + // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + ( + "(\\s*" + "(?:" + "||||)?" + "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\"" + ")" + ")[\\s\\S]*" + ), + }); + data.preserved_tokens = { + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "```", + "```json", + "```xml", + }; + }); + } + + return data; +} +static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex open_regex( + "(?:" + "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) + "(" // match 2 (open_tag) + "" + "|" + "|" + "|" + "|" + "|" + "|" + "|" + ")?" + "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) + ")" + "|]+)>" // match 4 (function name) + "|" // match 5 (function name again) + ); + + if (auto res = builder.try_find_regex(open_regex)) { + const auto & block_start = res->groups[1]; + std::string block_end = block_start.empty() ? "" : "```"; + + const auto & open_tag = res->groups[2]; + std::string close_tag; + + if (!res->groups[3].empty()) { + builder.move_to(res->groups[3].begin); + close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + builder.add_content(builder.consume_rest()); + } else { + throw common_chat_msg_partial_exception("failed to parse tool call"); + } + } else { + auto function_name = builder.str(res->groups[4]); + if (function_name.empty()) { + function_name = builder.str(res->groups[5]); + } + GGML_ASSERT(!function_name.empty()); + + close_tag = ""; + + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } + builder.add_content(builder.consume_rest()); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + data.grammar_lazy = false; + if (!inputs.json_schema.is_null()) { + if (!inputs.grammar.empty()) { + throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); + } + data.grammar = json_schema_to_grammar(inputs.json_schema); + } else { + data.grammar = inputs.grammar; + } + return data; +} + +static common_chat_params common_chat_templates_apply_jinja( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + templates_params params; + params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); + const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use + ? *tmpls->template_tool_use + : *tmpls->template_default; + const auto & src = tmpl.source(); + const auto & caps = tmpl.original_caps(); + params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); + params.add_generation_prompt = inputs.add_generation_prompt; + params.tool_choice = inputs.tool_choice; + params.enable_thinking = inputs.enable_thinking; + params.grammar = inputs.grammar; + params.now = inputs.now; + if (!inputs.json_schema.empty()) { + params.json_schema = json::parse(inputs.json_schema); + } + + if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { + LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); + params.parallel_tool_calls = false; + } else { + params.parallel_tool_calls = inputs.parallel_tool_calls; + } + + if (params.tools.is_array()) { + if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) { + throw std::runtime_error("Cannot specify grammar with tools"); + } + if (caps.supports_tool_calls && !caps.supports_tools) { + LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); + } + } + + // DeepSeek R1: use handler in all cases except json schema (thinking / tools). + if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_deepseek_r1(tmpl, params); + } + + // Command R7B: : use handler in all cases except json schema (thinking / tools). + if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_command_r7b(tmpl, params); + } + + // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) + if (src.find("") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_hermes_2_pro(tmpl, params); + } + + // Use generic handler when mixing tools + JSON schema. + // TODO: support that mix in handlers below. + if ((params.tools.is_array() && params.json_schema.is_object())) { + return common_chat_params_init_generic(tmpl, params); + } + + // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. + if (src.find(">>>all") != std::string::npos) { + return common_chat_params_init_functionary_v3_2(tmpl, params); + } + + // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. + if (src.find(" functools[") != std::string::npos) { + return common_chat_params_init_firefunction_v2(tmpl, params); + } + + // Functionary v3.1 (w/ tools) + if (src.find("<|start_header_id|>") != std::string::npos + && src.find("ipython<|end_header_id|>") != std::string::npos) { + auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; + return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools); + } + + // Plain handler (no tools) + if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + return common_chat_params_init_without_tools(tmpl, params); + } + + // Mistral Nemo (w/ tools) + if (src.find("[TOOL_CALLS]") != std::string::npos) { + return common_chat_params_init_mistral_nemo(tmpl, params); + } + + // Generic fallback + return common_chat_params_init_generic(tmpl, params); +} + +// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template. +static common_chat_params common_chat_templates_apply_legacy( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + int alloc_size = 0; + std::vector chat; + std::vector contents; + for (const auto & msg : inputs.messages) { + auto content = msg.content; + for (const auto & part : msg.content_parts) { + if (part.type != "text") { + LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str()); + continue; + } + if (!content.empty()) { + content += "\n";; + } + content += part.text; + } + contents.emplace_back(std::move(content)); + } + for (size_t i = 0; i < contents.size(); ++i) { + const auto & msg = inputs.messages[i]; + const auto & content = contents[i]; + chat.push_back({msg.role.c_str(), content.c_str()}); + alloc_size += (msg.role.size() + content.size()) * 1.25; + } + + std::vector buf(alloc_size); + + // run the first time to get the total output length + const auto & src = tmpls->template_default->source(); + int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + + // error: chat template is not supported + if (res < 0) { + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported, try using --jinja"); + } + + // if it turns out that our buffer is too small, we resize it + if ((size_t) res > buf.size()) { + buf.resize(res); + res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + } + + common_chat_params params; + params.prompt = std::string(buf.data(), res); + if (!inputs.json_schema.empty()) { + params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema)); + } else { + params.grammar = inputs.grammar; + } + return params; +} + +common_chat_params common_chat_templates_apply( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + GGML_ASSERT(tmpls != nullptr); + return inputs.use_jinja + ? common_chat_templates_apply_jinja(tmpls, inputs) + : common_chat_templates_apply_legacy(tmpls, inputs); +} + +static void common_chat_parse_content_only(common_chat_msg_parser & builder) { + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse(common_chat_msg_parser & builder) { + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); + + switch (builder.syntax().format) { + case COMMON_CHAT_FORMAT_CONTENT_ONLY: + common_chat_parse_content_only(builder); + break; + case COMMON_CHAT_FORMAT_GENERIC: + common_chat_parse_generic(builder); + break; + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: + common_chat_parse_mistral_nemo(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X: + common_chat_parse_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: + common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: + common_chat_parse_deepseek_r1(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: + common_chat_parse_functionary_v3_2(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: + common_chat_parse_functionary_v3_1_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_HERMES_2_PRO: + common_chat_parse_hermes_2_pro(builder); + break; + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: + common_chat_parse_firefunction_v2(builder); + break; + case COMMON_CHAT_FORMAT_COMMAND_R7B: + common_chat_parse_command_r7b(builder); + break; + default: + throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); + } + builder.finish(); +} + +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + common_chat_msg_parser builder(input, is_partial, syntax); + try { + common_chat_parse(builder); + } catch (const common_chat_msg_partial_exception & ex) { + LOG_DBG("Partial parse: %s\n", ex.what()); + if (!is_partial) { + builder.clear_tools(); + builder.move_to(0); + common_chat_parse_content_only(builder); + } + } + auto msg = builder.result(); + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + return msg; +} diff --git a/common/chat.h b/common/chat.h new file mode 100644 index 0000000000000000000000000000000000000000..9f59e6b08738dcc13e148d4cd5e71f053bd4536b --- /dev/null +++ b/common/chat.h @@ -0,0 +1,202 @@ +// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers. + +#pragma once + +#include "common.h" +#include +#include +#include +#include + +struct common_chat_templates; + +struct common_chat_tool_call { + std::string name; + std::string arguments; + std::string id; + + bool operator==(const common_chat_tool_call & other) const { + return name == other.name && arguments == other.arguments && id == other.id; + } +}; + +struct common_chat_msg_content_part { + std::string type; + std::string text; + + bool operator==(const common_chat_msg_content_part & other) const { + return type == other.type && text == other.text; + } +}; + +struct common_chat_msg { + std::string role; + std::string content; + std::vector content_parts = {}; + std::vector tool_calls = {}; + std::string reasoning_content; + std::string tool_name; + std::string tool_call_id; + + template T to_json_oaicompat() const; + + bool empty() const { + return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); + } + void ensure_tool_call_ids_set(std::vector & ids_cache, const std::function & gen_tool_call_id) { + for (auto i = 0u; i < tool_calls.size(); i++) { + if (ids_cache.size() <= i) { + auto id = tool_calls[i].id; + if (id.empty()) { + id = gen_tool_call_id(); + } + ids_cache.push_back(id); + } + tool_calls[i].id = ids_cache[i]; + } + } + bool operator==(const common_chat_msg & other) const { + return role == other.role + && content == other.content + && content_parts == other.content_parts + && tool_calls == other.tool_calls + && reasoning_content == other.reasoning_content + && tool_name == other.tool_name + && tool_call_id == other.tool_call_id; + } + bool operator!=(const common_chat_msg & other) const { + return !(*this == other); + } +}; + +struct common_chat_msg_diff { + std::string reasoning_content_delta; + std::string content_delta; + size_t tool_call_index = std::string::npos; + common_chat_tool_call tool_call_delta; + + static std::vector compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg); + + bool operator==(const common_chat_msg_diff & other) const { + return content_delta == other.content_delta + && tool_call_index == other.tool_call_index + && tool_call_delta == other.tool_call_delta; + } +}; + +struct common_chat_tool { + std::string name; + std::string description; + std::string parameters; +}; + +enum common_chat_tool_choice { + COMMON_CHAT_TOOL_CHOICE_AUTO, + COMMON_CHAT_TOOL_CHOICE_REQUIRED, + COMMON_CHAT_TOOL_CHOICE_NONE, +}; + +enum common_chat_format { + COMMON_CHAT_FORMAT_CONTENT_ONLY, + COMMON_CHAT_FORMAT_GENERIC, + COMMON_CHAT_FORMAT_MISTRAL_NEMO, + COMMON_CHAT_FORMAT_LLAMA_3_X, + COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, + COMMON_CHAT_FORMAT_DEEPSEEK_R1, + COMMON_CHAT_FORMAT_FIREFUNCTION_V2, + COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, + COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, + COMMON_CHAT_FORMAT_HERMES_2_PRO, + COMMON_CHAT_FORMAT_COMMAND_R7B, + + COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats +}; + +struct common_chat_templates_inputs { + std::vector messages; + std::string grammar; + std::string json_schema; + bool add_generation_prompt = true; + bool use_jinja = true; + // Parameters below only supported when use_jinja is true + std::vector tools; + common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; + bool parallel_tool_calls = false; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); +}; + +struct common_chat_params { + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + std::string prompt; + std::string grammar; + bool grammar_lazy = false; + bool thinking_forced_open = false; + std::vector grammar_triggers; + std::vector preserved_tokens; + std::vector additional_stops; +}; + +struct common_chat_syntax { + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) + bool reasoning_in_content = false; + bool thinking_forced_open = false; + bool parse_tool_calls = true; +}; + +// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); + +void common_chat_templates_free(struct common_chat_templates * tmpls); + +struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } }; + +typedef std::unique_ptr common_chat_templates_ptr; + +common_chat_templates_ptr common_chat_templates_init( + const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override = "", + const std::string & eos_token_override = ""); + +bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); +const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr); + + +struct common_chat_params common_chat_templates_apply( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs); + +// Format single message, while taking into account the position of that message in chat history +std::string common_chat_format_single( + const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja); + +// Returns an example of formatted chat +std::string common_chat_format_example( + const struct common_chat_templates * tmpls, + bool use_jinja); + +const char* common_chat_format_name(common_chat_format format); +const char* common_reasoning_format_name(common_reasoning_format format); +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); + +common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); + +// Parses a JSON array of messages in OpenAI's chat completion API format. +// T can be std::string containing JSON or nlohmann::ordered_json +template std::vector common_chat_msgs_parse_oaicompat(const T & messages); +template T common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); + +// Parses a JSON array of tools in OpenAI's chat completion tool call API format. +// T can be std::string containing JSON or nlohmann::ordered_json +template std::vector common_chat_tools_parse_oaicompat(const T & tools); +template T common_chat_tools_to_json_oaicompat(const std::vector & tools); + +template T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); diff --git a/common/common.cpp b/common/common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eb80cee0894a60e308aeae03530426c6991530a1 --- /dev/null +++ b/common/common.cpp @@ -0,0 +1,1541 @@ +#if defined(_MSC_VER) +#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING +#endif + +#include "ggml.h" +#include "gguf.h" + +#include "common.h" +#include "log.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) && defined(__MACH__) +#include +#include +#endif + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#include +#include +#include +#else +#include +#include +#include +#endif + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +// +// CPU utils +// + +int32_t cpu_get_num_physical_cores() { +#ifdef __linux__ + // enumerate the set of thread siblings, num entries is num cores + std::unordered_set siblings; + for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) { + std::ifstream thread_siblings("/sys/devices/system/cpu/cpu" + + std::to_string(cpu) + "/topology/thread_siblings"); + if (!thread_siblings.is_open()) { + break; // no more cpus + } + std::string line; + if (std::getline(thread_siblings, line)) { + siblings.insert(line); + } + } + if (!siblings.empty()) { + return static_cast(siblings.size()); + } +#elif defined(__APPLE__) && defined(__MACH__) + int32_t num_physical_cores; + size_t len = sizeof(num_physical_cores); + int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0); + if (result == 0) { + return num_physical_cores; + } + result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0); + if (result == 0) { + return num_physical_cores; + } +#elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later + // TODO: windows + arm64 + mingw64 + unsigned int n_threads_win = std::thread::hardware_concurrency(); + unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4; + + DWORD buffer_size = 0; + if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) { + if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) { + return default_threads; + } + } + + std::vector buffer(buffer_size); + if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast(buffer.data()), &buffer_size)) { + return default_threads; + } + + int32_t num_physical_cores = 0; + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast(buffer.data()); + while (buffer_size > 0) { + if (info->Relationship == RelationProcessorCore) { + num_physical_cores += info->Processor.GroupCount; + } + buffer_size -= info->Size; + info = reinterpret_cast(reinterpret_cast(info) + info->Size); + } + + return num_physical_cores > 0 ? num_physical_cores : default_threads; +#endif + unsigned int n_threads = std::thread::hardware_concurrency(); + return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; +} + +#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) +#include + +static void cpuid(unsigned leaf, unsigned subleaf, + unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) { + __asm__("movq\t%%rbx,%%rsi\n\t" + "cpuid\n\t" + "xchgq\t%%rbx,%%rsi" + : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx) + : "0"(leaf), "2"(subleaf)); +} + +static int pin_cpu(int cpu) { + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(cpu, &mask); + return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask); +} + +static bool is_hybrid_cpu(void) { + unsigned eax, ebx, ecx, edx; + cpuid(7, 0, &eax, &ebx, &ecx, &edx); + return !!(edx & (1u << 15)); +} + +static bool is_running_on_efficiency_core(void) { + unsigned eax, ebx, ecx, edx; + cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx); + int intel_atom = 0x20; + int core_type = (eax & 0xff000000u) >> 24; + return core_type == intel_atom; +} + +static int cpu_count_math_cpus(int n_cpu) { + int result = 0; + for (int cpu = 0; cpu < n_cpu; ++cpu) { + if (pin_cpu(cpu)) { + return -1; + } + if (is_running_on_efficiency_core()) { + continue; // efficiency cores harm lockstep threading + } + ++cpu; // hyperthreading isn't useful for linear algebra + ++result; + } + return result; +} + +#endif // __x86_64__ && __linux__ + +/** + * Returns number of CPUs on system that are useful for math. + */ +int32_t cpu_get_num_math() { +#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) + int n_cpu = sysconf(_SC_NPROCESSORS_ONLN); + if (n_cpu < 1) { + return cpu_get_num_physical_cores(); + } + if (is_hybrid_cpu()) { + cpu_set_t affinity; + if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) { + int result = cpu_count_math_cpus(n_cpu); + pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity); + if (result > 0) { + return result; + } + } + } +#endif + return cpu_get_num_physical_cores(); +} + +// Helper for setting process priority + +#if defined(_WIN32) + +bool set_process_priority(enum ggml_sched_priority prio) { + if (prio == GGML_SCHED_PRIO_NORMAL) { + return true; + } + + DWORD p = NORMAL_PRIORITY_CLASS; + switch (prio) { + case GGML_SCHED_PRIO_LOW: p = BELOW_NORMAL_PRIORITY_CLASS; break; + case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break; + case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break; + case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break; + case GGML_SCHED_PRIO_REALTIME: p = REALTIME_PRIORITY_CLASS; break; + } + + if (!SetPriorityClass(GetCurrentProcess(), p)) { + LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError()); + return false; + } + + return true; +} + +#else // MacOS and POSIX +#include +#include + +bool set_process_priority(enum ggml_sched_priority prio) { + if (prio == GGML_SCHED_PRIO_NORMAL) { + return true; + } + + int p = 0; + switch (prio) { + case GGML_SCHED_PRIO_LOW: p = 5; break; + case GGML_SCHED_PRIO_NORMAL: p = 0; break; + case GGML_SCHED_PRIO_MEDIUM: p = -5; break; + case GGML_SCHED_PRIO_HIGH: p = -10; break; + case GGML_SCHED_PRIO_REALTIME: p = -20; break; + } + + if (!setpriority(PRIO_PROCESS, 0, p)) { + LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno); + return false; + } + return true; +} + +#endif + +// +// CLI argument parsing +// + + +void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) { + int32_t n_set = 0; + + if (cpuparams.n_threads < 0) { + // Assuming everything about cpuparams is invalid + if (role_model != nullptr) { + cpuparams = *role_model; + } else { + cpuparams.n_threads = cpu_get_num_math(); + } + } + + for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) { + if (cpuparams.cpumask[i]) { + n_set++; + } + } + + if (n_set && n_set < cpuparams.n_threads) { + // Not enough set bits, may experience performance issues. + LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads); + } +} + +bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) { + size_t dash_loc = range.find('-'); + if (dash_loc == std::string::npos) { + LOG_ERR("Format of CPU range is invalid! Expected []-[].\n"); + return false; + } + + size_t start_i; + size_t end_i; + + if (dash_loc == 0) { + start_i = 0; + } else { + start_i = std::stoull(range.substr(0, dash_loc)); + if (start_i >= GGML_MAX_N_THREADS) { + LOG_ERR("Start index out of bounds!\n"); + return false; + } + } + + if (dash_loc == range.length() - 1) { + end_i = GGML_MAX_N_THREADS - 1; + } else { + end_i = std::stoull(range.substr(dash_loc + 1)); + if (end_i >= GGML_MAX_N_THREADS) { + LOG_ERR("End index out of bounds!\n"); + return false; + } + } + + for (size_t i = start_i; i <= end_i; i++) { + boolmask[i] = true; + } + + return true; +} + +bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREADS]) { + // Discard potential 0x prefix + size_t start_i = 0; + if (mask.length() >= 2 && mask.substr(0, 2) == "0x") { + start_i = 2; + } + + size_t num_digits = mask.length() - start_i; + if (num_digits > 128) num_digits = 128; + + size_t end_i = num_digits + start_i; + + for (size_t i = start_i, n = (num_digits*4 - 1); i < end_i; i++, n-=4) { + char c = mask.at(i); + int8_t id = c; + + if ((c >= '0' && c <= '9')) { + id -= '0'; + } else if (c >= 'a' && c <= 'f') { + id -= 'a' - 10; + } else if (c >= 'A' && c <= 'F') { + id -= 'A' - 10; + } else { + LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i)); + return false; + } + + boolmask[ n ] = boolmask[ n ] || ((id & 8) != 0); + boolmask[n - 1] = boolmask[n - 1] || ((id & 4) != 0); + boolmask[n - 2] = boolmask[n - 2] || ((id & 2) != 0); + boolmask[n - 3] = boolmask[n - 3] || ((id & 1) != 0); + } + + return true; +} + +void common_init() { + llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) { + if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) { + common_log_add(common_log_main(), level, "%s", text); + } + }, NULL); + +#ifdef NDEBUG + const char * build_type = ""; +#else + const char * build_type = " (debug)"; +#endif + + LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type); +} + +std::string common_params_get_system_info(const common_params & params) { + std::ostringstream os; + + os << "system_info: n_threads = " << params.cpuparams.n_threads; + if (params.cpuparams_batch.n_threads != -1) { + os << " (n_threads_batch = " << params.cpuparams_batch.n_threads << ")"; + } +#if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later + // TODO: windows + arm64 + mingw64 + DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS); + os << " / " << logicalProcessorCount << " | " << llama_print_system_info(); +#else + os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); +#endif + + return os.str(); +} + +// +// String utils +// + +std::string string_format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +std::string string_strip(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && std::isspace(str[start])) { + start++; + } + while (end > start && std::isspace(str[end - 1])) { + end--; + } + return str.substr(start, end - start); +} + +std::string string_get_sortable_timestamp() { + using clock = std::chrono::system_clock; + + const clock::time_point current_time = clock::now(); + const time_t as_time_t = clock::to_time_t(current_time); + char timestamp_no_ns[100]; + std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t)); + + const int64_t ns = std::chrono::duration_cast( + current_time.time_since_epoch() % 1000000000).count(); + char timestamp_ns[11]; + snprintf(timestamp_ns, 11, "%09" PRId64, ns); + + return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); +} + +void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { + if (search.empty()) { + return; + } + std::string builder; + builder.reserve(s.length()); + size_t pos = 0; + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); +} + +bool string_ends_with(const std::string_view & str, const std::string_view & suffix) { + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; +} +size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) { + if (!str.empty() && !stop.empty()) { + const char text_last_char = str.back(); + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { + if (stop[char_index] == text_last_char) { + const auto current_partial = stop.substr(0, char_index + 1); + if (string_ends_with(str, current_partial)) { + return str.size() - char_index - 1; + } + } + } + } + + return std::string::npos; +} + +std::string regex_escape(const std::string & s) { + static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); + return std::regex_replace(s, special_chars, "\\$&"); +} + +std::string string_join(const std::vector & values, const std::string & separator) { + std::ostringstream result; + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) { + result << separator; + } + result << values[i]; + } + return result.str(); +} + +std::vector string_split(const std::string & str, const std::string & delimiter) { + std::vector parts; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + parts.push_back(str.substr(start, end - start)); + start = end + delimiter.length(); + end = str.find(delimiter, start); + } + + parts.push_back(str.substr(start)); + + return parts; +} + +std::string string_repeat(const std::string & str, size_t n) { + if (n == 0) { + return ""; + } + + std::string result; + result.reserve(str.length() * n); + + for (size_t i = 0; i < n; ++i) { + result += str; + } + + return result; +} + +std::string string_from(bool value) { + return value ? "true" : "false"; +} + +std::string string_from(const std::vector & values) { + std::stringstream buf; + + buf << "[ "; + bool first = true; + for (auto e : values) { + if (first) { + first = false; + } else { + buf << ", "; + } + buf << std::to_string(e); + } + buf << " ]"; + + return buf.str(); +} + +std::string string_from(const struct llama_context * ctx, const std::vector & tokens) { + std::stringstream buf; + + buf << "[ "; + + bool first = true; + for (const auto & token : tokens) { + if (!first) { + buf << ", "; + } else { + first = false; + } + + auto detokenized = common_token_to_piece(ctx, token); + + detokenized.erase( + std::remove_if( + detokenized.begin(), + detokenized.end(), + [](const unsigned char c) { return !std::isprint(c); }), + detokenized.end()); + + buf << "'" << detokenized << "'" + << ":" << std::to_string(token); + } + + buf << " ]"; + + return buf.str(); +} + +std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) { + std::stringstream buf; + + buf << "[ "; + + bool first = true; + for (int i = 0; i < batch.n_tokens; ++i) { + if (!first) { + buf << ", "; + } else { + first = false; + } + + auto detokenized = common_token_to_piece(ctx, batch.token[i]); + + detokenized.erase( + std::remove_if( + detokenized.begin(), + detokenized.end(), + [](const unsigned char c) { return !std::isprint(c); }), + detokenized.end()); + + buf << "\n" << std::to_string(i) + << ", token '" << detokenized << "'" + << ", pos " << std::to_string(batch.pos[i]) + << ", n_seq_id " << std::to_string(batch.n_seq_id[i]) + << ", seq_id " << std::to_string(batch.seq_id[i][0]) + << ", logits " << std::to_string(batch.logits[i]); + } + + buf << " ]"; + + return buf.str(); +} + +void string_process_escapes(std::string & input) { + std::size_t input_len = input.length(); + std::size_t output_idx = 0; + + for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { + if (input[input_idx] == '\\' && input_idx + 1 < input_len) { + switch (input[++input_idx]) { + case 'n': input[output_idx++] = '\n'; break; + case 'r': input[output_idx++] = '\r'; break; + case 't': input[output_idx++] = '\t'; break; + case '\'': input[output_idx++] = '\''; break; + case '\"': input[output_idx++] = '\"'; break; + case '\\': input[output_idx++] = '\\'; break; + case 'x': + // Handle \x12, etc + if (input_idx + 2 < input_len) { + const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 }; + char *err_p = nullptr; + const long val = std::strtol(x, &err_p, 16); + if (err_p == x + 2) { + input_idx += 2; + input[output_idx++] = char(val); + break; + } + } + // fall through + default: input[output_idx++] = '\\'; + input[output_idx++] = input[input_idx]; break; + } + } else { + input[output_idx++] = input[input_idx]; + } + } + + input.resize(output_idx); +} + +bool string_parse_kv_override(const char * data, std::vector & overrides) { + const char * sep = strchr(data, '='); + if (sep == nullptr || sep - data >= 128) { + LOG_ERR("%s: malformed KV override '%s'\n", __func__, data); + return false; + } + llama_model_kv_override kvo; + std::strncpy(kvo.key, data, sep - data); + kvo.key[sep - data] = 0; + sep++; + if (strncmp(sep, "int:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; + kvo.val_i64 = std::atol(sep); + } else if (strncmp(sep, "float:", 6) == 0) { + sep += 6; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; + kvo.val_f64 = std::atof(sep); + } else if (strncmp(sep, "bool:", 5) == 0) { + sep += 5; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; + if (std::strcmp(sep, "true") == 0) { + kvo.val_bool = true; + } else if (std::strcmp(sep, "false") == 0) { + kvo.val_bool = false; + } else { + LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data); + return false; + } + } else if (strncmp(sep, "str:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; + if (strlen(sep) > 127) { + LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); + return false; + } + strncpy(kvo.val_str, sep, 127); + kvo.val_str[127] = '\0'; + } else { + LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data); + return false; + } + overrides.emplace_back(std::move(kvo)); + return true; +} + +// +// Filesystem utils +// + +// Validate if a filename is safe to use +// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function +bool fs_validate_filename(const std::string & filename) { + if (!filename.length()) { + // Empty filename invalid + return false; + } + if (filename.length() > 255) { + // Limit at common largest possible filename on Linux filesystems + // to avoid unnecessary further validation + // (On systems with smaller limits it will be caught by the OS) + return false; + } + + std::u32string filename_utf32; + try { +#if defined(__clang__) + // disable C++17 deprecation warning for std::codecvt_utf8 +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif + std::wstring_convert, char32_t> converter; + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif + + filename_utf32 = converter.from_bytes(filename); + + // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used, + // or invalid encodings were encountered. Reject such attempts + std::string filename_reencoded = converter.to_bytes(filename_utf32); + if (filename_reencoded != filename) { + return false; + } + } catch (const std::exception &) { + return false; + } + + // Check for forbidden codepoints: + // - Control characters + // - Unicode equivalents of illegal characters + // - UTF-16 surrogate pairs + // - UTF-8 replacement character + // - Byte order mark (BOM) + // - Illegal characters: / \ : * ? " < > | + for (char32_t c : filename_utf32) { + if (c <= 0x1F // Control characters (C0) + || c == 0x7F // Control characters (DEL) + || (c >= 0x80 && c <= 0x9F) // Control characters (C1) + || c == 0xFF0E // Fullwidth Full Stop (period equivalent) + || c == 0x2215 // Division Slash (forward slash equivalent) + || c == 0x2216 // Set Minus (backslash equivalent) + || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs + || c == 0xFFFD // Replacement Character (UTF-8) + || c == 0xFEFF // Byte Order Mark (BOM) + || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters + || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') { + return false; + } + } + + // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename + // Unicode and other whitespace is not affected, only 0x20 space + if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') { + return false; + } + + // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead) + if (filename.find("..") != std::string::npos) { + return false; + } + + // Reject "." + if (filename == ".") { + return false; + } + + return true; +} + +#include + + +// returns true if successful, false otherwise +bool fs_create_directory_with_parents(const std::string & path) { +#ifdef _WIN32 + std::wstring_convert> converter; + std::wstring wpath = converter.from_bytes(path); + + // if the path already exists, check whether it's a directory + const DWORD attributes = GetFileAttributesW(wpath.c_str()); + if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return true; + } + + size_t pos_slash = 0; + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { + const std::wstring subpath = wpath.substr(0, pos_slash); + + pos_slash += 1; + + // skip the drive letter, in some systems it can return an access denied error + if (subpath.length() == 2 && subpath[1] == ':') { + continue; + } + + const bool success = CreateDirectoryW(subpath.c_str(), NULL); + + if (!success) { + const DWORD error = GetLastError(); + + // if the path already exists, ensure that it's a directory + if (error == ERROR_ALREADY_EXISTS) { + const DWORD attributes = GetFileAttributesW(subpath.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return false; + } + } else { + return false; + } + } + } + + return true; +#else + // if the path already exists, check whether it's a directory + struct stat info; + if (stat(path.c_str(), &info) == 0) { + return S_ISDIR(info.st_mode); + } + + size_t pos_slash = 1; // skip leading slashes for directory creation + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { + const std::string subpath = path.substr(0, pos_slash); + struct stat info; + + // if the path already exists, ensure that it's a directory + if (stat(subpath.c_str(), &info) == 0) { + if (!S_ISDIR(info.st_mode)) { + return false; + } + } else { + // create parent directories + const int ret = mkdir(subpath.c_str(), 0755); + if (ret != 0) { + return false; + } + } + + pos_slash += 1; + } + + return true; +#endif // _WIN32 +} + +std::string fs_get_cache_directory() { + std::string cache_directory = ""; + auto ensure_trailing_slash = [](std::string p) { + // Make sure to add trailing slash + if (p.back() != DIRECTORY_SEPARATOR) { + p += DIRECTORY_SEPARATOR; + } + return p; + }; + if (getenv("LLAMA_CACHE")) { + cache_directory = std::getenv("LLAMA_CACHE"); + } else { +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) + if (std::getenv("XDG_CACHE_HOME")) { + cache_directory = std::getenv("XDG_CACHE_HOME"); + } else { + cache_directory = std::getenv("HOME") + std::string("/.cache/"); + } +#elif defined(__APPLE__) + cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); +#elif defined(_WIN32) + cache_directory = std::getenv("LOCALAPPDATA"); +#else +# error Unknown architecture +#endif + cache_directory = ensure_trailing_slash(cache_directory); + cache_directory += "llama.cpp"; + } + return ensure_trailing_slash(cache_directory); +} + +std::string fs_get_cache_file(const std::string & filename) { + GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos); + std::string cache_directory = fs_get_cache_directory(); + const bool success = fs_create_directory_with_parents(cache_directory); + if (!success) { + throw std::runtime_error("failed to create cache directory: " + cache_directory); + } + return cache_directory + filename; +} + + +// +// Model utils +// + +struct common_init_result common_init_from_params(common_params & params) { + common_init_result iparams; + auto mparams = common_model_params_to_llama(params); + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); + if (model == NULL) { + LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); + return iparams; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + auto cparams = common_context_params_to_llama(params); + + llama_context * lctx = llama_init_from_model(model, cparams); + if (lctx == NULL) { + LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); + llama_model_free(model); + return iparams; + } + + if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) { + LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); + params.ctx_shift = false; + } + + if (!params.control_vectors.empty()) { + if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; + if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model); + + const auto cvec = common_control_vector_load(params.control_vectors); + if (cvec.n_embd == -1) { + llama_free(lctx); + llama_model_free(model); + + return iparams; + } + + int err = llama_apply_adapter_cvec( + lctx, + cvec.data.data(), + cvec.data.size(), + cvec.n_embd, + params.control_vector_layer_start, + params.control_vector_layer_end); + if (err) { + llama_free(lctx); + llama_model_free(model); + + return iparams; + } + } + + if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) { + bool ok = true; + + if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); + ok = false; + } + + bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; + bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL; + + if (!has_eos && !has_sep) { + LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__); + ok = false; + } else if (!has_eos) { + LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__); + } else if (!has_sep) { + LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); + ok = false; + } + + if (!ok) { + llama_free(lctx); + llama_model_free(model); + + return iparams; + } + } + + // load and optionally apply lora adapters + for (auto & la : params.lora_adapters) { + llama_adapter_lora_ptr lora; + lora.reset(llama_adapter_lora_init(model, la.path.c_str())); + if (lora == nullptr) { + LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); + llama_free(lctx); + llama_model_free(model); + return iparams; + } + + la.ptr = lora.get(); + iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters + } + + if (!params.lora_init_without_apply) { + common_set_adapter_lora(lctx, params.lora_adapters); + } + + if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); + params.sampling.ignore_eos = false; + } + + if (params.sampling.ignore_eos) { + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { + LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); + params.sampling.logit_bias.push_back({i, -INFINITY}); + } + } + } + + if (params.sampling.penalty_last_n == -1) { + LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + params.sampling.penalty_last_n = llama_n_ctx(lctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); + } + + if (params.warmup) { + LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); + + llama_set_warmup(lctx, true); + + std::vector tmp; + llama_token bos = llama_vocab_bos(vocab); + llama_token eos = llama_vocab_eos(vocab); + + // some models (e.g. T5) don't have a BOS token + if (bos != LLAMA_TOKEN_NULL) { + tmp.push_back(bos); + } + if (eos != LLAMA_TOKEN_NULL) { + tmp.push_back(eos); + } + if (tmp.empty()) { + tmp.push_back(0); + } + + if (llama_model_has_encoder(model)) { + llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { + decoder_start_token_id = bos; + } + tmp.clear(); + tmp.push_back(decoder_start_token_id); + } + if (llama_model_has_decoder(model)) { + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); + } + llama_memory_clear(llama_get_memory(lctx), true); + llama_synchronize(lctx); + llama_perf_context_reset(lctx); + llama_set_warmup(lctx, false); + } + + iparams.model.reset(model); + iparams.context.reset(lctx); + + return iparams; +} + +std::string get_model_endpoint() { + const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); + // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. + const char * hf_endpoint_env = getenv("HF_ENDPOINT"); + const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env; + std::string model_endpoint = "https://huggingface.co/"; + if (endpoint_env) { + model_endpoint = endpoint_env; + if (model_endpoint.back() != '/') model_endpoint += '/'; + } + return model_endpoint; +} + +void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { + llama_clear_adapter_lora(ctx); + for (auto & la : lora) { + if (la.scale != 0.0f) { + llama_set_adapter_lora(ctx, la.ptr, la.scale); + } + } +} + +struct llama_model_params common_model_params_to_llama(common_params & params) { + auto mparams = llama_model_default_params(); + + if (!params.devices.empty()) { + mparams.devices = params.devices.data(); + } + + if (params.n_gpu_layers != -1) { + mparams.n_gpu_layers = params.n_gpu_layers; + } + + mparams.main_gpu = params.main_gpu; + mparams.split_mode = params.split_mode; + mparams.tensor_split = params.tensor_split; + mparams.use_mmap = params.use_mmap; + mparams.use_mlock = params.use_mlock; + mparams.check_tensors = params.check_tensors; + + if (params.kv_overrides.empty()) { + mparams.kv_overrides = NULL; + } else { + GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key"); + mparams.kv_overrides = params.kv_overrides.data(); + } + + if (params.tensor_buft_overrides.empty()) { + mparams.tensor_buft_overrides = NULL; + } else { + GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern"); + mparams.tensor_buft_overrides = params.tensor_buft_overrides.data(); + } + + mparams.progress_callback = params.load_progress_callback; + mparams.progress_callback_user_data = params.load_progress_callback_user_data; + + return mparams; +} + +struct llama_context_params common_context_params_to_llama(const common_params & params) { + auto cparams = llama_context_default_params(); + + cparams.n_ctx = params.n_ctx; + cparams.n_seq_max = params.n_parallel; + cparams.n_batch = params.n_batch; + cparams.n_ubatch = params.n_ubatch; + cparams.n_threads = params.cpuparams.n_threads; + cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? + params.cpuparams.n_threads : params.cpuparams_batch.n_threads; + cparams.embeddings = params.embedding; + cparams.rope_scaling_type = params.rope_scaling_type; + cparams.rope_freq_base = params.rope_freq_base; + cparams.rope_freq_scale = params.rope_freq_scale; + cparams.yarn_ext_factor = params.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow; + cparams.yarn_orig_ctx = params.yarn_orig_ctx; + cparams.pooling_type = params.pooling_type; + cparams.attention_type = params.attention_type; + cparams.defrag_thold = params.defrag_thold; + cparams.cb_eval = params.cb_eval; + cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.offload_kqv = !params.no_kv_offload; + cparams.flash_attn = params.flash_attn; + cparams.no_perf = params.no_perf; + cparams.op_offload = !params.no_op_offload; + cparams.swa_full = params.swa_full; + + cparams.type_k = params.cache_type_k; + cparams.type_v = params.cache_type_v; + + return cparams; +} + +struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) { + struct ggml_threadpool_params tpp; + + ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults + + if (params.mask_valid) { + std::memcpy(&tpp.cpumask, ¶ms.cpumask, GGML_MAX_N_THREADS); + } + + tpp.prio = params.priority; + tpp.poll = params.poll; + tpp.strict_cpu = params.strict_cpu; + + return tpp; +} + +// +// Batch utils +// + +void common_batch_clear(struct llama_batch & batch) { + batch.n_tokens = 0; +} + +void common_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector & seq_ids, + bool logits) { + GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); + + batch.token [batch.n_tokens] = id; + batch.pos [batch.n_tokens] = pos; + batch.n_seq_id[batch.n_tokens] = seq_ids.size(); + for (size_t i = 0; i < seq_ids.size(); ++i) { + batch.seq_id[batch.n_tokens][i] = seq_ids[i]; + } + batch.logits [batch.n_tokens] = logits; + + batch.n_tokens++; +} + +// +// Token utils +// + +size_t common_lcp(const llama_tokens & a, const llama_tokens & b) { + size_t i; + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} + + return i; +} + +size_t common_lcs(const llama_tokens & a, const llama_tokens & b) { + // check for empty sequences + if (a.empty() || b.empty()) { + return 0; + } + + // get the lengths of the input sequences + size_t a_len = a.size(); + size_t b_len = b.size(); + + // initialize the maximum length of the longest common subsequence (LCS) + size_t max_length = 0; + + // use two rows instead of a 2D matrix to optimize space + std::vector prev_row(b_len + 1, 0); + std::vector curr_row(b_len + 1, 0); + + // iterate through the elements of a + for (size_t i = 1; i <= a_len; i++) { + // iterate through the elements of b + for (size_t j = 1; j <= b_len; j++) { + // if elements at the current positions match + if (a[i - 1] == b[j - 1]) { + // if it's the first element of either sequences, set LCS length to 1 + if (i == 1 || j == 1) { + curr_row[j] = 1; + } else { + // increment LCS length by 1 compared to the previous element + curr_row[j] = prev_row[j - 1] + 1; + } + + // update max_length if necessary + if (curr_row[j] > max_length) { + max_length = curr_row[j]; + } + } else { + // reset LCS length if elements don't match + curr_row[j] = 0; + } + } + + // update the previous row for the next iteration + prev_row = curr_row; + } + + // return the maximum length of the LCS + return max_length; +} + +// +// Vocab utils +// + +std::vector common_tokenize( + const struct llama_context * ctx, + const std::string & text, + bool add_special, + bool parse_special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_tokenize(vocab, text, add_special, parse_special); +} + +std::vector common_tokenize( + const struct llama_vocab * vocab, + const std::string & text, + bool add_special, + bool parse_special) { + // upper limit for the number of tokens + int n_tokens = text.length() + 2 * add_special; + std::vector result(n_tokens); + n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + return result; +} + +std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_token_to_piece(vocab, token, special); +} + +std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) { + std::string piece; + piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' + const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); + } + else { + piece.resize(n_chars); + } + + return piece; +} + +std::string common_detokenize(const struct llama_context * ctx, const std::vector & tokens, bool special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_detokenize(vocab, tokens, special); +} + +std::string common_detokenize(const struct llama_vocab * vocab, const std::vector & tokens, bool special) { + std::string text; + text.resize(std::max(text.capacity(), tokens.size())); + int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + if (n_chars < 0) { + text.resize(-n_chars); + n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization + } + + text.resize(n_chars); + + // NOTE: the original tokenizer decodes bytes after collecting the pieces. + return text; +} + +// +// Embedding utils +// + +void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) { + double sum = 0.0; + + switch (embd_norm) { + case -1: // no normalisation + sum = 1.0; + break; + case 0: // max absolute + for (int i = 0; i < n; i++) { + if (sum < std::abs(inp[i])) { + sum = std::abs(inp[i]); + } + } + sum /= 32760.0; // make an int16 range + break; + case 2: // euclidean + for (int i = 0; i < n; i++) { + sum += inp[i] * inp[i]; + } + sum = std::sqrt(sum); + break; + default: // p-norm (euclidean is p-norm p=2) + for (int i = 0; i < n; i++) { + sum += std::pow(std::abs(inp[i]), embd_norm); + } + sum = std::pow(sum, 1.0 / embd_norm); + break; + } + + const float norm = sum > 0.0 ? 1.0 / sum : 0.0f; + + for (int i = 0; i < n; i++) { + out[i] = inp[i] * norm; + } +} + +float common_embd_similarity_cos(const float * embd1, const float * embd2, int n){ + double sum = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + + for (int i = 0; i < n; i++) { + sum += embd1[i] * embd2[i]; + sum1 += embd1[i] * embd1[i]; + sum2 += embd2[i] * embd2[i]; + } + + // Handle the case where one or both vectors are zero vectors + if (sum1 == 0.0 || sum2 == 0.0) { + if (sum1 == 0.0 && sum2 == 0.0) { + return 1.0f; // two zero vectors are similar + } + return 0.0f; + } + + return sum / (sqrt(sum1) * sqrt(sum2)); +} + +// +// Control vector utils +// + +static common_control_vector_data common_control_vector_load_one(const common_control_vector_load_info & load_info) { + common_control_vector_data result = { -1, {} }; + + ggml_context * ctx = nullptr; + struct gguf_init_params meta_gguf_params = { + /* .no_alloc = */ false, + /* .ctx = */ &ctx, + }; + struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params); + if (!ctx_gguf) { + LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str()); + return result; + } + + int32_t n_tensors = gguf_get_n_tensors(ctx_gguf); + if (n_tensors == 0) { + LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str()); + } + + for (int i = 0; i < n_tensors; i++) { + std::string name = gguf_get_tensor_name(ctx_gguf, i); + + int layer_idx = -1; + + // split on '.' + size_t dotpos = name.find('.'); + if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") { + try { + layer_idx = std::stoi(name.substr(dotpos + 1)); + } catch (...) { + layer_idx = -1; + } + } + if (layer_idx < 0) { + LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } else if (layer_idx == 0) { + LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } + + struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str()); + if (tensor->type != GGML_TYPE_F32) { + LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } + if (ggml_n_dims(tensor) != 1) { + LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } + + if (result.n_embd == -1) { + result.n_embd = ggml_nelements(tensor); + } else if (ggml_nelements(tensor) != result.n_embd) { + LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } + + // extend if necessary - do not store data for layer 0 (it's not used) + result.data.resize(std::max(result.data.size(), static_cast(result.n_embd * layer_idx)), 0.0f); + + const float * src = (const float *) tensor->data; + float * dst = result.data.data() + result.n_embd * (layer_idx - 1); // layer 1 at [0] + for (int j = 0; j < result.n_embd; j++) { + dst[j] += src[j] * load_info.strength; // allows multiple directions for same layer in same file + } + + } + + if (result.n_embd == -1) { + LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str()); + result.data.clear(); + } + + gguf_free(ctx_gguf); + ggml_free(ctx); + + return result; +} + +common_control_vector_data common_control_vector_load(const std::vector & load_infos) { + common_control_vector_data result = { -1, {} }; + + for (const auto & info : load_infos) { + auto cur = common_control_vector_load_one(info); + + if (cur.n_embd == -1) { + result.n_embd = -1; + break; + } + if (result.n_embd != -1 && result.n_embd != cur.n_embd) { + LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str()); + result.n_embd = -1; + break; + } + + if (result.n_embd == -1) { + result = std::move(cur); + } else { + result.data.resize(std::max(result.data.size(), cur.data.size()), 0.0f); // extend if necessary + for (size_t i = 0; i < cur.data.size(); i++) { + result.data[i] += cur.data[i]; + } + } + } + + if (result.n_embd == -1) { + LOG_ERR("%s: no valid control vector files passed\n", __func__); + result.data.clear(); + } + + return result; +} + +ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector & tokens, int64_t stride) { + const int64_t ne_datapoint = llama_n_ctx(ctx); + const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride; + ggml_opt_dataset_t result = ggml_opt_dataset_init( + GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1); + + llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data; + llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data; + + for (int64_t idata = 0; idata < ndata; ++idata) { + memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token)); + memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token)); + } + + return result; +} diff --git a/common/common.h b/common/common.h new file mode 100644 index 0000000000000000000000000000000000000000..00b6ca03a20b4b42634f8ef67e041f992694089b --- /dev/null +++ b/common/common.h @@ -0,0 +1,672 @@ +// Various helper functions and utilities + +#pragma once + +#include "llama-cpp.h" + +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#define DIRECTORY_SEPARATOR '\\' +#else +#define DIRECTORY_SEPARATOR '/' +#endif // _WIN32 + +#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) +#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) + +#define print_build_info() do { \ + fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ + fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ +} while(0) + +#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" + +struct common_adapter_lora_info { + std::string path; + float scale; + + struct llama_adapter_lora * ptr; +}; + +using llama_tokens = std::vector; + +// build info +extern int LLAMA_BUILD_NUMBER; +extern const char * LLAMA_COMMIT; +extern const char * LLAMA_COMPILER; +extern const char * LLAMA_BUILD_TARGET; + +struct common_control_vector_load_info; + +// +// CPU utils +// + +struct cpu_params { + int n_threads = -1; + bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask. + bool mask_valid = false; // Default: any CPU + enum ggml_sched_priority priority = GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime) + bool strict_cpu = false; // Use strict CPU placement + uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling) +}; + +int32_t cpu_get_num_physical_cores(); +int32_t cpu_get_num_math(); + +// +// Common params +// + +enum llama_example { + LLAMA_EXAMPLE_COMMON, + LLAMA_EXAMPLE_SPECULATIVE, + LLAMA_EXAMPLE_MAIN, + LLAMA_EXAMPLE_EMBEDDING, + LLAMA_EXAMPLE_PERPLEXITY, + LLAMA_EXAMPLE_RETRIEVAL, + LLAMA_EXAMPLE_PASSKEY, + LLAMA_EXAMPLE_IMATRIX, + LLAMA_EXAMPLE_BENCH, + LLAMA_EXAMPLE_SERVER, + LLAMA_EXAMPLE_CVECTOR_GENERATOR, + LLAMA_EXAMPLE_EXPORT_LORA, + LLAMA_EXAMPLE_MTMD, + LLAMA_EXAMPLE_LOOKUP, + LLAMA_EXAMPLE_PARALLEL, + LLAMA_EXAMPLE_TTS, + + LLAMA_EXAMPLE_COUNT, +}; + +enum common_sampler_type { + COMMON_SAMPLER_TYPE_NONE = 0, + COMMON_SAMPLER_TYPE_DRY = 1, + COMMON_SAMPLER_TYPE_TOP_K = 2, + COMMON_SAMPLER_TYPE_TOP_P = 3, + COMMON_SAMPLER_TYPE_MIN_P = 4, + //COMMON_SAMPLER_TYPE_TFS_Z = 5, + COMMON_SAMPLER_TYPE_TYPICAL_P = 6, + COMMON_SAMPLER_TYPE_TEMPERATURE = 7, + COMMON_SAMPLER_TYPE_XTC = 8, + COMMON_SAMPLER_TYPE_INFILL = 9, + COMMON_SAMPLER_TYPE_PENALTIES = 10, + COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11, +}; + +// dimensionality reduction methods, used by cvector-generator +enum dimre_method { + DIMRE_METHOD_PCA, + DIMRE_METHOD_MEAN, +}; + +enum common_conversation_mode { + COMMON_CONVERSATION_MODE_DISABLED = 0, + COMMON_CONVERSATION_MODE_ENABLED = 1, + COMMON_CONVERSATION_MODE_AUTO = 2, +}; + +enum common_grammar_trigger_type { + COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, +}; + +struct common_grammar_trigger { + common_grammar_trigger_type type; + std::string value; + llama_token token = LLAMA_TOKEN_NULL; +}; + +// sampling parameters +struct common_params_sampling { + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler + + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float xtc_probability = 0.00f; // 0.0 = disabled + float xtc_threshold = 0.10f; // > 0.5 disables XTC + float typ_p = 1.00f; // typical_p, 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: + float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) + int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty + int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float top_n_sigma = -1.00f;// -1.0 = disabled + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool ignore_eos = false; + bool no_perf = false; // disable performance metrics + bool timing_per_token = false; + + std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY + + + std::vector samplers = { + COMMON_SAMPLER_TYPE_PENALTIES, + COMMON_SAMPLER_TYPE_DRY, + COMMON_SAMPLER_TYPE_TOP_N_SIGMA, + COMMON_SAMPLER_TYPE_TOP_K, + COMMON_SAMPLER_TYPE_TYPICAL_P, + COMMON_SAMPLER_TYPE_TOP_P, + COMMON_SAMPLER_TYPE_MIN_P, + COMMON_SAMPLER_TYPE_XTC, + COMMON_SAMPLER_TYPE_TEMPERATURE, + }; + + std::string grammar; // optional BNF-like grammar to constrain sampling + bool grammar_lazy = false; + std::vector grammar_triggers; // optional triggers (for lazy grammars) + std::set preserved_tokens; + + std::vector logit_bias; // logit biases to apply + + // print the parameters into a string + std::string print() const; +}; + +struct common_params_model { + std::string path = ""; // model local path // NOLINT + std::string url = ""; // model url to download // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT +}; + +struct common_params_speculative { + std::vector devices; // devices to use for offloading + + int32_t n_ctx = 0; // draft context size + int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding + int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding + int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + float p_split = 0.1f; // speculative decoding split probability + float p_min = 0.75f; // minimum speculative decoding probability (greedy) + + struct cpu_params cpuparams; + struct cpu_params cpuparams_batch; + + struct common_params_model model; +}; + +struct common_params_vocoder { + struct common_params_model model; + + std::string speaker_file = ""; // speaker file path // NOLINT + + bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT +}; + +enum common_reasoning_format { + COMMON_REASONING_FORMAT_NONE, + COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in tags in stream mode + COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. +}; + +struct common_params { + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 4096; // context size + int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode + int32_t grp_attn_n = 1; // group-attention factor + int32_t grp_attn_w = 512; // group-attention width + int32_t n_print = -1; // print token count every n tokens (-1 = disabled) + float rope_freq_base = 0.0f; // RoPE base frequency + float rope_freq_scale = 0.0f; // RoPE frequency scaling factor + float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor + float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor + float yarn_beta_fast = 32.0f; // YaRN low correction dim + float yarn_beta_slow = 1.0f; // YaRN high correction dim + int32_t yarn_orig_ctx = 0; // YaRN original context length + float defrag_thold = 0.1f; // KV cache defragmentation threshold + + // offload params + std::vector devices; // devices to use for offloading + + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + + enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs + + struct cpu_params cpuparams; + struct cpu_params cpuparams_batch; + + ggml_backend_sched_eval_callback cb_eval = nullptr; + void * cb_eval_user_data = nullptr; + + ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; + + enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; + enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings + enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings + + struct common_params_sampling sampling; + struct common_params_speculative speculative; + struct common_params_vocoder vocoder; + + struct common_params_model model; + + std::string model_alias = ""; // model alias // NOLINT + std::string hf_token = ""; // HF token // NOLINT + std::string prompt = ""; // NOLINT + std::string system_prompt = ""; // NOLINT + std::string prompt_file = ""; // store the external prompt file name // NOLINT + std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT + std::string input_prefix = ""; // string to prefix user inputs with // NOLINT + std::string input_suffix = ""; // string to suffix user inputs with // NOLINT + std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT + std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT + std::string logits_file = ""; // file for saving *all* logits // NOLINT + + std::vector in_files; // all input files + std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) + std::vector kv_overrides; + std::vector tensor_buft_overrides; + + bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) + std::vector lora_adapters; // lora adapter path with user defined scale + + std::vector control_vectors; // control vector with user defined scale + + int32_t verbosity = 0; + int32_t control_vector_layer_start = -1; // layer range for control vector + int32_t control_vector_layer_end = -1; // layer range for control vector + bool offline = false; + + int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. + int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line + // (which is more convenient to use for plotting) + // + bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt + size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score + + bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt + size_t winogrande_tasks = 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed + + bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt + size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed + + bool kl_divergence = false; // compute KL divergence + + bool usage = false; // print usage + bool completion = false; // print source-able completion script + bool use_color = false; // use color to distinguish generations and inputs + bool special = false; // enable special token output + bool interactive = false; // interactive mode + bool interactive_first = false; // wait for user input immediately + bool prompt_cache_all = false; // save user input and generations to prompt cache + bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it + + bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\" + bool multiline_input = false; // reverse the usage of `\` + bool simple_io = false; // improves compatibility with subprocesses and limited consoles + bool cont_batching = true; // insert new sequences for decoding on-the-fly + bool flash_attn = false; // flash attention + bool no_perf = false; // disable performance metrics + bool ctx_shift = true; // context shift on inifinite text generation + bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) + + bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix + bool use_mmap = true; // use mmap for faster loads + bool use_mlock = false; // use mlock to keep model in memory + bool verbose_prompt = false; // print prompt tokens before generation + bool display_prompt = true; // print prompt before generation + bool no_kv_offload = false; // disable KV offloading + bool warmup = true; // warmup run + bool check_tensors = false; // validate tensor data + bool no_op_offload = false; // globally disable offload host tensor operations to device + + bool single_turn = false; // single turn chat conversation + + ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K + ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V + + common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; + + // multimodal models (see tools/mtmd) + struct common_params_model mmproj; + bool mmproj_use_gpu = true; // use GPU for multimodal model + bool no_mmproj = false; // explicitly disable multimodal model + std::vector image; // path to image file(s) + + // embedding + bool embedding = false; // get only sentence embedding + int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) + std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix + std::string embd_sep = "\n"; // separator of embeddings + + // server params + int32_t port = 8080; // server listens on this network port + int32_t timeout_read = 600; // http read timeout in seconds + int32_t timeout_write = timeout_read; // http write timeout in seconds + int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) + int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting + + std::string hostname = "127.0.0.1"; + std::string public_path = ""; // NOLINT + std::string chat_template = ""; // NOLINT + bool use_jinja = false; // NOLINT + bool enable_chat_template = true; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + int reasoning_budget = -1; + bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response + + std::vector api_keys; + + std::string ssl_file_key = ""; // NOLINT + std::string ssl_file_cert = ""; // NOLINT + + // "advanced" endpoints are disabled by default for better security + bool webui = true; + bool endpoint_slots = false; + bool endpoint_props = false; // only control POST requests, not GET + bool endpoint_metrics = false; + + bool log_json = false; + + std::string slot_save_path; + + float slot_prompt_similarity = 0.5f; + + // batched-bench params + bool is_pp_shared = false; + + std::vector n_pp; + std::vector n_tg; + std::vector n_pl; + + // retrieval params + std::vector context_files; // context files to embed + + int32_t chunk_size = 64; // chunk size for context embedding + + std::string chunk_separator = "\n"; // chunk separator for context embedding + + // passkey params + int32_t n_junk = 250; // number of times to repeat the junk text + int32_t i_pos = -1; // position of the passkey in the junk text + + // imatrix params + int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations + int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations + int32_t i_chunk = 0; // start processing from this chunk + + bool process_output = false; // collect data for the output tensor + bool compute_ppl = true; // whether to compute perplexity + bool parse_special = false; // whether to parse special tokens during imatrix tokenization + + // cvector-generator params + int n_pca_batch = 100; + int n_pca_iterations = 1000; + dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; + std::string cvector_positive_file = "tools/cvector-generator/positive.txt"; + std::string cvector_negative_file = "tools/cvector-generator/negative.txt"; + + bool spm_infill = false; // suffix/prefix/middle pattern for infill + + // batched-bench params + bool batched_bench_output_jsonl = false; + + // common params + std::string out_file; // output filename for all example programs + // optional callback for model loading progress and cancellation: + // called with a progress value between 0.0 and 1.0. + // return false from callback to abort model loading or true to continue + llama_progress_callback load_progress_callback = NULL; + void * load_progress_callback_user_data = NULL; +}; + +// call once at the start of a program if it uses libcommon +// initializes the logging system and prints info about the build +void common_init(); + +std::string common_params_get_system_info(const common_params & params); + +bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]); +bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]); +void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr); +bool set_process_priority(enum ggml_sched_priority prio); + +// +// String utils +// + +#ifdef __GNUC__ +# if defined(__MINGW32__) && !defined(__clang__) +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# else +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +# endif +#else +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) +#endif + +LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) +std::string string_format(const char * fmt, ...); + +std::string string_strip(const std::string & str); +std::string string_get_sortable_timestamp(); + +std::string string_join(const std::vector & values, const std::string & separator); +std::vector string_split(const std::string & str, const std::string & delimiter); +std::string string_repeat(const std::string & str, size_t n); + +void string_replace_all(std::string & s, const std::string & search, const std::string & replace); + +std::string regex_escape(const std::string & s); + +template +static std::vector string_split(const std::string & str, char delim) { + static_assert(!std::is_same::value, "Please use the specialized version for std::string"); + std::vector values; + std::istringstream str_stream(str); + std::string token; + while (std::getline(str_stream, token, delim)) { + T value; + std::istringstream token_stream(token); + token_stream >> value; + values.push_back(value); + } + return values; +} + +template<> +std::vector string_split(const std::string & input, char separator) +{ + std::vector parts; + size_t begin_pos = 0; + size_t separator_pos = input.find(separator); + while (separator_pos != std::string::npos) { + std::string part = input.substr(begin_pos, separator_pos - begin_pos); + parts.emplace_back(part); + begin_pos = separator_pos + 1; + separator_pos = input.find(separator, begin_pos); + } + parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos)); + return parts; +} + +static bool string_starts_with(const std::string & str, + const std::string & prefix) { // While we wait for C++20's std::string::starts_with... + return str.rfind(prefix, 0) == 0; +} + +// While we wait for C++20's std::string::ends_with... +bool string_ends_with(const std::string_view & str, const std::string_view & suffix); +size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop); + +bool string_parse_kv_override(const char * data, std::vector & overrides); +void string_process_escapes(std::string & input); + +std::string string_from(bool value); +std::string string_from(const std::vector & values); +std::string string_from(const struct llama_context * ctx, const std::vector & tokens); +std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch); + +// +// Filesystem utils +// + +bool fs_validate_filename(const std::string & filename); +bool fs_create_directory_with_parents(const std::string & path); + +std::string fs_get_cache_directory(); +std::string fs_get_cache_file(const std::string & filename); + +// +// Model utils +// + +// note: defines object's lifetime +struct common_init_result { + llama_model_ptr model; + llama_context_ptr context; + + std::vector lora; +}; + +struct common_init_result common_init_from_params(common_params & params); + +struct llama_model_params common_model_params_to_llama ( common_params & params); +struct llama_context_params common_context_params_to_llama(const common_params & params); +struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params); + +// clear LoRA adapters from context, then apply new list of adapters +void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); + +std::string get_model_endpoint(); + +// +// Batch utils +// + +void common_batch_clear(struct llama_batch & batch); + +void common_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector & seq_ids, + bool logits); + +// +// Token utils +// + +// longest common prefix +size_t common_lcp(const llama_tokens & a, const llama_tokens & b); + +// longet common subsequence +size_t common_lcs(const llama_tokens & a, const llama_tokens & b); + +// +// Vocab utils +// + +// tokenizes a string into a vector of tokens +// should work similar to Python's `tokenizer.encode` +std::vector common_tokenize( + const struct llama_context * ctx, + const std::string & text, + bool add_special, + bool parse_special = false); + +std::vector common_tokenize( + const struct llama_vocab * vocab, + const std::string & text, + bool add_special, + bool parse_special = false); + +// tokenizes a token into a piece, optionally renders special/control tokens +// should work similar to Python's `tokenizer.id_to_piece` +std::string common_token_to_piece( + const struct llama_context * ctx, + llama_token token, + bool special = true); + +std::string common_token_to_piece( + const struct llama_vocab * vocab, + llama_token token, + bool special = true); + +// detokenizes a vector of tokens into a string +// should work similar to Python's `tokenizer.decode` +// optionally renders special/control tokens +std::string common_detokenize( + const struct llama_context * ctx, + const std::vector & tokens, + bool special = true); + +std::string common_detokenize( + const struct llama_vocab * vocab, + const std::vector & tokens, + bool special = true); + +// +// Embedding utils +// + +// TODO: repace embd_norm with an enum +void common_embd_normalize(const float * inp, float * out, int n, int embd_norm); + +float common_embd_similarity_cos(const float * embd1, const float * embd2, int n); + +// +// Control vector utils +// + +struct common_control_vector_data { + int n_embd; + + // stores data for layers [1, n_layer] where n_layer = data.size() / n_embd + std::vector data; +}; + +struct common_control_vector_load_info { + float strength; + + std::string fname; +}; + +// Load control vectors, scale each by strength, and add them together. +// On error, returns {-1, empty} +common_control_vector_data common_control_vector_load(const std::vector & load_infos); + +// +// Split utils +// + +namespace { + +const char * const LLM_KV_SPLIT_NO = "split.no"; +const char * const LLM_KV_SPLIT_COUNT = "split.count"; +const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; + +} + +// +// training utils +// + +ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector & tokens, int64_t stride); diff --git a/common/console.cpp b/common/console.cpp new file mode 100644 index 0000000000000000000000000000000000000000..078a8d678d933bec79eda408ff45858758c16f37 --- /dev/null +++ b/common/console.cpp @@ -0,0 +1,504 @@ +#include "console.h" +#include +#include + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#include +#ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING +#define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004 +#endif +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#define ANSI_COLOR_RED "\x1b[31m" +#define ANSI_COLOR_GREEN "\x1b[32m" +#define ANSI_COLOR_YELLOW "\x1b[33m" +#define ANSI_COLOR_BLUE "\x1b[34m" +#define ANSI_COLOR_MAGENTA "\x1b[35m" +#define ANSI_COLOR_CYAN "\x1b[36m" +#define ANSI_COLOR_RESET "\x1b[0m" +#define ANSI_BOLD "\x1b[1m" + +namespace console { + + // + // Console state + // + + static bool advanced_display = false; + static bool simple_io = true; + static display_t current_display = reset; + + static FILE* out = stdout; + +#if defined (_WIN32) + static void* hConsole; +#else + static FILE* tty = nullptr; + static termios initial_state; +#endif + + // + // Init and cleanup + // + + void init(bool use_simple_io, bool use_advanced_display) { + advanced_display = use_advanced_display; + simple_io = use_simple_io; +#if defined(_WIN32) + // Windows-specific console initialization + DWORD dwMode = 0; + hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) { + hConsole = GetStdHandle(STD_ERROR_HANDLE); + if (hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(hConsole, &dwMode))) { + hConsole = nullptr; + simple_io = true; + } + } + if (hConsole) { + // Check conditions combined to reduce nesting + if (advanced_display && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) && + !SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + advanced_display = false; + } + // Set console output codepage to UTF8 + SetConsoleOutputCP(CP_UTF8); + } + HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { + // Set console input codepage to UTF16 + _setmode(_fileno(stdin), _O_WTEXT); + + // Set ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) + if (simple_io) { + dwMode |= ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT; + } else { + dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); + } + if (!SetConsoleMode(hConIn, dwMode)) { + simple_io = true; + } + } + if (simple_io) { + _setmode(_fileno(stdin), _O_U8TEXT); + } +#else + // POSIX-specific console initialization + if (!simple_io) { + struct termios new_termios; + tcgetattr(STDIN_FILENO, &initial_state); + new_termios = initial_state; + new_termios.c_lflag &= ~(ICANON | ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); + + tty = fopen("/dev/tty", "w+"); + if (tty != nullptr) { + out = tty; + } + } + + setlocale(LC_ALL, ""); +#endif + } + + void cleanup() { + // Reset console display + set_display(reset); + +#if !defined(_WIN32) + // Restore settings on POSIX systems + if (!simple_io) { + if (tty != nullptr) { + out = stdout; + fclose(tty); + tty = nullptr; + } + tcsetattr(STDIN_FILENO, TCSANOW, &initial_state); + } +#endif + } + + // + // Display and IO + // + + // Keep track of current display and only emit ANSI code if it changes + void set_display(display_t display) { + if (advanced_display && current_display != display) { + fflush(stdout); + switch(display) { + case reset: + fprintf(out, ANSI_COLOR_RESET); + break; + case prompt: + fprintf(out, ANSI_COLOR_YELLOW); + break; + case user_input: + fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + case error: + fprintf(out, ANSI_BOLD ANSI_COLOR_RED); + } + current_display = display; + fflush(out); + } + } + + static char32_t getchar32() { +#if defined(_WIN32) + HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE); + wchar_t high_surrogate = 0; + + while (true) { + INPUT_RECORD record; + DWORD count; + if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) { + return WEOF; + } + + if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) { + wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar; + if (wc == 0) { + continue; + } + + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + high_surrogate = wc; + continue; + } + if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate + if (high_surrogate != 0) { // Check if we have a high surrogate + return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000; + } + } + + high_surrogate = 0; // Reset the high surrogate + return static_cast(wc); + } + } +#else + wchar_t wc = getwchar(); + if (static_cast(wc) == WEOF) { + return WEOF; + } + +#if WCHAR_MAX == 0xFFFF + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + wchar_t low_surrogate = getwchar(); + if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate + return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; + } + } + if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair + return 0xFFFD; // Return the replacement character U+FFFD + } +#endif + + return static_cast(wc); +#endif + } + + static void pop_cursor() { +#if defined(_WIN32) + if (hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(hConsole, &bufferInfo); + + COORD newCursorPosition = bufferInfo.dwCursorPosition; + if (newCursorPosition.X == 0) { + newCursorPosition.X = bufferInfo.dwSize.X - 1; + newCursorPosition.Y -= 1; + } else { + newCursorPosition.X -= 1; + } + + SetConsoleCursorPosition(hConsole, newCursorPosition); + return; + } +#endif + putc('\b', out); + } + + static int estimateWidth(char32_t codepoint) { +#if defined(_WIN32) + (void)codepoint; + return 1; +#else + return wcwidth(codepoint); +#endif + } + + static int put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) { + // go with the default + return expectedWidth; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + DWORD nNumberOfChars = length; + WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); + + // Figure out our real position if we're in the last column + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + DWORD nNumberOfChars; + WriteConsole(hConsole, &" \b", 2, &nNumberOfChars, NULL); + GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); + } + + int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; + if (width < 0) { + width += newBufferInfo.dwSize.X; + } + return width; +#else + // We can trust expectedWidth if we've got one + if (expectedWidth >= 0 || tty == nullptr) { + fwrite(utf8_codepoint, length, 1, out); + return expectedWidth; + } + + fputs("\033[6n", tty); // Query cursor position + int x1; + int y1; + int x2; + int y2; + int results = 0; + results = fscanf(tty, "\033[%d;%dR", &y1, &x1); + + fwrite(utf8_codepoint, length, 1, tty); + + fputs("\033[6n", tty); // Query cursor position + results += fscanf(tty, "\033[%d;%dR", &y2, &x2); + + if (results != 4) { + return expectedWidth; + } + + int width = x2 - x1; + if (width < 0) { + // Calculate the width considering text wrapping + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + width += w.ws_col; + } + return width; +#endif + } + + static void replace_last(char ch) { +#if defined(_WIN32) + pop_cursor(); + put_codepoint(&ch, 1, 1); +#else + fprintf(out, "\b%c", ch); +#endif + } + + static void append_utf8(char32_t ch, std::string & out) { + if (ch <= 0x7F) { + out.push_back(static_cast(ch)); + } else if (ch <= 0x7FF) { + out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0xFFFF) { + out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0x10FFFF) { + out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); + out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else { + // Invalid Unicode code point + } + } + + // Helper function to remove the last UTF-8 character from a string + static void pop_back_utf8_char(std::string & line) { + if (line.empty()) { + return; + } + + size_t pos = line.length() - 1; + + // Find the start of the last UTF-8 character (checking up to 4 bytes back) + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { + if ((line[pos] & 0xC0) != 0x80) { + break; // Found the start of the character + } + } + line.erase(pos); + } + + static bool readline_advanced(std::string & line, bool multiline_input) { + if (out != stdout) { + fflush(stdout); + } + + line.clear(); + std::vector widths; + bool is_special_char = false; + bool end_of_stream = false; + + char32_t input_char; + while (true) { + fflush(out); // Ensure all output is displayed before waiting for input + input_char = getchar32(); + + if (input_char == '\r' || input_char == '\n') { + break; + } + + if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) { + end_of_stream = true; + break; + } + + if (is_special_char) { + set_display(user_input); + replace_last(line.back()); + is_special_char = false; + } + + if (input_char == '\033') { // Escape sequence + char32_t code = getchar32(); + if (code == '[' || code == 0x1B) { + // Discard the rest of the escape sequence + while ((code = getchar32()) != (char32_t) WEOF) { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + break; + } + } + } + } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace + if (!widths.empty()) { + int count; + do { + count = widths.back(); + widths.pop_back(); + // Move cursor back, print space, and move cursor back again + for (int i = 0; i < count; i++) { + replace_last(' '); + pop_cursor(); + } + pop_back_utf8_char(line); + } while (count == 0 && !widths.empty()); + } + } else { + int offset = line.length(); + append_utf8(input_char, line); + int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); + if (width < 0) { + width = 0; + } + widths.push_back(width); + } + + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + set_display(prompt); + replace_last(line.back()); + is_special_char = true; + } + } + + bool has_more = multiline_input; + if (is_special_char) { + replace_last(' '); + pop_cursor(); + + char last = line.back(); + line.pop_back(); + if (last == '\\') { + line += '\n'; + fputc('\n', out); + has_more = !has_more; + } else { + // llama will just eat the single space, it won't act as a space + if (line.length() == 1 && line.back() == ' ') { + line.clear(); + pop_cursor(); + } + has_more = false; + } + } else { + if (end_of_stream) { + has_more = false; + } else { + line += '\n'; + fputc('\n', out); + } + } + + fflush(out); + return has_more; + } + + static bool readline_simple(std::string & line, bool multiline_input) { +#if defined(_WIN32) + std::wstring wline; + if (!std::getline(std::wcin, wline)) { + // Input stream is bad or EOF received + line.clear(); + GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0); + return false; + } + + int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL); + line.resize(size_needed); + WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL); +#else + if (!std::getline(std::cin, line)) { + // Input stream is bad or EOF received + line.clear(); + return false; + } +#endif + if (!line.empty()) { + char last = line.back(); + if (last == '/') { // Always return control on '/' symbol + line.pop_back(); + return false; + } + if (last == '\\') { // '\\' changes the default action + line.pop_back(); + multiline_input = !multiline_input; + } + } + line += '\n'; + + // By default, continue input if multiline_input is set + return multiline_input; + } + + bool readline(std::string & line, bool multiline_input) { + set_display(user_input); + + if (simple_io) { + return readline_simple(line, multiline_input); + } + return readline_advanced(line, multiline_input); + } + +} diff --git a/common/console.h b/common/console.h new file mode 100644 index 0000000000000000000000000000000000000000..ec175269b9d8af48803d0b6e618d008a9ab99b4d --- /dev/null +++ b/common/console.h @@ -0,0 +1,19 @@ +// Console functions + +#pragma once + +#include + +namespace console { + enum display_t { + reset = 0, + prompt, + user_input, + error + }; + + void init(bool use_simple_io, bool use_advanced_display); + void cleanup(); + void set_display(display_t display); + bool readline(std::string & line, bool multiline_input); +} diff --git a/common/json-partial.cpp b/common/json-partial.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d9d91699899f7ba9870184caa7e3c5ff04280e9b --- /dev/null +++ b/common/json-partial.cpp @@ -0,0 +1,256 @@ +#include "json-partial.h" + +#include "log.h" + +#include + +#include + +using json = nlohmann::ordered_json; + +enum common_json_stack_element_type { + COMMON_JSON_STACK_ELEMENT_OBJECT, + COMMON_JSON_STACK_ELEMENT_KEY, + COMMON_JSON_STACK_ELEMENT_ARRAY, +}; + +struct common_json_stack_element { + common_json_stack_element_type type; + std::string key; +}; + +bool common_json_parse( + const std::string & input, + const std::string & healing_marker, + common_json & out) +{ + std::string::const_iterator it = input.begin(); + const auto end = input.end(); + return common_json_parse(it, end, healing_marker, out); +} + +bool common_json_parse( + std::string::const_iterator & it, + const std::string::const_iterator & end, + const std::string & healing_marker, + common_json & out) +{ + // // https://json.nlohmann.me/features/parsing/sax_interface/ + struct json_error_locator : public nlohmann::json_sax { + std::size_t position; + bool found_error; + std::string last_token; + std::string exception_message; + std::vector stack; + + json_error_locator() : position(0), found_error(false) {} + + bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT + this->position = position - 1; + this->found_error = true; + this->last_token = last_token; + this->exception_message = ex.what(); + return false; + } + void close_value() { + if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) { + stack.pop_back(); + } + } + bool null() override { // NOLINT + close_value(); + return true; + } + bool boolean(bool) override { // NOLINT + close_value(); + return true; + } + bool number_integer(number_integer_t) override { // NOLINT + close_value(); + return true; + } + bool number_unsigned(number_unsigned_t) override { // NOLINT + close_value(); + return true; + } + bool number_float(number_float_t, const string_t &) override { // NOLINT + close_value(); + return true; + } + bool string(string_t &) override { // NOLINT + close_value(); + return true; + } + bool binary(binary_t &) override { // NOLINT + close_value(); + return true; + } + bool start_object(std::size_t) override { // NOLINT + stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""}); + return true; + } + bool end_object() override { + GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT); + stack.pop_back(); + close_value(); + return true; + } + bool key(string_t & key) override { // NOLINT + stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key}); + return true; + } + bool start_array(std::size_t) override { // NOLINT + stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""}); + return true; + } + bool end_array() override { + GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY); + stack.pop_back(); + close_value(); + return true; + } + }; + json_error_locator err_loc; + auto start = it; + json::sax_parse(it, end, &err_loc); + + if (err_loc.found_error) { + it = start; + auto temptative_end = it + err_loc.position; + // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str()); + + auto input = std::string(it, temptative_end); + try { + out.json = json::parse(input); + // out.json = json::parse(it, temptative_end); + it = temptative_end; + return true; + } catch (const std::exception & ex) { + // No, needs healing. + LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str()); + } + auto can_parse = [](const std::string & str) { + try { + auto _ = json::parse(str); // NOLINT + return true; + } catch (const std::exception &) { + return false; + } + }; + if (!healing_marker.empty() && !err_loc.stack.empty()) { + std::string str(it, temptative_end); + auto last_non_sp_pos = str.find_last_not_of(" \n\r\t"); + if (last_non_sp_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); + } + auto last_non_sp_char = str[last_non_sp_pos]; + // Used to detect stops on a number, which may not be complete. + auto was_maybe_number = [&]() { + if (!str.empty() && std::isspace(str.back())) { + return false; + } + return std::isdigit(last_non_sp_char) || + last_non_sp_char == '.' || + last_non_sp_char == 'e' || + last_non_sp_char == 'E' || + last_non_sp_char == '-'; + }; + + std::string closing; + for (size_t i = err_loc.stack.size(); i > 0; i--) { + auto & el = err_loc.stack[i - 1]; + if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) { + closing += "}"; + } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) { + closing += "]"; + } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) { + throw std::runtime_error("Unexpected stack element type"); + } + } + + const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$"; + + if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) { + // We're inside an object value + if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) { + // Was about to create an object value + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } else if (can_parse(str + ": 1" + closing)) { + str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing; + } else if (last_non_sp_char == '{' && can_parse(str + closing)) { + // Was about to create an object + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; + } else if (can_parse(str + "\"" + closing)) { + // Was inside an object value string + str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { + // Was inside an object value string after an escape + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; + } else { + // find last : + auto last_pos = str.find_last_of(':'); + if (last_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); + } + // Cutting back to opening : for object value + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } + } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) { + if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) { + // Was about to create an array value + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } else if (can_parse(str + "\"" + closing)) { + // Was inside an array value string + str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { + // Was inside an array value string after an escape + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; + } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) { + // Had just finished a value + str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing; + } else { + auto last_pos = str.find_last_of("[,"); + if (last_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location"); + } + // Cutting back to last [ or , for array value + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } + } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) { + if ((last_non_sp_char == '{' && can_parse(str + closing)) || + (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) { + // Was about to create an object key+value + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; + } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) { + // Was about to create an object key+value + str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing; + } else if (can_parse(str + "\": 1" + closing)) { + // Was inside an object key string + str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing; + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) { + // Was inside an object key string after an escape + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing; + } else { + auto last_pos = str.find_last_of(':'); + if (last_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location"); + } + // fprintf(stderr, "Cutting back to last : for object key+value\n"); + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } + } else { + throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location"); + } + // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str()); + out.json = json::parse(str); + it = temptative_end; + return true; + } + // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...) + // fprintf(stderr, "Closing: TODO\n"); + return false; + } + out.json = json::parse(it, end); + it = end; + return true; +} diff --git a/common/json-partial.h b/common/json-partial.h new file mode 100644 index 0000000000000000000000000000000000000000..f63356dc48f7843ce0a11e424397d442e4337fdd --- /dev/null +++ b/common/json-partial.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +// Healing marker (empty if the JSON was fully parsed / wasn't healed). +struct common_healing_marker { + // Raw marker. + std::string marker; + + // Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format). + std::string json_dump_marker; +}; + +// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string) +struct common_json { + nlohmann::ordered_json json; + + common_healing_marker healing_marker; +}; + +// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty. +// +// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON. +// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker. +// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format). +// +// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again). +bool common_json_parse( + const std::string & input, + const std::string & healing_marker, + common_json & out); + +// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds. +bool common_json_parse( + std::string::const_iterator & it, + const std::string::const_iterator & end, + const std::string & healing_marker, + common_json & out); diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d38a74f95c213aaa9a08ecf5cd5ae371ad73df80 --- /dev/null +++ b/common/json-schema-to-grammar.cpp @@ -0,0 +1,1028 @@ +#include "json-schema-to-grammar.h" +#include "common.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { + auto has_max = max_items != std::numeric_limits::max(); + + if (max_items == 0) { + return ""; + } + if (min_items == 0 && max_items == 1) { + return item_rule + "?"; + } + + if (separator_rule.empty()) { + if (min_items == 1 && !has_max) { + return item_rule + "+"; + } else if (min_items == 0 && !has_max) { + return item_rule + "*"; + } else { + return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}"; + } + } + + auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items); + if (min_items == 0) { + result = "(" + result + ")?"; + } + return result; +} + +/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */ +class string_view { + const std::string & _str; + const size_t _start; + const size_t _end; +public: + string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {} + + size_t size() const { + return _end - _start; + } + + size_t length() const { + return size(); + } + + operator std::string() const { + return str(); + } + + std::string str() const { + return _str.substr(_start, _end - _start); + } + + string_view substr(size_t pos, size_t len = std::string::npos) const { + return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len); + } + + char operator[](size_t pos) const { + auto index = _start + pos; + if (index >= _end) { + throw std::out_of_range("string_view index out of range"); + } + return _str[_start + pos]; + } + + bool operator==(const string_view & other) const { + std::string this_str = *this; + std::string other_str = other; + return this_str == other_str; + } +}; + +static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { + auto has_min = min_value != std::numeric_limits::min(); + auto has_max = max_value != std::numeric_limits::max(); + + auto digit_range = [&](char from, char to) { + out << "["; + if (from == to) { + out << from; + } else { + out << from << "-" << to; + } + out << "]"; + }; + auto more_digits = [&](int min_digits, int max_digits) { + out << "[0-9]"; + if (min_digits == max_digits && min_digits == 1) { + return; + } + out << "{"; + out << min_digits; + if (max_digits != min_digits) { + out << ","; + if (max_digits != std::numeric_limits::max()) { + out << max_digits; + } + } + out << "}"; + }; + std::function uniform_range = + [&](const string_view & from, const string_view & to) { + size_t i = 0; + while (i < from.length() && i < to.length() && from[i] == to[i]) { + i++; + } + if (i > 0) { + out << "\"" << from.substr(0, i).str() << "\""; + } + if (i < from.length() && i < to.length()) { + if (i > 0) { + out << " "; + } + auto sub_len = from.length() - i - 1; + if (sub_len > 0) { + auto from_sub = from.substr(i + 1); + auto to_sub = to.substr(i + 1); + auto sub_zeros = string_repeat("0", sub_len); + auto sub_nines = string_repeat("9", sub_len); + + auto to_reached = false; + out << "("; + if (from_sub == sub_zeros) { + digit_range(from[i], to[i] - 1); + out << " "; + more_digits(sub_len, sub_len); + } else { + out << "[" << from[i] << "] "; + out << "("; + uniform_range(from_sub, sub_nines); + out << ")"; + if (from[i] < to[i] - 1) { + out << " | "; + if (to_sub == sub_nines) { + digit_range(from[i] + 1, to[i]); + to_reached = true; + } else { + digit_range(from[i] + 1, to[i] - 1); + } + out << " "; + more_digits(sub_len, sub_len); + } + } + if (!to_reached) { + out << " | "; + digit_range(to[i], to[i]); + out << " "; + uniform_range(sub_zeros, to_sub); + } + out << ")"; + } else { + out << "[" << from[i] << "-" << to[i] << "]"; + } + } + }; + + if (has_min && has_max) { + if (min_value < 0 && max_value < 0) { + out << "\"-\" ("; + _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true); + out << ")"; + return; + } + + if (min_value < 0) { + out << "\"-\" ("; + _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true); + out << ") | "; + min_value = 0; + } + + auto min_s = std::to_string(min_value); + auto max_s = std::to_string(max_value); + auto min_digits = min_s.length(); + auto max_digits = max_s.length(); + + for (auto digits = min_digits; digits < max_digits; digits++) { + uniform_range(min_s, string_repeat("9", digits)); + min_s = "1" + string_repeat("0", digits); + out << " | "; + } + uniform_range(min_s, max_s); + return; + } + + auto less_decimals = std::max(decimals_left - 1, 1); + + if (has_min) { + if (min_value < 0) { + out << "\"-\" ("; + _build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); + out << ") | [0] | [1-9] "; + more_digits(0, decimals_left - 1); + } else if (min_value == 0) { + if (top_level) { + out << "[0] | [1-9] "; + more_digits(0, less_decimals); + } else { + more_digits(1, decimals_left); + } + } else if (min_value <= 9) { + char c = '0' + min_value; + auto range_start = top_level ? '1' : '0'; + if (c > range_start) { + digit_range(range_start, c - 1); + out << " "; + more_digits(1, less_decimals); + out << " | "; + } + digit_range(c, '9'); + out << " "; + more_digits(0, less_decimals); + } else { + auto min_s = std::to_string(min_value); + auto len = min_s.length(); + auto c = min_s[0]; + + if (c > '1') { + digit_range(top_level ? '1' : '0', c - 1); + out << " "; + more_digits(len, less_decimals); + out << " | "; + } + digit_range(c, c); + out << " ("; + _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); + out << ")"; + if (c < '9') { + out << " | "; + digit_range(c + 1, '9'); + out << " "; + more_digits(len - 1, less_decimals); + } + } + return; + } + + if (has_max) { + if (max_value >= 0) { + if (top_level) { + out << "\"-\" [1-9] "; + more_digits(0, less_decimals); + out << " | "; + } + _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true); + } else { + out << "\"-\" ("; + _build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); + out << ")"; + } + return; + } + + throw std::runtime_error("At least one of min_value or max_value must be set"); +} + +const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}"; + +struct BuiltinRule { + std::string content; + std::vector deps; +}; + +std::unordered_map PRIMITIVE_RULES = { + {"boolean", {"(\"true\" | \"false\") space", {}}}, + {"decimal-part", {"[0-9]{1,16}", {}}}, + {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}}, + {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}}, + {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}}, + {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}}, + {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}}, + {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}}, + {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}}, + {"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}}, + {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}}, + {"null", {"\"null\" space", {}}}, +}; + +std::unordered_map STRING_FORMAT_RULES = { + {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, + {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, + {"date-time", {"date \"T\" time", {"date", "time"}}}, + {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}}, + {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}}, + {"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}} +}; + +static bool is_reserved_name(const std::string & name) { + static std::unordered_set RESERVED_NAMES; + if (RESERVED_NAMES.empty()) { + RESERVED_NAMES.insert("root"); + for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first); + for (const auto &p : STRING_FORMAT_RULES) RESERVED_NAMES.insert(p.first); + } + return RESERVED_NAMES.find(name) != RESERVED_NAMES.end(); +} + +std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+"); +std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]"); +std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]"); +std::unordered_map GRAMMAR_LITERAL_ESCAPES = { + {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"} +}; + +std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; +std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; + +static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function & replacement) { + std::smatch match; + std::string result; + + std::string::const_iterator searchStart(input.cbegin()); + std::string::const_iterator searchEnd(input.cend()); + + while (std::regex_search(searchStart, searchEnd, match, regex)) { + result.append(searchStart, searchStart + match.position()); + result.append(replacement(match)); + searchStart = match.suffix().first; + } + + result.append(searchStart, searchEnd); + + return result; +} + +static std::string format_literal(const std::string & literal) { + std::string escaped = replacePattern(literal, GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) { + char c = match.str()[0]; + return GRAMMAR_LITERAL_ESCAPES.at(c); + }); + return "\"" + escaped + "\""; +} + +class SchemaConverter { +private: + friend std::string build_grammar(const std::function & cb, const common_grammar_options & options); + std::function _fetch_json; + bool _dotall; + std::map _rules; + std::unordered_map _refs; + std::unordered_set _refs_being_resolved; + std::vector _errors; + std::vector _warnings; + + std::string _add_rule(const std::string & name, const std::string & rule) { + std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-"); + if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) { + _rules[esc_name] = rule; + return esc_name; + } else { + int i = 0; + while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) { + i++; + } + std::string key = esc_name + std::to_string(i); + _rules[key] = rule; + return key; + } + } + + std::string _generate_union_rule(const std::string & name, const std::vector & alt_schemas) { + std::vector rules; + for (size_t i = 0; i < alt_schemas.size(); i++) { + rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); + } + return string_join(rules, " | "); + } + + std::string _visit_pattern(const std::string & pattern, const std::string & name) { + if (!(pattern.front() == '^' && pattern.back() == '$')) { + _errors.push_back("Pattern must start with '^' and end with '$'"); + return ""; + } + std::string sub_pattern = pattern.substr(1, pattern.length() - 2); + std::unordered_map sub_rule_ids; + + size_t i = 0; + size_t length = sub_pattern.length(); + + using literal_or_rule = std::pair; + auto to_rule = [&](const literal_or_rule & ls) { + auto is_literal = ls.second; + auto s = ls.first; + return is_literal ? "\"" + s + "\"" : s; + }; + std::function transform = [&]() -> literal_or_rule { + size_t start = i; + std::vector seq; + + auto get_dot = [&]() { + std::string rule; + if (_dotall) { + rule = "[\\U00000000-\\U0010FFFF]"; + } else { + rule = "[^\\x0A\\x0D]"; + } + return _add_rule("dot", rule); + }; + + // Joins the sequence, merging consecutive literals together. + auto join_seq = [&]() { + std::vector ret; + + std::string literal; + auto flush_literal = [&]() { + if (literal.empty()) { + return false; + } + ret.emplace_back(literal, true); + literal.clear(); + return true; + }; + + for (const auto & item : seq) { + auto is_literal = item.second; + if (is_literal) { + literal += item.first; + } else { + flush_literal(); + ret.push_back(item); + } + } + flush_literal(); + + std::vector results; + for (const auto & item : ret) { + results.push_back(to_rule(item)); + } + return std::make_pair(string_join(results, " "), false); + }; + + while (i < length) { + char c = sub_pattern[i]; + if (c == '.') { + seq.emplace_back(get_dot(), false); + i++; + } else if (c == '(') { + i++; + if (i < length) { + if (sub_pattern[i] == '?') { + _warnings.push_back("Unsupported pattern syntax"); + } + } + seq.emplace_back("(" + to_rule(transform()) + ")", false); + } else if (c == ')') { + i++; + if (start > 0 && sub_pattern[start - 1] != '(') { + _errors.push_back("Unbalanced parentheses"); + } + return join_seq(); + } else if (c == '[') { + std::string square_brackets = std::string(1, c); + i++; + while (i < length && sub_pattern[i] != ']') { + if (sub_pattern[i] == '\\') { + square_brackets += sub_pattern.substr(i, 2); + i += 2; + } else { + square_brackets += sub_pattern[i]; + i++; + } + } + if (i >= length) { + _errors.push_back("Unbalanced square brackets"); + } + square_brackets += ']'; + i++; + seq.emplace_back(square_brackets, false); + } else if (c == '|') { + seq.emplace_back("|", false); + i++; + } else if (c == '*' || c == '+' || c == '?') { + seq.back() = std::make_pair(to_rule(seq.back()) + c, false); + i++; + } else if (c == '{') { + std::string curly_brackets = std::string(1, c); + i++; + while (i < length && sub_pattern[i] != '}') { + curly_brackets += sub_pattern[i]; + i++; + } + if (i >= length) { + _errors.push_back("Unbalanced curly brackets"); + } + curly_brackets += '}'; + i++; + auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); + int min_times = 0; + int max_times = std::numeric_limits::max(); + try { + if (nums.size() == 1) { + min_times = max_times = std::stoi(nums[0]); + } else if (nums.size() != 2) { + _errors.push_back("Wrong number of values in curly brackets"); + } else { + if (!nums[0].empty()) { + min_times = std::stoi(nums[0]); + } + if (!nums[1].empty()) { + max_times = std::stoi(nums[1]); + } + } + } catch (const std::invalid_argument & e) { + _errors.push_back("Invalid number in curly brackets"); + return std::make_pair("", false); + } + auto &last = seq.back(); + auto &sub = last.first; + auto sub_is_literal = last.second; + + if (!sub_is_literal) { + std::string & sub_id = sub_rule_ids[sub]; + if (sub_id.empty()) { + sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub); + } + sub = sub_id; + } + seq.back().first = build_repetition( + sub_is_literal ? "\"" + sub + "\"" : sub, + min_times, + max_times, + "" + ); + seq.back().second = false; + } else { + std::string literal; + auto is_non_literal = [&](char c) { + return NON_LITERAL_SET.find(c) != NON_LITERAL_SET.end(); + }; + while (i < length) { + if (sub_pattern[i] == '\\' && i < length - 1) { + char next = sub_pattern[i + 1]; + if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.find(next) != ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.end()) { + i++; + literal += sub_pattern[i]; + i++; + } else { + literal += sub_pattern.substr(i, 2); + i += 2; + } + } else if (sub_pattern[i] == '"') { + literal += "\\\""; + i++; + } else if (!is_non_literal(sub_pattern[i]) && + (i == length - 1 || literal.empty() || sub_pattern[i + 1] == '.' || !is_non_literal(sub_pattern[i + 1]))) { + literal += sub_pattern[i]; + i++; + } else { + break; + } + } + if (!literal.empty()) { + seq.emplace_back(literal, true); + } + } + } + return join_seq(); + }; + return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space"); + } + + /* + Returns a rule that matches a JSON string that is none of the provided strings + + not_strings({"a"}) + -> ["] ( [a] char+ | [^"a] char* )? ["] space + not_strings({"and", "also"}) + -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space + */ + std::string _not_strings(const std::vector & strings) { + + struct TrieNode { + std::map children; + bool is_end_of_string; + + TrieNode() : is_end_of_string(false) {} + + void insert(const std::string & string) { + auto node = this; + for (char c : string) { + node = &node->children[c]; + } + node->is_end_of_string = true; + } + }; + + TrieNode trie; + for (const auto & s : strings) { + trie.insert(s); + } + + std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); + std::ostringstream out; + out << "[\"] ( "; + std::function visit = [&](const TrieNode & node) { + std::ostringstream rejects; + auto first = true; + for (const auto & kv : node.children) { + rejects << kv.first; + if (first) { + first = false; + } else { + out << " | "; + } + out << "[" << kv.first << "]"; + if (!kv.second.children.empty()) { + out << " ("; + visit(kv.second); + out << ")"; + } else if (kv.second.is_end_of_string) { + out << " " << char_rule << "+"; + } + } + if (!node.children.empty()) { + if (!first) { + out << " | "; + } + out << "[^\"" << rejects.str() << "] " << char_rule << "*"; + } + }; + visit(trie); + + out << " )"; + if (!trie.is_end_of_string) { + out << "?"; + } + out << " [\"] space"; + return out.str(); + } + + std::string _resolve_ref(const std::string & ref) { + std::string ref_name = ref.substr(ref.find_last_of('/') + 1); + if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { + _refs_being_resolved.insert(ref); + json resolved = _refs[ref]; + ref_name = visit(resolved, ref_name); + _refs_being_resolved.erase(ref); + } + return ref_name; + } + + std::string _build_object_rule( + const std::vector> & properties, + const std::unordered_set & required, + const std::string & name, + const json & additional_properties) + { + std::vector required_props; + std::vector optional_props; + std::unordered_map prop_kv_rule_names; + std::vector prop_names; + for (const auto & kv : properties) { + const auto &prop_name = kv.first; + const auto &prop_schema = kv.second; + + std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name); + prop_kv_rule_names[prop_name] = _add_rule( + name + (name.empty() ? "" : "-") + prop_name + "-kv", + format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name + ); + if (required.find(prop_name) != required.end()) { + required_props.push_back(prop_name); + } else { + optional_props.push_back(prop_name); + } + prop_names.push_back(prop_name); + } + if ((additional_properties.is_boolean() && additional_properties.get()) || additional_properties.is_object()) { + std::string sub_name = name + (name.empty() ? "" : "-") + "additional"; + std::string value_rule = + additional_properties.is_object() ? visit(additional_properties, sub_name + "-value") + : _add_primitive("value", PRIMITIVE_RULES.at("value")); + + auto key_rule = + prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string")) + : _add_rule(sub_name + "-k", _not_strings(prop_names)); + std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule); + prop_kv_rule_names["*"] = kv_rule; + optional_props.push_back("*"); + } + + std::string rule = "\"{\" space "; + for (size_t i = 0; i < required_props.size(); i++) { + if (i > 0) { + rule += " \",\" space "; + } + rule += prop_kv_rule_names[required_props[i]]; + } + + if (!optional_props.empty()) { + rule += " ("; + if (!required_props.empty()) { + rule += " \",\" space ( "; + } + + std::function &, bool)> get_recursive_refs = [&](const std::vector & ks, bool first_is_optional) { + std::string res; + if (ks.empty()) { + return res; + } + std::string k = ks[0]; + std::string kv_rule_name = prop_kv_rule_names[k]; + std::string comma_ref = "( \",\" space " + kv_rule_name + " )"; + if (first_is_optional) { + res = comma_ref + (k == "*" ? "*" : "?"); + } else { + res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : ""); + } + if (ks.size() > 1) { + res += " " + _add_rule( + name + (name.empty() ? "" : "-") + k + "-rest", + get_recursive_refs(std::vector(ks.begin() + 1, ks.end()), true) + ); + } + return res; + }; + + for (size_t i = 0; i < optional_props.size(); i++) { + if (i > 0) { + rule += " | "; + } + rule += get_recursive_refs(std::vector(optional_props.begin() + i, optional_props.end()), false); + } + if (!required_props.empty()) { + rule += " )"; + } + rule += " )?"; + } + + rule += " \"}\" space"; + + return rule; + } + + std::string _add_primitive(const std::string & name, const BuiltinRule & rule) { + auto n = _add_rule(name, rule.content); + for (const auto & dep : rule.deps) { + BuiltinRule dep_rule; + auto it = PRIMITIVE_RULES.find(dep); + if (it == PRIMITIVE_RULES.end()) { + it = STRING_FORMAT_RULES.find(dep); + if (it == STRING_FORMAT_RULES.end()) { + _errors.push_back("Rule " + dep + " not known"); + continue; + } + } + if (_rules.find(dep) == _rules.end()) { + _add_primitive(dep, it->second); + } + } + return n; + } + +public: + SchemaConverter( + const std::function & fetch_json, + bool dotall) + : _fetch_json(fetch_json), _dotall(dotall) + { + _rules["space"] = SPACE_RULE; + } + + void resolve_refs(json & schema, const std::string & url) { + /* + * Resolves all $ref fields in the given schema, fetching any remote schemas, + * replacing each $ref with absolute reference URL and populates _refs with the + * respective referenced (sub)schema dictionaries. + */ + std::function visit_refs = [&](json & n) { + if (n.is_array()) { + for (auto & x : n) { + visit_refs(x); + } + } else if (n.is_object()) { + if (n.contains("$ref")) { + std::string ref = n["$ref"]; + if (_refs.find(ref) == _refs.end()) { + json target; + if (ref.find("https://") == 0) { + std::string base_url = ref.substr(0, ref.find('#')); + auto it = _refs.find(base_url); + if (it != _refs.end()) { + target = it->second; + } else { + // Fetch the referenced schema and resolve its refs + auto referenced = _fetch_json(ref); + resolve_refs(referenced, base_url); + _refs[base_url] = referenced; + } + if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) { + return; + } + } else if (ref.find("#/") == 0) { + target = schema; + n["$ref"] = url + ref; + ref = url + ref; + } else { + _errors.push_back("Unsupported ref: " + ref); + return; + } + std::string pointer = ref.substr(ref.find('#') + 1); + std::vector tokens = string_split(pointer, "/"); + for (size_t i = 1; i < tokens.size(); ++i) { + std::string sel = tokens[i]; + if (target.is_null() || !target.contains(sel)) { + _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); + return; + } + target = target[sel]; + } + _refs[ref] = target; + } + } else { + for (auto & kv : n.items()) { + visit_refs(kv.value()); + } + } + } + }; + + visit_refs(schema); + } + + std::string _generate_constant_rule(const json & value) { + return format_literal(value.dump()); + } + + std::string visit(const json & schema, const std::string & name) { + json schema_type = schema.contains("type") ? schema["type"] : json(); + std::string schema_format = schema.contains("format") ? schema["format"].get() : ""; + std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name; + + if (schema.contains("$ref")) { + return _add_rule(rule_name, _resolve_ref(schema["$ref"])); + } else if (schema.contains("oneOf") || schema.contains("anyOf")) { + std::vector alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get>() : schema["anyOf"].get>(); + return _add_rule(rule_name, _generate_union_rule(name, alt_schemas)); + } else if (schema_type.is_array()) { + std::vector schema_types; + for (const auto & t : schema_type) { + json schema_copy(schema); + schema_copy["type"] = t; + schema_types.push_back(schema_copy); + } + return _add_rule(rule_name, _generate_union_rule(name, schema_types)); + } else if (schema.contains("const")) { + return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space"); + } else if (schema.contains("enum")) { + std::vector enum_values; + for (const auto & v : schema["enum"]) { + enum_values.push_back(_generate_constant_rule(v)); + } + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); + } else if ((schema_type.is_null() || schema_type == "object") + && (schema.contains("properties") || + (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { + std::unordered_set required; + if (schema.contains("required") && schema["required"].is_array()) { + for (const auto & item : schema["required"]) { + if (item.is_string()) { + required.insert(item.get()); + } + } + } + std::vector> properties; + if (schema.contains("properties")) { + for (const auto & prop : schema["properties"].items()) { + properties.emplace_back(prop.key(), prop.value()); + } + } + return _add_rule(rule_name, + _build_object_rule( + properties, required, name, + schema.contains("additionalProperties") ? schema["additionalProperties"] : json())); + } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) { + std::unordered_set required; + std::vector> properties; + std::string hybrid_name = name; + std::function add_component = [&](const json & comp_schema, bool is_required) { + if (comp_schema.contains("$ref")) { + add_component(_refs[comp_schema["$ref"]], is_required); + } else if (comp_schema.contains("properties")) { + for (const auto & prop : comp_schema["properties"].items()) { + properties.emplace_back(prop.key(), prop.value()); + if (is_required) { + required.insert(prop.key()); + } + } + } else { + // todo warning + } + }; + for (auto & t : schema["allOf"]) { + if (t.contains("anyOf")) { + for (auto & tt : t["anyOf"]) { + add_component(tt, false); + } + } else { + add_component(t, true); + } + } + return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); + } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { + json items = schema.contains("items") ? schema["items"] : schema["prefixItems"]; + if (items.is_array()) { + std::string rule = "\"[\" space "; + for (size_t i = 0; i < items.size(); i++) { + if (i > 0) { + rule += " \",\" space "; + } + rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i)); + } + rule += " \"]\" space"; + return _add_rule(rule_name, rule); + } else { + std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item"); + int min_items = schema.contains("minItems") ? schema["minItems"].get() : 0; + json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); + int max_items = max_items_json.is_number_integer() ? max_items_json.get() : std::numeric_limits::max(); + + return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space"); + } + } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { + return _visit_pattern(schema["pattern"], rule_name); + } else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) { + return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid")); + } else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) { + auto prim_name = schema_format + "-string"; + return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name))); + } else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) { + std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); + int min_len = schema.contains("minLength") ? schema["minLength"].get() : 0; + int max_len = schema.contains("maxLength") ? schema["maxLength"].get() : std::numeric_limits::max(); + return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); + } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { + int min_value = std::numeric_limits::min(); + int max_value = std::numeric_limits::max(); + if (schema.contains("minimum")) { + min_value = schema["minimum"].get(); + } else if (schema.contains("exclusiveMinimum")) { + min_value = schema["exclusiveMinimum"].get() + 1; + } + if (schema.contains("maximum")) { + max_value = schema["maximum"].get(); + } else if (schema.contains("exclusiveMaximum")) { + max_value = schema["exclusiveMaximum"].get() - 1; + } + std::stringstream out; + out << "("; + _build_min_max_int(min_value, max_value, out); + out << ") space"; + return _add_rule(rule_name, out.str()); + } else if (schema.empty() || schema_type == "object") { + return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object"))); + } else { + if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get()) == PRIMITIVE_RULES.end()) { + _errors.push_back("Unrecognized schema: " + schema.dump()); + return ""; + } + // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero + return _add_primitive(rule_name == "root" ? "root" : schema_type.get(), PRIMITIVE_RULES.at(schema_type.get())); + } + } + + void check_errors() { + if (!_errors.empty()) { + throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); + } + if (!_warnings.empty()) { + fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); + } + } + + std::string format_grammar() { + std::stringstream ss; + for (const auto & kv : _rules) { + ss << kv.first << " ::= " << kv.second << std::endl; + } + return ss.str(); + } +}; + +std::string json_schema_to_grammar(const json & schema, bool force_gbnf) { +#ifdef LLAMA_USE_LLGUIDANCE + if (!force_gbnf) { + return "%llguidance {}\nstart: %json " + schema.dump(); + } +#else + (void)force_gbnf; +#endif // LLAMA_USE_LLGUIDANCE + return build_grammar([&](const common_grammar_builder & callbacks) { + auto copy = schema; + callbacks.resolve_refs(copy); + callbacks.add_schema("", copy); + }); +} + +std::string build_grammar(const std::function & cb, const common_grammar_options & options) { + SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall); + common_grammar_builder builder { + /* .add_rule = */ [&](const std::string & name, const std::string & rule) { + return converter._add_rule(name, rule); + }, + /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) { + return converter.visit(schema, name == "root" ? "" : name); + }, + /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) { + converter.resolve_refs(schema, ""); + } + }; + cb(builder); + converter.check_errors(); + return converter.format_grammar(); +} diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h new file mode 100644 index 0000000000000000000000000000000000000000..362991b542682c8085162934613c54bb9f72e49a --- /dev/null +++ b/common/json-schema-to-grammar.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include +#include + +std::string json_schema_to_grammar(const nlohmann::ordered_json & schema, + bool force_gbnf = false); + +struct common_grammar_builder { + std::function add_rule; + std::function add_schema; + std::function resolve_refs; +}; + +struct common_grammar_options { + bool dotall = false; +}; + +std::string build_grammar(const std::function & cb, const common_grammar_options & options = {}); diff --git a/common/llguidance.cpp b/common/llguidance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..adce620e4d62fc2563d82d558d96ff2c6da1c94a --- /dev/null +++ b/common/llguidance.cpp @@ -0,0 +1,254 @@ +#include "sampling.h" +#include "log.h" + +#ifdef LLAMA_USE_LLGUIDANCE + +# include "llguidance.h" +# include + +struct llama_sampler_llg { + const llama_vocab * vocab; + std::string grammar_kind; + std::string grammar_data; + LlgTokenizer * tokenizer; + LlgMatcher * grammar; +}; + +static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind, + const char * grammar_data) { + LlgConstraintInit cinit; + llg_constraint_init_set_defaults(&cinit, tokenizer); + const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL"); + if (log_level && *log_level) { + cinit.log_stderr_level = atoi(log_level); + } + auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data); + if (llg_matcher_get_error(c)) { + LOG_ERR("llg error: %s\n", llg_matcher_get_error(c)); + llg_free_matcher(c); + return nullptr; + } + + return c; +} + +static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) { + return "llguidance"; +} + +static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_llg *) smpl->ctx; + if (ctx->grammar) { + llg_matcher_consume_token(ctx->grammar, token); + } +} + +static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_llg *) smpl->ctx; + if (ctx->grammar) { + const uint32_t * mask = llg_matcher_get_mask(ctx->grammar); + if (mask == nullptr) { + if (llg_matcher_compute_mask(ctx->grammar) == 0) { + mask = llg_matcher_get_mask(ctx->grammar); + } else { + LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar)); + llg_free_matcher(ctx->grammar); + ctx->grammar = nullptr; + return; + } + } + + for (size_t i = 0; i < cur_p->size; ++i) { + auto token = cur_p->data[i].id; + if ((mask[token / 32] & (1 << (token % 32))) == 0) { + cur_p->data[i].logit = -INFINITY; + } + } + } +} + +static void llama_sampler_llg_reset(llama_sampler * smpl) { + auto * ctx = (llama_sampler_llg *) smpl->ctx; + if (ctx->grammar) { + llg_matcher_reset(ctx->grammar); + } +} + +static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_llg *) smpl->ctx; + + auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr); + + // copy the state + { + auto * result_ctx = (llama_sampler_llg *) result->ctx; + + if (ctx->grammar) { + result_ctx->grammar_kind = ctx->grammar_kind; + result_ctx->grammar_data = ctx->grammar_data; + result_ctx->grammar = llg_clone_matcher(ctx->grammar); + result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer); + } + } + + return result; +} + +static void llama_sampler_llg_free(llama_sampler * smpl) { + const auto * ctx = (llama_sampler_llg *) smpl->ctx; + + if (ctx->grammar) { + llg_free_matcher(ctx->grammar); + llg_free_tokenizer(ctx->tokenizer); + } + + delete ctx; +} + +static llama_sampler_i llama_sampler_llg_i = { + /* .name = */ llama_sampler_llg_name, + /* .accept = */ llama_sampler_llg_accept_impl, + /* .apply = */ llama_sampler_llg_apply, + /* .reset = */ llama_sampler_llg_reset, + /* .clone = */ llama_sampler_llg_clone, + /* .free = */ llama_sampler_llg_free, +}; + +static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len, + uint32_t * output_tokens, size_t output_tokens_len) { + const llama_vocab * vocab = (const llama_vocab *) user_data; + int r = 0; + try { + r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false, + true); + } catch (const std::exception & e) { + GGML_ABORT("llama_tokenize failed: %s\n", e.what()); + } + if (r < 0) { + return -r; + } + return r; +} + +static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) { + // TODO store the tokenizer in the vocab somehow + static const llama_vocab * vocab_cache; + static LlgTokenizer * tokenizer_cache; + + if (vocab_cache == vocab) { + return llg_clone_tokenizer(tokenizer_cache); + } + + auto tok_eos = llama_vocab_eot(vocab); + if (tok_eos == LLAMA_TOKEN_NULL) { + tok_eos = llama_vocab_eos(vocab); + } + + size_t vocab_size = llama_vocab_n_tokens(vocab); + + auto token_lens = new uint32_t[vocab_size]; + // we typically have ~7 bytes per token; let's go on the safe side here + auto token_bytes_size = vocab_size * 16 + 1024 * 1024; + auto token_bytes = new uint8_t[token_bytes_size]; + + size_t offset = 0; + for (size_t i = 0; i < vocab_size; i++) { + size_t max_token = 1024; + if (token_bytes_size - offset < max_token) { + GGML_ABORT("token_bytes buffer too small\n"); + } + + llama_token token = i; + auto dp = (char *) token_bytes + offset; + auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false); + if (size < 0) { + GGML_ABORT("llama_detokenize failed\n"); + } + if (size == 0) { + size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true); + if (size < 0) { + GGML_ABORT("llama_detokenize failed\n"); + } + if (size != 0) { + *dp = '\xff'; // special token prefix marker + size += 1; + } + } + + token_lens[i] = size; + offset += size; + } + + LlgTokenizerInit tinit = { + /* .vocab_size = */ (uint32_t) vocab_size, + /* .tok_eos = */ (uint32_t) tok_eos, + /* .token_lens = */ token_lens, + /* .token_bytes = */ token_bytes, + /* .tokenizer_json = */ nullptr, + /* .tokenize_assumes_string = */ true, + /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn, + /* .use_approximate_greedy_tokenize_fn = */ false, + /* .tokenize_user_data = */ vocab, + /* .slices = */ nullptr, + }; + + char error_buffer[1024]; + LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer)); + + delete[] token_bytes; + delete[] token_lens; + + if (tokenizer == nullptr) { + LOG_ERR("llg tokenizer error: %s\n", error_buffer); + return tokenizer; + } + + if (tokenizer_cache) { + llg_free_tokenizer(tokenizer_cache); + } + vocab_cache = vocab; + tokenizer_cache = tokenizer; + + return llg_clone_tokenizer(tokenizer_cache); +} + +llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, + const char * grammar_data) { + auto * ctx = new llama_sampler_llg; + + if (grammar_kind != nullptr && grammar_kind[0] != '\0') { + auto tokenizer = llama_sampler_llg_new_tokenizer(vocab); + *ctx = { + /* .vocab = */ vocab, + /* .grammar_kind = */ grammar_kind, + /* .grammar_data = */ grammar_data, + /* .tokenizer = */ tokenizer, + /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data), + }; + if (ctx->grammar) { + GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 == + llg_matcher_get_mask_byte_size(ctx->grammar)); + } + } else { + *ctx = { + /* .vocab = */ vocab, + /* .grammar_kind = */ {}, + /* .grammar_data = */ {}, + /* .tokenizer = */ nullptr, + /* .grammar = */ nullptr, + }; + } + + return llama_sampler_init( + /* .iface = */ &llama_sampler_llg_i, + /* .ctx = */ ctx); +} + +#else + +llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) { + LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); + return nullptr; +} + +#endif // LLAMA_USE_LLGUIDANCE diff --git a/common/log.cpp b/common/log.cpp new file mode 100644 index 0000000000000000000000000000000000000000..52b31470c46bdf241b8623fcaee6968c57d46f3a --- /dev/null +++ b/common/log.cpp @@ -0,0 +1,393 @@ +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +int common_log_verbosity_thold = LOG_DEFAULT_LLAMA; + +void common_log_set_verbosity_thold(int verbosity) { + common_log_verbosity_thold = verbosity; +} + +static int64_t t_us() { + return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); +} + +// colors +enum common_log_col : int { + COMMON_LOG_COL_DEFAULT = 0, + COMMON_LOG_COL_BOLD, + COMMON_LOG_COL_RED, + COMMON_LOG_COL_GREEN, + COMMON_LOG_COL_YELLOW, + COMMON_LOG_COL_BLUE, + COMMON_LOG_COL_MAGENTA, + COMMON_LOG_COL_CYAN, + COMMON_LOG_COL_WHITE, +}; + +// disable colors by default +static std::vector g_col = { + "", + "", + "", + "", + "", + "", + "", + "", + "", +}; + +struct common_log_entry { + enum ggml_log_level level; + + bool prefix; + + int64_t timestamp; + + std::vector msg; + + // signals the worker thread to stop + bool is_end; + + void print(FILE * file = nullptr) const { + FILE * fcur = file; + if (!fcur) { + // stderr displays DBG messages only when their verbosity level is not higher than the threshold + // these messages will still be logged to a file + if (level == GGML_LOG_LEVEL_DEBUG && common_log_verbosity_thold < LOG_DEFAULT_DEBUG) { + return; + } + + fcur = stdout; + + if (level != GGML_LOG_LEVEL_NONE) { + fcur = stderr; + } + } + + if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) { + if (timestamp) { + // [M.s.ms.us] + fprintf(fcur, "%s%d.%02d.%03d.%03d%s ", + g_col[COMMON_LOG_COL_BLUE], + (int) (timestamp / 1000000 / 60), + (int) (timestamp / 1000000 % 60), + (int) (timestamp / 1000 % 1000), + (int) (timestamp % 1000), + g_col[COMMON_LOG_COL_DEFAULT]); + } + + switch (level) { + case GGML_LOG_LEVEL_INFO: fprintf(fcur, "%sI %s", g_col[COMMON_LOG_COL_GREEN], g_col[COMMON_LOG_COL_DEFAULT]); break; + case GGML_LOG_LEVEL_WARN: fprintf(fcur, "%sW %s", g_col[COMMON_LOG_COL_MAGENTA], "" ); break; + case GGML_LOG_LEVEL_ERROR: fprintf(fcur, "%sE %s", g_col[COMMON_LOG_COL_RED], "" ); break; + case GGML_LOG_LEVEL_DEBUG: fprintf(fcur, "%sD %s", g_col[COMMON_LOG_COL_YELLOW], "" ); break; + default: + break; + } + } + + fprintf(fcur, "%s", msg.data()); + + if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_DEBUG) { + fprintf(fcur, "%s", g_col[COMMON_LOG_COL_DEFAULT]); + } + + fflush(fcur); + } +}; + +struct common_log { + // default capacity - will be expanded if needed + common_log() : common_log(256) {} + + common_log(size_t capacity) { + file = nullptr; + prefix = false; + timestamps = false; + running = false; + t_start = t_us(); + + // initial message size - will be expanded if longer messages arrive + entries.resize(capacity); + for (auto & entry : entries) { + entry.msg.resize(256); + } + + head = 0; + tail = 0; + + resume(); + } + + ~common_log() { + pause(); + if (file) { + fclose(file); + } + } + +private: + std::mutex mtx; + std::thread thrd; + std::condition_variable cv; + + FILE * file; + + bool prefix; + bool timestamps; + bool running; + + int64_t t_start; + + // ring buffer of entries + std::vector entries; + size_t head; + size_t tail; + + // worker thread copies into this + common_log_entry cur; + +public: + void add(enum ggml_log_level level, const char * fmt, va_list args) { + std::lock_guard lock(mtx); + + if (!running) { + // discard messages while the worker thread is paused + return; + } + + auto & entry = entries[tail]; + + { + // cannot use args twice, so make a copy in case we need to expand the buffer + va_list args_copy; + va_copy(args_copy, args); + +#if 1 + const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args); + if (n >= entry.msg.size()) { + entry.msg.resize(n + 1); + vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args_copy); + } +#else + // hack for bolding arguments + + std::stringstream ss; + for (int i = 0; fmt[i] != 0; i++) { + if (fmt[i] == '%') { + ss << LOG_COL_BOLD; + while (fmt[i] != ' ' && fmt[i] != ')' && fmt[i] != ']' && fmt[i] != 0) ss << fmt[i++]; + ss << LOG_COL_DEFAULT; + if (fmt[i] == 0) break; + } + ss << fmt[i]; + } + const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args); + if (n >= entry.msg.size()) { + entry.msg.resize(n + 1); + vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy); + } +#endif + va_end(args_copy); + } + + entry.level = level; + entry.prefix = prefix; + entry.timestamp = 0; + if (timestamps) { + entry.timestamp = t_us() - t_start; + } + entry.is_end = false; + + tail = (tail + 1) % entries.size(); + if (tail == head) { + // expand the buffer + std::vector new_entries(2*entries.size()); + + size_t new_tail = 0; + + do { + new_entries[new_tail] = std::move(entries[head]); + + head = (head + 1) % entries.size(); + new_tail = (new_tail + 1); + } while (head != tail); + + head = 0; + tail = new_tail; + + for (size_t i = tail; i < new_entries.size(); i++) { + new_entries[i].msg.resize(256); + } + + entries = std::move(new_entries); + } + + cv.notify_one(); + } + + void resume() { + std::lock_guard lock(mtx); + + if (running) { + return; + } + + running = true; + + thrd = std::thread([this]() { + while (true) { + { + std::unique_lock lock(mtx); + cv.wait(lock, [this]() { return head != tail; }); + + cur = entries[head]; + + head = (head + 1) % entries.size(); + } + + if (cur.is_end) { + break; + } + + cur.print(); // stdout and stderr + + if (file) { + cur.print(file); + } + } + }); + } + + void pause() { + { + std::lock_guard lock(mtx); + + if (!running) { + return; + } + + running = false; + + // push an entry to signal the worker thread to stop + { + auto & entry = entries[tail]; + entry.is_end = true; + + tail = (tail + 1) % entries.size(); + } + + cv.notify_one(); + } + + thrd.join(); + } + + void set_file(const char * path) { + pause(); + + if (file) { + fclose(file); + } + + if (path) { + file = fopen(path, "w"); + } else { + file = nullptr; + } + + resume(); + } + + void set_colors(bool colors) { + pause(); + + if (colors) { + g_col[COMMON_LOG_COL_DEFAULT] = LOG_COL_DEFAULT; + g_col[COMMON_LOG_COL_BOLD] = LOG_COL_BOLD; + g_col[COMMON_LOG_COL_RED] = LOG_COL_RED; + g_col[COMMON_LOG_COL_GREEN] = LOG_COL_GREEN; + g_col[COMMON_LOG_COL_YELLOW] = LOG_COL_YELLOW; + g_col[COMMON_LOG_COL_BLUE] = LOG_COL_BLUE; + g_col[COMMON_LOG_COL_MAGENTA] = LOG_COL_MAGENTA; + g_col[COMMON_LOG_COL_CYAN] = LOG_COL_CYAN; + g_col[COMMON_LOG_COL_WHITE] = LOG_COL_WHITE; + } else { + for (size_t i = 0; i < g_col.size(); i++) { + g_col[i] = ""; + } + } + + resume(); + } + + void set_prefix(bool prefix) { + std::lock_guard lock(mtx); + + this->prefix = prefix; + } + + void set_timestamps(bool timestamps) { + std::lock_guard lock(mtx); + + this->timestamps = timestamps; + } +}; + +// +// public API +// + +struct common_log * common_log_init() { + return new common_log; +} + +struct common_log * common_log_main() { + static struct common_log log; + + return &log; +} + +void common_log_pause(struct common_log * log) { + log->pause(); +} + +void common_log_resume(struct common_log * log) { + log->resume(); +} + +void common_log_free(struct common_log * log) { + delete log; +} + +void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...) { + va_list args; + va_start(args, fmt); + log->add(level, fmt, args); + va_end(args); +} + +void common_log_set_file(struct common_log * log, const char * file) { + log->set_file(file); +} + +void common_log_set_colors(struct common_log * log, bool colors) { + log->set_colors(colors); +} + +void common_log_set_prefix(struct common_log * log, bool prefix) { + log->set_prefix(prefix); +} + +void common_log_set_timestamps(struct common_log * log, bool timestamps) { + log->set_timestamps(timestamps); +} diff --git a/common/log.h b/common/log.h new file mode 100644 index 0000000000000000000000000000000000000000..c56bb50d95db06483a4f71f689ad7bf010c0321b --- /dev/null +++ b/common/log.h @@ -0,0 +1,103 @@ +#pragma once + +#include "ggml.h" // for ggml_log_level + +#define LOG_CLR_TO_EOL "\033[K\r" +#define LOG_COL_DEFAULT "\033[0m" +#define LOG_COL_BOLD "\033[1m" +#define LOG_COL_RED "\033[31m" +#define LOG_COL_GREEN "\033[32m" +#define LOG_COL_YELLOW "\033[33m" +#define LOG_COL_BLUE "\033[34m" +#define LOG_COL_MAGENTA "\033[35m" +#define LOG_COL_CYAN "\033[36m" +#define LOG_COL_WHITE "\033[37m" + +#ifndef __GNUC__ +# define LOG_ATTRIBUTE_FORMAT(...) +#elif defined(__MINGW32__) && !defined(__clang__) +# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif + +#define LOG_DEFAULT_DEBUG 1 +#define LOG_DEFAULT_LLAMA 0 + +// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower +// set via common_log_set_verbosity() +extern int common_log_verbosity_thold; + +void common_log_set_verbosity_thold(int verbosity); // not thread-safe + +// the common_log uses an internal worker thread to print/write log messages +// when the worker thread is paused, incoming log messages are discarded +struct common_log; + +struct common_log * common_log_init(); +struct common_log * common_log_main(); // singleton, automatically destroys itself on exit +void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe +void common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe +void common_log_free (struct common_log * log); + +LOG_ATTRIBUTE_FORMAT(3, 4) +void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...); + +// defaults: file = NULL, colors = false, prefix = false, timestamps = false +// +// regular log output: +// +// ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34) +// llm_load_tensors: ggml ctx size = 0.27 MiB +// llm_load_tensors: offloading 32 repeating layers to GPU +// llm_load_tensors: offloading non-repeating layers to GPU +// +// with prefix = true, timestamps = true, the log output will look like this: +// +// 0.00.035.060 D ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34) +// 0.00.035.064 I llm_load_tensors: ggml ctx size = 0.27 MiB +// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU +// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU +// +// I - info (stdout, V = 0) +// W - warning (stderr, V = 0) +// E - error (stderr, V = 0) +// D - debug (stderr, V = LOG_DEFAULT_DEBUG) +// + +void common_log_set_file (struct common_log * log, const char * file); // not thread-safe +void common_log_set_colors (struct common_log * log, bool colors); // not thread-safe +void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log +void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix + +// helper macros for logging +// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold +// +// for example: +// +// LOG_DBG("this is a debug message: %d\n", expensive_function()); +// +// this will avoid calling expensive_function() if LOG_DEFAULT_DEBUG > common_log_verbosity_thold +// + +#define LOG_TMPL(level, verbosity, ...) \ + do { \ + if ((verbosity) <= common_log_verbosity_thold) { \ + common_log_add(common_log_main(), (level), __VA_ARGS__); \ + } \ + } while (0) + +#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__) +#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__) + +#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__) +#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__) +#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__) +#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__) +#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__) + +#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__) +#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__) +#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__) +#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__) +#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__) diff --git a/common/ngram-cache.cpp b/common/ngram-cache.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d1a4d84c40f1c742ece7213139b1325ac62a8edc --- /dev/null +++ b/common/ngram-cache.cpp @@ -0,0 +1,286 @@ +#include "ngram-cache.h" +#include "common.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include + +void common_ngram_cache_update(common_ngram_cache & ngram_cache, int ngram_min, int ngram_max, + std::vector & inp, int nnew, bool print_progress) { + const int64_t t_start_ms = ggml_time_ms(); + const int64_t inp_size = inp.size(); + + const int64_t n_todo = inp_size * (ngram_max - ngram_min + 1); + int64_t n_done = 0; + + for (int64_t ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) { + const int64_t i_start = std::max(inp_size - nnew, ngram_size); + for (int64_t i = i_start; i < inp_size; ++i) { + const int64_t ngram_start = i - ngram_size; + common_ngram ngram(&inp[ngram_start], ngram_size); + const llama_token token = inp[i]; + + common_ngram_cache::iterator part_it = ngram_cache.find(ngram); + if (part_it == ngram_cache.end()) { + common_ngram_cache_part part; + part.emplace(token, 1); + ngram_cache.emplace(ngram, part); + } else { + common_ngram_cache_part::iterator token_count_it = part_it->second.find(token); + if (token_count_it == part_it->second.end()) { + part_it->second.emplace(token, 1); + } else { + token_count_it->second++; + } + } + ++n_done; + + if (print_progress && n_done % 10000000 == 0) { + const int64_t t_now_ms = ggml_time_ms(); + const int64_t eta_ms = (inp_size*(ngram_max-ngram_min+1) - n_done) * (t_now_ms - t_start_ms) / n_done; + const int64_t eta_min = eta_ms / (60*1000); + const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000; + + fprintf(stderr, "%s: %" PRId64 "/%" PRId64 " done, ETA: %02" PRId64 ":%02" PRId64 "\n", __func__, n_done, n_todo, eta_min, eta_s); + } + } + } +} + +// Helper function to get a token from the combined, speculative sequence of inp and draft. +static llama_token get_token(const std::vector & inp, const std::vector & draft, const size_t i) { + return i < inp.size() ? inp[i] : draft[1 + i - inp.size()]; +} + +// If sample size or percentage are below these thresholds the draft is aborted early: +constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1}; +constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50}; +constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2}; +constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66}; + +// Helper function that tries to draft a token from only the static ngram cache: +static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram ngram_static) { + common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static); + if (part_static_it == nc_static.end()) { + return LLAMA_TOKEN_NULL; + } + const common_ngram_cache_part part_static = part_static_it->second; + + int max_count_static = 0; + int sum_count_static = 0; + llama_token max_token = LLAMA_TOKEN_NULL; + + for (std::pair token_count_static : part_static) { + const llama_token token = token_count_static.first; + const int32_t count_static = token_count_static.second; + + if (count_static > max_count_static) { + max_token = token; + max_count_static = count_static; + } + sum_count_static += count_static; + } + + if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) { + return LLAMA_TOKEN_NULL; + } + if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) { + return LLAMA_TOKEN_NULL; + } + return max_token; +} + +// Try to draft a token from primary cache (context/dynamic), validate with static cache: +static llama_token try_draft( + common_ngram_cache & nc_primary, const std::vector & ngrams_primary, common_ngram_cache_part & part_static, + const int * min_sample_size, const int * min_percent) { + + llama_token drafted_token = LLAMA_TOKEN_NULL; + + for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == LLAMA_TOKEN_NULL; --i) { + const common_ngram ngram_primary = ngrams_primary[i]; + + common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary); + if (part_primary_it == nc_primary.end()) { + continue; + } + const common_ngram_cache_part part_primary = part_primary_it->second; + + int max_count_primary = 0; + int max_count_static = 0; + int sum_count_primary = 0; + llama_token max_token = LLAMA_TOKEN_NULL; + + for (std::pair token_count_primary : part_primary) { + const llama_token token = token_count_primary.first; + + common_ngram_cache_part::iterator token_count_static_it = part_static.find(token); + + const int32_t count_primary = token_count_primary.second; + const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1; + + if (count_primary*count_static > max_count_primary*max_count_static) { + max_token = token; + max_count_primary = count_primary; + max_count_static = count_static; + } + sum_count_primary += count_primary; + } + + if (sum_count_primary < min_sample_size[i]) { + continue; + } + if (100*max_count_primary < min_percent[i]*sum_count_primary) { + continue;; + } + drafted_token = max_token; + } + + return drafted_token; +} + +void common_ngram_cache_draft( + std::vector & inp, std::vector & draft, int n_draft, int ngram_min, int ngram_max, + common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static +) { + GGML_ASSERT(draft.size() == 1); + const int inp_size = inp.size(); + + if (inp_size < LLAMA_NGRAM_STATIC) { + return; + } + + while ((int) draft.size()-1 < n_draft) { + llama_token drafted_token = LLAMA_TOKEN_NULL; + + const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1; + common_ngram ngram_static; + for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) { + ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j); + } + common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static); + common_ngram_cache_part part_static; + if (part_static_it != nc_static.end()) { + part_static = part_static_it->second; + } + + // cd = context + dynamic + std::vector ngrams_cd; + for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) { + const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1; + common_ngram ngram_cd; + for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) { + ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j); + } + ngrams_cd.push_back(ngram_cd); + } + if (drafted_token == LLAMA_TOKEN_NULL) { + drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax); + } + if (drafted_token == LLAMA_TOKEN_NULL) { + drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict); + } + if (drafted_token == LLAMA_TOKEN_NULL) { + drafted_token = try_draft(nc_static, ngram_static); + } + + if (drafted_token == LLAMA_TOKEN_NULL) { + break; + } + + LOG(" - draft candidate: token=%d\n", drafted_token); + draft.push_back(drafted_token); + } +} + +void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) { + std::ofstream file_out(filename, std::ios::binary); + for (std::pair item : ngram_cache) { + const common_ngram ngram = item.first; + common_ngram_cache_part token_counts = item.second; + GGML_ASSERT(!token_counts.empty()); + const int32_t ntokens = token_counts.size(); + GGML_ASSERT(ntokens > 0); + + file_out.write(reinterpret_cast(&ngram), sizeof(common_ngram)); + file_out.write(reinterpret_cast(&ntokens), sizeof(int32_t)); + for (std::pair item2 : token_counts) { + const llama_token token = item2.first; + const int32_t count = item2.second; + GGML_ASSERT(count > 0); + + file_out.write(reinterpret_cast(&token), sizeof(llama_token)); + file_out.write(reinterpret_cast(&count), sizeof(int32_t)); + } + } + +} + +common_ngram_cache common_ngram_cache_load(std::string & filename) { + std::ifstream hashmap_file(filename, std::ios::binary); + if (!hashmap_file) { + throw std::ifstream::failure("Unable to open file " + filename); + } + common_ngram_cache ngram_cache; + + common_ngram ngram; + int32_t ntokens; + llama_token token; + int32_t count; + + char * ngramc = reinterpret_cast(&ngram); + char * ntokensc = reinterpret_cast(&ntokens); + char * tokenc = reinterpret_cast(&token); + char * countc = reinterpret_cast(&count); + while(hashmap_file.read(ngramc, sizeof(common_ngram))) { + GGML_ASSERT(!hashmap_file.eof()); + GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t))); + GGML_ASSERT(ntokens > 0); + common_ngram_cache_part token_counts; + + for (int i = 0; i < ntokens; ++i) { + GGML_ASSERT(!hashmap_file.eof()); + GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token))); + GGML_ASSERT(!hashmap_file.eof()); + GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t))); + GGML_ASSERT(count > 0); + token_counts.emplace(token, count); + } + + ngram_cache.emplace(ngram, token_counts); + } + GGML_ASSERT(hashmap_file.eof()); + + return ngram_cache; +} + +void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add) { + for (std::pair ngram_part : ngram_cache_add) { + const common_ngram ngram = ngram_part.first; + common_ngram_cache_part part = ngram_part.second; + + common_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram); + if (part_merged_it == ngram_cache_target.end()) { + ngram_cache_target.emplace(ngram, part); + continue; + } + + for (std::pair token_count : part) { + const llama_token token = token_count.first; + const int32_t count = token_count.second; + GGML_ASSERT(count > 0); + + common_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token); + if (token_count_merged_it == part_merged_it->second.end()) { + part_merged_it->second.emplace(token, count); + continue; + } + + token_count_merged_it->second += count; + } + } +} diff --git a/common/ngram-cache.h b/common/ngram-cache.h new file mode 100644 index 0000000000000000000000000000000000000000..dfe012abe493dc03d2cda367572adc96d61f7b7c --- /dev/null +++ b/common/ngram-cache.h @@ -0,0 +1,101 @@ +#pragma once + +#include "llama.h" + +#include +#include +#include + +#define LLAMA_NGRAM_MIN 1 +#define LLAMA_NGRAM_MAX 4 +#define LLAMA_NGRAM_STATIC 2 + +// Data structures to map n-grams to empirical token probabilities: + +struct common_ngram { + llama_token tokens[LLAMA_NGRAM_MAX]; + + common_ngram() { + for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { + tokens[i] = LLAMA_TOKEN_NULL; + } + } + + common_ngram(const llama_token * input, const int ngram_size) { + for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { + tokens[i] = i < ngram_size ? input[i] : LLAMA_TOKEN_NULL; + } + } + + bool operator==(const common_ngram & other) const { + for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { + if (tokens[i] != other.tokens[i]) { + return false; + } + } + return true; + } +}; + +struct common_token_hash_function { + size_t operator()(const llama_token token) const { + // see https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ + return token * 11400714819323198485llu; + } +}; + +struct common_ngram_hash_function { + size_t operator()(const common_ngram & ngram) const { + size_t hash = common_token_hash_function{}(ngram.tokens[0]); + for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) { + hash ^= common_token_hash_function{}(ngram.tokens[i]); + } + return hash; + } +}; + +// token -> number of times token has been seen +typedef std::unordered_map common_ngram_cache_part; + +// n-gram -> empirical distribution of following tokens +typedef std::unordered_map common_ngram_cache; + + +// Update an ngram cache with tokens. +// ngram_cache: the cache to modify. +// ngram_min/ngram_max: the min/max size of the ngrams to extract from inp_data. +// inp_data: the token sequence with which to update ngram_cache. +// nnew: how many new tokens have been appended to inp_data since the last call to this function. +// print_progress: whether to print progress to stderr. +// +// In order to get correct results inp_data can ONLY BE APPENDED TO. +// Changes in the middle need a complete rebuild. +void common_ngram_cache_update( + common_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector & inp_data, int nnew, bool print_progress); + +// Try to draft tokens from ngram caches. +// inp: the tokens generated so far. +// draft: the token sequence to draft. Expected to initially contain the previously sampled token. +// n_draft: maximum number of tokens to add to draft. +// ngram_min/gram_max: the min/max size of the ngrams in nc_context and nc_dynamic. +// nc_context: ngram cache based on current context. +// nc_dynamic: ngram cache based on previous user generations. +// nc_static: ngram cache generated from a large text corpus, used for validation. +void common_ngram_cache_draft( + std::vector & inp, std::vector & draft, int n_draft, int ngram_min, int ngram_max, + common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static); + +// Save an ngram cache to a file. +// ngram_cache: the ngram cache to save. +// filename: the path under which to save the ngram cache. +void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename); + +// Load an ngram cache saved with common_ngram_cache_save. +// filename: the path from which to load the ngram cache. +// returns: an ngram cache containing the information saved to filename. +common_ngram_cache common_ngram_cache_load(std::string & filename); + +// Merge two ngram caches. +// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add. +// ngram_cache_add: the ngram cache to add to ngram_cache_target. +void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add); diff --git a/common/regex-partial.cpp b/common/regex-partial.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4bff6b66336e246283e133cb929960722a6e9cd3 --- /dev/null +++ b/common/regex-partial.cpp @@ -0,0 +1,204 @@ +#include "regex-partial.h" +#include "common.h" +#include +#include + +common_regex::common_regex(const std::string & pattern) : + pattern(pattern), + rx(pattern), + rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {} + +common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const { + std::smatch match; + if (pos > input.size()) { + throw std::runtime_error("Position out of bounds"); + } + auto start = input.begin() + pos; + auto found = as_match + ? std::regex_match(start, input.end(), match, rx) + : std::regex_search(start, input.end(), match, rx); + if (found) { + common_regex_match res; + res.type = COMMON_REGEX_MATCH_TYPE_FULL; + for (size_t i = 0; i < match.size(); ++i) { + auto begin = pos + match.position(i); + res.groups.emplace_back(begin, begin + match.length(i)); + } + return res; + } + std::match_results srmatch; + if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) { + auto group = srmatch[1].str(); + if (group.length() != 0) { + auto it = srmatch[1].second.base(); + // auto position = static_cast(std::distance(input.begin(), it)); + if ((!as_match) || it == input.begin()) { + common_regex_match res; + res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL; + const size_t begin = std::distance(input.begin(), it); + const size_t end = input.size(); + if (begin == std::string::npos || end == std::string::npos || begin > end) { + throw std::runtime_error("Invalid range"); + } + res.groups.push_back({begin, end}); + return res; + } + } + } + return {}; +} + +/* + Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern. + + Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html) + to see if a string ends with a partial regex match, but but it's not in std::regex yet. + Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input. + + - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).* + - /a|b/ -> (a|b).* + - /a*?/ -> error, could match "" + - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager) + - /.*?ab/ -> ((?:b)?a).* (merge .*) + - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches) + - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).* + - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).* + - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).* + + The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern + (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored) +*/ +std::string regex_to_reversed_partial_regex(const std::string & pattern) { + auto it = pattern.begin(); + const auto end = pattern.end(); + + std::function process = [&]() { + std::vector> alternatives(1); + std::vector * sequence = &alternatives.back(); + + while (it != end) { + if (*it == '[') { + auto start = it; + ++it; + while (it != end) { + if ((*it == '\\') && (++it != end)) { + ++it; + } else if ((it != end) && (*it == ']')) { + break; + } else { + ++it; + } + } + if (it == end) { + throw std::runtime_error("Unmatched '[' in pattern"); + } + ++it; + sequence->push_back(std::string(start, it)); + } else if (*it == '*' || *it == '?' || *it == '+') { + if (sequence->empty()) { + throw std::runtime_error("Quantifier without preceding element"); + } + sequence->back() += *it; + auto is_star = *it == '*'; + ++it; + if (is_star) { + if (*it == '?') { + ++it; + } + } + } else if (*it == '{') { + if (sequence->empty()) { + throw std::runtime_error("Repetition without preceding element"); + } + ++it; + auto start = it; + while (it != end && *it != '}') { + ++it; + } + if (it == end) { + throw std::runtime_error("Unmatched '{' in pattern"); + } + auto parts = string_split(std::string(start, it), ","); + ++it; + if (parts.size() > 2) { + throw std::runtime_error("Invalid repetition range in pattern"); + } + + auto parseOptInt = [&](const std::string & s, const std::optional & def = std::nullopt) -> std::optional { + if (s.empty()) { + return def; + } + return std::stoi(s); + }; + auto min = parseOptInt(parts[0], 0); + auto max = parts.size() == 1 ? min : parseOptInt(parts[1]); + if (min && max && *max < *min) { + throw std::runtime_error("Invalid repetition range in pattern"); + } + // Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded) + auto part = sequence->back(); + sequence->pop_back(); + for (int i = 0; i < *min; i++) { + sequence->push_back(part); + } + if (max) { + for (int i = *min; i < *max; i++) { + sequence->push_back(part + "?"); + } + } else { + sequence->push_back(part + "*"); + } + } else if (*it == '(') { + ++it; + if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') { + it += 2; + } + auto sub = process(); + if (*it != ')') { + throw std::runtime_error("Unmatched '(' in pattern"); + } + ++it; + auto & part = sequence->emplace_back("(?:"); + part += sub; + part += ")"; + } else if (*it == ')') { + break; + } else if (*it == '|') { + ++it; + alternatives.emplace_back(); + sequence = &alternatives.back(); + } else if (*it == '\\' && (++it != end)) { + auto str = std::string("\\") + *it; + sequence->push_back(str); + ++it; + } else if (it != end) { + sequence->push_back(std::string(1, *it)); + ++it; + } + } + + // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).* + // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group + // We'll do the outermost capturing group and final .* in the enclosing function. + std::vector res_alts; + for (const auto & parts : alternatives) { + auto & res = res_alts.emplace_back(); + for (size_t i = 0; i < parts.size() - 1; i++) { + res += "(?:"; + } + for (auto it = parts.rbegin(); it != parts.rend(); ++it) { + res += *it; + if (it != parts.rend() - 1) { + res += ")?"; + } + } + } + return string_join(res_alts, "|"); + }; + auto res = process(); + if (it != end) { + throw std::runtime_error("Unmatched '(' in pattern"); + } + + return "(" + res + ")[\\s\\S]*"; +} diff --git a/common/regex-partial.h b/common/regex-partial.h new file mode 100644 index 0000000000000000000000000000000000000000..634cb4022bd1dea664ad98d2cd44c188813ee48f --- /dev/null +++ b/common/regex-partial.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include + +enum common_regex_match_type { + COMMON_REGEX_MATCH_TYPE_NONE, + COMMON_REGEX_MATCH_TYPE_PARTIAL, + COMMON_REGEX_MATCH_TYPE_FULL, +}; + +struct common_string_range { + size_t begin; + size_t end; + common_string_range(size_t begin, size_t end) : begin(begin), end(end) { + if (begin > end) { + throw std::runtime_error("Invalid range"); + } + } + // prevent default ctor + common_string_range() = delete; + bool empty() const { + return begin == end; + } + bool operator==(const common_string_range & other) const { + return begin == other.begin && end == other.end; + } +}; + +struct common_regex_match { + common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE; + std::vector groups; + + bool operator==(const common_regex_match & other) const { + return type == other.type && groups == other.groups; + } + bool operator!=(const common_regex_match & other) const { + return !(*this == other); + } +}; + +class common_regex { + std::string pattern; + std::regex rx; + std::regex rx_reversed_partial; + + public: + explicit common_regex(const std::string & pattern); + + common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const; + + const std::string & str() const { return pattern; } +}; + +// For testing only (pretty print of failures). +std::string regex_to_reversed_partial_regex(const std::string & pattern); diff --git a/common/sampling.cpp b/common/sampling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9c04d35fd00a290a0710a82788dc9c754a37e25b --- /dev/null +++ b/common/sampling.cpp @@ -0,0 +1,579 @@ +#include "sampling.h" + +#include "common.h" +#include "log.h" + +#include +#include +#include + +// the ring buffer works similarly to std::deque, but with a fixed capacity +// TODO: deduplicate with llama-impl.h +template +struct ring_buffer { + ring_buffer(size_t cap) : capacity(cap), data(cap) {} + + T & front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + const T & front() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + T & back() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + const T & back() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + void push_back(const T & value) { + if (sz == capacity) { + // advance the start when buffer is full + first = (first + 1) % capacity; + } else { + sz++; + } + data[pos] = value; + pos = (pos + 1) % capacity; + } + + T pop_front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + T value = data[first]; + first = (first + 1) % capacity; + sz--; + return value; + } + + const T & rat(size_t i) const { + if (i >= sz) { + throw std::runtime_error("ring buffer: index out of bounds"); + } + return data[(first + sz - i - 1) % capacity]; + } + + std::vector to_vector() const { + std::vector result; + result.reserve(sz); + for (size_t i = 0; i < sz; i++) { + result.push_back(data[(first + i) % capacity]); + } + return result; + } + + void clear() { + // here only reset the status of the buffer + sz = 0; + first = 0; + pos = 0; + } + + bool empty() const { + return sz == 0; + } + + size_t size() const { + return sz; + } + + size_t capacity = 0; + size_t sz = 0; + size_t first = 0; + size_t pos = 0; + std::vector data; +}; + +struct common_sampler { + common_params_sampling params; + + struct llama_sampler * grmr; + struct llama_sampler * chain; + + ring_buffer prev; + + std::vector cur; + + llama_token_data_array cur_p; + + void set_logits(struct llama_context * ctx, int idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + cur.resize(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + cur_p = { cur.data(), cur.size(), -1, false }; + } +}; + +std::string common_params_sampling::print() const { + char result[1024]; + + snprintf(result, sizeof(result), + "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" + "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" + "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", + penalty_last_n, penalty_repeat, penalty_freq, penalty_present, + dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, + top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, + mirostat, mirostat_eta, mirostat_tau); + + return std::string(result); +} + +struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + + llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); + + lparams.no_perf = params.no_perf; + + struct llama_sampler * grmr; + if (params.grammar.compare(0, 11, "%llguidance") == 0) { +#ifdef LLAMA_USE_LLGUIDANCE + grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); +#else + GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); +#endif // LLAMA_USE_LLGUIDANCE + } else { + std::vector trigger_patterns; + std::vector patterns_anywhere; + std::vector trigger_tokens; + for (const auto & trigger : params.grammar_triggers) { + switch (trigger.type) { + case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: + { + const auto & word = trigger.value; + patterns_anywhere.push_back(regex_escape(word)); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: + { + patterns_anywhere.push_back(trigger.value); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: + { + trigger_patterns.push_back(trigger.value); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: + { + const auto token = trigger.token; + trigger_tokens.push_back(token); + break; + } + default: + GGML_ASSERT(false && "unknown trigger type"); + } + } + + if (!patterns_anywhere.empty()) { + trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); + } + + std::vector trigger_patterns_c; + trigger_patterns_c.reserve(trigger_patterns.size()); + for (const auto & regex : trigger_patterns) { + trigger_patterns_c.push_back(regex.c_str()); + } + + grmr = params.grammar_lazy + ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", + trigger_patterns_c.data(), trigger_patterns_c.size(), + trigger_tokens.data(), trigger_tokens.size()) + : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); + if (!grmr) { + return nullptr; + } + } + + auto * result = new common_sampler { + /* .params = */ params, + /* .grmr = */ grmr, + /* .chain = */ llama_sampler_chain_init(lparams), + /* .prev = */ ring_buffer(std::max(32, params.n_prev)), + /* .cur = */ {}, + /* .cur_p = */ {}, + }; + + llama_sampler_chain_add(result->chain, + llama_sampler_init_logit_bias( + llama_vocab_n_tokens(vocab), + params.logit_bias.size(), + params.logit_bias.data())); + + if (params.mirostat == 0) { + for (const auto & cnstr : params.samplers) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: + { + std::vector c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto & str : params.dry_sequence_breakers) { + c_breakers.push_back(str.c_str()); + } + + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + } + break; + case COMMON_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + break; + case COMMON_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); + break; + case COMMON_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_XTC: + llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + case COMMON_SAMPLER_TYPE_INFILL: + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); + break; + case COMMON_SAMPLER_TYPE_PENALTIES: + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + default: + GGML_ASSERT(false && "unknown sampler type"); + } + } + llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); + } else if (params.mirostat == 1) { + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + } else if (params.mirostat == 2) { + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); + } else { + GGML_ASSERT(false && "unknown mirostat version"); + } + + return result; +} + +void common_sampler_free(struct common_sampler * gsmpl) { + if (gsmpl) { + llama_sampler_free(gsmpl->grmr); + + llama_sampler_free(gsmpl->chain); + + delete gsmpl; + } +} + +void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { + if (accept_grammar) { + llama_sampler_accept(gsmpl->grmr, token); + } + + llama_sampler_accept(gsmpl->chain, token); + + gsmpl->prev.push_back(token); +} + +void common_sampler_reset(struct common_sampler * gsmpl) { + llama_sampler_reset(gsmpl->grmr); + + llama_sampler_reset(gsmpl->chain); +} + +struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { + return new common_sampler { + /* .params = */ gsmpl->params, + /* .grmr = */ llama_sampler_clone(gsmpl->grmr), + /* .chain = */ llama_sampler_clone(gsmpl->chain), + /* .prev = */ gsmpl->prev, + /* .cur = */ gsmpl->cur, + /* .cur_p = */ gsmpl->cur_p, + }; +} + +void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) { + // TODO: measure grammar performance + + if (gsmpl) { + llama_perf_sampler_print(gsmpl->chain); + } + if (ctx) { + llama_perf_context_print(ctx); + } +} + +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { + gsmpl->set_logits(ctx, idx); + + auto & grmr = gsmpl->grmr; + auto & chain = gsmpl->chain; + auto & cur_p = gsmpl->cur_p; // initialized by set_logits + + if (grammar_first) { + llama_sampler_apply(grmr, &cur_p); + } + + llama_sampler_apply(chain, &cur_p); + + GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); + + const llama_token id = cur_p.data[cur_p.selected].id; + + if (grammar_first) { + return id; + } + + // check if it the sampled token fits the grammar + { + llama_token_data single_token_data = { id, 1.0f, 0.0f }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false }; + + llama_sampler_apply(grmr, &single_token_data_array); + + const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + if (is_valid) { + return id; + } + } + + // resampling: + // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain + gsmpl->set_logits(ctx, idx); + + llama_sampler_apply(grmr, &cur_p); + llama_sampler_apply(chain, &cur_p); + + GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration"); + + return cur_p.data[cur_p.selected].id; +} + +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { + GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); + + std::vector result; + result.reserve(idxs.size()); + + size_t i = 0; + for (; i < draft.size(); i++) { + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + + common_sampler_accept(gsmpl, id, true); + + result.push_back(id); + + if (draft[i] != id) { + break; + } + } + + if (i == draft.size()) { + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + + common_sampler_accept(gsmpl, id, true); + + result.push_back(id); + } + + return result; +} + +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { + std::vector idxs(draft.size() + 1); + for (size_t i = 0; i < idxs.size(); ++i) { + idxs[i] = i; + } + + return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); +} + +uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { + return llama_sampler_get_seed(gsmpl->chain); +} + +// helpers + +llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) { + return &gsmpl->cur_p; +} + +llama_token common_sampler_last(const struct common_sampler * gsmpl) { + return gsmpl->prev.rat(0); +} + +std::string common_sampler_print(const struct common_sampler * gsmpl) { + std::string result = "logits "; + + for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { + const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); + result += std::string("-> ") + llama_sampler_name(smpl) + " "; + } + + return result; +} + +std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) { + n = std::min(n, (int) gsmpl->prev.size()); + + if (n <= 0) { + return ""; + } + + std::string result; + result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab + + for (int i = n - 1; i >= 0; i--) { + const llama_token id = gsmpl->prev.rat(i); + + GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen"); + + result += common_token_to_piece(ctx_main, id); + } + + return result; +} + +char common_sampler_type_to_chr(enum common_sampler_type cnstr) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: return 'd'; + case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; + case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; + case COMMON_SAMPLER_TYPE_TOP_P: return 'p'; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's'; + case COMMON_SAMPLER_TYPE_MIN_P: return 'm'; + case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; + case COMMON_SAMPLER_TYPE_XTC: return 'x'; + case COMMON_SAMPLER_TYPE_INFILL: return 'i'; + case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; + default : return '?'; + } +} + +std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: return "dry"; + case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; + case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; + case COMMON_SAMPLER_TYPE_TOP_P: return "top_p"; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma"; + case COMMON_SAMPLER_TYPE_MIN_P: return "min_p"; + case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; + case COMMON_SAMPLER_TYPE_XTC: return "xtc"; + case COMMON_SAMPLER_TYPE_INFILL: return "infill"; + case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; + default : return ""; + } +} + +std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names) { + std::unordered_map sampler_canonical_name_map { + { "dry", COMMON_SAMPLER_TYPE_DRY }, + { "top_k", COMMON_SAMPLER_TYPE_TOP_K }, + { "top_p", COMMON_SAMPLER_TYPE_TOP_P }, + { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, + { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "min_p", COMMON_SAMPLER_TYPE_MIN_P }, + { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, + { "xtc", COMMON_SAMPLER_TYPE_XTC }, + { "infill", COMMON_SAMPLER_TYPE_INFILL }, + { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, + }; + + // since samplers names are written multiple ways + // make it ready for both system names and input names + std::unordered_map sampler_alt_name_map { + { "top-k", COMMON_SAMPLER_TYPE_TOP_K }, + { "top-p", COMMON_SAMPLER_TYPE_TOP_P }, + { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, + { "nucleus", COMMON_SAMPLER_TYPE_TOP_P }, + { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "typical", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, + { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, + }; + + std::vector samplers; + samplers.reserve(names.size()); + + for (const auto & name : names) { + auto sampler = sampler_canonical_name_map.find(name); + if (sampler != sampler_canonical_name_map.end()) { + samplers.push_back(sampler->second); + continue; + } + if (allow_alt_names) { + sampler = sampler_alt_name_map.find(name); + if (sampler != sampler_alt_name_map.end()) { + samplers.push_back(sampler->second); + continue; + } + } + LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str()); + } + + return samplers; +} + +std::vector common_sampler_types_from_chars(const std::string & chars) { + std::unordered_map sampler_name_map = { + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, + }; + + std::vector samplers; + samplers.reserve(chars.size()); + + for (const auto & c : chars) { + const auto sampler = sampler_name_map.find(c); + if (sampler != sampler_name_map.end()) { + samplers.push_back(sampler->second); + } else { + LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c); + } + } + + return samplers; +} diff --git a/common/sampling.h b/common/sampling.h new file mode 100644 index 0000000000000000000000000000000000000000..2064421db4e80237e0dc8cc3acbcef4894000416 --- /dev/null +++ b/common/sampling.h @@ -0,0 +1,107 @@ +#pragma once + +#include "llama.h" + +#include "common.h" + +#include +#include + +// common_sampler extends llama_sampler with additional functionality: +// +// - grammar support +// - custom sampler logic based on the parameters +// - history of the last accepted tokens +// - performance metrics +// +// This goal is to have a common implementation of the sampling logic shared across the examples. +// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more +// complex (top-k, top-p, etc). +// +// Another example is related to the grammar. In general, the grammar constraints applied on the full +// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled +// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the +// grammar constraints are applied to the full vocabulary and the token is resampled. +// +// The common_sampler also maintains a container with the last accepted tokens. In the future, this can +// be moved into the core llama library. +// +// For convenience, the common_sampler also maintains a container with the current candidate tokens. +// This can be used to access the probabilities of the rest of the non-sampled tokens. +// +// TODO: measure grammar performance +// + +struct common_sampler; + +// llama_sampler API overloads + +struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); + +void common_sampler_free(struct common_sampler * gsmpl); + +// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar +void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar); +void common_sampler_reset (struct common_sampler * gsmpl); +struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); + +// arguments can be nullptr to skip printing +void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); + +// extended sampling implementation: +// +// - set logits +// - apply the configured sampler chain +// - check if the token fits the grammar (if any) +// - if not: resample by first applying the grammar constraints and then sampling again (slower path) +// +// if grammar_first is true, the grammar is applied before the samplers (slower) +// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar +// +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); + +// generalized version of common_sampler_sample +// +// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match +// if the sampler disagrees at some point, we stop and return the accepted tokens up to now +// +// common_sampler_sample_n(gsmpl, ctx, { idx }, {}); +// +// is equivalent to +// +// common_sampler_sample(gsmpl, ctx, idx); +// common_sampler_accept(gsmpl, token, true); +// +// requires: idxs.size() == draft.size() + 1 +// +// returns at least 1 token, up to idxs.size() +// +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); + +// assume idxs == [ 0, 1, 2, ..., draft.size() ] +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); + +uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); + +// helpers + +// access the internal list of current candidate tokens +llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl); + +// get the last accepted token +llama_token common_sampler_last(const struct common_sampler * gsmpl); + +// print the sampler chain into a string +std::string common_sampler_print(const struct common_sampler * gsmpl); + +// get a string representation of the last accepted tokens +std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n); + +char common_sampler_type_to_chr(enum common_sampler_type cnstr); +std::string common_sampler_type_to_str(enum common_sampler_type cnstr); + +std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector common_sampler_types_from_chars(const std::string & chars); + +llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, + const char * grammar_kind, const char * grammar_data); diff --git a/common/speculative.cpp b/common/speculative.cpp new file mode 100644 index 0000000000000000000000000000000000000000..843bd1ddbdbd79090a9a119c84e5f06bdac394be --- /dev/null +++ b/common/speculative.cpp @@ -0,0 +1,280 @@ +#include "speculative.h" + +#include "log.h" +#include "common.h" +#include "sampling.h" + +#include +#include + +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 +#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 + +struct common_speculative { + struct llama_context * ctx; + struct common_sampler * smpl; + + llama_batch batch; + llama_tokens prompt; +}; + +struct common_speculative * common_speculative_init( + struct llama_context * ctx_dft) { + auto * result = new common_speculative { + /* .ctx = */ ctx_dft, + /* .smpl = */ nullptr, + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), + /* .prompt = */ {}, + }; + + // TODO: optimize or pass from outside? +#if 0 + { + common_params_sampling params; + params.no_perf = false; + + params.top_k = 40; + params.top_p = 0.9; + + params.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + COMMON_SAMPLER_TYPE_TOP_P, + COMMON_SAMPLER_TYPE_INFILL, + }; + + result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + } +#else + { + common_params_sampling params; + params.no_perf = false; + + params.top_k = 10; + + params.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + + result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + } +#endif + + return result; +} + +void common_speculative_free(struct common_speculative * spec) { + if (spec == nullptr) { + return; + } + + common_sampler_free(spec->smpl); + + llama_batch_free(spec->batch); + + delete spec; +} + +bool common_speculative_are_compatible( + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft) { + const struct llama_model * model_tgt = llama_get_model(ctx_tgt); + const struct llama_model * model_dft = llama_get_model(ctx_dft); + + const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); + + const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); + LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); + + const bool vocab_type_dft = llama_vocab_type(vocab_dft); + LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); + + if (vocab_type_tgt != vocab_type_dft) { + LOG_ERR("%s: draft model vocab type must match target model to use speculation but " + "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); + return false; + } + + if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || + llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || + llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || + llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) { + LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); + LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt)); + LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft)); + return false; + } + + { + const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); + const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); + + const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " + "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", + __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + return false; + } + + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { + LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but " + "token %d content differs - target '%s', draft '%s'\n", __func__, i, + common_token_to_piece(ctx_tgt, i).c_str(), + common_token_to_piece(ctx_dft, i).c_str()); + return false; + } + } + } + + return true; +} + +llama_tokens common_speculative_gen_draft( + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt_tgt, + llama_token id_last) { + auto & batch = spec->batch; + auto & ctx = spec->ctx; + auto & smpl = spec->smpl; + auto & prompt = spec->prompt; + + auto * mem = llama_get_memory(ctx); + + int reuse_i = 0; + int reuse_n = 0; + + const int n_ctx = llama_n_ctx(ctx) - params.n_draft; + + const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); + + // reuse as much as possible from the old draft context + // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt + for (int i = 0; i < (int) prompt.size(); ++i) { + int cur = 0; + while (i_start + cur < (int) prompt_tgt.size() && + i + cur < (int) prompt.size() && + prompt_tgt[i_start + cur] == prompt[i + cur]) { + cur++; + } + + if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) { + reuse_i = i; + reuse_n = cur; + } + } + + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); + + llama_tokens result; + result.reserve(params.n_draft); + + if (reuse_n == 0) { + llama_memory_clear(mem, false); + + prompt.clear(); + } else { + // this happens when a previous draft has been discarded (for example, due to being too small), but the + // target model agreed with it. in this case, we simply pass back the previous results to save compute + if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) { + result.push_back(prompt[i]); + + if (params.n_draft <= (int) result.size()) { + break; + } + } + + return result; + } + + if (reuse_i > 0) { + llama_memory_seq_rm (mem, 0, 0, reuse_i); + llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i); + + prompt.erase(prompt.begin(), prompt.begin() + reuse_i); + } + + if (reuse_n < (int) prompt.size()) { + llama_memory_seq_rm (mem, 0, reuse_n, -1); + + prompt.erase(prompt.begin() + reuse_n, prompt.end()); + } + } + + // prepare a batch to evaluate any new tokens in the prompt + common_batch_clear(batch); + + for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); + common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); + + prompt.push_back(prompt_tgt[i]); + } + + // we should rarely end-up here during normal decoding + if (batch.n_tokens > 0) { + //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + + llama_decode(ctx, batch); + } + + const llama_pos n_past = prompt.size(); + + LOG_DBG("%s: n_past = %d\n", __func__, n_past); + + common_batch_clear(batch); + common_batch_add (batch, id_last, n_past, { 0 }, true); + + prompt.push_back(id_last); + + //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); + + llama_decode(ctx, batch); + + common_sampler_reset(smpl); + + // sample n_draft tokens from the draft model + for (int i = 0; i < params.n_draft; ++i) { + common_batch_clear(batch); + + common_sampler_sample(smpl, ctx, 0, true); + + const auto * cur_p = common_sampler_get_candidates(smpl); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, id, true); + + result.push_back(id); + + if (params.n_draft <= (int) result.size()) { + break; + } + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + break; + } + + common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + + // evaluate the drafted tokens on the draft model + llama_decode(ctx, batch); + + prompt.push_back(id); + } + + return result; +} diff --git a/common/speculative.h b/common/speculative.h new file mode 100644 index 0000000000000000000000000000000000000000..2b51a70ca1f72ca56ffbcaddfadea33ccd489cbb --- /dev/null +++ b/common/speculative.h @@ -0,0 +1,28 @@ +#pragma once + +#include "llama.h" +#include "common.h" + +struct common_speculative; + +struct common_speculative_params { + int n_draft = 16; // max drafted tokens + int n_reuse = 256; + + float p_min = 0.75f; // min probability required to accept a token in the draft +}; + +struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); + +void common_speculative_free(struct common_speculative * spec); + +bool common_speculative_are_compatible( + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft); + +// sample up to n_draft tokens and add them to the batch using the draft model +llama_tokens common_speculative_gen_draft( + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt, + llama_token id_last); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py new file mode 100644 index 0000000000000000000000000000000000000000..58e455ae645ede217f88b915c05e7931ed1ffa79 --- /dev/null +++ b/convert_hf_to_gguf.py @@ -0,0 +1,6578 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import ast +import logging +import argparse +import contextlib +import json +import os +import re +import sys +from enum import IntEnum +from pathlib import Path +from hashlib import sha256 +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast +from itertools import chain +from transformers import AutoConfig + +import math +import numpy as np +import torch + +if TYPE_CHECKING: + from torch import Tensor + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + +logger = logging.getLogger("hf-to-gguf") + + +###### MODEL DEFINITIONS ###### + +class SentencePieceTokenTypes(IntEnum): + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 + USER_DEFINED = 4 + UNUSED = 5 + BYTE = 6 + + +class ModelType(IntEnum): + TEXT = 1 + MMPROJ = 2 + + +AnyModel = TypeVar("AnyModel", bound="type[ModelBase]") + + +class ModelBase: + _model_classes: dict[ModelType, dict[str, type[ModelBase]]] = { + ModelType.TEXT: {}, + ModelType.MMPROJ: {}, + } + + dir_model: Path + ftype: gguf.LlamaFileType + fname_out: Path + is_big_endian: bool + endianess: gguf.GGUFEndian + use_temp_file: bool + lazy: bool + part_names: list[str] + is_safetensors: bool + hparams: dict[str, Any] + tensor_names: set[str] | None + gguf_writer: gguf.GGUFWriter + model_name: str | None + metadata_override: Path | None + dir_model_card: Path + remote_hf_model_id: str | None + + # subclasses should define this! + model_arch: gguf.MODEL_ARCH + + # subclasses should initialize this! + block_count: int + tensor_map: gguf.TensorNameMap + + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False, + use_temp_file: bool = False, eager: bool = False, + metadata_override: Path | None = None, model_name: str | None = None, + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, + small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None): + if type(self) is ModelBase or \ + type(self) is TextModel or \ + type(self) is MmprojModel: + raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") + + self.dir_model = dir_model + self.ftype = ftype + self.fname_out = fname_out + self.is_big_endian = is_big_endian + self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + self.use_temp_file = use_temp_file + self.lazy = not eager or (remote_hf_model_id is not None) + self.remote_hf_model_id = remote_hf_model_id + if remote_hf_model_id is not None: + self.is_safetensors = True + + def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: + logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}") + remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id) + self.tensor_names = set(name for name in remote_tensors.keys()) + for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items(): + yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor)) + + self.get_tensors = get_remote_tensors + else: + self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors") + self.is_safetensors = len(self.part_names) > 0 + if not self.is_safetensors: + self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin") + self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams + self.tensor_names = None + self.metadata_override = metadata_override + self.model_name = model_name + self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py + + # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type + if self.ftype == gguf.LlamaFileType.GUESSED: + # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. + _, first_tensor = next(self.get_tensors()) + if first_tensor.dtype == torch.float16: + logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})") + self.ftype = gguf.LlamaFileType.MOSTLY_F16 + else: + logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})") + self.ftype = gguf.LlamaFileType.MOSTLY_BF16 + + # Configure GGUF Writer + self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, + split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) + + @classmethod + def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path: + stem, suffix = path.stem, path.suffix + new_name = f"{prefix}{stem}{suffix}" + return path.with_name(new_name) + + def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any: + key = next((k for k in keys if k in self.hparams), None) + if key is not None: + return self.hparams[key] + if optional: + return None + raise KeyError(f"could not find any of: {keys}") + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + tensor_names_from_parts: set[str] = set() + + index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin" + index_name += ".index.json" + index_file = self.dir_model / index_name + + if index_file.is_file(): + self.tensor_names = set() + logger.info(f"gguf: loading model weight map from '{index_name}'") + with open(index_file, "r", encoding="utf-8") as f: + index: dict[str, Any] = json.load(f) + weight_map = index.get("weight_map") + if weight_map is None or not isinstance(weight_map, dict): + raise ValueError(f"Can't load 'weight_map' from {index_name!r}") + self.tensor_names.update(weight_map.keys()) + else: + self.tensor_names = tensor_names_from_parts + weight_map = {} + + for part_name in self.part_names: + logger.info(f"gguf: loading model part '{part_name}'") + ctx: ContextManager[Any] + if self.is_safetensors: + from safetensors import safe_open + ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) + else: + ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) + + with ctx as model_part: + tensor_names_from_parts.update(model_part.keys()) + + for name in model_part.keys(): + if self.is_safetensors: + if self.lazy: + data = model_part.get_slice(name) + data = LazyTorchTensor.from_safetensors_slice(data) + else: + data = model_part.get_tensor(name) + else: + data = model_part[name] + if self.lazy: + data = LazyTorchTensor.from_eager(data) + yield name, data + + # verify tensor name presence and identify potentially missing files + if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0: + missing = sorted(self.tensor_names.difference(tensor_names_from_parts)) + extra = sorted(tensor_names_from_parts.difference(self.tensor_names)) + missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map)) + if len(extra) == 0 and len(missing_files) > 0: + raise ValueError(f"Missing or incomplete model files: {missing_files}\n" + f"Missing tensors: {missing}") + else: + raise ValueError("Mismatch between weight map and model parts for tensor names:\n" + f"Missing tensors: {missing}\n" + f"Extra tensors: {extra}") + + def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: + if key not in gguf.MODEL_TENSORS[self.model_arch]: + raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}") + name: str = gguf.TENSOR_NAMES[key] + if "{bid}" in name: + assert bid is not None + name = name.format(bid=bid) + return name + suffix + + def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool: + if key not in gguf.MODEL_TENSORS[self.model_arch]: + return False + key_name: str = gguf.TENSOR_NAMES[key] + if "{bid}" in key_name: + if bid is None: + return False + key_name = key_name.format(bid=bid) + else: + if bid is not None: + return False + return name == (key_name + suffix) + + def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: + new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) + if new_name is None: + raise ValueError(f"Can not map tensor {name!r}") + return new_name + + def set_gguf_parameters(self): + raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses") + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + return [(self.map_tensor_name(name), data_torch)] + + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: + del name, new_name, bid, n_dims # unused + + return False + + # some models need extra generated tensors (like rope_freqs) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + return () + + def prepare_tensors(self): + max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") + + for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # use the first number-like part of the tensor name as the block id + bid = None + for part in name.split("."): + if part.isdecimal(): + bid = int(part) + break + + for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): + # TODO: why do we squeeze here? + # data = data_torch.squeeze().numpy() + data = data_torch.numpy() + + # if data ends up empty, it means data_torch was a scalar tensor -> restore + if len(data.shape) == 0: + data = data_torch.numpy() + + n_dims = len(data.shape) + data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) + + # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors + if n_dims <= 1 or new_name.endswith("_norm.weight"): + data_qtype = gguf.GGMLQuantizationType.F32 + + # Conditions should closely match those in llama_model_quantize_internal in llama.cpp + # Some tensor types are always in float32 + if data_qtype is False and ( + any( + self.match_model_tensor_name(new_name, key, bid) + for key in ( + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.POS_EMBD, + gguf.MODEL_TENSOR.TOKEN_TYPES, + gguf.MODEL_TENSOR.SSM_CONV1D, + gguf.MODEL_TENSOR.TIME_MIX_FIRST, + gguf.MODEL_TENSOR.TIME_MIX_W1, + gguf.MODEL_TENSOR.TIME_MIX_W2, + gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1, + gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2, + gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED, + gguf.MODEL_TENSOR.POSNET_NORM1, + gguf.MODEL_TENSOR.POSNET_NORM2, + gguf.MODEL_TENSOR.V_ENC_EMBD_POS, + gguf.MODEL_TENSOR.A_ENC_EMBD_POS, + ) + ) + or not new_name.endswith(".weight") + ): + data_qtype = gguf.GGMLQuantizationType.F32 + + if data_qtype is False and any( + self.match_model_tensor_name(new_name, key, bid) + for key in ( + gguf.MODEL_TENSOR.TOKEN_EMBD, + gguf.MODEL_TENSOR.OUTPUT, + ) + ): + if self.ftype in ( + gguf.LlamaFileType.MOSTLY_TQ1_0, + gguf.LlamaFileType.MOSTLY_TQ2_0, + ): + # TODO: use Q4_K and Q6_K + data_qtype = gguf.GGMLQuantizationType.F16 + + # No override (data_qtype is False), or wants to be quantized (data_qtype is True) + if isinstance(data_qtype, bool): + if self.ftype == gguf.LlamaFileType.ALL_F32: + data_qtype = gguf.GGMLQuantizationType.F32 + elif self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_qtype = gguf.GGMLQuantizationType.F16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data_qtype = gguf.GGMLQuantizationType.BF16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: + data_qtype = gguf.GGMLQuantizationType.Q8_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ1_0: + data_qtype = gguf.GGMLQuantizationType.TQ1_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0: + data_qtype = gguf.GGMLQuantizationType.TQ2_0 + else: + raise ValueError(f"Unknown file type: {self.ftype.name}") + + try: + data = gguf.quants.quantize(data, data_qtype) + except gguf.QuantError as e: + logger.warning("%s, %s", e, "falling back to F16") + data_qtype = gguf.GGMLQuantizationType.F16 + data = gguf.quants.quantize(data, data_qtype) + + shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape + + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" + + # n_dims is implicit in the shape + logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype) + + def set_type(self): + self.gguf_writer.add_type(gguf.GGUFType.MODEL) + + def prepare_metadata(self, vocab_only: bool): + + total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count() + + self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params) + + # If we are using HF model id, set the metadata name to the model id + if self.remote_hf_model_id: + self.metadata.name = self.remote_hf_model_id + + # Fallback to model directory name if metadata name is still missing + if self.metadata.name is None: + self.metadata.name = self.dir_model.name + + # Generate parameter weight class (useful for leader boards) if not yet determined + if self.metadata.size_label is None and total_params > 0: + self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count) + + self.set_type() + + logger.info("Set meta model") + self.metadata.set_gguf_meta_model(self.gguf_writer) + + logger.info("Set model parameters") + self.set_gguf_parameters() + + logger.info("Set model quantization version") + self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) + + def write_vocab(self): + raise NotImplementedError("write_vocab() must be implemented in subclasses") + + def write(self): + self.prepare_tensors() + self.prepare_metadata(vocab_only=False) + self.gguf_writer.write_header_to_file(path=self.fname_out) + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file(progress=True) + self.gguf_writer.close() + + @staticmethod + def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]: + part_names: list[str] = [] + for filename in os.listdir(dir_model): + if filename.startswith(prefix) and filename.endswith(suffix): + part_names.append(filename) + + part_names.sort() + + return part_names + + @staticmethod + def load_hparams(dir_model: Path): + try: + # for security reason, we don't allow loading remote code by default + # if a model need remote code, we will fallback to config.json + config = AutoConfig.from_pretrained(dir_model, trust_remote_code=False).to_dict() + except Exception as e: + logger.warning(f"Failed to load model config from {dir_model}: {e}") + logger.warning("Trying to load config.json instead") + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + config = json.load(f) + if "llm_config" in config: + # rename for InternVL + config["text_config"] = config["llm_config"] + if "thinker_config" in config: + # rename for Qwen2.5-Omni + config["text_config"] = config["thinker_config"]["text_config"] + return config + + @classmethod + def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: + assert names + + def func(modelcls: AnyModel) -> AnyModel: + model_type = ModelType.MMPROJ if modelcls.model_arch == gguf.MODEL_ARCH.MMPROJ else ModelType.TEXT + for name in names: + cls._model_classes[model_type][name] = modelcls + return modelcls + return func + + @classmethod + def print_registered_models(cls): + for model_type, model_classes in cls._model_classes.items(): + logger.error(f"{model_type.name} models:") + for name in sorted(model_classes.keys()): + logger.error(f" - {name}") + + @classmethod + def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type[ModelBase]: + try: + return cls._model_classes[model_type][arch] + except KeyError: + raise NotImplementedError(f'Architecture {arch!r} not supported!') from None + + +class TextModel(ModelBase): + model_type = ModelType.TEXT + hf_arch: str + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hf_arch = get_model_architecture(self.hparams, self.model_type) + + if "text_config" in self.hparams: + # move the text_config to the root level + self.hparams = {**self.hparams, **self.hparams["text_config"]} + + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + @classmethod + def __init_subclass__(cls): + # can't use an abstract property, because overriding it without type errors + # would require using decorated functions instead of simply defining the property + if "model_arch" not in cls.__dict__: + raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}") + + def set_vocab(self): + self._set_vocab_gpt2() + + def prepare_metadata(self, vocab_only: bool): + super().prepare_metadata(vocab_only=vocab_only) + + total_params = self.gguf_writer.get_total_parameter_count()[0] + # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0' + output_type: str = self.ftype.name.partition("_")[2] + + # Filename Output + if self.fname_out.is_dir(): + # Generate default filename based on model specification and available metadata + if not vocab_only: + fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type, model_type="LoRA" if total_params < 0 else None) + else: + fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=None, model_type="vocab") + + # Use the default filename + self.fname_out = self.fname_out / f"{fname_default}.gguf" + else: + # Output path is a custom defined templated filename + # Note: `not is_dir()` is used because `.is_file()` will not detect + # file template strings as it doesn't actually exist as a file + + # Process templated file name with the output ftype, useful with the "auto" ftype + self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type) + + logger.info("Set model tokenizer") + self.set_vocab() + + def set_gguf_parameters(self): + self.gguf_writer.add_block_count(self.block_count) + + if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)) is not None: + self.gguf_writer.add_context_length(n_ctx) + logger.info(f"gguf: context length = {n_ctx}") + + if (n_embd := self.find_hparam(["hidden_size", "n_embd", "dim"], optional=True)) is not None: + self.gguf_writer.add_embedding_length(n_embd) + logger.info(f"gguf: embedding length = {n_embd}") + + if (n_ff := self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"], optional=True)) is not None: + self.gguf_writer.add_feed_forward_length(n_ff) + logger.info(f"gguf: feed forward length = {n_ff}") + + if (n_head := self.find_hparam(["num_attention_heads", "n_head", "n_heads"], optional=True)) is not None: + self.gguf_writer.add_head_count(n_head) + logger.info(f"gguf: head count = {n_head}") + + if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: + self.gguf_writer.add_head_count_kv(n_head_kv) + logger.info(f"gguf: key-value head count = {n_head_kv}") + + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + logger.info(f"gguf: rope theta = {rope_theta}") + if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: + self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) + logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") + if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: + self.gguf_writer.add_layer_norm_eps(f_norm_eps) + logger.info(f"gguf: layer norm epsilon = {f_norm_eps}") + if (n_experts := self.hparams.get("num_local_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + logger.info(f"gguf: expert count = {n_experts}") + if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + logger.info(f"gguf: experts used count = {n_experts_used}") + + if (head_dim := self.hparams.get("head_dim")) is not None: + # Workaround for incorrect AutoConfig value for DeepSeekV3 (is set correctly in DeepSeekV2Model class) + # https://github.com/huggingface/transformers/blob/19224c3642705c5b6988c9f5f4251f83323d05ae/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py#L210 + if self.hparams.get("model_type") != "deepseek_v3": + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) + + self.gguf_writer.add_file_type(self.ftype) + logger.info(f"gguf: file type = {self.ftype}") + + def write_vocab(self): + if len(self.gguf_writer.tensors) != 1: + raise ValueError('Splitting the vocabulary is not supported') + + self.prepare_metadata(vocab_only=True) + self.gguf_writer.write_header_to_file(path=self.fname_out) + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.close() + + def does_token_look_special(self, token: str | bytes) -> bool: + if isinstance(token, (bytes, bytearray)): + token_text = token.decode(encoding="utf-8") + elif isinstance(token, memoryview): + token_text = token.tobytes().decode(encoding="utf-8") + else: + token_text = token + + # Some models mark some added tokens which ought to be control tokens as not special. + # (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2}) + seems_special = token_text in ( + "", # deepseek-coder + "", "<2mass>", "[@BOS@]", # gemma{,-2} + ) + + seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>")) + seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>")) # deepseek-coder + + # TODO: should these be marked as UNUSED instead? (maybe not) + seems_special = seems_special or (token_text.startswith("")) # gemma{,-2} + + return seems_special + + # used for GPT-2 BPE and WordPiece vocabs + def get_vocab_base(self) -> tuple[list[str], list[int], str]: + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) + assert max(tokenizer.vocab.values()) < vocab_size + + tokpre = self.get_vocab_base_pre(tokenizer) + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + added_vocab = tokenizer.get_added_vocab() + + added_tokens_decoder = tokenizer.added_tokens_decoder + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token: str = reverse_vocab[i] + if token in added_vocab: + # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized. + # To avoid unexpected issues - we make sure to normalize non-normalized tokens + if not added_tokens_decoder[i].normalized: + previous_token = token + token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) + if previous_token != token: + logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer") + + if added_tokens_decoder[i].special or self.does_token_look_special(token): + toktypes.append(gguf.TokenType.CONTROL) + else: + # NOTE: this was added for Gemma. + # Encoding and decoding the tokens above isn't sufficient for this case. + token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + toktypes.append(gguf.TokenType.NORMAL) + tokens.append(token) + + return tokens, toktypes, tokpre + + # NOTE: this function is generated by convert_hf_to_gguf_update.py + # do not modify it manually! + # ref: https://github.com/ggml-org/llama.cpp/pull/6920 + # Marker: Start get_vocab_base_pre + def get_vocab_base_pre(self, tokenizer) -> str: + # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that + # is specific for the BPE pre-tokenizer used by the model + # we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can + # use in llama.cpp to implement the same pre-tokenizer + + chktxt = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````""""......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL' + + chktok = tokenizer.encode(chktxt) + chkhsh = sha256(str(chktok).encode()).hexdigest() + + logger.debug(f"chktok: {chktok}") + logger.debug(f"chkhsh: {chkhsh}") + + res = None + + # NOTE: if you get an error here, you need to update the convert_hf_to_gguf_update.py script + # or pull the latest version of the model from Huggingface + # don't edit the hashes manually! + if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": + # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B + res = "llama-bpe" + if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754": + # ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base + res = "deepseek-llm" + if chkhsh == "347715f544604f9118bb75ed199f68779f423cabb20db6de6f31b908d04d7821": + # ref: https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base + res = "deepseek-coder" + if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed": + # ref: https://huggingface.co/tiiuae/falcon-7b + res = "falcon" + if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": + # ref: https://huggingface.co/BAAI/bge-small-en-v1.5 + res = "bert-bge" + if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e": + # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base + res = "falcon3" + if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7": + # ref: https://huggingface.co/BAAI/bge-large-zh-v1.5 + res = "bert-bge-large" + if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": + # ref: https://huggingface.co/mosaicml/mpt-7b + res = "mpt" + if chkhsh == "35d91631860c815f952d711435f48d356ebac988362536bed955d43bfa436e34": + # ref: https://huggingface.co/bigcode/starcoder2-3b + res = "starcoder" + if chkhsh == "3ce83efda5659b07b1ad37ca97ca5797ea4285d9b9ab0dc679e4a720c9da7454": + # ref: https://huggingface.co/openai-community/gpt2 + res = "gpt-2" + if chkhsh == "32d85c31273f8019248f2559fed492d929ea28b17e51d81d3bb36fff23ca72b3": + # ref: https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b + res = "stablelm2" + if chkhsh == "6221ad2852e85ce96f791f476e0b390cf9b474c9e3d1362f53a24a06dc8220ff": + # ref: https://huggingface.co/smallcloudai/Refact-1_6-base + res = "refact" + if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": + # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 + res = "command-r" + if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea": + # ref: https://huggingface.co/Qwen/Qwen1.5-7B + res = "qwen2" + if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": + # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf + res = "olmo" + if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e": + # ref: https://huggingface.co/databricks/dbrx-base + res = "dbrx" + if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448": + # ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + res = "jina-v1-en" + if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": + # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en + res = "jina-v2-en" + if chkhsh == "171aeeedd6fb548d418a7461d053f11b6f1f1fc9b387bd66640d28a4b9f5c643": + # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-es + res = "jina-v2-es" + if chkhsh == "27949a2493fc4a9f53f5b9b029c82689cfbe5d3a1929bb25e043089e28466de6": + # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-de + res = "jina-v2-de" + if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d": + # ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct + res = "smaug-bpe" + if chkhsh == "c7ea5862a53e4272c035c8238367063e2b270d51faa48c0f09e9d5b54746c360": + # ref: https://huggingface.co/LumiOpen/Poro-34B-chat + res = "poro-chat" + if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a": + # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code + res = "jina-v2-code" + if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee": + # ref: https://huggingface.co/LumiOpen/Viking-7B + res = "viking" + if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901": + # ref: https://huggingface.co/core42/jais-13b + res = "jais" + if chkhsh == "7b3e7548e4308f52a76e8229e4e6cc831195d0d1df43aed21ac6c93da05fec5f": + # ref: https://huggingface.co/WisdomShell/CodeShell-7B + res = "codeshell" + if chkhsh == "63b97e4253352e6f357cc59ea5b583e3a680eaeaf2632188c2b952de2588485e": + # ref: https://huggingface.co/mistralai/Mistral-Nemo-Base-2407 + res = "tekken" + if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249": + # ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M + res = "smollm" + if chkhsh == "3c30d3ad1d6b64202cd222813e7736c2db6e1bd6d67197090fc1211fbc612ae7": + # ref: https://huggingface.co/bigscience/bloom + res = "bloom" + if chkhsh == "bc01ce58980e1db43859146dc51b1758b3b88729b217a74792e9f8d43e479d21": + # ref: https://huggingface.co/TurkuNLP/gpt3-finnish-small + res = "gpt3-finnish" + if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae": + # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct + res = "exaone" + if chkhsh == "fcace8b9cac38ce847670c970cd5892031a753a1ef381abd1d9af00f713da085": + # ref: https://huggingface.co/microsoft/phi-2 + res = "phi-2" + if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450": + # ref: https://huggingface.co/facebook/chameleon-7b + res = "chameleon" + if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65": + # ref: https://huggingface.co/sentence-transformers/stsb-roberta-base + res = "roberta-bpe" + if chkhsh == "ad851be1dba641f2e3711822f816db2c265f788b37c63b4e1aeacb9ee92de8eb": + # ref: https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct + res = "gigachat" + if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1": + # ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct + res = "megrez" + if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5": + # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3 + res = "deepseek-v3" + if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5": + # ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B + res = "deepseek-r1-qwen" + if chkhsh == "ccc2ef013c104be7bae2965776d611e1d7a8a2a9c547dd93a682c9a9fc80352e": + # ref: https://huggingface.co/Xenova/gpt-4o + res = "gpt-4o" + if chkhsh == "7dec86086fcc38b66b7bc1575a160ae21cf705be7718b9d5598190d7c12db76f": + # ref: https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k + res = "superbpe" + if chkhsh == "1994ffd01900cfb37395608534236ecd63f2bd5995d6cb1004dda1af50240f15": + # ref: https://huggingface.co/trillionlabs/Trillion-7B-preview + res = "trillion" + if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224": + # ref: https://huggingface.co/inclusionAI/Ling-lite + res = "bailingmoe" + if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406": + # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct + res = "llama4" + if chkhsh == "0e9433cbbb161f89e264eb32e8e64bfe69e834973ffca5d41d3948a604a3e2a3": + # ref: https://huggingface.co/mistral-community/pixtral-12b + res = "pixtral" + if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec": + # ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base + res = "seed-coder" + if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": + # ref: https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" + if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516": + # ref: https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" + if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": + # ref: https://huggingface.co/THUDM/glm-4-9b-hf + res = "glm4" + if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": + # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 + res = "minerva-7b" + + if res is None: + logger.warning("\n") + logger.warning("**************************************************************************************") + logger.warning("** WARNING: The BPE pre-tokenizer was not recognized!") + logger.warning("** There are 2 possible reasons for this:") + logger.warning("** - the model has not been added to convert_hf_to_gguf_update.py yet") + logger.warning("** - the pre-tokenization config has changed upstream") + logger.warning("** Check your model files and convert_hf_to_gguf_update.py and update them accordingly.") + logger.warning("** ref: https://github.com/ggml-org/llama.cpp/pull/6920") + logger.warning("**") + logger.warning(f"** chkhsh: {chkhsh}") + logger.warning("**************************************************************************************") + logger.warning("\n") + raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()") + + logger.debug(f"tokenizer.ggml.pre: {repr(res)}") + logger.debug(f"chkhsh: {chkhsh}") + + return res + # Marker: End get_vocab_base_pre + + def _set_vocab_none(self) -> None: + self.gguf_writer.add_tokenizer_model("none") + + def _set_vocab_gpt2(self) -> None: + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + + def _set_vocab_qwen(self): + dir_model = self.dir_model + hparams = self.hparams + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + vocab_size = hparams["vocab_size"] + assert max(tokenizer.get_vocab().values()) < vocab_size + + tokpre = self.get_vocab_base_pre(tokenizer) + + merges = [] + vocab = {} + mergeable_ranks = tokenizer.mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[QwenModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) + assert len(merged) == 2 + merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + + # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined + added_vocab = tokenizer.special_tokens + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()} + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + elif reverse_vocab[i] in added_vocab: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.CONTROL) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(dir_model, load_merges=False) + special_vocab.merges = merges + # only add special tokens when they were not already loaded from config.json + if len(special_vocab.special_token_ids) == 0: + special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"]) + special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"]) + # this one is usually not in config.json anyway + special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + + def _set_vocab_sentencepiece(self, add_to_gguf=True): + tokens, scores, toktypes = self._create_vocab_sentencepiece() + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def _create_vocab_sentencepiece(self): + from sentencepiece import SentencePieceProcessor + + tokenizer_path = self.dir_model / 'tokenizer.model' + + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size + + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + for key in added_tokens_json: + token_id = added_tokens_json[key] + if token_id >= vocab_size: + logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + tokens[token_id] = key.encode("utf-8") + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {}) + for token_id, token_data in added_tokens_decoder.items(): + token_id = int(token_id) + token: str = token_data["content"] + if token_id >= vocab_size: + logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: + if tokens[token_id] != token.encode("utf-8"): + logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}') + if token_data.get("special") or self.does_token_look_special(token): + toktypes[token_id] = SentencePieceTokenTypes.CONTROL + else: + token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + scores[token_id] = -1000.0 + tokens[token_id] = token.encode("utf-8") + + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + for i in range(1, pad_count + 1): + tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.UNUSED) + + return tokens, scores, toktypes + + def _set_vocab_llama_hf(self): + vocab = gguf.LlamaHfVocab(self.dir_model) + tokens = [] + scores = [] + toktypes = [] + + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def _set_vocab_rwkv_world(self): + assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file() + vocab_size = self.hparams.get("vocab_size", 65536) + + tokens: list[bytes] = [''.encode("utf-8")] + toktypes: list[int] = [gguf.TokenType.CONTROL] + + with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f: + lines = f.readlines() + for line in lines: + parts = line.split(' ') + assert len(parts) >= 3 + token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1]) + token = token.encode("utf-8") if isinstance(token, str) else token + assert isinstance(token, bytes) + assert len(token) == token_len + token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff" + tokens.append(token_text.encode("utf-8")) + toktypes.append(gguf.TokenType.NORMAL) + remainder = vocab_size - len(tokens) + assert remainder >= 0 + for i in range(len(tokens), vocab_size): + tokens.append(f"[PAD{i}]".encode("utf-8")) + toktypes.append(gguf.TokenType.UNUSED) + + self.gguf_writer.add_tokenizer_model("rwkv") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.chat_template = "rwkv-world" + # hack: Add '\n\n' as the EOT token to make it chat normally + special_vocab._set_special_token("eot", 261) + # hack: Override these as they have already been set (incorrectly) + special_vocab.special_token_ids["bos"] = 0 + special_vocab.special_token_ids["eos"] = 0 + + special_vocab.add_to_gguf(self.gguf_writer) + + def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int): + tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf" + logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'") + vocab_reader = gguf.GGUFReader(tokenizer_path, "r") + + default_pre = "mpt" if model_name == "gpt-neox" else "default" + + field = vocab_reader.get_field(gguf.Keys.Tokenizer.MODEL) + assert field # tokenizer model + self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]).decode("utf-8")) + + field = vocab_reader.get_field(gguf.Keys.Tokenizer.PRE) + self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1]).decode("utf-8") if field else default_pre) + + field = vocab_reader.get_field(gguf.Keys.Tokenizer.LIST) + assert field # token list + self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size]) + + if model_name == "llama-spm": + field = vocab_reader.get_field(gguf.Keys.Tokenizer.SCORES) + assert field # token scores + self.gguf_writer.add_token_scores([field.parts[i].tolist()[0] for i in field.data][:vocab_size]) + + field = vocab_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE) + assert field # token types + self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size]) + + if model_name != "llama-spm": + field = vocab_reader.get_field(gguf.Keys.Tokenizer.MERGES) + assert field # token merges + self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data]) + + if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)) is not None: + self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0]) + if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)) is not None: + self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0]) + if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)) is not None: + self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0]) + if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.PAD_ID)) is not None: + self.gguf_writer.add_pad_token_id(field.parts[-1].tolist()[0]) + if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_BOS)) is not None: + self.gguf_writer.add_add_bos_token(field.parts[-1].tolist()[0]) + if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_EOS)) is not None: + self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0]) + + def _try_set_pooling_type(self) -> None: + # get pooling path + pooling_path = None + module_path = self.dir_model / "modules.json" + if module_path.is_file(): + with open(module_path, encoding="utf-8") as f: + modules = json.load(f) + for mod in modules: + if mod["type"] == "sentence_transformers.models.Pooling": + pooling_path = mod["path"] + break + + # get pooling type + if pooling_path is not None: + with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f: + pooling = json.load(f) + if pooling["pooling_mode_mean_tokens"]: + pooling_type = gguf.PoolingType.MEAN + elif pooling["pooling_mode_cls_token"]: + pooling_type = gguf.PoolingType.CLS + elif pooling["pooling_mode_lasttoken"]: + pooling_type = gguf.PoolingType.LAST + else: + raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported") + self.gguf_writer.add_pooling_type(pooling_type) + + +class MmprojModel(ModelBase): + model_type = ModelType.MMPROJ + model_arch = gguf.MODEL_ARCH.MMPROJ + preprocessor_config: dict[str, Any] + global_config: dict[str, Any] + + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"] + + has_vision_encoder: bool = True # by default + has_audio_encoder: bool = False + + # for models having multiple encoders, we need to separate their hparams + hparams_vision: dict[str, Any] | None = None + hparams_audio: dict[str, Any] | None = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.model_arch != gguf.MODEL_ARCH.MMPROJ: + raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ") + + # get n_embd of the text model + if "text_config" not in self.hparams: + self.hparams["text_config"] = {} + if "audio_config" not in self.hparams: + self.hparams["audio_config"] = {} + text_config = {**self.hparams, **self.hparams["text_config"]} + self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0)) + assert self.n_embd_text > 0, "n_embd not found in hparams" + + # move vision config to the top level, while preserving the original hparams in global_config + import copy + self.global_config = copy.deepcopy(self.hparams) + self.hparams_vision = self.get_vision_config() + self.hparams_audio = self.get_audio_config() + + if self.hparams_vision is None and self.hparams_audio is None: + raise ValueError("vision_config / audio_config not found in hparams") + + # for compat with vision-only models + self.hparams = self.hparams_vision or self.hparams_audio or self.hparams + + # TODO @ngxson : this is a hack to support both vision and audio encoders + have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder + self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) + self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) + + # load preprocessor config + with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: + self.preprocessor_config = json.load(f) + + def get_vision_config(self) -> dict[str, Any] | None: + return self.global_config.get("vision_config") + + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config.get("audio_config") + + def set_type(self): + self.gguf_writer.add_type(gguf.GGUFType.MMPROJ) + + def set_gguf_parameters(self): + self.gguf_writer.add_file_type(self.ftype) + + if self.has_vision_encoder: + self.gguf_writer.add_clip_has_vision_encoder(True) + self.gguf_writer.add_vision_projection_dim(self.n_embd_text) + + # vision config + self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"])) + self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) + self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) + self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"])) + self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys)) + self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"])) + + # preprocessor config + self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"]) + self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"]) + + if self.has_audio_encoder: + self.gguf_writer.add_clip_has_audio_encoder(True) + self.gguf_writer.add_audio_projection_dim(self.n_embd_text) + + # audio config + self.gguf_writer.add_audio_embedding_length(self.find_aparam(["hidden_size"])) + self.gguf_writer.add_audio_feed_forward_length(self.find_aparam(["intermediate_size"])) + self.gguf_writer.add_audio_block_count(self.find_aparam(self.n_block_keys)) + self.gguf_writer.add_audio_head_count(self.find_aparam(["num_attention_heads"])) + + if not self.has_vision_encoder and not self.has_audio_encoder: + raise ValueError("MmprojModel must have either vision or audio encoder") + + def write_vocab(self): + raise ValueError("MmprojModel does not support vocab writing") + + def find_vparam(self, keys: Iterable[str], optional: bool = False) -> Any: + assert self.hparams_vision is not None + return self._find_param(self.hparams_vision, keys, optional) + + def find_aparam(self, keys: Iterable[str], optional: bool = False) -> Any: + assert self.hparams_audio is not None + return self._find_param(self.hparams_audio, keys, optional) + + def _find_param(self, obj: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any: + key = next((k for k in keys if k in obj), None) + if key is not None: + return obj[key] + if optional: + return None + raise KeyError(f"could not find any of: {keys}") + + +@ModelBase.register("GPTNeoXForCausalLM") +class GPTNeoXModel(TextModel): + model_arch = gguf.MODEL_ARCH.GPTNEOX + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count( + int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])), + ) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True)) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + + tensors: list[tuple[str, Tensor]] = [] + + if re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.weight", name): + # Map bloom-style qkv_linear to gpt-style qkv_linear + # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa + # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa + qkv_weights = data_torch.reshape((n_head, 3, n_embed // n_head, n_embed)) + data_torch = torch.cat( + ( + qkv_weights[:, 0, :, :].reshape((-1, n_embed)), + qkv_weights[:, 1, :, :].reshape((-1, n_embed)), + qkv_weights[:, 2, :, :].reshape((-1, n_embed)), + ), + dim=0, + ) + logger.info("re-format attention.linear_qkv.weight") + elif re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.bias", name): + qkv_bias = data_torch.reshape((n_head, 3, n_embed // n_head)) + data_torch = torch.cat( + ( + qkv_bias[:, 0, :].reshape((n_embed,)), + qkv_bias[:, 1, :].reshape((n_embed,)), + qkv_bias[:, 2, :].reshape((n_embed,)), + ), + dim=0, + ) + logger.info("re-format attention.linear_qkv.bias") + + tensors.append((self.map_tensor_name(name), data_torch)) + + return tensors + + +@ModelBase.register("BloomForCausalLM", "BloomModel") +class BloomModel(TextModel): + model_arch = gguf.MODEL_ARCH.BLOOM + + def set_gguf_parameters(self): + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) + self.gguf_writer.add_embedding_length(n_embed) + self.gguf_writer.add_feed_forward_length(4 * n_embed) + self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + + name = re.sub(r'transformer\.', '', name) + + tensors: list[tuple[str, Tensor]] = [] + + if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name): + # Map bloom-style qkv_linear to gpt-style qkv_linear + # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa + # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa + qkv_weights = data_torch.reshape((n_head, 3, n_embed // n_head, n_embed)) + data_torch = torch.cat( + ( + qkv_weights[:, 0, :, :].reshape((-1, n_embed)), + qkv_weights[:, 1, :, :].reshape((-1, n_embed)), + qkv_weights[:, 2, :, :].reshape((-1, n_embed)), + ), + dim=0, + ) + logger.info("re-format attention.linear_qkv.weight") + elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name): + qkv_bias = data_torch.reshape((n_head, 3, n_embed // n_head)) + data_torch = torch.cat( + ( + qkv_bias[:, 0, :].reshape((n_embed,)), + qkv_bias[:, 1, :].reshape((n_embed,)), + qkv_bias[:, 2, :].reshape((n_embed,)), + ), + dim=0, + ) + logger.info("re-format attention.linear_qkv.bias") + + tensors.append((self.map_tensor_name(name), data_torch)) + + return tensors + + +@ModelBase.register("MPTForCausalLM") +class MPTModel(TextModel): + model_arch = gguf.MODEL_ARCH.MPT + + def set_vocab(self): + try: + self._set_vocab_gpt2() + except Exception: + # Fallback for SEA-LION model + self._set_vocab_sentencepiece() + self.gguf_writer.add_add_bos_token(False) + self.gguf_writer.add_pad_token_id(3) + self.gguf_writer.add_eos_token_id(1) + self.gguf_writer.add_unk_token_id(0) + + def set_gguf_parameters(self): + block_count = self.hparams["n_layers"] + self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) + self.gguf_writer.add_embedding_length(self.hparams["d_model"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"]) + self.gguf_writer.add_head_count(self.hparams["n_heads"]) + if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"): + self.gguf_writer.add_head_count_kv(kv_n_heads) + self.gguf_writer.add_layer_norm_eps(1e-5) + if self.hparams["attn_config"]["clip_qkv"] is not None: + self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"]) + if self.hparams["attn_config"]["alibi"]: + self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"]) + else: + self.gguf_writer.add_max_alibi_bias(0.0) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if "scales" in name: + new_name = self.map_tensor_name(name, try_suffixes=(".weight", ".bias", ".scales")) + new_name = new_name.replace("scales", "act.scales") + else: + new_name = self.map_tensor_name(name, try_suffixes=(".weight", ".bias")) + + return [(new_name, data_torch)] + + +@ModelBase.register("OrionForCausalLM") +class OrionModel(TextModel): + model_arch = gguf.MODEL_ARCH.ORION + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + head_count = self.hparams["num_attention_heads"] + head_count_kv = self.hparams.get("num_key_value_heads", head_count) + + ctx_length = 0 + if "max_sequence_length" in self.hparams: + ctx_length = self.hparams["max_sequence_length"] + elif "max_position_embeddings" in self.hparams: + ctx_length = self.hparams["max_position_embeddings"] + elif "model_max_length" in self.hparams: + ctx_length = self.hparams["model_max_length"] + else: + raise ValueError("gguf: can not find ctx length parameter.") + + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_tensor_data_layout("Meta AI original pth") + self.gguf_writer.add_context_length(ctx_length) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_head_count(head_count) + self.gguf_writer.add_head_count_kv(head_count_kv) + # note: config provides rms norm but it is actually layer norm + # ref: https://huggingface.co/OrionStarAI/Orion-14B-Chat/blob/276a17221ce42beb45f66fac657a41540e71f4f5/modeling_orion.py#L570-L571 + self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"]) + + +@ModelBase.register("BaichuanForCausalLM", "BaiChuanForCausalLM") +class BaichuanModel(TextModel): + model_arch = gguf.MODEL_ARCH.BAICHUAN + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + head_count = self.hparams["num_attention_heads"] + head_count_kv = self.hparams.get("num_key_value_heads", head_count) + + ctx_length = 0 + if "max_sequence_length" in self.hparams: + ctx_length = self.hparams["max_sequence_length"] + elif "max_position_embeddings" in self.hparams: + ctx_length = self.hparams["max_position_embeddings"] + elif "model_max_length" in self.hparams: + ctx_length = self.hparams["model_max_length"] + else: + raise ValueError("gguf: can not find ctx length parameter.") + + self.gguf_writer.add_tensor_data_layout("Meta AI original pth") + self.gguf_writer.add_context_length(ctx_length) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count(head_count) + self.gguf_writer.add_head_count_kv(head_count_kv) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_file_type(self.ftype) + + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + head_count = self.hparams["num_attention_heads"] + head_count_kv = self.hparams.get("num_key_value_heads", head_count) + + tensors: list[tuple[str, Tensor]] = [] + + if bid is not None and name == f"model.layers.{bid}.self_attn.W_pack.weight": + logger.info(f"Unpacking and permuting layer {bid}") + tensors = [ + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), + self._reverse_hf_permute_part(data_torch, 0, head_count, head_count)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), + self._reverse_hf_permute_part(data_torch, 1, head_count, head_count_kv)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), + self._reverse_hf_part(data_torch, 2)), + ] + else: + tensors = [(self.map_tensor_name(name), data_torch)] + + return tensors + + def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + if n_kv_head is not None and n_head != n_kv_head: + n_head //= n_kv_head + + return ( + weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape) + ) + + def _reverse_hf_permute_part( + self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None, + ) -> Tensor: + r = weights.shape[0] // 3 + return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv) + + def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor: + r = weights.shape[0] // 3 + return weights[r * n_part:r * n_part + r, ...] + + +@ModelBase.register("XverseForCausalLM") +class XverseModel(TextModel): + model_arch = gguf.MODEL_ARCH.XVERSE + + def set_vocab(self): + assert (self.dir_model / "tokenizer.json").is_file() + dir_model = self.dir_model + hparams = self.hparams + + tokens: list[bytes] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model) + vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) + # Since we are checking the maximum index, we need to ensure it's strictly less than vocab_size, + # because vocab_size is the count of items, and indexes start at 0. + max_vocab_index = max(tokenizer.get_vocab().values()) + if max_vocab_index >= vocab_size: + raise ValueError("Vocabulary size exceeds expected maximum size.") + + reverse_vocab: dict[int, str] = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + added_vocab = tokenizer.get_added_vocab() + + for token_id in range(vocab_size): + token_text = reverse_vocab[token_id].encode('utf-8') + # replace "\x00" to string with length > 0 + if token_text == b"\x00": + toktype = gguf.TokenType.BYTE # special + token_text = f"<{token_text}>".encode('utf-8') + elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text): + toktype = gguf.TokenType.BYTE # special + elif reverse_vocab[token_id] in added_vocab: + if tokenizer.added_tokens_decoder[token_id].special: + toktype = gguf.TokenType.CONTROL + else: + toktype = gguf.TokenType.USER_DEFINED + else: + toktype = gguf.TokenType.NORMAL + + tokens.append(token_text) + toktypes.append(toktype) + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + head_count = self.hparams["num_attention_heads"] + head_count_kv = self.hparams.get("num_key_value_heads", head_count) + + ctx_length = 0 + if "max_sequence_length" in self.hparams: + ctx_length = self.hparams["max_sequence_length"] + elif "max_position_embeddings" in self.hparams: + ctx_length = self.hparams["max_position_embeddings"] + elif "model_max_length" in self.hparams: + ctx_length = self.hparams["model_max_length"] + else: + raise ValueError("gguf: can not find ctx length parameter.") + + self.gguf_writer.add_tensor_data_layout("Meta AI original pth") + self.gguf_writer.add_context_length(ctx_length) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count(head_count) + self.gguf_writer.add_head_count_kv(head_count_kv) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_file_type(self.ftype) + + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + head_count = self.hparams["num_attention_heads"] + head_count_kv = self.hparams.get("num_key_value_heads", head_count) + + # HF models permute some of the tensors, so we need to undo that + if name.endswith("q_proj.weight"): + data_torch = self._reverse_hf_permute(data_torch, head_count, head_count) + if name.endswith("k_proj.weight"): + data_torch = self._reverse_hf_permute(data_torch, head_count, head_count_kv) + + return [(self.map_tensor_name(name), data_torch)] + + def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + if n_kv_head is not None and n_head != n_kv_head: + n_head //= n_kv_head + + return ( + weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape) + ) + + +@ModelBase.register("FalconForCausalLM", "RWForCausalLM") +class FalconModel(TextModel): + model_arch = gguf.MODEL_ARCH.FALCON + + def set_gguf_parameters(self): + block_count = self.hparams.get("num_hidden_layers") + if block_count is None: + block_count = self.hparams["n_layer"] # old name + + n_head = self.hparams.get("num_attention_heads") + if n_head is None: + n_head = self.hparams["n_head"] # old name + + n_head_kv = self.hparams.get("num_kv_heads") + if n_head_kv is None: + n_head_kv = self.hparams.get("n_head_kv", 1) # old name + + self.gguf_writer.add_context_length(2048) # not in config.json + self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head_kv) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # QKV tensor transform + # The original query_key_value tensor contains n_head_kv "kv groups", + # each consisting of n_head/n_head_kv query weights followed by one key + # and one value weight (shared by all query heads in the kv group). + # This layout makes it a big pain to work with in GGML. + # So we rearrange them here,, so that we have n_head query weights + # followed by n_head_kv key weights followed by n_head_kv value weights, + # in contiguous fashion. + # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py + + if "query_key_value" in name: + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + n_head_kv = self.find_hparam(["num_kv_heads", "n_head_kv"], optional=True) or 1 + head_dim = self.hparams["hidden_size"] // n_head + + qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head) + q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head) + k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head) + v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head) + data_torch = torch.cat((q, k, v)).reshape_as(data_torch) + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("GPTBigCodeForCausalLM") +class StarCoderModel(TextModel): + model_arch = gguf.MODEL_ARCH.STARCODER + + def set_gguf_parameters(self): + block_count = self.hparams["n_layer"] + + self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_head_count_kv(1) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + +@ModelBase.register("GPTRefactForCausalLM") +class RefactModel(TextModel): + model_arch = gguf.MODEL_ARCH.REFACT + + def set_vocab(self): + super().set_vocab() + + # TODO: how to determine special FIM tokens automatically? + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, + special_token_types = ['prefix', 'suffix', 'middle', 'eot']) + special_vocab._set_special_token("prefix", 1) + special_vocab._set_special_token("suffix", 3) + special_vocab._set_special_token("middle", 2) + special_vocab.chat_template = None # do not add it twice + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + hidden_dim = self.hparams["n_embd"] + inner_dim = 4 * hidden_dim + hidden_dim = int(2 * inner_dim / 3) + multiple_of = 256 + ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + block_count = self.hparams["n_layer"] + + # refact uses Alibi. So this is from config.json which might be used by training. + self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + + self.gguf_writer.add_feed_forward_length(ff_dim) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_head_count_kv(1) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + hidden_dim = self.hparams["n_embd"] + inner_dim = 4 * hidden_dim + hidden_dim = int(2 * inner_dim / 3) + multiple_of = 256 + ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + n_head = self.hparams["n_head"] + n_head_kv = 1 + head_dim = self.hparams["n_embd"] // n_head + + tensors: list[tuple[str, Tensor]] = [] + + if bid is not None: + if name == f"transformer.h.{bid}.attn.kv.weight": + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), data_torch[:n_head_kv * head_dim])) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), data_torch[n_head_kv * head_dim:])) + elif name == f"transformer.h.{bid}.attn.q.weight": + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), data_torch)) + elif name == f"transformer.h.{bid}.mlp.gate_up_proj.weight": + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim])) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:])) + + if len(tensors) == 0: + tensors.append((self.map_tensor_name(name), data_torch)) + + return tensors + + +@ModelBase.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM") +class StableLMModel(TextModel): + model_arch = gguf.MODEL_ARCH.STABLELM + + def set_vocab(self): + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + else: + # StableLM 2 1.6B used to have a vocab in a similar format to Qwen's vocab + self._set_vocab_qwen() + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"]) + self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) + self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True) + self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"])) + self.gguf_writer.add_file_type(self.ftype) + + _q_norms: list[dict[str, Tensor]] | None = None + _k_norms: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams["num_key_value_heads"] + + if name.find("q_layernorm.norms") != -1: + assert bid is not None + + if self._q_norms is None: + self._q_norms = [{} for _ in range(self.block_count)] + + self._q_norms[bid][name] = data_torch + + if len(self._q_norms[bid]) >= n_head: + return self._stack_qk_norm(bid, n_head, self._q_norms[bid], "q_layernorm") + else: + return [] + + if name.find("k_layernorm.norms") != -1: + assert bid is not None + + if self._k_norms is None: + self._k_norms = [{} for _ in range(self.block_count)] + + self._k_norms[bid][name] = data_torch + + if len(self._k_norms[bid]) >= n_kv_head: + return self._stack_qk_norm(bid, n_kv_head, self._k_norms[bid], "k_layernorm") + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def _stack_qk_norm(self, bid: int, n_head: int, norms: dict[str, Tensor], layer_name: str = "q_layernorm"): + datas: list[Tensor] = [] + # extract the norms in order + for xid in range(n_head): + ename = f"model.layers.{bid}.self_attn.{layer_name}.norms.{xid}.weight" + datas.append(norms[ename]) + del norms[ename] + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.self_attn.{layer_name}.weight" + new_name = self.map_tensor_name(merged_name) + + return [(new_name, data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._q_norms is not None or self._k_norms is not None: + # flatten two `list[dict[str, Tensor]]` into a single `list[str]` + norms = ( + [k for d in self._q_norms for k in d.keys()] if self._q_norms is not None else [] + ) + ( + [k for d in self._k_norms for k in d.keys()] if self._k_norms is not None else [] + ) + if len(norms) > 0: + raise ValueError(f"Unprocessed norms: {norms}") + + +@ModelBase.register( + "LLaMAForCausalLM", + "LlamaForCausalLM", + "MistralForCausalLM", + "MixtralForCausalLM", + "VLlama3ForCausalLM", + "LlavaForConditionalGeneration", + "LlamaModel") +class LlamaModel(TextModel): + model_arch = gguf.MODEL_ARCH.LLAMA + undo_permute = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # fix for SmolVLM2, missing `num_attention_heads` in config.json + if self.hf_arch == "VLlama3ForCausalLM": + self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + try: + self._set_vocab_llama_hf() + except (FileNotFoundError, TypeError): + # Llama 3 + self._set_vocab_gpt2() + + # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256) + if self.hparams.get("vocab_size", 32000) == 32016: + special_vocab = gguf.SpecialVocab( + self.dir_model, load_merges=False, + special_token_types = ['prefix', 'suffix', 'middle', 'eot'] + ) + special_vocab._set_special_token("prefix", 32007) + special_vocab._set_special_token("suffix", 32008) + special_vocab._set_special_token("middle", 32009) + special_vocab._set_special_token("eot", 32010) + special_vocab.add_to_gguf(self.gguf_writer) + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + # Apply to granite small models only + if self.hparams.get("vocab_size", 32000) == 49152: + self.gguf_writer.add_add_bos_token(False) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + is_vision_tensor = "vision_tower" in name \ + or "vision_model" in name \ + or "model.connector" in name \ + or "multi_modal_projector" in name + + if is_vision_tensor: + return [] # skip vision tensors + elif self.hf_arch == "LlamaModel": + name = "model." + name + elif name.startswith("model.text_model"): + name = name.replace("text_model.", "") # for SmolVLM + elif name.startswith("language_model."): + name = name.replace("language_model.", "") # for the rest + + if self.undo_permute: + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.hparams["num_local_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for wid in ["w1", "w2", "w3"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): + if rope_scaling.get("rope_type", '').lower() == "llama3": + base = self.hparams.get("rope_theta", 10000.0) + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = rope_scaling.get("factor", 8.0) + low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + # assert low_freq_wavelen != high_freq_wavelen # Errors for Llama4 + + rope_factors = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + rope_factors.append(1) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1 / ((1 - smooth) / factor + smooth)) + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("ArceeForCausalLM") +class ArceeModel(LlamaModel): + model_arch = gguf.MODEL_ARCH.ARCEE + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self._try_set_pooling_type() + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + + +@ModelBase.register( + "LlavaForConditionalGeneration", # pixtral + "Mistral3ForConditionalGeneration", # mistral small 3.1 +) +class LlavaVisionModel(MmprojModel): + img_break_tok_id = -1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.hparams["model_type"] == "pixtral": + # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py + self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5) + self.img_break_tok_id = self.get_token_id("[IMG_BREAK]") + logger.info(f"Image break token id: {self.img_break_tok_id}") + else: + raise ValueError(f"Unsupported model type: {self.hparams['model_type']}") + + def get_token_id(self, token: str) -> int: + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + added_tokens_decoder = json.load(f)['added_tokens_decoder'] + for id_, token_data in added_tokens_decoder.items(): + if token_data["content"] == token: + return int(id_) + raise ValueError(f"Token '{token}' not found in tokenizer config.") + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + if hparams["model_type"] == "pixtral": + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL) + self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) + + # hidden_act + if hparams["hidden_act"] == "silu": + self.gguf_writer.add_vision_use_silu(True) + elif hparams["hidden_act"] == "gelu": + self.gguf_writer.add_vision_use_gelu(True) + else: + raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}") + + # spatial_merge_size + if "spatial_merge_size" in self.global_config: + self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + n_head = self.hparams["num_attention_heads"] + n_kv_head = n_head + + if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."): + # process vision tensors + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + if self.img_break_tok_id > 0 and "embed_tokens.weight" in name: + logger.info(f"Extracting [IMG_BREAK] token embedding from {name}") + # for pixtral model, we need to extract the [IMG_BREAK] token embedding + img_break_embd = data_torch[self.img_break_tok_id] + name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK] + return [(self.map_tensor_name(name), img_break_embd)] + + return [] # skip other tensors + + +@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration") +class SmolVLMModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.hparams["model_type"] == "smolvlm_vision": + # fix for SmolVLM2, missing some keys in config.json + # default values are taken from transformers code + self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152) + self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16) + self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.IDEFICS3) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) + self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2)) + self.gguf_writer.add_vision_use_gelu(True) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + if ".embeddings." in name: + return gguf.GGMLQuantizationType.F32 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name + + if is_vision_tensor: + return [(self.map_tensor_name(name), data_torch)] + + return [] # skip other tensors + + +@ModelBase.register("Llama4ForConditionalGeneration") +class Llama4Model(LlamaModel): + model_arch = gguf.MODEL_ARCH.LLAMA4 + undo_permute = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this + self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"] + self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"] + + def set_vocab(self): + self._set_vocab_gpt2() + self.gguf_writer.add_add_bos_token(True) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"]) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + if name.startswith("language_model."): + name = name.replace("language_model.", "") + + # split the gate_up into gate and up + if "gate_up_proj" in name: + name_up = name.replace("gate_up_proj", "up_proj.weight") + name_gate = name.replace("gate_up_proj", "gate_proj.weight") + dim_half = data_torch.shape[-1] // 2 + gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2) + return [ + (self.map_tensor_name(name_gate), gate_proj_weight), + (self.map_tensor_name(name_up), up_proj_weight) + ] + + if name.endswith("down_proj"): + name += ".weight" + data_torch = data_torch.transpose(-1, -2) + + if "multi_modal_projector" in name or "vision_model" in name: + return [] + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("Llama4ForConditionalGeneration") +class Llama4VisionModel(MmprojModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LLAMA4) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"]) + self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"])) + assert self.hparams["hidden_act"] == "gelu" + self.gguf_writer.add_vision_use_gelu(True) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if "multi_modal_projector" in name or "vision_model" in name: + # process vision tensors + if "positional_embedding_vlm" in name and ".weight" not in name: + name += ".weight" + if "multi_modal_projector.linear_1" in name: + # despite the name with number postfix, this is a single fully connected layer + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC], data_torch)] + return [(self.map_tensor_name(name), data_torch)] + return [] + + +@ModelBase.register("Mistral3ForConditionalGeneration") +class Mistral3Model(LlamaModel): + model_arch = gguf.MODEL_ARCH.LLAMA + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + name = name.replace("language_model.", "") + if "multi_modal_projector" in name or "vision_tower" in name: + return [] + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("DeciLMForCausalLM") +class DeciModel(TextModel): + model_arch = gguf.MODEL_ARCH.DECI + + @staticmethod + def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int: + # DeciLM-specific code + intermediate_size = int(2 * ffn_mult * n_embd / 3) + return DeciModel._find_multiple(intermediate_size, 256) + + @staticmethod + def _find_multiple(n: int, k: int) -> int: + # DeciLM-specific code + if n % k == 0: + return n + return n + k - (n % k) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B + _block_configs: list[dict[str,Any]] = self.hparams["block_configs"] + assert self.block_count == len(_block_configs) + self._num_kv_heads = list() + self._num_heads = list() + _ffn_multipliers = list() + # ***linear attention layer*** + # if n_heads_in_group is None and replace_with_linear is True + # then _num_kv_heads[il] is 0 and _num_heads[il] is num_attention_heads + # ***attention-free layer*** + # if n_heads_in_group is None and replace_with_linear is False + # then _num_kv_heads[il] is 0 and _num_heads[il] is 0 + # ***normal attention-layer*** + # if n_heads_in_group is not None, then + # _num_kv_heads[il] is num_attention_head // n_heads_in_group and + # _num_heads[il] is num_attention_head + # ***dummy layer*** for nemotron 253B + # if n_heads_in_group is None and ffn_mult is None + # then _num_kv_heads[il] is 0 and _num_heads[il] is 0 and _ffn_dims is 0 + for il in range(len(_block_configs)): + if _block_configs[il]["attention"]["n_heads_in_group"] is None: + if _block_configs[il]["attention"]["replace_with_linear"] is True: + self._num_kv_heads.append(0) + self._num_heads.append(self.hparams["num_attention_heads"]) + else: + self._num_kv_heads.append(0) + self._num_heads.append(0) + else: + self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"]) + self._num_heads.append(self.hparams["num_attention_heads"]) + if _block_configs[il]["ffn"]["ffn_mult"] is None: # dummy layer + _ffn_multipliers.append(0.0) + else: + _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"]) + assert self.block_count == len(self._num_kv_heads) + assert self.block_count == len(self._num_heads) + assert self.block_count == len(_ffn_multipliers) + assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int) + assert isinstance(self._num_heads, list) and isinstance(self._num_heads[0], int) + assert isinstance(_ffn_multipliers, list) and isinstance(_ffn_multipliers[0], float) + self._ffn_dims: list[int] = [ + DeciModel._ffn_mult_to_intermediate_size(multiplier, self.hparams["hidden_size"]) + for multiplier in _ffn_multipliers + ] + + def set_vocab(self): + # Please change tokenizer_config.json of Llama-3_1-Nemotron-51B's + # eos_token from '|eot_id|' to '|end_of_text|' + if self.hparams.get("vocab_size", 128256) == 128256: + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + else: + # DeciLM-7B + self._set_vocab_llama_hf() + + def set_gguf_parameters(self): + if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B + assert self.block_count == len(self._num_kv_heads) + assert self.block_count == len(self._num_heads) + assert self.block_count == len(self._ffn_dims) + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + self.gguf_writer.add_head_count_kv(self._num_kv_heads) + self.gguf_writer.add_head_count(self._num_heads) + self.gguf_writer.add_feed_forward_length(self._ffn_dims) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_key_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_value_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_file_type(self.ftype) + else: # DeciLM-7B + super().set_gguf_parameters() + if "num_key_value_heads_per_layer" in self.hparams: # DeciLM-7B + self._num_kv_heads: list[int] = self.hparams["num_key_value_heads_per_layer"] + assert self.block_count == len(self._num_kv_heads) + self.gguf_writer.add_head_count_kv(self._num_kv_heads) + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + if bid is not None: + if "num_key_value_heads_per_layer" in self.hparams: + n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid] + elif "block_configs" in self.hparams: + n_kv_head = self._num_kv_heads[bid] + n_head = self._num_heads[bid] + else: + n_kv_head = self.hparams.get("num_key_value_heads") + else: + n_kv_head = self.hparams.get("num_key_value_heads") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = DeciModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = DeciModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): + if rope_scaling.get("rope_type", '').lower() == "llama3": + base = self.hparams.get("rope_theta", 10000.0) + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = rope_scaling.get("factor", 8.0) + low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + assert low_freq_wavelen != high_freq_wavelen + + rope_factors = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + rope_factors.append(1) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1 / ((1 - smooth) / factor + smooth)) + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + + def prepare_tensors(self): + super().prepare_tensors() + + +@ModelBase.register("BitnetForCausalLM") +class BitnetModel(TextModel): + model_arch = gguf.MODEL_ARCH.BITNET + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + + def weight_quant(self, weight: Tensor) -> Tensor: + dtype = weight.dtype + weight = weight.float() + scale = weight.abs().mean().clamp(min=1e-5) + iscale = 1 / scale + # TODO: multiply by the scale directly instead of inverting it twice + # (this is also unnecessarily doubly inverted upstream) + # ref: https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/af89e318d78a70802061246bf037199d2fb97020/utils_quant.py#L10 + result = (weight * iscale).round().clamp(-1, 1) / iscale + return result.type(dtype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + new_name = self.map_tensor_name(name) + + if any(self.match_model_tensor_name(new_name, key, bid) for key in [ + gguf.MODEL_TENSOR.ATTN_Q, + gguf.MODEL_TENSOR.ATTN_K, + gguf.MODEL_TENSOR.ATTN_V, + gguf.MODEL_TENSOR.ATTN_OUT, + gguf.MODEL_TENSOR.FFN_UP, + gguf.MODEL_TENSOR.FFN_DOWN, + gguf.MODEL_TENSOR.FFN_GATE, + ]): + # transform weight into 1/0/-1 (in fp32) + data_torch = self.weight_quant(data_torch) + + yield (new_name, data_torch) + + +@ModelBase.register("GrokForCausalLM") +class GrokModel(TextModel): + model_arch = gguf.MODEL_ARCH.GROK + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find(".moe.") != -1: + n_experts = self.hparams["num_local_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for wid in ["linear", "linear_1", "linear_v"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("DbrxForCausalLM") +class DbrxModel(TextModel): + model_arch = gguf.MODEL_ARCH.DBRX + + def set_gguf_parameters(self): + ffn_config = self.hparams["ffn_config"] + attn_config = self.hparams["attn_config"] + self.gguf_writer.add_block_count(self.hparams["n_layers"]) + + self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) + self.gguf_writer.add_embedding_length(self.hparams["d_model"]) + self.gguf_writer.add_feed_forward_length(ffn_config["ffn_hidden_size"]) + + self.gguf_writer.add_head_count(self.hparams["n_heads"]) + self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"]) + + self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"]) + + self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"]) + + self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"]) + self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"]) + + self.gguf_writer.add_layer_norm_eps(1e-5) + + self.gguf_writer.add_file_type(self.ftype) + logger.info(f"gguf: file type = {self.ftype}") + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + n_expert = self.hparams["ffn_config"]["moe_num_experts"] + n_ff = self.hparams["ffn_config"]["ffn_hidden_size"] + n_embd = self.hparams["d_model"] + + # Specific behavior for experts tensors: suffix .weight, view as 3D and transpose + # original implementation expects (n_expert, n_ff, n_embd) for all experts weights + # But llama.cpp moe graph works differently + # AND the dimensions in ggml are typically in the reverse order of the pytorch dimensions + # so (n_expert, n_ff, n_embd) in pytorch is {n_embd, n_ff, n_expert} in ggml_tensor + exp_tensor_names = {"ffn.experts.mlp.w1": None, # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert} + "ffn.experts.mlp.w2": (0, 2, 1), # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff, n_embd, n_expert} + "ffn.experts.mlp.v1": None} # LLM_TENSOR_FFN_UP_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert} + experts = False + + for exp_tensor_name in exp_tensor_names.keys(): + if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1: + experts = True + data_torch = data_torch.view(n_expert, n_ff, n_embd) + if (permute_tensor := exp_tensor_names[exp_tensor_name]) is not None: + data_torch = data_torch.permute(*permute_tensor) + break + + # map tensor names + # In MoE models the ffn tensors are typically most of the model weights, + # and need to be quantizable. Quantize expects tensor names to be suffixed by .weight. + # Every other model has the weight names ending in .weight, + # let's assume that is the convention which is not the case for dbrx: + # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15 + new_name = self.map_tensor_name(name if not experts else name + ".weight", try_suffixes=(".weight",)) + + return [(new_name, data_torch)] + + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: + del name, new_name, bid # unused + + return n_dims > 1 + + +@ModelBase.register("MiniCPMForCausalLM") +class MiniCPMModel(TextModel): + model_arch = gguf.MODEL_ARCH.MINICPM + + def set_gguf_parameters(self): + super().set_gguf_parameters() + embedding_scale = float(self.hparams["scale_emb"]) + self.gguf_writer.add_embedding_scale(embedding_scale) + logger.info(f"gguf: (minicpm) embedding_scale = {embedding_scale}") + residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5 + self.gguf_writer.add_residual_scale(residual_scale) + logger.info(f"gguf: (minicpm) residual_scale = {residual_scale}") + logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"] + self.gguf_writer.add_logit_scale(logit_scale) + logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}") + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "longrope": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE) + logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}") + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + + rope_scaling = self.find_hparam(['rope_scaling'], True) + if rope_scaling is not None: + long_factors = rope_scaling.get('long_factor', None) + short_factors = rope_scaling.get('short_factor', None) + + if long_factors is None or short_factors is None: + raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + + if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: + raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + + # HF models permute some of the tensors, so we need to undo that + if name.endswith(("q_proj.weight")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("MiniCPM3ForCausalLM") +class MiniCPM3Model(TextModel): + model_arch = gguf.MODEL_ARCH.MINICPM3 + + def set_gguf_parameters(self): + hparams = self.hparams + + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + rope_scaling = self.find_hparam(['rope_scaling'], True) + if rope_scaling is not None: + rope_dims = self.hparams["qk_rope_head_dim"] + + long_factors = rope_scaling.get('long_factor', None) + short_factors = rope_scaling.get('short_factor', None) + + if long_factors is None or short_factors is None: + raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + + if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: + raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + if n_kv_head is not None and n_head != n_kv_head: + n_head //= n_kv_head + + return ( + weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape) + ) + + +@ModelBase.register("QWenLMHeadModel") +class QwenModel(TextModel): + model_arch = gguf.MODEL_ARCH.QWEN + + @staticmethod + def token_bytes_to_string(b): + from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + byte_encoder = bytes_to_unicode() + return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')]) + + @staticmethod + def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]: + parts = [bytes([b]) for b in token] + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank is None or (max_rank is not None and min_rank >= max_rank): + break + assert min_idx is not None + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:] + return parts + + def set_vocab(self): + self._set_vocab_qwen() + + def set_gguf_parameters(self): + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) + self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + +@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration") +class Qwen2Model(TextModel): + model_arch = gguf.MODEL_ARCH.QWEN2 + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self._try_set_pooling_type() + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if self.hf_arch == "Qwen2Model": + name = f"model.{name}" # map to Qwen2ForCausalLM tensors + if "language_model." in name: + name = name.replace("language_model.", "") # for InternVL + if name.startswith("mlp") or name.startswith("multi_modal_projector") \ + or name.startswith("vision_model") or name.startswith("audio_tower"): + # skip vision and audio tensors + return [] + yield from super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register( + "Qwen2VLModel", + "Qwen2VLForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5OmniModel", +) +class Qwen2VLModel(TextModel): + model_arch = gguf.MODEL_ARCH.QWEN2VL + + def set_gguf_parameters(self): + super().set_gguf_parameters() + mrope_section = self.hparams["rope_scaling"]["mrope_section"] + mrope_section += [0] * max(0, 4 - len(mrope_section)) + self.gguf_writer.add_rope_dimension_sections(mrope_section) + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if name.startswith("thinker."): + name = name.replace("thinker.", "") + if name.startswith("visual") or name.startswith("audio") or \ + name.startswith("talker") or name.startswith("token2wav"): + # skip multimodal tensors + return [] + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") +class Qwen2VLVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560) + # rename config.json values + self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads") + self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth") + if "embed_dim" in self.hparams_vision: # qwen2vl + self.hparams_vision["intermediate_size"] = self.hparams_vision.get("hidden_size") + self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim") + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.hparams_vision is not None + hparams = self.hparams_vision + model_type = self.global_config['model_type'] + if model_type == 'qwen2_vl': + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL) + elif model_type == 'qwen2_5_vl' or model_type == 'qwen2_5_omni': + if model_type == 'qwen2_5_omni': + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O) + else: + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL) + self.gguf_writer.add_vision_use_silu(True) + # find n_wa_pattern (window attention pattern) + fullatt_block_indexes = hparams.get("fullatt_block_indexes") + assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for qwen2_5_vl" + n_wa_pattern = fullatt_block_indexes[0] + 1 + # validate n_wa_pattern + for i in range(1, len(fullatt_block_indexes)): + if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern: + raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}") + self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern) + else: + raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}") + # default values below are taken from HF tranformers code + self.gguf_writer.add_vision_attention_layernorm_eps(self.global_config.get("rms_norm_eps", 1e-6)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, name, n_dims # unused + if ".patch_embd." in new_name: + return gguf.GGMLQuantizationType.F16 + if ".position_embd." in new_name: + return gguf.GGMLQuantizationType.F32 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if name.startswith("visual."): + # process visual tensors + # split QKV tensors if needed + if ".qkv." in name: + if data_torch.ndim == 2: # weight + c3, _ = data_torch.shape + else: # bias + c3 = data_torch.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = data_torch[:c] + wk = data_torch[c: c * 2] + wv = data_torch[c * 2:] + return [ + (self.map_tensor_name(name.replace("qkv", "q")), wq), + (self.map_tensor_name(name.replace("qkv", "k")), wk), + (self.map_tensor_name(name.replace("qkv", "v")), wv), + ] + elif 'patch_embed.proj.weight' in name: + # split Conv3D into Conv2Ds + c1, c2, kt, kh, kw = data_torch.shape + del c1, c2, kh, kw # unused + assert kt == 2, "Current implmentation only support temporal_patch_size of 2" + return [ + (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight" , data_torch[:, :, 0, ...]), + (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]), + ] + else: + return [(self.map_tensor_name(name), data_torch)] + return [] # skip other tensors + + +@ModelBase.register("Qwen2_5OmniModel") +class Qwen25OmniModel(Qwen2VLVisionModel): + has_vision_encoder = True + has_audio_encoder = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_audio is not None + self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"] + self.hparams_audio["intermediate_size"] = self.hparams_audio["encoder_ffn_dim"] + self.hparams_audio["num_attention_heads"] = self.hparams_audio["encoder_attention_heads"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.hparams_audio is not None + self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"]) + self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5)) + + def get_vision_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("vision_config") + + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("audio_config") + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # SinusoidsPositionEmbedding + assert self.hparams_audio is not None + max_timescale = 10000 + length = 1500 + channels = self.hparams_audio["hidden_size"] + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + pos_embd = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(dtype=torch.float32) + yield ("audio_tower.embed_positions.weight", pos_embd) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + if ".conv" in name and ".weight" in name: + return gguf.GGMLQuantizationType.F16 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("thinker."): + name = name.replace("thinker.", "") + + if name.startswith("audio_tower"): + # process audio tensors + if "conv1.bias" in name or "conv2.bias" in name: + # transpose conv1 and conv2 bias + data_torch = data_torch.unsqueeze(-1) + if "audio_bos_eos_token" in name: + # this tensor is left unused in transformers code + # https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809 + return [] + return [(self.map_tensor_name(name), data_torch)] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("InternVisionModel") +class InternVisionModel(MmprojModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL) + self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) + # hidden_act + if hparams["hidden_act"] == "silu": + self.gguf_writer.add_vision_use_silu(True) + elif hparams["hidden_act"] == "gelu": + self.gguf_writer.add_vision_use_gelu(True) + else: + raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}") + # downsample_ratio + downsample_ratio = self.global_config.get("downsample_ratio") + assert downsample_ratio is not None + self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, name, n_dims # unused + if ".patch_embd." in new_name: + return gguf.GGMLQuantizationType.F16 + if ".position_embd." in new_name: + return gguf.GGMLQuantizationType.F32 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if name.startswith("vision_model") or name.startswith("mlp"): + # process visual tensors + # correct name + if name.startswith("vision_model"): + name = "vision_tower." + name + if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"): + name += ".weight" + # split QKV tensors if needed + if ".qkv." in name: + if data_torch.ndim == 2: # weight + c3, _ = data_torch.shape + else: # bias + c3 = data_torch.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = data_torch[:c] + wk = data_torch[c: c * 2] + wv = data_torch[c * 2:] + return [ + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.q_proj")), wq), + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.k_proj")), wk), + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.v_proj")), wv), + ] + return [(self.map_tensor_name(name), data_torch)] + return [] # skip other tensors + + +@ModelBase.register("WavTokenizerDec") +class WavTokenizerDecModel(TextModel): + model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if \ + name.endswith("codebook.cluster_size") or \ + name.endswith("codebook.embed_avg") or \ + name.endswith("codebook.inited"): + logger.debug(f"Skipping {name!r}") + return [] + + logger.info(f"{self.map_tensor_name(name)} -> {data_torch.shape}") + + return [(self.map_tensor_name(name), data_torch)] + + def set_vocab(self): + self._set_vocab_none() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_vocab_size (self.hparams["vocab_size"]) + self.gguf_writer.add_features_length (self.hparams["n_embd_features"]) + self.gguf_writer.add_feed_forward_length(self.hparams["n_ff"]) + self.gguf_writer.add_group_norm_eps (self.hparams["group_norm_epsilon"]) + self.gguf_writer.add_group_norm_groups (self.hparams["group_norm_groups"]) + + self.gguf_writer.add_posnet_embedding_length(self.hparams["posnet"]["n_embd"]) + self.gguf_writer.add_posnet_block_count (self.hparams["posnet"]["n_layer"]) + + self.gguf_writer.add_convnext_embedding_length(self.hparams["convnext"]["n_embd"]) + self.gguf_writer.add_convnext_block_count (self.hparams["convnext"]["n_layer"]) + + self.gguf_writer.add_causal_attention(False) + + +@ModelBase.register("Qwen2MoeForCausalLM") +class Qwen2MoeModel(TextModel): + model_arch = gguf.MODEL_ARCH.QWEN2MOE + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if (n_experts := self.hparams.get("num_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") + if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None: + self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size) + logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}") + # YaRN is not enabled by default + # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("experts") != -1: + n_experts = self.hparams["num_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("Qwen3ForCausalLM") +class Qwen3Model(Qwen2Model): + model_arch = gguf.MODEL_ARCH.QWEN3 + + +@ModelBase.register("Qwen3MoeForCausalLM") +class Qwen3MoeModel(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.QWEN3MOE + + +@ModelBase.register("GPT2LMHeadModel") +class GPT2Model(TextModel): + model_arch = gguf.MODEL_ARCH.GPT2 + + def set_gguf_parameters(self): + self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_context_length(self.hparams["n_ctx"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + tensors: list[tuple[str, Tensor]] = [] + + # we don't need these + if name.endswith((".attn.bias", ".attn.masked_bias")): + return tensors + + if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")): + data_torch = data_torch.transpose(1, 0) + + new_name = self.map_tensor_name(name) + + tensors.append((new_name, data_torch)) + + return tensors + + +@ModelBase.register("PhiForCausalLM") +class Phi2Model(TextModel): + model_arch = gguf.MODEL_ARCH.PHI2 + + def set_gguf_parameters(self): + block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) + + rot_pct = self.find_hparam(["partial_rotary_factor"]) + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + + self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"])) + + self.gguf_writer.add_embedding_length(n_embd) + self.gguf_writer.add_feed_forward_length(4 * n_embd) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head) + self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"])) + self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_add_bos_token(False) + + +@ModelBase.register("Phi3ForCausalLM") +class Phi3MiniModel(TextModel): + model_arch = gguf.MODEL_ARCH.PHI3 + + def set_vocab(self): + # Phi-4 model uses GPT2Tokenizer + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + tokenizer_class = tokenizer_config_json['tokenizer_class'] + if tokenizer_class == 'GPT2Tokenizer': + return self._set_vocab_gpt2() + + from sentencepiece import SentencePieceProcessor + + tokenizer_path = self.dir_model / 'tokenizer.model' + + if not tokenizer_path.is_file(): + raise ValueError(f'Error: Missing {tokenizer_path}') + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size + + for token_id in range(tokenizer.vocab_size()): + + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + + for key in added_tokens_json: + token_id = added_tokens_json[key] + if token_id >= vocab_size: + logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + tokens[token_id] = key.encode("utf-8") + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {}) + for token_id, foken_data in added_tokens_decoder.items(): + token_id = int(token_id) + token = foken_data["content"].encode("utf-8") + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: + if tokens[token_id] != token: + logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') + tokens[token_id] = token + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + if foken_data.get("special"): + toktypes[token_id] = SentencePieceTokenTypes.CONTROL + + tokenizer_file = self.dir_model / 'tokenizer.json' + if tokenizer_file.is_file(): + with open(tokenizer_file, "r", encoding="utf-8") as f: + tokenizer_json = json.load(f) + added_tokens = tokenizer_json.get("added_tokens", []) + for foken_data in added_tokens: + token_id = int(foken_data["id"]) + token = foken_data["content"].encode("utf-8") + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: + if tokens[token_id] != token: + logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') + tokens[token_id] = token + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + if foken_data.get("special"): + toktypes[token_id] = SentencePieceTokenTypes.CONTROL + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) + + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) + rms_eps = self.find_hparam(["rms_norm_eps"]) + max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) + orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) + rot_pct = self.hparams.get("partial_rotary_factor", 1.0) + rope_dims = int(rot_pct * n_embd) // n_head + + self.gguf_writer.add_context_length(max_pos_embds) + self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds) + self.gguf_writer.add_embedding_length(n_embd) + self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head_kv) + self.gguf_writer.add_layer_norm_rms_eps(rms_eps) + self.gguf_writer.add_rope_dimension_count(rope_dims) + self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"])) + self.gguf_writer.add_file_type(self.ftype) + sliding_window = self.hparams.get("sliding_window") + # use zero value of sliding_window to distinguish Phi-4 from other PHI3 models + if sliding_window is None: + sliding_window = 0 + self.gguf_writer.add_sliding_window(sliding_window) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) + orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) + rot_pct = self.hparams.get("partial_rotary_factor", 1.0) + rope_dims = int(rot_pct * n_embd) // n_head + + # write rope scaling for long context (128k) model + rope_scaling = self.find_hparam(['rope_scaling'], True) + if rope_scaling is None: + return + + scale = max_pos_embds / orig_max_pos_embds + + rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower() + if len(rope_scaling_type) == 0: + raise KeyError('Missing the required key rope_scaling.type') + + if rope_scaling_type == 'su' or rope_scaling_type == 'longrope': + attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0 + elif rope_scaling_type == 'yarn': + attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0 + else: + raise NotImplementedError(f'The rope scaling type {rope_scaling_type} is not supported yet') + + self.gguf_writer.add_rope_scaling_attn_factors(attn_factor) + + long_factors = rope_scaling.get('long_factor', None) + short_factors = rope_scaling.get('short_factor', None) + + if long_factors is None or short_factors is None: + raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + + if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: + raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}. long_factors = {len(long_factors)}, short_factors = {len(short_factors)}.') + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) + + +@ModelBase.register("PhiMoEForCausalLM") +class PhiMoeModel(Phi3MiniModel): + model_arch = gguf.MODEL_ARCH.PHIMOE + + _experts: list[dict[str, Tensor]] | None = None + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_expert_count(self.hparams["num_local_experts"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.hparams["num_local_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["w1", "w2", "w3"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("PlamoForCausalLM") +class PlamoModel(TextModel): + model_arch = gguf.MODEL_ARCH.PLAMO + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + self.gguf_writer.add_context_length(4096) # not in config.json + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong + self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + self.gguf_writer.add_file_type(self.ftype) + + def shuffle_attn_q_weight(self, data_torch): + assert data_torch.size() == (5120, 5120) + data_torch = data_torch.reshape(8, 5, 128, 5120) + data_torch = torch.permute(data_torch, (1, 0, 2, 3)) + data_torch = torch.reshape(data_torch, (5120, 5120)) + return data_torch + + def shuffle_attn_output_weight(self, data_torch): + assert data_torch.size() == (5120, 5120) + data_torch = data_torch.reshape(5120, 8, 5, 128) + data_torch = torch.permute(data_torch, (0, 2, 1, 3)) + data_torch = torch.reshape(data_torch, (5120, 5120)) + return data_torch + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + new_name = self.map_tensor_name(name) + + # shuffle for broadcasting of gqa in ggml_mul_mat + if new_name.endswith("attn_q.weight"): + data_torch = self.shuffle_attn_q_weight(data_torch) + elif new_name.endswith("attn_output.weight"): + data_torch = self.shuffle_attn_output_weight(data_torch) + + return [(new_name, data_torch)] + + +@ModelBase.register("CodeShellForCausalLM") +class CodeShellModel(TextModel): + model_arch = gguf.MODEL_ARCH.CODESHELL + + def set_gguf_parameters(self): + block_count = self.hparams["n_layer"] + + self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_rope_freq_base(10000.0) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + + _has_tok_embd = False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) + tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD) + + new_name = self.map_tensor_name(name) + + # assuming token_embd.weight is seen before output.weight + if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT): + # even though the tensor file(s) does not contain the word embeddings they are still in the weight map + if self.tensor_names and "transformer.wte.weight" in self.tensor_names: + logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied") + self.tensor_names.remove("transformer.wte.weight") + elif new_name == tok_embd_name: + self._has_tok_embd = True + + return [(new_name, data_torch)] + + +@ModelBase.register("InternLM2ForCausalLM") +class InternLM2Model(TextModel): + model_arch = gguf.MODEL_ARCH.INTERNLM2 + + def set_vocab(self): + # (TODO): Is there a better way? + # Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character + # \x00 specially and convert it into an emoji character to prevent it from being mistakenly + # recognized as an empty string in C++. + from sentencepiece import SentencePieceProcessor + from sentencepiece import sentencepiece_model_pb2 as model + + tokenizer_path = self.dir_model / 'tokenizer.model' + + tokens: list[bytes] = [] + scores: list[float] = [] + toktypes: list[int] = [] + + if not tokenizer_path.is_file(): + logger.error(f'Error: Missing {tokenizer_path}') + sys.exit(1) + + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + for token_id in range(vocab_size): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + if text == b"\x00": + # (TODO): fixme + # Hack here and replace the \x00 characters. + logger.warning(f"InternLM2 convert token '{text}' to '🐉'!") + text = "🐉".encode("utf-8") + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + # take care of ununsed raw token + if piece.startswith('[UNUSED'): + toktype = SentencePieceTokenTypes.UNUSED + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + + for key in added_tokens_json: + tokens.append(key.encode("utf-8")) + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.USER_DEFINED) + + chat_eos_token = '<|im_end|>' + chat_eos_token_id = None + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {}) + for token_id, foken_data in added_tokens_decoder.items(): + token_id = int(token_id) + token = foken_data["content"] + if token == chat_eos_token: + chat_eos_token_id = token_id + token = token.encode("utf-8") + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: + if tokens[token_id] != token: + logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') + tokens[token_id] = token + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + if foken_data.get("special"): + toktypes[token_id] = SentencePieceTokenTypes.CONTROL + + tokenizer_file = self.dir_model / 'tokenizer.json' + if tokenizer_file.is_file(): + with open(tokenizer_file, "r", encoding="utf-8") as f: + tokenizer_json = json.load(f) + added_tokens = tokenizer_json.get("added_tokens", []) + for foken_data in added_tokens: + token_id = int(foken_data["id"]) + token = foken_data["content"] + if token == chat_eos_token: + chat_eos_token_id = token_id + token = token.encode("utf-8") + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: + if tokens[token_id] != token: + logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') + tokens[token_id] = token + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + if foken_data.get("special"): + toktypes[token_id] = SentencePieceTokenTypes.CONTROL + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_add_space_prefix(add_prefix) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + old_eos = special_vocab.special_token_ids["eos"] + if chat_eos_token_id is not None: + # For the chat model, we replace the eos with '<|im_end|>'. + # TODO: this is a hack, should be fixed + # https://github.com/ggml-org/llama.cpp/pull/6745#issuecomment-2067687048 + special_vocab.special_token_ids["eos"] = chat_eos_token_id + logger.warning(f"Replace eos:{old_eos} with a special token:{chat_eos_token_id}" + " in chat mode so that the conversation can end normally.") + + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) + self.gguf_writer.add_file_type(self.ftype) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + num_heads = self.hparams["num_attention_heads"] + num_kv_heads = self.hparams["num_key_value_heads"] + n_embd = self.hparams["hidden_size"] + q_per_kv = num_heads // num_kv_heads + head_dim = n_embd // num_heads + num_groups = num_heads // q_per_kv + + name = name.replace("language_model.", "") # InternVL + if name.startswith("mlp") or name.startswith("vision_model"): + # skip visual tensors + return [] + + if bid is not None and f"model.layers.{bid}.attention.wqkv" in name: + qkv = data_torch + + qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd)) + q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1] + + # The model weights of q and k equire additional reshape. + q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads) + k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads) + v = v.reshape((-1, v.shape[-1])) + + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v), + ] + else: + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("InternLM3ForCausalLM") +class InternLM3Model(TextModel): + model_arch = gguf.MODEL_ARCH.LLAMA + + def set_vocab(self): + tokens, scores, toktypes = self._create_vocab_sentencepiece() + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + if "added_tokens_decoder" in tokenizer_config_json: + for token_id, token_data in tokenizer_config_json["added_tokens_decoder"].items(): + if token_data.get("special"): + token_id = int(token_id) + token = token_data["content"] + special_vocab._set_special_token(token, token_id) + # update eos token + if token == '<|im_end|>' and "eos" in special_vocab.special_token_ids: + special_vocab.special_token_ids["eos"] = token_id + + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + name = name.replace("language_model.", "") # InternVL + if name.startswith("mlp") or name.startswith("vision_model"): + # skip visual tensors + return [] + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel", "BertForSequenceClassification") +class BertModel(TextModel): + model_arch = gguf.MODEL_ARCH.BERT + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.vocab_size = None + + if cls_out_labels := self.hparams.get("id2label"): + if len(cls_out_labels) == 2 and cls_out_labels[0] == "LABEL_0": + # Remove dummy labels added by AutoConfig + cls_out_labels = None + self.cls_out_labels = cls_out_labels + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_causal_attention(False) + self._try_set_pooling_type() + + if self.cls_out_labels: + self.gguf_writer.add_classifier_output_labels([v for k, v in sorted(self.cls_out_labels.items())]) + + def set_vocab(self): + tokens, toktypes, tokpre = self.get_vocab_base() + self.vocab_size = len(tokens) + + # we need this to validate the size of the token_type embeddings + # though currently we are passing all zeros to the token_type embeddings + # "Sequence A" or "Sequence B" + self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) + + # convert to phantom space vocab + def phantom(tok): + if tok.startswith("[") and tok.endswith("]"): + return tok + if tok.startswith("##"): + return tok[2:] + return "\u2581" + tok + tokens = list(map(phantom, tokens)) + + # add vocab to gguf + self.gguf_writer.add_tokenizer_model("bert") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + # handle special tokens + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.startswith("bert."): + name = name[5:] + + if name.endswith(".gamma"): + name = name[:-6] + ".weight" + + if name.endswith(".beta"): + name = name[:-5] + ".bias" + + # we are only using BERT for embeddings so we don't need the pooling layer + if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): + return [] # we don't need these + + if name.startswith("cls.predictions"): + return [] + + if name.startswith("cls.seq_relationship"): + return [] + + if self.cls_out_labels: + # For BertForSequenceClassification (direct projection layer) + if name == "classifier.weight": + name = "classifier.out_proj.weight" + + if name == "classifier.bias": + name = "classifier.out_proj.bias" + + return [(self.map_tensor_name(name), data_torch)] + + def _xlmroberta_tokenizer_init(self) -> None: + # we need the pad_token_id to know how to chop down position_embd matrix + if (pad_token_id := self.hparams.get("pad_token_id")) is not None: + self._position_offset = 1 + pad_token_id + if "max_position_embeddings" in self.hparams: + self.hparams["max_position_embeddings"] -= self._position_offset + else: + self._position_offset = None + + def _xlmroberta_set_vocab(self) -> None: + # to avoid TypeError: Descriptors cannot be created directly + # exception when importing sentencepiece_model_pb2 + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + from sentencepiece import SentencePieceProcessor + from sentencepiece import sentencepiece_model_pb2 as model + + tokenizer_path = self.dir_model / 'sentencepiece.bpe.model' + + tokenizer_json = {} + tokenizer_config_json = {} + if not tokenizer_path.is_file(): + tokenizer_path = self.dir_model / 'tokenizer.json' + tokenizer_config_path = self.dir_model / 'tokenizer_config.json' + + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") + + from base64 import b64decode + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + + with open(tokenizer_path, "r", encoding="utf-8") as fp: + tokenizer_json = json.load(fp) + + if tokenizer_config_path.is_file(): + with open(tokenizer_config_path, "r", encoding="utf-8") as fp: + tokenizer_config_json = json.load(fp) + + add_prefix = tokenizer.add_prefix_space + remove_whitespaces = tokenizer.clean_up_tokenization_spaces + precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"]) + + vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size) + else: + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces + precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size + + if isinstance(tokenizer, SentencePieceProcessor): + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + else: + added_vocab = tokenizer.get_added_vocab() + unk_token = tokenizer_config_json.get("unk_token") + unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3)) + + for token_id in range(tokenizer.vocab_size): + piece = tokenizer._convert_id_to_token(token_id) + if (piece := tokenizer._convert_id_to_token(token_id)) is not None: + text = piece.encode("utf-8") + score = tokenizer_json["model"]["vocab"][token_id][1] + + toktype = SentencePieceTokenTypes.NORMAL + if token_id == unk_token_id: + toktype = SentencePieceTokenTypes.UNKNOWN + elif token_id in tokenizer.all_special_ids: + toktype = SentencePieceTokenTypes.CONTROL + elif token_id in added_vocab.values(): + toktype = SentencePieceTokenTypes.USER_DEFINED + # No reliable way to detect this, but jina doesn't have any + # elif tokenizer.IsByte(token_id): + # toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + if isinstance(tokenizer, SentencePieceProcessor): + # realign tokens (see HF tokenizer code) + tokens = [b'', b'', b'', b''] + tokens[3:-1] + scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1] + toktypes = [ + SentencePieceTokenTypes.CONTROL, + SentencePieceTokenTypes.CONTROL, + SentencePieceTokenTypes.CONTROL, + SentencePieceTokenTypes.UNKNOWN, + ] + toktypes[3:-1] + + if self.model_arch == gguf.MODEL_ARCH.NOMIC_BERT_MOE: + # Add mask token missing from sentencepiece.bpe.model + tokens[250001] = b'' + scores[250001] = 0.0 + toktypes[250001] = SentencePieceTokenTypes.CONTROL + + self.gguf_writer.add_tokenizer_model("t5") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_add_space_prefix(add_prefix) + self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) + self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces) + if precompiled_charsmap: + self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(True) + + +@ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification") +class DistilBertModel(BertModel): + model_arch = gguf.MODEL_ARCH.BERT + + def set_gguf_parameters(self): + self.gguf_writer.add_layer_norm_eps(1e-12) + logger.info("gguf: layer norm epsilon = 1e-12") + super().set_gguf_parameters() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("distilbert."): + name = name[11:] + + # These layers act as MLM head, so we don't need them + if name.startswith("vocab_"): + return [] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("RobertaModel", "RobertaForSequenceClassification") +class RobertaModel(BertModel): + model_arch = gguf.MODEL_ARCH.BERT + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # we need the pad_token_id to know how to chop down position_embd matrix + if (pad_token_id := self.hparams.get("pad_token_id")) is not None: + self._position_offset = 1 + pad_token_id + if "max_position_embeddings" in self.hparams: + self.hparams["max_position_embeddings"] -= self._position_offset + else: + self._position_offset = None + + def set_vocab(self): + """Support BPE tokenizers for roberta models""" + bpe_tok_path = self.dir_model / "tokenizer.json" + if bpe_tok_path.exists(): + self._set_vocab_gpt2() + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(True) + + # we need this to validate the size of the token_type embeddings + # though currently we are passing all zeros to the token_type embeddings + # "Sequence A" or "Sequence B" + self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) + + else: + return super().set_vocab() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "roberta.", remove the prefix + # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main + if name.startswith("roberta."): + name = name[8:] + + # position embeddings start at pad_token_id + 1, so just chop down the weight tensor + if name == "embeddings.position_embeddings.weight": + if self._position_offset is not None: + data_torch = data_torch[self._position_offset:,:] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("NomicBertModel") +class NomicBertModel(BertModel): + model_arch = gguf.MODEL_ARCH.BERT + + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any): + hparams = kwargs.pop("hparams", None) + if hparams is None: + hparams = ModelBase.load_hparams(dir_model) + + self.is_moe = bool(hparams.get("moe_every_n_layers")) + self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT + + super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs) + + self._tokenizer_is_xlmroberta = self._is_tokenizer_xlmroberta() + if self._tokenizer_is_xlmroberta: + self._xlmroberta_tokenizer_init() + + npos, mtp = self.hparams["n_positions"], self.hparams.get("max_trained_positions", 2048) + if npos == 8192 and mtp == 2048: + self.hparams["n_positions"] = 2048 # nomic-embed-text v1 and v1.5 are trained for 2048 tokens. + elif npos == 2048 and mtp == 2048: + self.hparams["n_positions"] = 512 # nomic-embed-text-v2-moe is trained for 512 tokens. + else: + raise ValueError(f"unrecognized parameters: n_positions={npos}, max_trained_positions={mtp}") + + assert self.hparams["activation_function"] == "gelu" if self.is_moe else "swiglu" + + # this doesn't do anything in the HF version + assert self.hparams["causal"] is False + # no bias tensors unless MoE + assert self.hparams["qkv_proj_bias"] == self.is_moe + assert self.hparams["mlp_fc1_bias"] == self.is_moe + assert self.hparams["mlp_fc2_bias"] == self.is_moe + + # norm at end of layer + assert self.hparams["prenorm"] is False + # standard RoPE + assert self.hparams["rotary_emb_fraction"] == 1.0 + assert self.hparams["rotary_emb_interleaved"] is False + assert self.hparams["rotary_emb_scale_base"] is None + + def set_vocab(self) -> None: + if self._tokenizer_is_xlmroberta: + return self._xlmroberta_set_vocab() + return super().set_vocab() + + def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[tuple[str, torch.Tensor]]: + # If the tensor is an experts bias tensor, skip it by returning an empty list. + if "mlp.experts.bias" in name: + return [] # Explicitly return an empty list. + + if "mlp.experts.mlp.w1" in name: + data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) + name += ".weight" + + if "mlp.experts.mlp.w2" in name: + data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) + data_torch = data_torch.transpose(1, 2) + name += ".weight" + + return [(self.map_tensor_name(name), data_torch)] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) + if self.is_moe: + self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"]) + self.gguf_writer.add_expert_count(self.hparams["num_experts"]) + self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"]) + + def _is_tokenizer_xlmroberta(self) -> bool: + with open(self.dir_model / "tokenizer.json") as f: + tokenizer_json = json.load(f) + toktyp = tokenizer_json["model"]["type"] + if toktyp == "Unigram": + return True + if toktyp == "WordPiece": + return False + raise ValueError(f"unknown tokenizer: {toktyp}") + + +@ModelBase.register("NeoBERT", "NeoBERTLMHead", "NeoBERTForSequenceClassification") +class NeoBert(BertModel): + model_arch = gguf.MODEL_ARCH.NEO_BERT + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # NeoBERT uses 2/3 of the intermediate size as feed forward length + self.gguf_writer.add_feed_forward_length(int(2 * self.hparams["intermediate_size"] / 3)) + self.gguf_writer.add_rope_freq_base(10000.0) # default value for NeoBERT + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + + f_rms_eps = self.hparams.get("norm_eps", 1e-6) # default value for NeoBERT + self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) + logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") + + self.gguf_writer.add_pooling_type(gguf.PoolingType.CLS) # https://huggingface.co/chandar-lab/NeoBERT#how-to-use + + def modify_tensors(self, data_torch, name, bid): + if name.startswith("decoder."): + return [] + + if name.startswith("model."): + name = name[6:] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") +class XLMRobertaModel(BertModel): + model_arch = gguf.MODEL_ARCH.BERT + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._xlmroberta_tokenizer_init() + + def set_vocab(self): + self._xlmroberta_set_vocab() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "roberta.", remove the prefix + # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main + if name.startswith("roberta."): + name = name[8:] + + # position embeddings start at pad_token_id + 1, so just chop down the weight tensor + if name == "embeddings.position_embeddings.weight": + if self._position_offset is not None: + data_torch = data_torch[self._position_offset:,:] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("GemmaForCausalLM") +class GemmaModel(TextModel): + model_arch = gguf.MODEL_ARCH.GEMMA + + def set_vocab(self): + self._set_vocab_sentencepiece() + + # TODO: these special tokens should be exported only for the CodeGemma family + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, + special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot']) + special_vocab._set_special_token("prefix", 67) + special_vocab._set_special_token("suffix", 69) + special_vocab._set_special_token("middle", 68) + special_vocab._set_special_token("fsep", 70) + special_vocab._set_special_token("eot", 107) + special_vocab.chat_template = None # do not add it twice + special_vocab.add_to_gguf(self.gguf_writer) + + self.gguf_writer.add_add_space_prefix(False) + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_key_length(hparams["head_dim"]) + self.gguf_writer.add_value_length(hparams["head_dim"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # lm_head is not used in llama.cpp, while autoawq will include this tensor in model + # To prevent errors, skip loading lm_head.weight. + if name == "lm_head.weight": + logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.") + return [] + + # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89 + if name.endswith("norm.weight"): + data_torch = data_torch + 1 + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("Gemma2ForCausalLM") +class Gemma2Model(TextModel): + model_arch = gguf.MODEL_ARCH.GEMMA2 + + def set_vocab(self): + self._set_vocab_sentencepiece() + + self.gguf_writer.add_add_space_prefix(False) + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_key_length(hparams["head_dim"]) + self.gguf_writer.add_value_length(hparams["head_dim"]) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_attn_logit_softcapping( + self.hparams["attn_logit_softcapping"] + ) + self.gguf_writer.add_final_logit_softcapping( + self.hparams["final_logit_softcapping"] + ) + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # lm_head is not used in llama.cpp, while autoawq will include this tensor in model + # To prevent errors, skip loading lm_head.weight. + if name == "lm_head.weight": + logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.") + return [] + + # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89 + if name.endswith("norm.weight"): + data_torch = data_torch + 1 + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration") +class Gemma3Model(TextModel): + model_arch = gguf.MODEL_ARCH.GEMMA3 + + def set_vocab(self): + self._set_vocab_sentencepiece() + + self.gguf_writer.add_add_space_prefix(False) + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + # some default values are not specified in the hparams + self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072)) + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8)) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6)) + self.gguf_writer.add_key_length(hparams.get("head_dim", 256)) + self.gguf_writer.add_value_length(hparams.get("head_dim", 256)) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers + # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3 + assert hparams.get("attn_logit_softcapping") is None + assert hparams.get("final_logit_softcapping") is None + self.gguf_writer.add_sliding_window(hparams["sliding_window"]) + self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) + if hparams.get("rope_scaling") is not None: + assert hparams["rope_scaling"]["rope_type"] == "linear" + # important: this rope_scaling is only applied for global layers, and not used by 1B model + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.startswith("language_model."): + name = name.replace("language_model.", "") + + elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ + or name.startswith("multimodal_projector.") or name.startswith("vision_model."): + return [] # skip vision tensors + + # remove OOV (out-of-vocabulary) rows in token_embd + if "embed_tokens.weight" in name: + vocab = self._create_vocab_sentencepiece() + tokens = vocab[0] + data_torch = data_torch[:len(tokens)] + + # ref code in Gemma3RMSNorm + # output = output * (1.0 + self.weight.float()) + if name.endswith("norm.weight"): + data_torch = data_torch + 1 + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("Gemma3ForConditionalGeneration") +class Gemma3VisionModel(MmprojModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3) + # default values below are taken from HF tranformers code + self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) + self.gguf_writer.add_vision_use_gelu(True) + # calculate proj_scale_factor (used by tinygemma3 test model) + image_seq_length = self.preprocessor_config.get("image_seq_length", 256) + n_per_side = int(image_seq_length ** 0.5) + image_size = self.hparams["image_size"] + patch_size = self.hparams["patch_size"] + proj_scale_factor = (image_size // patch_size) // n_per_side + if proj_scale_factor > 0 and proj_scale_factor != 4: + # we only need to write this if it's not the default value + # in this case, we are converting a test model + self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + # related to https://github.com/ggml-org/llama.cpp/issues/13025 + if "input_projection" in name: + return gguf.GGMLQuantizationType.F16 + if ".embeddings." in name: + return gguf.GGMLQuantizationType.F32 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if "vision_model.head." in name: + return [] # skip redundant tensors for tinygemma3 + + if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ + or name.startswith("multimodal_projector.") or name.startswith("vision_model."): + # process vision tensors + name = name.replace("_weight", ".weight") + + # correct norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector + # the other norm values are part of SigLIP model, and they are already correct + # ref code: Gemma3RMSNorm + if "soft_emb_norm.weight" in name: + logger.info(f"Correcting norm value for '{name}'") + data_torch = data_torch + 1 + + return [(self.map_tensor_name(name), data_torch)] + + return [] # skip other tensors + + +@ModelBase.register("Starcoder2ForCausalLM") +class StarCoder2Model(TextModel): + model_arch = gguf.MODEL_ARCH.STARCODER2 + + +@ModelBase.register("Rwkv6ForCausalLM") +class Rwkv6Model(TextModel): + model_arch = gguf.MODEL_ARCH.RWKV6 + + def set_vocab(self): + self._set_vocab_rwkv_world() + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + head_size = self.hparams["head_size"] + hidden_size = self.hparams["hidden_size"] + layer_norm_eps = self.hparams["layer_norm_epsilon"] + rescale_every_n_layers = self.hparams["rescale_every"] + intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else int((hidden_size * 3.5) // 32 * 32) + time_mix_extra_dim = 64 if hidden_size == 4096 else 32 + time_decay_extra_dim = 128 if hidden_size == 4096 else 64 + + # RWKV isn't context limited + self.gguf_writer.add_context_length(1048576) + self.gguf_writer.add_embedding_length(hidden_size) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_layer_norm_eps(layer_norm_eps) + self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers) + self.gguf_writer.add_wkv_head_size(head_size) + self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim) + self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim) + self.gguf_writer.add_feed_forward_length(intermediate_size) + self.gguf_writer.add_file_type(self.ftype) + + # required by llama.cpp, unused + self.gguf_writer.add_head_count(0) + + lerp_weights: dict[int, dict[str, Tensor]] = {} + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + new_name = self.map_tensor_name(name) + + if not (new_name.endswith(".weight") or new_name.endswith(".bias")): + new_name += ".weight" + + if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"): + data_torch = data_torch.transpose(0, 1) + + if new_name.endswith("time_mix_w2.weight"): + data_torch = data_torch.permute(0, 2, 1) + + if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name: + data_torch = data_torch.squeeze() + + try: + rescale_every_n_layers = self.hparams["rescale_every"] + if rescale_every_n_layers > 0: + if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"): + data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers)) + except KeyError: + pass + + # concat time_mix_lerp weights to reduce some cpu overhead + # also reduces the number of tensors in the model + if bid is not None and "time_mix_lerp" in new_name and "time_mix_lerp_x" not in new_name: + try: + self.lerp_weights[bid][new_name] = data_torch + except KeyError: + self.lerp_weights[bid] = {new_name: data_torch} + if all(f"blk.{bid}.time_mix_lerp_{i}.weight" in self.lerp_weights[bid].keys() for i in ["w", "k", "v", "r", "g"]): + new_name = f"blk.{bid}.time_mix_lerp_fused.weight" + data = torch.stack([self.lerp_weights[bid][f"blk.{bid}.time_mix_lerp_{i}.weight"].unsqueeze(0) for i in ["w", "k", "v", "r", "g"]], dim=0).unsqueeze(1) + yield (new_name, data) + return + + yield (new_name, data_torch) + + +@ModelBase.register("RWKV6Qwen2ForCausalLM") +class RWKV6Qwen2Model(Rwkv6Model): + model_arch = gguf.MODEL_ARCH.RWKV6QWEN2 + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + num_attention_heads = self.hparams["num_attention_heads"] + num_key_value_heads = self.hparams["num_key_value_heads"] + hidden_size = self.hparams["hidden_size"] + head_size = hidden_size // num_attention_heads + rms_norm_eps = self.hparams["rms_norm_eps"] + intermediate_size = self.hparams["intermediate_size"] + time_mix_extra_dim = self.hparams.get("lora_rank_tokenshift", 64 if hidden_size >= 4096 else 32) + time_decay_extra_dim = self.hparams.get("lora_rank_decay", 128 if hidden_size >= 4096 else 64) + + # RWKV isn't context limited + self.gguf_writer.add_context_length(1048576) + self.gguf_writer.add_embedding_length(hidden_size) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_wkv_head_size(head_size) + self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim) + self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim) + self.gguf_writer.add_feed_forward_length(intermediate_size) + self.gguf_writer.add_file_type(self.ftype) + + # special parameters for time_mixing in RWKV6QWEN2 + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_token_shift_count(1) + # RWKV6QWEN2 use grouped key/value like GQA + self.gguf_writer.add_head_count_kv(num_key_value_heads) + + # required by llama.cpp, unused + self.gguf_writer.add_head_count(0) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + for new_name, data in super().modify_tensors(data_torch, name, bid): + if "time_mix_w1" in new_name or "time_mix_w2" in new_name: + data = data.view(5, -1, data.shape[-1]) + # rwkv6qwen2 has a different order of rkvwg instead of the original wkvrg + # permute them here to avoid code changes + data = torch.stack([data[3], data[1], data[2], data[0], data[4]], dim=0).view(-1, data.shape[-1]) + if "w2" in new_name: + data = data.view(5, -1, data.shape[-1]) + yield (new_name, data) + continue + yield (new_name, data) + + +@ModelBase.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM") +class Rwkv7Model(TextModel): + model_arch = gguf.MODEL_ARCH.RWKV7 + + def set_vocab(self): + self._set_vocab_rwkv_world() + + def calc_lora_rank(self, hidden_size, exponent, multiplier): + return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + try: + head_size = self.hparams["head_size"] + layer_norm_eps = self.hparams["layer_norm_epsilon"] + except KeyError: + head_size = self.hparams["head_dim"] + layer_norm_eps = self.hparams["norm_eps"] + hidden_size = self.hparams["hidden_size"] + intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4) + + # ICLR: In-Context-Learning-Rate + try: + lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8) + lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8) + lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3) + lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6) + except KeyError: + lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8) + lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8) + lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3) + lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6) + + # RWKV isn't context limited + self.gguf_writer.add_context_length(1048576) + self.gguf_writer.add_embedding_length(hidden_size) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_layer_norm_eps(layer_norm_eps) + self.gguf_writer.add_wkv_head_size(head_size) + self.gguf_writer.add_decay_lora_rank(lora_rank_decay) + self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr) + self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix) + self.gguf_writer.add_gate_lora_rank(lora_rank_gate) + self.gguf_writer.add_feed_forward_length(intermediate_size) + self.gguf_writer.add_file_type(self.ftype) + + # required by llama.cpp, unused + self.gguf_writer.add_head_count(0) + + lerp_weights: dict[int, dict[str, Tensor]] = {} + lora_needs_transpose: bool = True + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # unify tensor names here to make life easier + name = name.replace("blocks", "layers").replace("ffn", "feed_forward") + name = name.replace("self_attn", "attention").replace("attn", "attention") + name = name.replace("time_mixer.", "") + # lora layer names in fla-hub's impl + if "_lora.lora" in name: + self.lora_needs_transpose = False + name = name.replace("_lora.lora.0.weight", "1.weight") + name = name.replace("_lora.lora.2.weight", "2.weight") + name = name.replace("_lora.lora.2.bias", "0.weight") + + name = name.replace("feed_forward_norm", "ln2") + name = name.replace("g_norm", "ln_x") + + if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0: + # some models have dummy v0/v1/v2 on first layer while others don't + # ignore them all since they are not used + return + + wkv_has_gate = self.hparams.get("wkv_has_gate", True) + lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"] + + if bid is not None and "attention.x_" in name: + if "attention.x_x" in name: + # already concatenated + new_name = f"blk.{bid}.time_mix_lerp_fused.weight" + data = data_torch.reshape(len(lerp_list), 1, 1, -1) + yield (new_name, data) + else: + try: + self.lerp_weights[bid][name] = data_torch + except KeyError: + self.lerp_weights[bid] = {name: data_torch} + if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list): + new_name = f"blk.{bid}.time_mix_lerp_fused.weight" + data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0) + yield (new_name, data) + return + else: + data_torch = data_torch.squeeze() + new_name = self.map_tensor_name(name) + + if not (new_name.endswith(".weight") or new_name.endswith(".bias")): + new_name += ".weight" + + if self.lora_needs_transpose and any( + new_name.endswith(t) for t in [ + "time_mix_w1.weight", "time_mix_w2.weight", + "time_mix_a1.weight", "time_mix_a2.weight", + "time_mix_v1.weight", "time_mix_v2.weight", + "time_mix_g1.weight", "time_mix_g2.weight", + ] + ): + data_torch = data_torch.transpose(0, 1) + + if 'r_k' in new_name: + data_torch = data_torch.flatten() + + if bid == 0 and "time_mix_a" in new_name: + # dummy v0/v1/v2 on first layer + # easist way to make llama happy + yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch) + + yield (new_name, data_torch) + + +@ModelBase.register("RwkvHybridForCausalLM") +class ARwkv7Model(Rwkv7Model): + model_arch = gguf.MODEL_ARCH.ARWKV7 + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + hidden_size = self.hparams["hidden_size"] + head_size = self.hparams["head_size"] + rms_norm_eps = self.hparams["rms_norm_eps"] + intermediate_size = self.hparams["intermediate_size"] + wkv_has_gate = self.hparams["wkv_has_gate"] + assert self.hparams["wkv_version"] == 7 + + # ICLR: In-Context-Learning-Rate + lora_rank_decay = 64 + lora_rank_iclr = 64 + lora_rank_value_residual_mix = 32 + lora_rank_gate = 128 if wkv_has_gate else 0 + + # RWKV isn't context limited + self.gguf_writer.add_context_length(1048576) + self.gguf_writer.add_embedding_length(hidden_size) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_wkv_head_size(head_size) + self.gguf_writer.add_decay_lora_rank(lora_rank_decay) + self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr) + self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix) + self.gguf_writer.add_gate_lora_rank(lora_rank_gate) + self.gguf_writer.add_feed_forward_length(intermediate_size) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_token_shift_count(1) + + # required by llama.cpp, unused + self.gguf_writer.add_head_count(0) + + +@ModelBase.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") +class MambaModel(TextModel): + model_arch = gguf.MODEL_ARCH.MAMBA + + def set_vocab(self): + vocab_size = self.hparams["vocab_size"] + # Round vocab size to next multiple of 8 + pad_vocab = self.hparams.get("pad_vocab_size_multiple", 8) + # pad using ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + vocab_size = -(vocab_size // -pad_vocab) * pad_vocab + self.hparams["vocab_size"] = vocab_size + + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + elif (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + else: + # Use the GPT-NeoX tokenizer when no tokenizer files are present + self._set_vocab_builtin("gpt-neox", vocab_size) + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "d_model"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16 + # ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 + dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16) + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + use_dt_b_c_norm = False + # For falconmamba we do apply RMS norm on B / DT and C layers + if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",): + use_dt_b_c_norm = True + # Fail early for models which don't have a block expansion factor of 2 + assert d_inner == 2 * d_model + + self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading + self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(dt_rank) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers + self.gguf_writer.add_file_type(self.ftype) + + _tok_embd = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) + tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD) + + new_name = self.map_tensor_name(name) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + # [4 1 8192 1] -> [4 8192 1 1] + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid): + data_torch = data_torch.squeeze() + + # assuming token_embd.weight is seen before output.weight + if self._tok_embd is not None and new_name == output_name: + if torch.equal(self._tok_embd, data_torch): + logger.debug(f"{output_name} is equivalent to {tok_embd_name}, omitting") + return [] + elif new_name == tok_embd_name: + self._tok_embd = data_torch + + return [(new_name, data_torch)] + + +@ModelBase.register("CohereForCausalLM") +class CommandR2Model(TextModel): + model_arch = gguf.MODEL_ARCH.COMMAND_R + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # max_position_embeddings = 8192 in config.json but model was actually + # trained on 128k context length + # aya-23 models don't have model_max_length specified + self.hparams["max_position_embeddings"] = self.find_hparam(["model_max_length", "max_position_embeddings"]) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_logit_scale(self.hparams["logit_scale"]) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + + +@ModelBase.register("Cohere2ForCausalLM") +class Cohere2Model(TextModel): + model_arch = gguf.MODEL_ARCH.COHERE2 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_logit_scale(self.hparams["logit_scale"]) + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + rotary_pct = self.hparams["rotary_pct"] + hidden_size = self.hparams["hidden_size"] + num_attention_heads = self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads))) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + + +@ModelBase.register("OlmoForCausalLM") +@ModelBase.register("OLMoForCausalLM") +class OlmoModel(TextModel): + model_arch = gguf.MODEL_ARCH.OLMO + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_layer_norm_eps(1e-5) + clip_qkv = self.hparams.get("clip_qkv") + if clip_qkv is not None: + self.gguf_writer.add_clamp_kqv(clip_qkv) + + # Same as super class, but permuting q_proj, k_proj + # Copied from: LlamaModel + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + + if name.endswith("q_proj.weight"): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith("k_proj.weight"): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("Olmo2ForCausalLM") +class Olmo2Model(TextModel): + model_arch = gguf.MODEL_ARCH.OLMO2 + + +@ModelBase.register("OlmoeForCausalLM") +class OlmoeModel(TextModel): + model_arch = gguf.MODEL_ARCH.OLMOE + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_layer_norm_rms_eps(1e-5) + if (n_experts := self.hparams.get("num_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + + _experts: list[dict[str, Tensor]] | None = None + + # Copied from: Qwen2MoeModel + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("experts") != -1: + n_experts = self.hparams["num_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + # Copied from: Qwen2MoeModel + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("JinaBertModel", "JinaBertForMaskedLM") +class JinaBertV2Model(BertModel): + model_arch = gguf.MODEL_ARCH.JINA_BERT_V2 + + def set_vocab(self): + tokenizer_class = 'BertTokenizer' + with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_class = json.load(f)['tokenizer_class'] + + if tokenizer_class == 'BertTokenizer': + super().set_vocab() + elif tokenizer_class == 'RobertaTokenizer': + self._set_vocab_gpt2() + self.gguf_writer.add_token_type_count(2) + else: + raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel') + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(True) + + +@ModelBase.register("OpenELMForCausalLM") +class OpenELMModel(TextModel): + model_arch = gguf.MODEL_ARCH.OPENELM + + @staticmethod + def _make_divisible(v: float | int, divisor: int) -> int: + # ref: https://huggingface.co/apple/OpenELM-270M-Instruct/blob/eb111ff2e6724348e5b905984063d4064d4bc579/configuration_openelm.py#L34-L38 + new_v = max(divisor, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + ffn_multipliers: list[float] = self.hparams["ffn_multipliers"] + ffn_dim_divisor: int = self.hparams["ffn_dim_divisor"] + self._n_embd: int = self.hparams["model_dim"] + self._num_kv_heads: list[int] = self.hparams["num_kv_heads"] + self._num_query_heads: list[int] = self.hparams["num_query_heads"] + self._ffn_dims: list[int] = [ + OpenELMModel._make_divisible(multiplier * self._n_embd, ffn_dim_divisor) + for multiplier in ffn_multipliers + ] + assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int) + assert isinstance(self._num_query_heads, list) and isinstance(self._num_query_heads[0], int) + + # Uses the tokenizer from meta-llama/Llama-2-7b-hf + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_builtin("llama-spm", self.hparams["vocab_size"]) + + def set_gguf_parameters(self): + n_embd = self._n_embd + head_dim = self.hparams["head_dim"] + rot_pct = 1.0 + assert self.block_count == len(self._num_kv_heads) + assert self.block_count == len(self._num_query_heads) + assert self.block_count == len(self._ffn_dims) + + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams["max_context_length"]) + self.gguf_writer.add_embedding_length(n_embd) + self.gguf_writer.add_feed_forward_length(self._ffn_dims) + self.gguf_writer.add_head_count(self._num_query_heads) + self.gguf_writer.add_head_count_kv(self._num_kv_heads) + self.gguf_writer.add_rope_freq_base(self.hparams["rope_freq_constant"]) + # https://huggingface.co/apple/OpenELM-270M-Instruct/blob/c401df2/modeling_openelm.py#L30 + self.gguf_writer.add_layer_norm_rms_eps(1e-6) + self.gguf_writer.add_rope_dimension_count(int(rot_pct * head_dim)) + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) + self.gguf_writer.add_file_type(self.ftype) + + def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any: + if "n_layers" in keys: + return self.hparams["num_transformer_layers"] + + return super().find_hparam(keys, optional) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + # split ff + if bid is not None and name == f"transformer.layers.{bid}.ffn.proj_1.weight": + ff_dim = self._ffn_dims[bid] + yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim]) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:]) + return + + yield (self.map_tensor_name(name), data_torch) + + +@ModelBase.register("ArcticForCausalLM") +class ArcticModel(TextModel): + model_arch = gguf.MODEL_ARCH.ARCTIC + + def set_vocab(self): + # The reason for using a custom implementation here is that the + # snowflake-arctic-instruct model redefined tokens 31998 and 31999 from + # tokenizer.model and used them as BOS and EOS instead of adding new tokens. + from sentencepiece import SentencePieceProcessor + + tokenizer_path = self.dir_model / 'tokenizer.model' + + if not tokenizer_path.is_file(): + logger.error(f'Error: Missing {tokenizer_path}') + sys.exit(1) + + # Read the whole vocabulary from the tokenizer.model file + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size + + for token_id in range(tokenizer.vocab_size()): + + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + # Use the added_tokens_decoder field from tokeniser_config.json as the source + # of information about added/redefined tokens and modify them accordingly. + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + + if "added_tokens_decoder" in tokenizer_config_json: + added_tokens_decoder = tokenizer_config_json["added_tokens_decoder"] + for token_id, token_json in added_tokens_decoder.items(): + token_id = int(token_id) + if token_id >= vocab_size: + logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + token_content = token_json["content"] + token_type = SentencePieceTokenTypes.USER_DEFINED + token_score = -10000.0 + + # Map unk_token to UNKNOWN, other special tokens to CONTROL + # Set the score to 0.0 as in the original tokenizer.model + if ("special" in token_json) and token_json["special"]: + if token_content == tokenizer_config_json["unk_token"]: + token_type = SentencePieceTokenTypes.UNKNOWN + else: + token_type = SentencePieceTokenTypes.CONTROL + token_score = 0.0 + + logger.info(f"Setting added token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})") + tokens[token_id] = token_content.encode("utf-8") + toktypes[token_id] = token_type + scores[token_id] = token_score + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + + if name.endswith("q_proj.weight"): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith("k_proj.weight"): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.hparams["num_local_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for wid in ["w1", "w2", "w3"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("DeepseekForCausalLM") +class DeepseekModel(TextModel): + model_arch = gguf.MODEL_ARCH.DEEPSEEK + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + + self.gguf_writer.add_rope_dimension_count(rope_dim) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_weights_scale(1.0) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) + + _experts: list[dict[str, Tensor]] | None = None + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = DeepseekModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = DeepseekModel.permute(data_torch, n_head, n_kv_head) + + # process the experts separately + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("DeepseekV2ForCausalLM") +@ModelBase.register("DeepseekV3ForCausalLM") +class DeepseekV2Model(TextModel): + model_arch = gguf.MODEL_ARCH.DEEPSEEK2 + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + + # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 + + super().set_gguf_parameters() + hparams = self.hparams + + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + + # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA + self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + if hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif hparams["scoring_func"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"]) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # rename e_score_correction_bias tensors + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # skip Multi-Token Prediction (MTP) layers + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return [] + + # process the experts separately + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + + return [ + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("Dots1ForCausalLM") +class Dots1Model(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.DOTS1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hparams["num_experts"] = self.hparams["n_routed_experts"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_leading_dense_block_count(self.hparams["first_k_dense_replace"]) + self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"]) + + if self.hparams["scoring_func"] == "noaux_tc": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + else: + raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}") + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + if "shared_experts" in name: + return [(self.map_tensor_name(name), data_torch)] + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("PLMForCausalLM") +class PLMModel(TextModel): + model_arch = gguf.MODEL_ARCH.PLM + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["v_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + +@ModelBase.register("T5WithLMHeadModel") +@ModelBase.register("T5ForConditionalGeneration") +@ModelBase.register("MT5ForConditionalGeneration") +@ModelBase.register("UMT5ForConditionalGeneration") +class T5Model(TextModel): + model_arch = gguf.MODEL_ARCH.T5 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.shared_token_embeddings_found = False + + def set_vocab(self): + # to avoid TypeError: Descriptors cannot be created directly + # exception when importing sentencepiece_model_pb2 + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + from sentencepiece import SentencePieceProcessor + from sentencepiece import sentencepiece_model_pb2 as model + + tokenizer_path = self.dir_model / 'tokenizer.model' + + # many older models use spiece.model tokenizer model filename + if not tokenizer_path.is_file(): + tokenizer_path = self.dir_model / 'spiece.model' + + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") + + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + + # some models like Pile-T5 family use BPE tokenizer instead of Unigram + if sentencepiece_model.trainer_spec.model_type == 2: # BPE + # assure the tokenizer model file name is correct + assert tokenizer_path.name == 'tokenizer.model' + return self._set_vocab_sentencepiece() + else: + assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces + precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size + + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + for key in added_tokens_json: + token_id = added_tokens_json[key] + if token_id >= vocab_size: + logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + tokens[token_id] = key.encode("utf-8") + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + for i in range(1, pad_count + 1): + tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.UNUSED) + + self.gguf_writer.add_tokenizer_model("t5") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_add_space_prefix(add_prefix) + self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces) + if precompiled_charsmap: + self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + self.gguf_writer.add_add_bos_token(False) + self.gguf_writer.add_add_eos_token(True) + + def set_gguf_parameters(self): + if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None: + logger.warning("Couldn't find context length in config.json, assuming default value of 512") + n_ctx = 512 + self.gguf_writer.add_context_length(n_ctx) + self.gguf_writer.add_embedding_length(self.hparams["d_model"]) + self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) + self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_head_count(self.hparams["num_heads"]) + self.gguf_writer.add_key_length(self.hparams["d_kv"]) + self.gguf_writer.add_value_length(self.hparams["d_kv"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_decoder_start_token_id(self.hparams["decoder_start_token_id"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight", + # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored + # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder + # and decoder and ignore the remaining ones. + if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]: + if not self.shared_token_embeddings_found: + name = "shared.weight" + self.shared_token_embeddings_found = True + else: + logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.") + return [] + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("T5EncoderModel") +class T5EncoderModel(TextModel): + model_arch = gguf.MODEL_ARCH.T5ENCODER + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.shared_token_embeddings_found = False + + def set_vocab(self): + # to avoid TypeError: Descriptors cannot be created directly + # exception when importing sentencepiece_model_pb2 + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + from sentencepiece import SentencePieceProcessor + from sentencepiece import sentencepiece_model_pb2 as model + + tokenizer_path = self.dir_model / 'tokenizer.model' + + # many older models use spiece.model tokenizer model filename + if not tokenizer_path.is_file(): + tokenizer_path = self.dir_model / 'spiece.model' + + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") + + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + + # some models like Pile-T5 family use BPE tokenizer instead of Unigram + if sentencepiece_model.trainer_spec.model_type == 2: # BPE + # assure the tokenizer model file name is correct + assert tokenizer_path.name == 'tokenizer.model' + return self._set_vocab_sentencepiece() + else: + assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces + precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size + + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + for key in added_tokens_json: + token_id = added_tokens_json[key] + if token_id >= vocab_size: + logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + tokens[token_id] = key.encode("utf-8") + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + for i in range(1, pad_count + 1): + tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.UNUSED) + + self.gguf_writer.add_tokenizer_model("t5") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_add_space_prefix(add_prefix) + self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces) + if precompiled_charsmap: + self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + self.gguf_writer.add_add_bos_token(False) + self.gguf_writer.add_add_eos_token(True) + + def set_gguf_parameters(self): + if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None: + logger.warning("Couldn't find context length in config.json, assuming default value of 512") + n_ctx = 512 + self.gguf_writer.add_context_length(n_ctx) + self.gguf_writer.add_embedding_length(self.hparams["d_model"]) + self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) + self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_head_count(self.hparams["num_heads"]) + self.gguf_writer.add_key_length(self.hparams["d_kv"]) + self.gguf_writer.add_value_length(self.hparams["d_kv"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight", + # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored + # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder + # and decoder and ignore the remaining ones. + if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]: + if not self.shared_token_embeddings_found: + name = "shared.weight" + self.shared_token_embeddings_found = True + else: + logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.") + return [] + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("JAISLMHeadModel") +class JaisModel(TextModel): + model_arch = gguf.MODEL_ARCH.JAIS + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # SwigLU activation + assert self.hparams["activation_function"] == "swiglu" + # ALiBi position embedding + assert self.hparams["position_embedding_type"] == "alibi" + + # Embeddings scale + self.embeddings_scale = 1.0 + if 'mup_embeddings_scale' in self.hparams: + self.embeddings_scale = self.hparams['mup_embeddings_scale'] + elif 'embeddings_scale' in self.hparams: + self.embeddings_scale = self.hparams['embeddings_scale'] + else: + assert False + + self.width_scale = 1.0 + if 'mup_output_alpha' in self.hparams: + assert 'mup_width_scale' in self.hparams + self.width_scale = self.hparams['mup_output_alpha'] * self.hparams['mup_width_scale'] + elif 'width_scale' in self.hparams: + self.width_scale = self.hparams['width_scale'] + else: + assert False + + self.max_alibi_bias = 8.0 + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"]) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + tensors: list[tuple[str, Tensor]] = [] + + # we don't need these + if name.endswith((".attn.bias")): + return tensors + + if name.endswith(("relative_pe.slopes")): + # Calculate max ALiBi bias (this is the inverse of the ALiBi calculation) + # Some other models has max_alibi_bias spelled out explicitly in the hyperparams, + # but Jais's PyTorch model simply precalculates the slope values and places them + # in relative_pes.slopes + n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"])) + first_val = float(data_torch[0].item()) + self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2) + + return tensors + + if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")): + data_torch = data_torch.transpose(1, 0) + + new_name = self.map_tensor_name(name) + + if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD): + tensors.append((new_name, data_torch * self.embeddings_scale)) + elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT): + tensors.append((new_name, data_torch * self.width_scale)) + else: + tensors.append((new_name, data_torch)) + + return tensors + + def prepare_tensors(self): + super().prepare_tensors() + self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias) + + +@ModelBase.register("Glm4ForCausalLM") +class Glm4Model(TextModel): + model_arch = gguf.MODEL_ARCH.GLM4 + + def set_vocab(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + rope_dim = self.hparams["head_dim"] + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + + +@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") +class ChatGLMModel(TextModel): + model_arch = gguf.MODEL_ARCH.CHATGLM + + def set_vocab_chatglm3(self): + dir_model = self.dir_model + hparams = self.hparams + tokens: list[bytes] = [] + toktypes: list[int] = [] + scores: list[float] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) + assert max(tokenizer.get_vocab().values()) < vocab_size + role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] + special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens + for token_id in range(vocab_size): + piece = tokenizer._convert_id_to_token(token_id) + if token_id == 0: + piece = "" + elif token_id == 1: + piece = "" + elif token_id == 2: + piece = "" + + text = piece.encode("utf-8") + score = 0.0 + # Referencing the tokenizer Python implementation(https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py), + # it is only valid if it is less than tokenizer.tokenizer.sp_model.vocab_size() + if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size(): + score = tokenizer.tokenizer.sp_model.get_score(token_id) + + if token_id >= tokenizer.tokenizer.sp_model.vocab_size(): + if piece in special_tokens: + toktype = SentencePieceTokenTypes.CONTROL + elif len(piece) == 0: + text = f"[PAD{token_id}]".encode("utf-8") + toktype = SentencePieceTokenTypes.UNUSED + else: + toktype = SentencePieceTokenTypes.USER_DEFINED + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + continue + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.tokenizer.sp_model.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.tokenizer.sp_model.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.tokenizer.sp_model.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.tokenizer.sp_model.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + self.gguf_writer.add_tokenizer_model("llama") + # glm3 needs prefix and suffix formatted as: + # prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>" + self.gguf_writer.add_tokenizer_pre("chatglm-spm") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + @staticmethod + def token_bytes_to_string(b): + from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + byte_encoder = bytes_to_unicode() + return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')]) + + @staticmethod + def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]: + parts = [bytes([b]) for b in token] + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank is None or (max_rank is not None and min_rank >= max_rank): + break + assert min_idx is not None + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:] + return parts + + def set_vocab(self): + if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""): + self.set_vocab_chatglm3() + return + + dir_model = self.dir_model + hparams = self.hparams + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + vocab_size = hparams.get("padded_vocab_size",hparams["vocab_size"]) + assert max(tokenizer.get_vocab().values()) < vocab_size + + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + # only add special tokens when they were not already loaded from config.json + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) + # this one is usually not in config.json anyway + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + n_head_kv = self.hparams.get("multi_query_group_num", self.hparams.get("num_key_value_heads", n_head)) + self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) + self.gguf_writer.add_embedding_length(n_embed) + self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed))) + self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"])) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head_kv) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5)) + self.gguf_writer.add_file_type(self.ftype) + if "attention_dim" in self.hparams: + rope_dim = self.hparams["attention_dim"] + else: + rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) + self.gguf_writer.add_add_bos_token(False) + rope_freq = 10000 + if "rope_ratio" in self.hparams: + rope_freq = rope_freq * self.hparams["rope_ratio"] + self.gguf_writer.add_rope_freq_base(rope_freq) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.endswith(".rotary_pos_emb.inv_freq") or name.startswith("model.vision."): + return [] + + name = name.removeprefix("transformer.") + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("NemotronForCausalLM") +class NemotronModel(TextModel): + model_arch = gguf.MODEL_ARCH.NEMOTRON + + def set_vocab(self): + self._set_vocab_sentencepiece() + self.gguf_writer.add_pad_token_id(0) + self.gguf_writer.add_unk_token_id(1) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + f_norm_eps = self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon", "norm_eps"]) + self.gguf_writer.add_layer_norm_eps(f_norm_eps) + + # * Partial RoPE + rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"]) + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) + + # * RopeScaling for Nemotron + if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + else: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side + # model.layers.{l}.input_layernorm.weight + # model.layers.{l}.post_attention_layernorm.weight + # model.norm.weight + if name.endswith("norm.weight"): + data_torch = data_torch + 1 + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("ExaoneForCausalLM") +class ExaoneModel(TextModel): + model_arch = gguf.MODEL_ARCH.EXAONE + + def set_gguf_parameters(self): + hparams = self.hparams + + assert (hparams["activation_function"] == "silu") + + max_position_embeddings = hparams["max_position_embeddings"] + embed_dim = hparams["hidden_size"] + num_heads = hparams["num_attention_heads"] + num_kv_heads = hparams.get("num_key_value_heads", num_heads) + layer_norm_eps = hparams["layer_norm_epsilon"] + intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim + num_layers = hparams["num_layers"] + # ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0 + # attention_dropout_rate = hparams["attention_dropout"] + # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0 + # embed_dropout_rate = hparams["embed_dropout"] + self.gguf_writer.add_embedding_length(embed_dim) + self.gguf_writer.add_head_count(num_heads) + self.gguf_writer.add_head_count_kv(num_kv_heads) + self.gguf_writer.add_context_length(max_position_embeddings) + self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps) + self.gguf_writer.add_feed_forward_length(intermediate_size) + self.gguf_writer.add_block_count(num_layers) + self.gguf_writer.add_file_type(self.ftype) + + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True) + rotary_factor = rotary_factor if rotary_factor is not None else 1.0 + self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): + if rope_scaling.get("rope_type", '').lower() == "llama3": + base = self.hparams.get("rope_theta", 10000.0) + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = rope_scaling.get("factor", 8.0) + low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + assert low_freq_wavelen != high_freq_wavelen + + rope_factors = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + rope_factors.append(1) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1 / ((1 - smooth) / factor + smooth)) + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + + +@ModelBase.register("GraniteForCausalLM") +class GraniteModel(LlamaModel): + """Conversion for IBM's GraniteForCausalLM""" + model_arch = gguf.MODEL_ARCH.GRANITE + + def set_gguf_parameters(self): + """Granite uses standard llama parameters with the following differences: + + - No head_dim support + - New multiplier params: + - attention_scale + - embedding_scale + - residual_scale + - logits_scaling + """ + if head_dim := self.hparams.pop("head_dim", None): + logger.warning("Ignoring head_dim (%s) from config for Granite", head_dim) + super().set_gguf_parameters() + # NOTE: Convert _multiplier params to _scale params for naming + # consistency + if attention_scale := self.hparams.get("attention_multiplier"): + self.gguf_writer.add_attention_scale(attention_scale) + logger.info("gguf: (granite) attention_scale = %s", attention_scale) + if embedding_scale := self.hparams.get("embedding_multiplier"): + self.gguf_writer.add_embedding_scale(embedding_scale) + logger.info("gguf: (granite) embedding_scale = %s", embedding_scale) + if residual_scale := self.hparams.get("residual_multiplier"): + self.gguf_writer.add_residual_scale(residual_scale) + logger.info("gguf: (granite) residual_scale = %s", residual_scale) + if logits_scale := self.hparams.get("logits_scaling"): + self.gguf_writer.add_logit_scale(logits_scale) + logger.info("gguf: (granite) logits_scale = %s", logits_scale) + + +@ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM") +class GraniteMoeModel(GraniteModel): + """Conversion for IBM's GraniteMoeForCausalLM""" + model_arch = gguf.MODEL_ARCH.GRANITE_MOE + + def set_gguf_parameters(self): + """GraniteMoeShared uses GraniteMoe parameters plus the following: + - shared_intermediate_size + """ + super().set_gguf_parameters() + if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"): + self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length) + logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + """In modeling_granitemoe, the JetMoe implementation of parallel experts + is used. This essentially merges w1 and w3 into a single tensor with 2x + the hidden size that is then split during forward. To keep compatibility + with existing mixtral support, we pull them apart here. + """ + + if name.endswith("block_sparse_moe.input_linear.weight"): + ffn_dim = self.hparams["intermediate_size"] + assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size" + gate, up = data_torch.split(ffn_dim, dim=-2) + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up), + ] + + if name.endswith("shared_mlp.input_linear.weight"): + ffn_dim = self.hparams["shared_intermediate_size"] + assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" + gate, up = data_torch.split(ffn_dim, dim=-2) + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up), + ] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("BailingMoeForCausalLM") +class BailingMoeModel(TextModel): + model_arch = gguf.MODEL_ARCH.BAILINGMOE + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + rope_dim = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"] + + self.gguf_writer.add_rope_dimension_count(rope_dim) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + else: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_weights_scale(1.0) + self.gguf_writer.add_expert_count(hparams["num_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + _experts: list[dict[str, Tensor]] | None = None + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + n_embd = self.hparams["hidden_size"] + head_dim = self.hparams.get("head_dim") or n_embd // n_head + + output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) + + if name.endswith("attention.dense.weight"): + return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), data_torch)] + elif name.endswith("query_key_value.weight"): + q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2) + + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), BailingMoeModel.permute(q, n_head, n_head)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), BailingMoeModel.permute(k, n_head, n_kv_head)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v) + ] + elif name.find("mlp.experts") != -1: + n_experts = self.hparams["num_experts"] + assert bid is not None + + tensors: list[tuple[str, Tensor]] = [] + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + + return tensors + + new_name = self.map_tensor_name(name) + + if new_name == output_name and self.hparams.get("norm_head"): + data_torch = data_torch.float() + data_torch /= torch.norm(data_torch, p=2, dim=0, keepdim=True) + 1e-7 + + return [(new_name, data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("ChameleonForConditionalGeneration") +@ModelBase.register("ChameleonForCausalLM") # obsolete +class ChameleonModel(TextModel): + model_arch = gguf.MODEL_ARCH.CHAMELEON + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_swin_norm(self.hparams.get("swin_norm", False)) + + def set_vocab(self): + self._set_vocab_gpt2() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # ignore image tokenizer for now + # TODO: remove this once image support is implemented for Chameleon + if name.startswith("model.vqmodel"): + return [] + + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + hidden_dim = self.hparams.get("hidden_size") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + if name.endswith(("q_norm.weight", "q_norm.bias")): + data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim) + if name.endswith(("k_norm.weight", "k_norm.bias")): + data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim) + + return [(self.map_tensor_name(name), data_torch)] + + # see: https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203 + @staticmethod + def _reverse_hf_permute(data_torch, n_heads, hidden_dim): + head_dim = hidden_dim // n_heads + data_torch = data_torch[0].view(2, head_dim // 2).t().reshape(1, -1) + data_torch = data_torch.repeat_interleave(n_heads, 0) + return data_torch + + +@ModelBase.register("UltravoxModel") +class UltravoxModel(TextModel): + model_arch = gguf.MODEL_ARCH.LLAMA # dummy + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument") + + +@ModelBase.register("Qwen2AudioForConditionalGeneration") +class WhisperEncoderModel(MmprojModel): + has_vision_encoder = False # no vision encoder + has_audio_encoder = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hparams["hidden_size"] = self.hparams["d_model"] + self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"] + self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2A) + self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"]) + self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + if ".conv" in name and ".weight" in name: + return gguf.GGMLQuantizationType.F16 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.startswith("language_model."): + # skip language model tensors + return [] + + # prevent clash naming with vision tensors + if name.startswith("multi_modal_projector"): + name = "audio." + name + + if "conv1.bias" in name or "conv2.bias" in name: + # transpose conv1 and conv2 bias + data_torch = data_torch.unsqueeze(-1) + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("UltravoxModel") +class UltravoxWhisperEncoderModel(WhisperEncoderModel): + has_vision_encoder = False # no vision encoder + has_audio_encoder = True + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"]) + +###### CONVERSION LOGIC ###### + + +# tree of lazy tensors +class LazyTorchTensor(gguf.LazyBase): + _tensor_type = torch.Tensor + # to keep the type-checker happy + dtype: torch.dtype + shape: torch.Size + + # only used when converting a torch.Tensor to a np.ndarray + _dtype_map: dict[torch.dtype, type] = { + torch.float16: np.float16, + torch.float32: np.float32, + } + + # used for safetensors slices + # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 + # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 + _dtype_str_map: dict[str, torch.dtype] = { + "F64": torch.float64, + "F32": torch.float32, + "BF16": torch.bfloat16, + "F16": torch.float16, + # "U64": torch.uint64, + "I64": torch.int64, + # "U32": torch.uint32, + "I32": torch.int32, + # "U16": torch.uint16, + "I16": torch.int16, + "U8": torch.uint8, + "I8": torch.int8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + } + + def numpy(self) -> gguf.LazyNumpyTensor: + dtype = self._dtype_map[self.dtype] + return gguf.LazyNumpyTensor( + meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape), + args=(self,), + func=(lambda s: s.numpy()) + ) + + @classmethod + def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor: + return torch.empty(size=shape, dtype=dtype, device="meta") + + @classmethod + def from_safetensors_slice(cls, st_slice: Any) -> Tensor: + dtype = cls._dtype_str_map[st_slice.get_dtype()] + shape: tuple[int, ...] = tuple(st_slice.get_shape()) + lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) + return cast(torch.Tensor, lazy) + + @classmethod + def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): + dtype = cls._dtype_str_map[remote_tensor.dtype] + shape = remote_tensor.shape + meta = cls.meta_with_dtype_and_shape(dtype, shape) + lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) + return cast(torch.Tensor, lazy) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + del types # unused + + if kwargs is None: + kwargs = {} + + if func is torch.Tensor.numpy: + return args[0].numpy() + + return cls._wrap_fn(func)(*args, **kwargs) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert a huggingface model to a GGML compatible file") + parser.add_argument( + "--vocab-only", action="store_true", + help="extract only the vocab", + ) + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "model", type=Path, + help="directory containing model file", + nargs="?", + ) + parser.add_argument( + "--use-temp-file", action="store_true", + help="use the tempfile library while processing (helpful when running out of memory, process killed)", + ) + parser.add_argument( + "--no-lazy", action="store_true", + help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)", + ) + parser.add_argument( + "--model-name", type=str, default=None, + help="name of the model", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + parser.add_argument( + "--split-max-tensors", type=int, default=0, + help="max tensors in each split", + ) + parser.add_argument( + "--split-max-size", type=str, default="0", + help="max size per split N(M|G)", + ) + parser.add_argument( + "--dry-run", action="store_true", + help="only print out a split plan and exit, without writing any new files", + ) + parser.add_argument( + "--no-tensor-first-split", action="store_true", + help="do not add tensors to the first split (disabled by default)" + ) + parser.add_argument( + "--metadata", type=Path, + help="Specify the path for an authorship metadata override file" + ) + parser.add_argument( + "--print-supported-models", action="store_true", + help="Print the supported models" + ) + parser.add_argument( + "--remote", action="store_true", + help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.", + ) + parser.add_argument( + "--mmproj", action="store_true", + help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.", + ) + + args = parser.parse_args() + if not args.print_supported_models and args.model is None: + parser.error("the following arguments are required: model") + return args + + +def split_str_to_n_bytes(split_str: str) -> int: + if split_str.endswith("K"): + n = int(split_str[:-1]) * 1000 + elif split_str.endswith("M"): + n = int(split_str[:-1]) * 1000 * 1000 + elif split_str.endswith("G"): + n = int(split_str[:-1]) * 1000 * 1000 * 1000 + elif split_str.isnumeric(): + n = int(split_str) + else: + raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G") + + if n < 0: + raise ValueError(f"Invalid split size: {split_str}, must be positive") + + return n + + +def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str: + # TODO @ngxson : this won't work correctly if the model has both audio & vision encoders + # maybe we should fallback to text model's arch in that case, since not many models have both + text_config = hparams.get("text_config", {}) + vision_config = hparams.get("vision_config", {}) + arch = hparams["architectures"][0] + # if "architectures" is found in the sub-config, use that instead + if model_type == ModelType.TEXT and text_config.get("architectures") is not None: + arch = text_config["architectures"][0] + elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None: + arch = vision_config["architectures"][0] + return arch + + +def main() -> None: + args = parse_args() + + if args.print_supported_models: + logger.error("Supported models:") + ModelBase.print_registered_models() + sys.exit(0) + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + dir_model = args.model + + if args.remote: + from huggingface_hub import snapshot_download + local_dir = snapshot_download( + repo_id=str(dir_model), + allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]) + dir_model = Path(local_dir) + logger.info(f"Downloaded config and tokenizer to {local_dir}") + + if not dir_model.is_dir(): + logger.error(f'Error: {args.model} is not a directory') + sys.exit(1) + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + "tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0, + "tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0, + "auto": gguf.LlamaFileType.GUESSED, + } + + is_split = args.split_max_tensors > 0 or args.split_max_size != "0" + if args.use_temp_file and is_split: + logger.error("Error: Cannot use temp file when splitting") + sys.exit(1) + + if args.outfile is not None: + fname_out = args.outfile + elif args.remote: + # if remote, use the model ID as the output file name + fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf") + else: + fname_out = dir_model + + logger.info(f"Loading model: {dir_model.name}") + + if args.mmproj: + if "mmproj" not in fname_out.name: + fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-") + + with torch.inference_mode(): + output_type = ftype_map[args.outtype] + model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT + hparams = ModelBase.load_hparams(dir_model) + model_architecture = get_model_architecture(hparams, model_type) + logger.info(f"Model architecture: {model_architecture}") + try: + model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type) + except NotImplementedError: + logger.error(f"Model {model_architecture} is not supported") + sys.exit(1) + + model_instance = model_class(dir_model, output_type, fname_out, + is_big_endian=args.bigendian, use_temp_file=args.use_temp_file, + eager=args.no_lazy, + metadata_override=args.metadata, model_name=args.model_name, + split_max_tensors=args.split_max_tensors, + split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, + small_first_shard=args.no_tensor_first_split, + remote_hf_model_id=str(args.model) if args.remote else None) + + if args.vocab_only: + logger.info("Exporting model vocab...") + model_instance.write_vocab() + logger.info(f"Model vocab successfully exported to {model_instance.fname_out}") + else: + logger.info("Exporting model...") + model_instance.write() + out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out + logger.info(f"Model successfully exported to {out_path}") + + +if __name__ == '__main__': + main() diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py new file mode 100644 index 0000000000000000000000000000000000000000..2f733f0973686f3f7d57508864348a4629243942 --- /dev/null +++ b/convert_hf_to_gguf_update.py @@ -0,0 +1,445 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import logging +import os +import pathlib +import re + +import requests +import sys +import json +import shutil +import argparse + +from hashlib import sha256 +from enum import IntEnum, auto +from transformers import AutoTokenizer + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger("convert_hf_to_gguf_update") +sess = requests.Session() + +convert_py_pth = pathlib.Path("convert_hf_to_gguf.py") +convert_py = convert_py_pth.read_text(encoding="utf-8") +hf_token_pth = pathlib.Path.home() / ".cache" / "huggingface" / "token" +hf_token = hf_token_pth.read_text(encoding="utf-8").strip() if hf_token_pth.exists() else None + + +class TOKENIZER_TYPE(IntEnum): + SPM = auto() + BPE = auto() + WPM = auto() + UGM = auto() + + +DOC_STRING = """ +This script downloads the tokenizer models of the specified models from Huggingface and +generates the get_vocab_base_pre() function for convert_hf_to_gguf.py + +/!\\ It is intended to be used by contributors and is not meant to be run by end users + +This is necessary in order to analyze the type of pre-tokenizer used by the model and +provide the necessary information to llama.cpp via the GGUF header in order to implement +the same pre-tokenizer. + +ref: https://github.com/ggml-org/llama.cpp/pull/6920 + +Instructions: + +- Add a new model to the "models" list +- Run the script with your huggingface token + By default, token will be read from ~/.cache/huggingface/token +- The convert_hf_to_gguf.py script will have had its get_vocab_base_pre() function updated +- Update llama.cpp with the new pre-tokenizer if necessary +""" +# TODO: generate tokenizer tests for llama.cpp + +parser = argparse.ArgumentParser(description=DOC_STRING, formatter_class=argparse.RawTextHelpFormatter) +parser.add_argument( + "--full", action="store_true", + help="download full list of models - make sure you have access to all of them", +) +parser.add_argument( + "hf_token", + help="optional HF token", + nargs="?", +) +args = parser.parse_args() +hf_token = args.hf_token if args.hf_token is not None else hf_token + +if hf_token is None: + logger.error("HF token is required. Please provide it as an argument or set it in ~/.cache/huggingface/token") + sys.exit(1) + +# TODO: this string has to exercise as much pre-tokenizer functionality as possible +# will be updated with time - contributions welcome +CHK_TXT = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL' + +# TODO: add models here, base models preferred +models = [ + {"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", }, + {"name": "llama-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", }, + {"name": "phi-3", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", }, + {"name": "deepseek-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", }, + {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", }, + {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", }, + {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", }, + {"name": "falcon3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon3-7B-Base", }, + {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", }, + {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", }, + {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", }, + {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, + {"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", }, + {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, + {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, + {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, + {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, + {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, + {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", }, + {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! + {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, + {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, + {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", }, + {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", }, + {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", }, + {"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B + {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", }, + {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", }, + {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, + {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, + {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, + {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, + {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", }, + {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", }, + {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, + {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, + {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, + {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, + {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"}, + {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"}, + {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, + {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, + {"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"}, + {"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", }, + {"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", }, + {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", }, + {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", }, + {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", }, + {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", }, + {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", }, +] + +# some models are known to be broken upstream, so we will skip them as exceptions +pre_computed_hashes = [ + # chatglm-bpe has 2 hashes, why? + {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"}, + {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, + {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, +] + + +def download_file_with_auth(url, token, save_path): + headers = {"Authorization": f"Bearer {token}"} + response = sess.get(url, headers=headers) + response.raise_for_status() + os.makedirs(os.path.dirname(save_path), exist_ok=True) + with open(save_path, 'wb') as downloaded_file: + downloaded_file.write(response.content) + logger.info(f"File {save_path} downloaded successfully") + + +def download_model(model): + name = model["name"] + repo = model["repo"] + tokt = model["tokt"] + + os.makedirs(f"models/tokenizers/{name}", exist_ok=True) + + files = ["config.json", "tokenizer.json", "tokenizer_config.json"] + + if name == "gpt-4o": + # Xenova/gpt-4o is tokenizer-only, it does not contain config.json + files = ["tokenizer.json", "tokenizer_config.json"] + + if tokt == TOKENIZER_TYPE.SPM: + files.append("tokenizer.model") + + if tokt == TOKENIZER_TYPE.UGM: + files.append("spiece.model") + + if os.path.isdir(repo): + # If repo is a path on the file system, copy the directory + for file in files: + src_path = os.path.join(repo, file) + dst_path = f"models/tokenizers/{name}/{file}" + if os.path.isfile(dst_path): + logger.info(f"{name}: File {dst_path} already exists - skipping") + continue + if os.path.isfile(src_path): + shutil.copy2(src_path, dst_path) + logger.info(f"{name}: Copied {src_path} to {dst_path}") + else: + logger.warning(f"{name}: Source file {src_path} does not exist") + else: + # If repo is a URL, download the files + for file in files: + save_path = f"models/tokenizers/{name}/{file}" + if os.path.isfile(save_path): + logger.info(f"{name}: File {save_path} already exists - skipping") + continue + download_file_with_auth(f"{repo}/resolve/main/{file}", hf_token, save_path) + + +# get list of existing models and chkhsh from the convert_hf_to_gguf.py file +# returns mapping res --> chkhsh +def get_existing_models(convert_py): + pattern = r'if chkhsh == "([a-f0-9]{64})":\s*\n\s*.*\s*res = "([^"]+)"' + matches = re.findall(pattern, convert_py) + output = {} + for chkhsh, res in matches: + output[res] = chkhsh + return output + + +existing_models = {} +all_models = models.copy() +if not args.full: + # Filter out models that already exist in convert_hf_to_gguf.py + existing_models = get_existing_models(convert_py) + all_models = models.copy() + models = [model for model in all_models if model["name"] not in existing_models] + +logging.info(f"Downloading {len(models)} models...") +for model in models: + try: + download_model(model) + except Exception as e: + logger.error(f"Failed to download model {model['name']}. Error: {e}") + + +# generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function: + +src_ifs = "" +for model in [*all_models, *pre_computed_hashes]: + name = model["name"] + tokt = model["tokt"] + chkhsh = model.get("chkhsh") + + if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM: + continue + + # Skip if the tokenizer folder does not exist or there are other download issues previously + if not os.path.exists(f"models/tokenizers/{name}"): + logger.warning(f"Directory for tokenizer {name} not found. Skipping...") + continue + + # create the tokenizer + if chkhsh is not None: + # if the model has a pre-computed hash, use it + logger.info(f"Using pre-computed hash for model {name}: {chkhsh}") + elif name in existing_models: + # if the model already exists in convert_hf_to_gguf.py, skip compute hash + chkhsh = existing_models[name] + else: + # otherwise, compute the hash of the tokenizer + try: + logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...") + if name == "t5": + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + except OSError as e: + logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}") + continue # Skip to the next model if the tokenizer can't be loaded + + chktok = tokenizer.encode(CHK_TXT) + chkhsh = sha256(str(chktok).encode()).hexdigest() + + logger.info(f"model: {name}") + logger.info(f"tokt: {tokt}") + logger.info(f"repo: {model['repo']}") + logger.info(f"chktok: {chktok}") + logger.info(f"chkhsh: {chkhsh}") + + # print the "pre_tokenizer" content from the tokenizer.json + with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: + cfg = json.load(f) + normalizer = cfg["normalizer"] + logger.info("normalizer: " + json.dumps(normalizer, indent=4)) + pre_tokenizer = cfg["pre_tokenizer"] + logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) + if "ignore_merges" in cfg["model"]: + logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4)) + + logger.info("") + + src_ifs += f" if chkhsh == \"{chkhsh}\":\n" + src_ifs += f" # ref: {model['repo']}\n" + src_ifs += f" res = \"{name}\"\n" + +src_func = f""" + def get_vocab_base_pre(self, tokenizer) -> str: + # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that + # is specific for the BPE pre-tokenizer used by the model + # we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can + # use in llama.cpp to implement the same pre-tokenizer + + chktxt = {repr(CHK_TXT)} + + chktok = tokenizer.encode(chktxt) + chkhsh = sha256(str(chktok).encode()).hexdigest() + + logger.debug(f"chktok: {{chktok}}") + logger.debug(f"chkhsh: {{chkhsh}}") + + res = None + + # NOTE: if you get an error here, you need to update the convert_hf_to_gguf_update.py script + # or pull the latest version of the model from Huggingface + # don't edit the hashes manually! +{src_ifs} + if res is None: + logger.warning("\\n") + logger.warning("**************************************************************************************") + logger.warning("** WARNING: The BPE pre-tokenizer was not recognized!") + logger.warning("** There are 2 possible reasons for this:") + logger.warning("** - the model has not been added to convert_hf_to_gguf_update.py yet") + logger.warning("** - the pre-tokenization config has changed upstream") + logger.warning("** Check your model files and convert_hf_to_gguf_update.py and update them accordingly.") + logger.warning("** ref: https://github.com/ggml-org/llama.cpp/pull/6920") + logger.warning("**") + logger.warning(f"** chkhsh: {{chkhsh}}") + logger.warning("**************************************************************************************") + logger.warning("\\n") + raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()") + + logger.debug(f"tokenizer.ggml.pre: {{repr(res)}}") + logger.debug(f"chkhsh: {{chkhsh}}") + + return res +""" + +convert_py = re.sub( + r"(# Marker: Start get_vocab_base_pre)(.+?)( +# Marker: End get_vocab_base_pre)", + lambda m: m.group(1) + src_func + m.group(3), + convert_py, + flags=re.DOTALL | re.MULTILINE, +) + +convert_py_pth.write_text(convert_py, encoding="utf-8") + +logger.info("+++ convert_hf_to_gguf.py was updated") + +# generate tests for each tokenizer model + +tests = [ + "ied 4 ½ months", + "Äpfel", + "", + " ", + " ", + " ", + "\t", + "\n", + "\n\n", + "\n\n\n", + "\t\n", + "Hello world", + " Hello world", + "Hello World", + " Hello World", + " Hello World!", + "Hello, world!", + " Hello, world!", + " this is 🦙.cpp", + "w048 7tuijk dsdfhu", + "нещо на Български", + "កាន់តែពិសេសអាចខលចេញ", + "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", + "Hello", + " Hello", + " Hello", + " Hello", + " Hello", + " Hello\n Hello", + " (", + "\n =", + "' era", + "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~", + "!!!!!!", + "3", + "33", + "333", + "3333", + "33333", + "333333", + "3333333", + "33333333", + "333333333", + "Cửa Việt", # llama-bpe fails on this + " discards", + CHK_TXT, +] + +# write the tests to ./models/ggml-vocab-{name}.gguf.inp +# the format is: +# +# test0 +# __ggml_vocab_test__ +# test1 +# __ggml_vocab_test__ +# ... +# + +# with each model, encode all tests and write the results in ./models/ggml-vocab-{name}.gguf.out +# for each test, write the resulting tokens on a separate line + +for model in models: + name = model["name"] + tokt = model["tokt"] + + # Skip if the tokenizer folder does not exist or there are other download issues previously + if not os.path.exists(f"models/tokenizers/{name}"): + logger.warning(f"Directory for tokenizer {name} not found. Skipping...") + continue + + # create the tokenizer + try: + if name == "t5": + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + except OSError as e: + logger.error(f"Failed to load tokenizer for model {name}. Error: {e}") + continue # Skip this model and continue with the next one in the loop + + if not os.path.exists(f"models/ggml-vocab-{name}.gguf"): + logger.info(f"Skip vocab files for model {name}, no GGUF file found") + continue + + with open(f"models/ggml-vocab-{name}.gguf.inp", "w", encoding="utf-8") as f: + for text in tests: + f.write(f"{text}") + f.write("\n__ggml_vocab_test__\n") + + with open(f"models/ggml-vocab-{name}.gguf.out", "w") as f: + for text in tests: + res = tokenizer.encode(text, add_special_tokens=False) + for r in res: + f.write(f" {r}") + f.write("\n") + + logger.info(f"Tests for {name} written in ./models/ggml-vocab-{name}.gguf.*") + +# generate commands for creating vocab files + +logger.info("\nRun the following commands to generate the vocab files for testing:\n") + +for model in models: + name = model["name"] + + print(f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") # noqa: NP100 + +logger.info("\n") diff --git a/convert_llama_ggml_to_gguf.py b/convert_llama_ggml_to_gguf.py new file mode 100644 index 0000000000000000000000000000000000000000..29b14e98dd237a4c04e139cc6e35d9494576a896 --- /dev/null +++ b/convert_llama_ggml_to_gguf.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import argparse +import os +import struct +import sys +from enum import IntEnum +from pathlib import Path + +import numpy as np + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + +logger = logging.getLogger("ggml-to-gguf") + + +class GGMLFormat(IntEnum): + GGML = 0 + GGMF = 1 + GGJT = 2 + + +class GGMLFType(IntEnum): + ALL_F32 = 0 + MOSTLY_F16 = 1 + MOSTLY_Q4_0 = 2 + MOSTLY_Q4_1 = 3 + MOSTLY_Q4_1_SOME_F16 = 4 + MOSTLY_Q8_0 = 7 + MOSTLY_Q5_0 = 8 + MOSTLY_Q5_1 = 9 + MOSTLY_Q2_K = 10 + MOSTLY_Q3_K_S = 11 + MOSTLY_Q3_K_M = 12 + MOSTLY_Q3_K_L = 13 + MOSTLY_Q4_K_S = 14 + MOSTLY_Q4_K_M = 15 + MOSTLY_Q5_K_S = 16 + MOSTLY_Q5_K_M = 17 + MOSTLY_Q6_K = 18 + + +class Hyperparameters: + def __init__(self): + self.n_vocab = self.n_embd = self.n_mult = self.n_head = 0 + self.n_layer = self.n_rot = self.n_ff = 0 + self.ftype = GGMLFType.ALL_F32 + + def set_n_ff(self, model): + ff_tensor_idx = model.tensor_map.get(b'layers.0.feed_forward.w1.weight') + assert ff_tensor_idx is not None, 'Missing layer 0 FF tensor' + ff_tensor = model.tensors[ff_tensor_idx] + self.n_ff = ff_tensor.dims[1] + + def load(self, data, offset): + ( + self.n_vocab, + self.n_embd, + self.n_mult, + self.n_head, + self.n_layer, + self.n_rot, + ftype, + ) = struct.unpack('<7I', data[offset:offset + (4 * 7)]) + try: + self.ftype = GGMLFType(ftype) + except ValueError: + raise ValueError(f'Invalid ftype {ftype}') + return 4 * 7 + + def __str__(self): + return f'' + + +class Vocab: + def __init__(self, load_scores = True): + self.items = [] + self.load_scores = load_scores + + def load(self, data, offset, n_vocab): + orig_offset = offset + for _ in range(n_vocab): + itemlen = struct.unpack('= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}' + assert name_len < 4096, 'Absurd tensor name length' + quant = gguf.GGML_QUANT_SIZES.get(dtype) + assert quant is not None, 'Unknown tensor type' + (blksize, tysize) = quant + offset += 12 + self.dtype= gguf.GGMLQuantizationType(dtype) + self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)]) + offset += 4 * n_dims + self.name = bytes(data[offset:offset + name_len]) + offset += name_len + pad = ((offset + 31) & ~31) - offset if self.use_padding else 0 + offset += pad + n_elems = np.prod(self.dims) + n_bytes = np.int64(np.int64(n_elems) * np.int64(tysize)) // np.int64(blksize) + self.start_offset = offset + self.len_bytes = n_bytes + offset += n_bytes + return offset - orig_offset + + +class GGMLModel: + + file_format: GGMLFormat + format_version: int + + def __init__(self): + self.hyperparameters = None + self.vocab = None + self.tensor_map = {} + self.tensors = [] + + def validate_header(self, data, offset): + magic = bytes(data[offset:offset + 4]) + if magic == b'GGUF': + raise ValueError('File is already in GGUF format.') + if magic == b'lmgg': + self.file_format = GGMLFormat.GGML + self.format_version = 1 + return 4 + version = struct.unpack(' 3: + raise ValueError(f'Cannot handle unexpected GGJT file version {version}') + self.file_format = GGMLFormat.GGJT + self.format_version = version + return 8 + raise ValueError(f"Unexpected file magic {magic!r}! This doesn't look like a GGML format file.") + + def validate_conversion(self, ftype): + err = '' + if (self.file_format < GGMLFormat.GGJT or self.format_version < 2): + if ftype not in (GGMLFType.ALL_F32, GGMLFType.MOSTLY_F16): + err = 'Quantizations changed in GGJTv2. Can only convert unquantized GGML files older than GGJTv2.' + elif (self.file_format == GGMLFormat.GGJT and self.format_version == 2): + if ftype in (GGMLFType.MOSTLY_Q4_0, GGMLFType.MOSTLY_Q4_1, + GGMLFType.MOSTLY_Q4_1_SOME_F16, GGMLFType.MOSTLY_Q8_0): + err = 'Q4 and Q8 quantizations changed in GGJTv3.' + if len(err) > 0: + raise ValueError(f'{err} Sorry, your {self.file_format.name}v{self.format_version} file of type {ftype.name} is not eligible for conversion.') + + def load(self, data, offset): + offset += self.validate_header(data, offset) + hp = Hyperparameters() + offset += hp.load(data, offset) + logger.info(f'* File format: {self.file_format.name}v{self.format_version} with ftype {hp.ftype.name}') + self.validate_conversion(hp.ftype) + vocab = Vocab(load_scores = self.file_format > GGMLFormat.GGML) + offset += vocab.load(data, offset, hp.n_vocab) + tensors: list[Tensor] = [] + tensor_map = {} + while offset < len(data): + tensor = Tensor(use_padding = self.file_format > GGMLFormat.GGMF) + offset += tensor.load(data, offset) + tensor_map[tensor.name] = len(tensors) + tensors.append(tensor) + self.hyperparameters = hp + self.vocab = vocab + self.tensors = tensors + self.tensor_map = tensor_map + hp.set_n_ff(self) + return offset + + +class GGMLToGGUF: + def __init__(self, ggml_model, data, cfg, params_override = None, vocab_override = None, special_vocab = None): + hp = ggml_model.hyperparameters + self.model = ggml_model + self.data = data + self.cfg = cfg + self.params_override = params_override + self.vocab_override = vocab_override + self.special_vocab = special_vocab + if params_override is not None: + n_kv_head = params_override.n_head_kv + else: + if cfg.gqa == 1: + n_kv_head = hp.n_head + else: + gqa = float(cfg.gqa) + n_kv_head = None + for x in range(1, 256): + if float(hp.n_head) / float(x) == gqa: + n_kv_head = x + assert n_kv_head is not None, "Couldn't determine n_kv_head from GQA param" + logger.info(f'- Guessed n_kv_head = {n_kv_head} based on GQA {cfg.gqa}') + self.n_kv_head = n_kv_head + self.name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, ggml_model.hyperparameters.n_layer) + + def save(self): + logger.info('* Preparing to save GGUF file') + gguf_writer = gguf.GGUFWriter( + self.cfg.output, + gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA], + use_temp_file = False) + self.add_params(gguf_writer) + self.add_vocab(gguf_writer) + if self.special_vocab is not None: + self.special_vocab.add_to_gguf(gguf_writer) + self.add_tensors(gguf_writer) + logger.info(" gguf: write header") + gguf_writer.write_header_to_file() + logger.info(" gguf: write metadata") + gguf_writer.write_kv_data_to_file() + logger.info(" gguf: write tensors") + gguf_writer.write_tensors_to_file() + gguf_writer.close() + + def add_params(self, gguf_writer): + hp = self.model.hyperparameters + cfg = self.cfg + if cfg.desc is not None: + desc = cfg.desc + else: + desc = f'converted from legacy {self.model.file_format.name}v{self.model.format_version} {hp.ftype.name} format' + try: + # Filenames aren't necessarily valid UTF8. + name = cfg.name if cfg.name is not None else cfg.input.name + except UnicodeDecodeError: + name = None + logger.info('* Adding model parameters and KV items') + if name is not None: + gguf_writer.add_name(name) + gguf_writer.add_description(desc) + gguf_writer.add_file_type(int(hp.ftype)) + if self.params_override is not None: + po = self.params_override + assert po.n_embd == hp.n_embd, 'Model hyperparams mismatch' + assert po.n_layer == hp.n_layer, 'Model hyperparams mismatch' + assert po.n_head == hp.n_head, 'Model hyperparams mismatch' + gguf_writer.add_context_length (po.n_ctx) + gguf_writer.add_embedding_length (po.n_embd) + gguf_writer.add_block_count (po.n_layer) + gguf_writer.add_feed_forward_length (po.n_ff) + gguf_writer.add_rope_dimension_count(po.n_embd // po.n_head) + gguf_writer.add_head_count (po.n_head) + gguf_writer.add_head_count_kv (po.n_head_kv) + gguf_writer.add_layer_norm_rms_eps (po.f_norm_eps) + return + gguf_writer.add_context_length(cfg.context_length) + gguf_writer.add_embedding_length(hp.n_embd) + gguf_writer.add_block_count(hp.n_layer) + gguf_writer.add_feed_forward_length(hp.n_ff) + gguf_writer.add_rope_dimension_count(hp.n_embd // hp.n_head) + gguf_writer.add_head_count(hp.n_head) + gguf_writer.add_head_count_kv(self.n_kv_head) + gguf_writer.add_layer_norm_rms_eps(float(cfg.eps)) + + def add_vocab(self, gguf_writer): + hp = self.model.hyperparameters + gguf_writer.add_tokenizer_model('llama') + gguf_writer.add_tokenizer_pre('default') + tokens = [] + scores = [] + toktypes = [] + if self.vocab_override is not None: + vo = self.vocab_override + logger.info('* Adding vocab item(s)') + for (_, (vbytes, score, ttype)) in enumerate(vo.all_tokens()): + tokens.append(vbytes) + scores.append(score) + toktypes.append(ttype) + assert len(tokens) == hp.n_vocab, \ + f'Override vocab has a different number of items than hyperparameters - override = {len(tokens)} but n_vocab={hp.n_vocab}' + gguf_writer.add_token_list(tokens) + gguf_writer.add_token_scores(scores) + if len(toktypes) > 0: + gguf_writer.add_token_types(toktypes) + return + logger.info(f'* Adding {hp.n_vocab} vocab item(s)') + assert len(self.model.vocab.items) >= 3, 'Cannot handle unexpectedly short model vocab' + for (tokid, (vbytes, vscore)) in enumerate(self.model.vocab.items): + tt = 1 # Normal + # Special handling for UNK, BOS, EOS tokens. + if tokid <= 2: + if tokid == 0: + vbytes = b'' + tt = 2 + elif tokid == 1: + vbytes = b'' + tt = 3 + else: + vbytes = b'' + tt = 3 + elif len(vbytes) == 0: + tt = 3 # Control + elif tokid >= 3 and tokid <= 258 and len(vbytes) == 1: + vbytes = bytes(f'<0x{vbytes[0]:02X}>', encoding = 'UTF-8') + tt = 6 # Byte + else: + vbytes = vbytes.replace(b' ', b'\xe2\x96\x81') + toktypes.append(tt) + tokens.append(vbytes) + scores.append(vscore) + gguf_writer.add_token_list(tokens) + gguf_writer.add_token_scores(scores) + gguf_writer.add_token_types(toktypes) + gguf_writer.add_unk_token_id(0) + gguf_writer.add_bos_token_id(1) + gguf_writer.add_eos_token_id(2) + + def add_tensors(self, gguf_writer): + tensor_map = self.name_map + data = self.data + logger.info(f'* Adding {len(self.model.tensors)} tensor(s)') + for tensor in self.model.tensors: + name = str(tensor.name, 'UTF-8') + mapped_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias")) + assert mapped_name is not None, f'Bad name {name}' + tempdims = list(tensor.dims[:]) + if len(tempdims) > 1: + temp = tempdims[1] + tempdims[1] = tempdims[0] + tempdims[0] = temp + gguf_writer.add_tensor( + mapped_name, + data[tensor.start_offset:tensor.start_offset + tensor.len_bytes], + raw_shape = tempdims, + raw_dtype = tensor.dtype) + + +def handle_metadata(cfg, hp): + import examples.convert_legacy_llama as convert + + assert cfg.model_metadata_dir.is_dir(), 'Metadata dir is not a directory' + hf_config_path = cfg.model_metadata_dir / "config.json" + orig_config_path = cfg.model_metadata_dir / "params.json" + # We pass a fake model here. "original" mode will check the shapes of some + # tensors if information is missing in the .json file: other than that, the + # model data isn't used so this should be safe (at least for now). + fakemodel = { + 'tok_embeddings.weight': convert.LazyTensor.__new__(convert.LazyTensor), + 'layers.0.feed_forward.w1.weight': convert.LazyTensor.__new__(convert.LazyTensor), + } + fakemodel['tok_embeddings.weight'].shape = [hp.n_vocab] + fakemodel['layers.0.feed_forward.w1.weight'].shape = [hp.n_ff] + if hf_config_path.exists(): + params = convert.Params.loadHFTransformerJson(fakemodel, hf_config_path) + elif orig_config_path.exists(): + params = convert.Params.loadOriginalParamsJson(fakemodel, orig_config_path) + else: + raise ValueError('Unable to load metadata') + vocab_path = Path(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir) + vocab_factory = convert.VocabFactory(vocab_path) + vocab, special_vocab = vocab_factory.load_vocab(cfg.vocabtype.split(","), cfg.model_metadata_dir) + convert.check_vocab_size(params, vocab) + return params, vocab, special_vocab + + +def handle_args(): + parser = argparse.ArgumentParser(description = 'Convert GGML models to GGUF') + parser.add_argument('--input', '-i', type = Path, required = True, + help = 'Input GGMLv3 filename') + parser.add_argument('--output', '-o', type = Path, required = True, + help ='Output GGUF filename') + parser.add_argument('--name', + help = 'Set model name') + parser.add_argument('--desc', + help = 'Set model description') + parser.add_argument('--gqa', type = int, default = 1, + help = 'grouped-query attention factor (use 8 for LLaMA2 70B)') + parser.add_argument('--eps', default = '5.0e-06', + help = 'RMS norm eps: Use 1e-6 for LLaMA1 and OpenLLaMA, use 1e-5 for LLaMA2') + parser.add_argument('--context-length', '-c', type=int, default = 2048, + help = 'Default max context length: LLaMA1 is typically 2048, LLaMA2 is typically 4096') + parser.add_argument('--model-metadata-dir', '-m', type = Path, + help ='Load HuggingFace/.pth vocab and metadata from the specified directory') + parser.add_argument("--vocab-dir", type=Path, + help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir") + parser.add_argument("--vocabtype", default="spm,hfft", + help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm,hfft)") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + return parser.parse_args() + + +def main(): + cfg = handle_args() + logging.basicConfig(level=logging.DEBUG if cfg.verbose else logging.INFO) + logger.info(f'* Using config: {cfg}') + logger.warning('=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===') + if cfg.model_metadata_dir is None and (cfg.gqa == 1 or cfg.eps == '5.0e-06'): + logger.info('- Note: If converting LLaMA2, specifying "--eps 1e-5" is required. 70B models also need "--gqa 8".') + data = np.memmap(cfg.input, mode = 'r') + model = GGMLModel() + logger.info('* Scanning GGML input file') + offset = model.load(data, 0) # noqa + logger.info(f'* GGML model hyperparameters: {model.hyperparameters}') + vocab_override = None + params_override = None + special_vocab = None + if cfg.model_metadata_dir is not None: + (params_override, vocab_override, special_vocab) = handle_metadata(cfg, model.hyperparameters) + logger.info('!! Note: When overriding params the --gqa, --eps and --context-length options are ignored.') + logger.info(f'* Overriding params: {params_override}') + logger.info(f'* Overriding vocab: {vocab_override}') + logger.info(f'* Special vocab: {special_vocab}') + else: + logger.warning('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n') + if model.file_format == GGMLFormat.GGML: + logger.info('! This is a very old GGML file that does not contain vocab scores. Strongly recommend using model metadata!') + converter = GGMLToGGUF( + model, data, cfg, + params_override = params_override, + vocab_override = vocab_override, + special_vocab = special_vocab + ) + converter.save() + logger.info(f'* Successful completion. Output saved to: {cfg.output}') + + +if __name__ == '__main__': + main() diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py new file mode 100644 index 0000000000000000000000000000000000000000..00a6733cbd3606dece097e308780dacb23ca54cf --- /dev/null +++ b/convert_lora_to_gguf.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from dataclasses import dataclass +import logging +import argparse +import os +import sys +import json +from math import prod +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast +from transformers import AutoConfig + +import torch + +if TYPE_CHECKING: + from torch import Tensor + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + +# reuse model definitions from convert_hf_to_gguf.py +from convert_hf_to_gguf import LazyTorchTensor, ModelBase + +logger = logging.getLogger("lora-to-gguf") + + +@dataclass +class PartialLoraTensor: + A: Tensor | None = None + B: Tensor | None = None + + +# magic to support tensor shape modifications and splitting +class LoraTorchTensor: + _lora_A: Tensor # (n_rank, row_size) + _lora_B: Tensor # (col_size, n_rank) + _rank: int + + def __init__(self, A: Tensor, B: Tensor): + assert len(A.shape) == len(B.shape) + assert A.shape[-2] == B.shape[-1] + if A.dtype != B.dtype: + A = A.to(torch.float32) + B = B.to(torch.float32) + self._lora_A = A + self._lora_B = B + self._rank = B.shape[-1] + + def get_lora_A_B(self) -> tuple[Tensor, Tensor]: + return (self._lora_A, self._lora_B) + + def __getitem__( + self, + indices: ( + SupportsIndex + | slice + | tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature + ), + ) -> LoraTorchTensor: + shape = self.shape + if isinstance(indices, SupportsIndex): + if len(shape) > 2: + return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) + else: + raise NotImplementedError # can't return a vector + elif isinstance(indices, slice): + if len(shape) > 2: + return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) + else: + return LoraTorchTensor(self._lora_A, self._lora_B[indices]) + elif isinstance(indices, tuple): + assert len(indices) > 0 + if indices[-1] is Ellipsis: + return self[indices[:-1]] + # expand ellipsis + indices = tuple( + u + for v in ( + ( + (slice(None, None) for _ in range(len(indices) - 1)) + if i is Ellipsis + else (i,) + ) + for i in indices + ) + for u in v + ) + + if len(indices) < len(shape): + indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape)))) + + # TODO: make sure this is correct + indices_A = ( + *( + ( + j.__index__() % self._lora_A.shape[i] + if isinstance(j, SupportsIndex) + else slice(None, None) + ) + for i, j in enumerate(indices[:-2]) + ), + slice(None, None), + indices[-1], + ) + indices_B = indices[:-1] + return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B]) + else: + raise NotImplementedError # unknown indice type + + @property + def dtype(self) -> torch.dtype: + assert self._lora_A.dtype == self._lora_B.dtype + return self._lora_A.dtype + + @property + def shape(self) -> tuple[int, ...]: + assert len(self._lora_A.shape) == len(self._lora_B.shape) + return (*self._lora_B.shape[:-1], self._lora_A.shape[-1]) + + def size(self, dim=None): + assert dim is None + return self.shape + + def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor: + if isinstance(shape[0], tuple): + new_shape: tuple[int, ...] = shape[0] + else: + new_shape = cast(tuple[int, ...], shape) + orig_shape = self.shape + if len(new_shape) < 2: + raise NotImplementedError # can't become a vector + + # expand -1 in the shape + if any(dim == -1 for dim in new_shape): + n_elems = prod(orig_shape) + n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape) + assert n_elems % n_new_elems == 0 + new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),) + + if new_shape[-1] != orig_shape[-1]: + raise NotImplementedError # can't reshape the row size trivially + + shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1]) + shape_B = (*new_shape[:-1], self._rank) + return LoraTorchTensor( + self._lora_A.reshape(shape_A), + self._lora_B.reshape(shape_B), + ) + + def reshape_as(self, other: Tensor) -> LoraTorchTensor: + return self.reshape(*other.shape) + + def view(self, *size: int) -> LoraTorchTensor: + return self.reshape(*size) + + def permute(self, *dims: int) -> LoraTorchTensor: + shape = self.shape + dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims) + if dims[-1] == -1: + # TODO: support higher dimensional A shapes bigger than 1 + assert all(dim == 1 for dim in self._lora_A.shape[:-2]) + return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims)) + if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1: + return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims)) + else: + # TODO: compose the above two + raise NotImplementedError + + def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor: + shape = self.shape + dims = [i for i in range(len(shape))] + dims[dim0], dims[dim1] = dims[dim1], dims[dim0] + return self.permute(*dims) + + def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor: + return self.transpose(axis0, axis1) + + def to(self, *args, **kwargs): + return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)) + + @classmethod + def __torch_function__(cls, func: Callable, types, args=(), kwargs=None): + del types # unused + + if kwargs is None: + kwargs = {} + + if func is torch.permute: + return type(args[0]).permute(*args, **kwargs) + elif func is torch.reshape: + return type(args[0]).reshape(*args, **kwargs) + elif func is torch.stack: + assert isinstance(args[0], Sequence) + dim = kwargs.get("dim", 0) + assert dim == 0 + return LoraTorchTensor( + torch.stack([a._lora_A for a in args[0]], dim), + torch.stack([b._lora_B for b in args[0]], dim), + ) + elif func is torch.cat: + assert isinstance(args[0], Sequence) + dim = kwargs.get("dim", 0) + assert dim == 0 + if len(args[0][0].shape) > 2: + return LoraTorchTensor( + torch.cat([a._lora_A for a in args[0]], dim), + torch.cat([b._lora_B for b in args[0]], dim), + ) + elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]): + return LoraTorchTensor( + args[0][0]._lora_A, + torch.cat([b._lora_B for b in args[0]], dim), + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + +def get_base_tensor_name(lora_tensor_name: str) -> str: + base_name = lora_tensor_name.replace("base_model.model.", "") + base_name = base_name.replace(".lora_A.weight", ".weight") + base_name = base_name.replace(".lora_B.weight", ".weight") + # models produced by mergekit-extract-lora have token embeddings in the adapter + base_name = base_name.replace(".lora_embedding_A", ".weight") + base_name = base_name.replace(".lora_embedding_B", ".weight") + return base_name + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert a Hugging Face PEFT LoRA adapter to a GGUF file") + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "--no-lazy", action="store_true", + help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + parser.add_argument( + "--dry-run", action="store_true", + help="only print out what will be done, without writing any new files", + ) + parser.add_argument( + "--base", type=Path, + help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config", + ) + parser.add_argument( + "--base-model-id", type=str, + help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')", + ) + parser.add_argument( + "lora_path", type=Path, + help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)", + ) + + return parser.parse_args() + + +def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]: + # normally, adapter does not come with base model config, we need to load it from AutoConfig + config = AutoConfig.from_pretrained(hf_model_id) + return config.to_dict() + + +if __name__ == '__main__': + args = parse_args() + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + "auto": gguf.LlamaFileType.GUESSED, + } + + ftype = ftype_map[args.outtype] + + dir_base_model: Path | None = args.base + dir_lora: Path = args.lora_path + base_model_id: str | None = args.base_model_id + lora_config = dir_lora / "adapter_config.json" + input_model = dir_lora / "adapter_model.safetensors" + + if args.outfile is not None: + fname_out = args.outfile + else: + # output in the same directory as the model by default + fname_out = dir_lora + + if os.path.exists(input_model): + # lazy import load_file only if lora is in safetensors format. + from safetensors.torch import load_file + + lora_model = load_file(input_model, device="cpu") + else: + input_model = os.path.join(dir_lora, "adapter_model.bin") + lora_model = torch.load(input_model, map_location="cpu", weights_only=True) + + # load LoRA config + with open(lora_config, "r") as f: + lparams: dict[str, Any] = json.load(f) + + # load base model + if base_model_id is not None: + logger.info(f"Loading base model from Hugging Face: {base_model_id}") + hparams = load_hparams_from_hf(base_model_id) + elif dir_base_model is None: + if "base_model_name_or_path" in lparams: + model_id = lparams["base_model_name_or_path"] + logger.info(f"Loading base model from Hugging Face: {model_id}") + try: + hparams = load_hparams_from_hf(model_id) + except OSError as e: + logger.error(f"Failed to load base model config: {e}") + logger.error("Please try downloading the base model and add its path to --base") + sys.exit(1) + else: + logger.error("'base_model_name_or_path' is not found in adapter_config.json") + logger.error("Base model config is required. Please download the base model and add its path to --base") + sys.exit(1) + else: + logger.info(f"Loading base model: {dir_base_model.name}") + hparams = ModelBase.load_hparams(dir_base_model) + + with torch.inference_mode(): + try: + model_class = ModelBase.from_model_architecture(hparams["architectures"][0]) + except NotImplementedError: + logger.error(f"Model {hparams['architectures'][0]} is not supported") + sys.exit(1) + + class LoraModel(model_class): + model_arch = model_class.model_arch + + lora_alpha: float + + def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs): + + super().__init__(*args, **kwargs) + + self.dir_model_card = dir_lora_model + self.lora_alpha = float(lora_alpha) + + def set_vocab(self): + pass + + def set_type(self): + self.gguf_writer.add_type(gguf.GGUFType.ADAPTER) + self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") + + def set_gguf_parameters(self): + self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # Never add extra tensors (e.g. rope_freqs) for LoRA adapters + return () + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + tensor_map: dict[str, PartialLoraTensor] = {} + + for name, tensor in lora_model.items(): + if self.lazy: + tensor = LazyTorchTensor.from_eager(tensor) + base_name = get_base_tensor_name(name) + # note: mergekit-extract-lora also adds token embeddings to the adapter + is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name + is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name + if not is_lora_a and not is_lora_b: + if ".base_layer.weight" in name: + continue + # mergekit-extract-lora add these layernorm to the adapter, we need to keep them + if "_layernorm" in name or ".norm" in name: + yield (base_name, tensor) + continue + logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") + if ".embed_tokens.weight" in name or ".lm_head.weight" in name: + logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning") + logger.error("Please refer to https://github.com/ggml-org/llama.cpp/pull/9948") + sys.exit(1) + + if base_name in tensor_map: + if is_lora_a: + tensor_map[base_name].A = tensor + else: + tensor_map[base_name].B = tensor + else: + if is_lora_a: + tensor_map[base_name] = PartialLoraTensor(A=tensor) + else: + tensor_map[base_name] = PartialLoraTensor(B=tensor) + + for name, tensor in tensor_map.items(): + assert tensor.A is not None + assert tensor.B is not None + yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B))) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + dest = list(super().modify_tensors(data_torch, name, bid)) + # some archs may have the same tensor for lm_head and output (tie word embeddings) + # in this case, adapters targeting lm_head will fail when using llama-export-lora + # therefore, we ignore them for now + # see: https://github.com/ggml-org/llama.cpp/issues/9065 + if name == "lm_head.weight" and len(dest) == 0: + raise ValueError("lm_head is present in adapter, but is ignored in base model") + for dest_name, dest_data in dest: + # mergekit-extract-lora add these layernorm to the adapter + if "_norm" in dest_name: + assert dest_data.dim() == 1 + yield (dest_name, dest_data) + continue + + # otherwise, we must get the lora_A and lora_B tensors + assert isinstance(dest_data, LoraTorchTensor) + lora_a, lora_b = dest_data.get_lora_A_B() + + # note: mergekit-extract-lora flip and transpose A and B + # here we only need to transpose token_embd.lora_a, see llm_build_inp_embd() + if "token_embd.weight" in dest_name: + lora_a = lora_a.T + + yield (dest_name + ".lora_a", lora_a) + yield (dest_name + ".lora_b", lora_b) + + alpha: float = lparams["lora_alpha"] + + model_instance = LoraModel( + dir_base_model, + ftype, + fname_out, + is_big_endian=args.bigendian, + use_temp_file=False, + eager=args.no_lazy, + dry_run=args.dry_run, + dir_lora_model=dir_lora, + lora_alpha=alpha, + hparams=hparams, + ) + + logger.info("Exporting model...") + model_instance.write() + logger.info(f"Model successfully exported to {model_instance.fname_out}") diff --git a/docs/android.md b/docs/android.md new file mode 100644 index 0000000000000000000000000000000000000000..d2a835653fe5d40c88d10e7504bf9f62bd4da1a9 --- /dev/null +++ b/docs/android.md @@ -0,0 +1,83 @@ + +# Android + +## Build on Android using Termux + +[Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid. + +With Termux, you can install and run `llama.cpp` as if the environment were Linux. Once in the Termux shell: + +``` +$ apt update && apt upgrade -y +$ apt install git cmake +``` + +Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake. + +Once the binaries are built, download your model of choice (e.g., from Hugging Face). It's recommended to place it in the `~/` directory for best performance: + +``` +$ curl -L {model-url} -o ~/{model}.gguf +``` + +Then, if you are not already in the repo directory, `cd` into `llama.cpp` and: + +``` +$ ./build/bin/llama-cli -m ~/{model}.gguf -c {context-size} -p "{your-prompt}" +``` + +Here, we show `llama-cli`, but any of the executables under `examples` should work, in theory. Be sure to set `context-size` to a reasonable number (say, 4096) to start with; otherwise, memory could spike and kill your terminal. + +To see what it might look like visually, here's an old demo of an interactive session running on a Pixel 5 phone: + +https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4 + +## Cross-compile using Android NDK +It's possible to build `llama.cpp` for Android on your host system via CMake and the Android NDK. If you are interested in this path, ensure you already have an environment prepared to cross-compile programs for Android (i.e., install the Android SDK). Note that, unlike desktop environments, the Android environment ships with a limited set of native libraries, and so only those libraries are available to CMake when building with the Android NDK (see: https://developer.android.com/ndk/guides/stable_apis.) + +Once you're ready and have cloned `llama.cpp`, invoke the following in the project directory: + +``` +$ cmake \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DCMAKE_C_FLAGS="-march=armv8.7a" \ + -DCMAKE_CXX_FLAGS="-march=armv8.7a" \ + -DGGML_OPENMP=OFF \ + -DGGML_LLAMAFILE=OFF \ + -B build-android +``` + +Notes: + - While later versions of Android NDK ship with OpenMP, it must still be installed by CMake as a dependency, which is not supported at this time + - `llamafile` does not appear to support Android devices (see: https://github.com/Mozilla-Ocho/llamafile/issues/325) + +The above command should configure `llama.cpp` with the most performant options for modern devices. Even if your device is not running `armv8.7a`, `llama.cpp` includes runtime checks for available CPU features it can use. + +Feel free to adjust the Android ABI for your target. Once the project is configured: + +``` +$ cmake --build build-android --config Release -j{n} +$ cmake --install build-android --prefix {install-dir} --config Release +``` + +After installing, go ahead and download the model of your choice to your host system. Then: + +``` +$ adb shell "mkdir /data/local/tmp/llama.cpp" +$ adb push {install-dir} /data/local/tmp/llama.cpp/ +$ adb push {model}.gguf /data/local/tmp/llama.cpp/ +$ adb shell +``` + +In the `adb shell`: + +``` +$ cd /data/local/tmp/llama.cpp +$ LD_LIBRARY_PATH=lib ./bin/llama-simple -m {model}.gguf -c {context-size} -p "{your-prompt}" +``` + +That's it! + +Be aware that Android will not find the library path `lib` on its own, so we must specify `LD_LIBRARY_PATH` in order to run the installed executables. Android does support `RPATH` in later API levels, so this could change in the future. Refer to the previous section for information about `context-size` (very important!) and running other `examples`. diff --git a/docs/backend/BLIS.md b/docs/backend/BLIS.md new file mode 100644 index 0000000000000000000000000000000000000000..9045485771ea643111b1dd77a88d8ed3ef2abb37 --- /dev/null +++ b/docs/backend/BLIS.md @@ -0,0 +1,60 @@ +BLIS Installation Manual +------------------------ + +BLIS is a portable software framework for high-performance BLAS-like dense linear algebra libraries. It has received awards and recognition, including the 2023 James H. Wilkinson Prize for Numerical Software and the 2020 SIAM Activity Group on Supercomputing Best Paper Prize. BLIS provides a new BLAS-like API and a compatibility layer for traditional BLAS routine calls. It offers features such as object-based API, typed API, BLAS and CBLAS compatibility layers. + +Project URL: https://github.com/flame/blis + +### Prepare: + +Compile BLIS: + +```bash +git clone https://github.com/flame/blis +cd blis +./configure --enable-cblas -t openmp,pthreads auto +# will install to /usr/local/ by default. +make -j +``` + +Install BLIS: + +```bash +sudo make install +``` + +We recommend using openmp since it's easier to modify the cores being used. + +### llama.cpp compilation + +CMake: + +```bash +mkdir build +cd build +cmake -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=FLAME .. +make -j +``` + +### llama.cpp execution + +According to the BLIS documentation, we could set the following +environment variables to modify the behavior of openmp: + +```bash +export GOMP_CPU_AFFINITY="0-19" +export BLIS_NUM_THREADS=14 +``` + +And then run the binaries as normal. + + +### Intel specific issue + +Some might get the error message saying that `libimf.so` cannot be found. +Please follow this [stackoverflow page](https://stackoverflow.com/questions/70687930/intel-oneapi-2022-libimf-so-no-such-file-or-directory-during-openmpi-compila). + +### Reference: + +1. https://github.com/flame/blis#getting-started +2. https://github.com/flame/blis/blob/master/docs/Multithreading.md diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md new file mode 100644 index 0000000000000000000000000000000000000000..2b001f09abe45d956f8b982a1bbc109fa51b67ca --- /dev/null +++ b/docs/backend/CANN.md @@ -0,0 +1,314 @@ +# llama.cpp for CANN + + - [Background](#background) + - [News](#news) + - [OS](#os) + - [Hardware](#hardware) + - [Model Supports](#model-supports) + - [DataType Supports](#datatype-supports) + - [Docker](#docker) + - [Linux](#linux) + - [Environment variable setup](#environment-variable-setup) + - [TODO](#todo) + + +## Background + +**Ascend NPU** is a range of AI processors using Neural Processing Unit. It will efficiently handle matrix-matrix multiplication, dot-product and scalars. + +**CANN** (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for AI scenarios, providing support for multiple AI frameworks on the top and serving AI processors and programming at the bottom. It plays a crucial role in bridging the gap between upper and lower layers, and is a key platform for improving the computing efficiency of Ascend AI processors. Meanwhile, it offers a highly efficient and easy-to-use programming interface for diverse application scenarios, allowing users to rapidly build AI applications and services based on the Ascend platform. + +**Llama.cpp + CANN** + +The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the ability of AscendC and ACLNN which are intergrated to CANN Toolkit and kernels to using Ascend NPU directly. + +## News + +- 2024.11 + - Support F16 and F32 data type model for Ascend 310P NPU. +- 2024.8 + - Support `Q4_0` and `Q8_0` data type for Ascend NPU. +- 2024.7 + - Create CANN backend for Ascend NPU. + +## OS + +| OS | Status | Verified | +|:-------:|:-------:|:----------------------------------------------:| +| Linux | Support | Ubuntu 22.04, OpenEuler22.03 | + + +## Hardware + +### Ascend NPU + +**Verified devices** + +| Ascend NPU | Status | +|:-----------------------------:|:-------:| +| Atlas 300T A2 | Support | +| Atlas 300I Duo | Support | + +*Notes:* + +- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag. +- If you run successfully with your Ascend NPU device, please help update the upper table. + + +## Model Supports + +| Model Name | FP16 | Q4_0 | Q8_0 | +|:----------------------------|:-----:|:----:|:----:| +| Llama-2 | √ | √ | √ | +| Llama-3 | √ | √ | √ | +| Mistral-7B | √ | √ | √ | +| Mistral MOE | √ | √ | √ | +| DBRX | - | - | - | +| Falcon | √ | √ | √ | +| Chinese LLaMA/Alpaca | √ | √ | √ | +| Vigogne(French) | √ | √ | √ | +| BERT | x | x | x | +| Koala | √ | √ | √ | +| Baichuan | √ | √ | √ | +| Aquila 1 & 2 | √ | √ | √ | +| Starcoder models | √ | √ | √ | +| Refact | √ | √ | √ | +| MPT | √ | √ | √ | +| Bloom | √ | √ | √ | +| Yi models | √ | √ | √ | +| stablelm models | √ | √ | √ | +| DeepSeek models | x | x | x | +| Qwen models | √ | √ | √ | +| PLaMo-13B | √ | √ | √ | +| Phi models | √ | √ | √ | +| PhiMoE | √ | √ | √ | +| GPT-2 | √ | √ | √ | +| Orion | √ | √ | √ | +| InternlLM2 | √ | √ | √ | +| CodeShell | √ | √ | √ | +| Gemma | √ | √ | √ | +| Mamba | √ | √ | √ | +| Xverse | √ | √ | √ | +| command-r models | √ | √ | √ | +| Grok-1 | - | - | - | +| SEA-LION | √ | √ | √ | +| GritLM-7B | √ | √ | √ | +| OLMo | √ | √ | √ | +| OLMo 2 | √ | √ | √ | +| OLMoE | √ | √ | √ | +| Granite models | √ | √ | √ | +| GPT-NeoX | √ | √ | √ | +| Pythia | √ | √ | √ | +| Snowflake-Arctic MoE | - | - | - | +| Smaug | √ | √ | √ | +| Poro 34B | √ | √ | √ | +| Bitnet b1.58 models | √ | x | x | +| Flan-T5 | √ | √ | √ | +| Open Elm models | x | √ | √ | +| chatGLM3-6B + ChatGLM4-9b + GLMEdge-1.5b + GLMEdge-4b | √ | √ | √ | +| GLM-4-0414 | √ | √ | √ | +| SmolLM | √ | √ | √ | +| EXAONE-3.0-7.8B-Instruct | √ | √ | √ | +| FalconMamba Models | √ | √ | √ | +| Jais Models | - | x | x | +| Bielik-11B-v2.3 | √ | √ | √ | +| RWKV-6 | - | √ | √ | +| QRWKV-6 | √ | √ | √ | +| GigaChat-20B-A3B | x | x | x | +| Trillion-7B-preview | √ | √ | √ | +| Ling models | √ | √ | √ | + + +**Multimodal** +| Model Name | FP16 | Q4_0 | Q8_0 | +|:----------------------------|:-----:|:----:|:----:| +| LLaVA 1.5 models, LLaVA 1.6 models | x | x | x | +| BakLLaVA | √ | √ | √ | +| Obsidian | √ | - | - | +| ShareGPT4V | x | - | - | +| MobileVLM 1.7B/3B models | - | - | - | +| Yi-VL | - | - | - | +| Mini CPM | √ | √ | √ | +| Moondream | √ | √ | √ | +| Bunny | √ | - | - | +| GLM-EDGE | √ | √ | √ | +| Qwen2-VL | √ | √ | √ | + + + +## DataType Supports + +| DataType | Status | +|:----------------------:|:-------:| +| FP16 | Support | +| Q8_0 | Support | +| Q4_0 | Support | + +## Docker + +### Build Images +You can get a image with llama.cpp in one command. +```sh +docker build -t llama-cpp-cann -f .devops/llama-cli-cann.Dockerfile . +``` + +### Run container + +```sh +# Find all cards. +npu-smi info + +# Select the cards that you want to use, make sure these cards are not used by someone. +# Following using cards of device0. +docker run --name llamacpp --device /dev/davinci0 --device /dev/davinci_manager --device /dev/devmm_svm --device /dev/hisi_hdc -v /usr/local/dcmi:/usr/local/dcmi -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info -v /PATH_TO_YOUR_MODELS/:/app/models -it llama-cpp-cann -m /app/models/MODEL_PATH -ngl 32 -p "Building a website can be done in 10 simple steps:" +``` + +*Notes:* + +- You may need to install Ascend Driver and firmware on the **host** machine *(Please refer to the [Linux configuration](#linux) for details)*. + +## Linux + +### I. Setup Environment + +1. **Install Ascend Driver and firmware** + + ```sh + # create driver running user. + sudo groupadd -g HwHiAiUser + sudo useradd -g HwHiAiUser -d /home/HwHiAiUser -m HwHiAiUser -s /bin/bash + sudo usermod -aG HwHiAiUser $USER + + # download driver from https://www.hiascend.com/hardware/firmware-drivers/community according to your system + # and install driver. + sudo sh Ascend-hdk-910b-npu-driver_x.x.x_linux-{arch}.run --full --install-for-all + ``` + + Once installed, run `npu-smi info` to check whether driver is installed successfully. + ```sh + +-------------------------------------------------------------------------------------------+ + | npu-smi 24.1.rc2 Version: 24.1.rc2 | + +----------------------+---------------+----------------------------------------------------+ + | NPU Name | Health | Power(W) Temp(C) Hugepages-Usage(page)| + | Chip | Bus-Id | AICore(%) Memory-Usage(MB) HBM-Usage(MB) | + +======================+===============+====================================================+ + | 2 xxx | OK | 64.4 51 15 / 15 | + | 0 | 0000:01:00.0 | 0 1873 / 15077 0 / 32768 | + +======================+===============+====================================================+ + | 5 xxx | OK | 64.0 52 15 / 15 | + | 0 | 0000:81:00.0 | 0 1874 / 15077 0 / 32768 | + +======================+===============+====================================================+ + | No running processes found in NPU 2 | + +======================+===============+====================================================+ + | No running processes found in NPU 5 | + +======================+===============+====================================================+ + ``` + +2. **Install Ascend Firmware** + ```sh + # download driver from https://www.hiascend.com/hardware/firmware-drivers/community according to your system + # and install driver. + sudo sh Ascend-hdk-910b-npu-firmware_x.x.x.x.X.run --full + ``` + If the following messaage appers, firmware is installed successfully. + ```sh + Firmware package installed successfully! + ``` + + +3. **Install CANN toolkit and kernels** + + CANN toolkit and kernels can be obtained from the official [CANN Toolkit](https://www.hiascend.com/zh/developer/download/community/result?module=cann) page. + + Please download the corresponding version that satified your system. The minimum version required is 8.0.RC2.alpha002 and here is the install command. + ```sh + pip3 install attrs numpy decorator sympy cffi pyyaml pathlib2 psutil protobuf scipy requests absl-py wheel typing_extensions + sh Ascend-cann-toolkit_8.0.RC2.alpha002_linux-aarch64.run --install + sh Ascend-cann-kernels-910b_8.0.RC2.alpha002_linux.run --install + ``` + + Set Ascend Variables: + ```sh + echo "source ~/Ascend/ascend-toolkit/set_env.sh" >> ~/.bashrc + source ~/.bashrc + ``` + +Upon a successful installation, CANN is enabled for the available ascend devices. + +### II. Build llama.cpp + +```sh +cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release +cmake --build build --config release +``` + +### III. Run the inference + +1. **Retrieve and prepare model** + + You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration. + + **Notes**: + + - CANN backend only supports FP16/Q4_0/Q8_0 models currently. + +2. **Launch inference** + + There are two device selection modes: + + - Single device: Use one device target specified by the user. + - Multiple devices: Automatically choose the devices with the same backend. + + | Device selection | Parameter | + |:----------------:|:--------------------------------------:| + | Single device | --split-mode none --main-gpu DEVICE_ID | + | Multiple devices | --split-mode layer (default) | + + Examples: + + - Use device 0: + + ```sh + ./build/bin/llama-cli -m path_to_model -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0 + ``` + + - Use multiple devices: + + ```sh + ./build/bin/llama-cli -m path_to_model -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer + ``` + +### **GitHub contribution**: +Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay. + +## Updates +### Basic Flash Attention Support +The basic FA kernel with aclnnops has been added in aclnn_ops.cpp. +Currently, the FA only supports the cases with FP16 KV tensors and NO logit softcap. +Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the quantized version in the future. + +Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn). + +We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd for their help during the code development and pull request. + +## Environment variable setup + +### GGML_CANN_ASYNC_MODE + +Enables asynchronous operator submission. Disabled by default. + +### GGML_CANN_MEM_POOL + +Specifies the memory pool management strategy: + +- vmm: Utilizes a virtual memory manager pool. If hardware support for VMM is unavailable, falls back to the legacy (leg) memory pool. + +- prio: Employs a priority queue-based memory pool management. +- leg: Uses a fixed-size buffer pool. + +### GGML_CANN_DISABLE_BUF_POOL_CLEAN + +Controls automatic cleanup of the memory pool. This option is only effective when using the prio or leg memory pool strategies. + +## TODO +- Support more models and data types. diff --git a/docs/backend/CUDA-FEDORA.md b/docs/backend/CUDA-FEDORA.md new file mode 100644 index 0000000000000000000000000000000000000000..1508faf776d28ded89f5d03933e34c5a1899c4bd --- /dev/null +++ b/docs/backend/CUDA-FEDORA.md @@ -0,0 +1,283 @@ +# Setting Up CUDA on Fedora + +In this guide we setup [Nvidia CUDA](https://docs.nvidia.com/cuda/) in a toolbox container. This guide is applicable for: + +- [Fedora Workstation](https://fedoraproject.org/workstation/) +- [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/) +- [Fedora Spins](https://fedoraproject.org/spins) +- [Other Distributions](https://containertoolbx.org/distros/), including `Red Hat Enterprise Linux >= 8.5`, `Arch Linux`, and `Ubuntu`. + +## Table of Contents + +- [Prerequisites](#prerequisites) +- [Using the Fedora 41 CUDA Repository](#using-the-fedora-41-cuda-repository) +- [Creating a Fedora Toolbox Environment](#creating-a-fedora-toolbox-environment) +- [Installing Essential Development Tools](#installing-essential-development-tools) +- [Adding the CUDA Repository](#adding-the-cuda-repository) +- [Installing Nvidia Driver Libraries](#installing-nvidia-driver-libraries) +- [Installing the CUDA Meta-Package](#installing-the-cuda-meta-package) +- [Configuring the Environment](#configuring-the-environment) +- [Verifying the Installation](#verifying-the-installation) +- [Conclusion](#conclusion) +- [Troubleshooting](#troubleshooting) +- [Additional Notes](#additional-notes) +- [References](#references) + +## Prerequisites + +- **Toolbox Installed on the Host System** `Fedora Silverblue` and `Fedora Workstation` both have toolbox by default, other distributions may need to install the [toolbox package](https://containertoolbx.org/install/). +- **NVIDIA Drivers and Graphics Card installed on Host System (recommended)** To run CUDA program, such as `llama.cpp`, the host should be setup to access your NVIDIA hardware. Fedora Hosts can use the [RPM Fusion Repository](https://rpmfusion.org/Howto/NVIDIA). +- **Internet connectivity** to download packages. + +### Using the Fedora 41 CUDA Repository + +The latest release is 41. + +- [Fedora 41 CUDA Repository](https://developer.download.nvidia.com/compute/cuda/repos/fedora41/x86_64/) + +**Note:** We recommend using a toolbox environment to prevent system conflicts. + +## Creating a Fedora Toolbox Environment + +This guide focuses on Fedora hosts, but with small adjustments, it can work for other hosts. Using the Fedora Toolbox allows us to install the necessary packages without affecting the host system. + +**Note:** Toolbox is available for other systems, and even without Toolbox, it is possible to use Podman or Docker. + +1. **Create a Fedora 41 Toolbox:** + + ```bash + toolbox create --image registry.fedoraproject.org/fedora-toolbox:41 --container fedora-toolbox-41-cuda + ``` + +2. **Enter the Toolbox:** + + ```bash + toolbox enter --container fedora-toolbox-41-cuda + ``` + + Inside the toolbox, you have root privileges and can install packages without affecting the host system. + +## Installing Essential Development Tools + +1. **Synchronize the DNF Package Manager:** + + ```bash + sudo dnf distro-sync + ``` + +2. **Install **Vim** the default text editor (Optional):** + + ```bash + sudo dnf install vim-default-editor --allowerasing + ``` + + The `--allowerasing` flag will allow the removal of the conflicting `nano-default-editor` package. + +3. **Install Development Tools and Libraries:** + + ```bash + sudo dnf install @c-development @development-tools cmake + ``` + + This installs essential packages for compiling software, including `gcc`, `make`, and other development headers. + +## Adding the CUDA Repository + +Add the NVIDIA CUDA repository to your DNF configuration: + +```bash +sudo dnf config-manager addrepo --from-repofile=https://developer.download.nvidia.com/compute/cuda/repos/fedora41/x86_64/cuda-fedora41.repo +``` + +After adding the repository, synchronize the package manager again: + +```bash +sudo dnf distro-sync +``` + +## Installing Nvidia Driver Libraries + +First, we need to detect if the host is supplying the [NVIDIA driver libraries into the toolbox](https://github.com/containers/toolbox/blob/main/src/pkg/nvidia/nvidia.go): + +```bash +ls -la /usr/lib64/libcuda.so.1 +``` + +### If *`libcuda.so.1`* is missing: + +``` +ls: cannot access '/usr/lib64/libcuda.so.1': No such file or directory +``` + +**Explanation:** +The host dose not supply the CUDA drivers, **install them now:** + +#### Install the Nvidia Driver Libraries on Guest: + +```bash +sudo dnf install nvidia-driver-cuda nvidia-driver-libs nvidia-driver-cuda-libs nvidia-persistenced +``` + +### If *`libcuda.so.1`* exists: +``` +lrwxrwxrwx. 1 root root 21 Mar 24 11:26 /usr/lib64/libcuda.so.1 -> libcuda.so.570.133.07 +``` + +**Explanation:** +The host is supply the CUDA drivers, **we need to update the guest RPM Database accordingly:** + +#### Update the Toolbox RPM Database to include the Host-Supplied Libraries: + +Note: we do not actually install the libraries, we just update the DB so that the guest system knows they are supplied by the host. + +##### 1. Download `nvidia-` parts that are supplied by the host RPM's (with dependencies) + +```bash +sudo dnf download --destdir=/tmp/nvidia-driver-libs --resolve --arch x86_64 nvidia-driver-cuda nvidia-driver-libs nvidia-driver-cuda-libs nvidia-persistenced +``` + +##### 2. Update the RPM database to assume the installation of these packages. + +```bash +sudo rpm --install --verbose --hash --justdb /tmp/nvidia-driver-libs/* +``` + +**Note:** + +- The `--justdb` option only updates the RPM database, without touching the filesystem elsewhere. + +##### Check that the RPM Database has been correctly updated: + +**Note:** This is the same command as in the *"Install the Nvidia Driver Libraries on Guest"* for if *`libcuda.so.1`* was missing. + + +```bash +sudo dnf install nvidia-driver-cuda nvidia-driver-libs nvidia-driver-cuda-libs nvidia-persistenced +``` + +*(this time it will not install anything, as the database things that these packages are already installed)* + +``` +Updating and loading repositories: +Repositories loaded. +Package "nvidia-driver-cuda-3:570.124.06-1.fc41.x86_64" is already installed. +Package "nvidia-driver-libs-3:570.124.06-1.fc41.x86_64" is already installed. +Package "nvidia-driver-cuda-libs-3:570.124.06-1.fc41.x86_64" is already installed. +Package "nvidia-persistenced-3:570.124.06-1.fc41.x86_64" is already installed. + +Nothing to do. +``` + +## Installing the CUDA Meta-Package + +Now that the driver libraries are installed, proceed to install CUDA: + +```bash +sudo dnf install cuda +``` + +This installs the CUDA toolkit and associated packages. + +## Configuring the Environment + +To use CUDA, add its binary directory to your system's `PATH`. + +1. **Create a Profile Script:** + + ```bash + sudo sh -c 'echo "export PATH=\$PATH:/usr/local/cuda/bin" >> /etc/profile.d/cuda.sh' + ``` + + **Explanation:** + + - We add to `/etc/profile.d/` as the `/etc/` folder is unique to this particular container, and is not shared with other containers or the host system. + - The backslash `\` before `$PATH` ensures the variable is correctly written into the script. + +2. **Make the Script Executable:** + + ```bash + sudo chmod +x /etc/profile.d/cuda.sh + ``` + +3. **Source the Script to Update Your Environment:** + + ```bash + source /etc/profile.d/cuda.sh + ``` + + **Note:** This command updates your current shell session with the new `PATH`. The `/etc/profile.d/cuda.sh` script ensures that the CUDA binaries are available in your `PATH` for all future sessions. + +## Verifying the Installation + +To confirm that CUDA is correctly installed and configured, check the version of the NVIDIA CUDA Compiler (`nvcc`): + +```bash +nvcc --version +``` + +You should see output similar to: + +``` +nvcc: NVIDIA (R) Cuda compiler driver +Copyright (c) 2005-2025 NVIDIA Corporation +Built on Fri_Feb_21_20:23:50_PST_2025 +Cuda compilation tools, release 12.8, V12.8.93 +Build cuda_12.8.r12.8/compiler.35583870_0 +``` + +This output confirms that the CUDA compiler is accessible and indicates the installed version. + +## Conclusion + +You have successfully set up CUDA on Fedora within a toolbox environment using the Fedora 41 CUDA repository. By manually updating the RPM db and configuring the environment, you can develop CUDA applications without affecting your host system. + +## Troubleshooting + +- **Installation Failures:** + + - If you encounter errors during installation, carefully read the error messages. They often indicate conflicting files or missing dependencies. + - You may use the `--excludepath` option with `rpm` to exclude conflicting files during manual RPM installations. + +- **Rebooting the Container:** + + - Sometimes there may be a bug in the NVIDIA driver host passthrough (such as missing a shared library). Rebooting the container may solve this issue: + + ```bash + # on the host system + podman container restart --all + ``` + +- **Environment Variables Not Set:** + - If `nvcc` is not found after installation, ensure that `/usr/local/cuda/bin` is in your `PATH`. + - Run `echo $PATH` to check if the path is included. + - Re-source the profile script or open a new terminal session. + +## Additional Notes + +- **Updating CUDA in the Future:** + + - Keep an eye on the official NVIDIA repositories for updates to your Fedora version. + - When an updated repository becomes available, adjust your `dnf` configuration accordingly. + +- **Building `llama.cpp`:** + + - With CUDA installed, you can follow these [build instructions for `llama.cpp`](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md) to compile it with CUDA support. + - Ensure that any CUDA-specific build flags or paths are correctly set in your build configuration. + +- **Using the Toolbox Environment:** + - The toolbox environment is isolated from your host system, which helps prevent conflicts. + - Remember that system files and configurations inside the toolbox are separate from the host. By default the home directory of the user is shared between the host and the toolbox. + +--- + +**Disclaimer:** Manually installing and modifying system packages can lead to instability of the container. The above steps are provided as a guideline and may need adjustments based on your specific system configuration. Always back up important data before making significant system changes, especially as your home folder is writable and shared with he toolbox. + +**Acknowledgments:** Special thanks to the Fedora community and NVIDIA documentation for providing resources that assisted in creating this guide. + +## References + +- [Fedora Toolbox Documentation](https://docs.fedoraproject.org/en-US/fedora-silverblue/toolbox/) +- [NVIDIA CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) +- [Podman Documentation](https://podman.io/get-started) + +--- diff --git a/docs/backend/OPENCL.md b/docs/backend/OPENCL.md new file mode 100644 index 0000000000000000000000000000000000000000..07146f7102f3df1f299e9468dc72f0871767b19b --- /dev/null +++ b/docs/backend/OPENCL.md @@ -0,0 +1,209 @@ +# llama.cpp for OpenCL + +- [Background](#background) +- [OS](#os) +- [Hardware](#hardware) +- [DataType Supports](#datatype-supports) +- [Model Preparation](#model-preparation) +- [CMake Options](#cmake-options) +- [Android](#android) +- [Windows 11 Arm64](#windows-11-arm64) +- [Known Issue](#known-issues) +- [TODO](#todo) + +## Background + +OpenCL (Open Computing Language) is an open, royalty-free standard for cross-platform, parallel programming of diverse accelerators found in supercomputers, cloud servers, personal computers, mobile devices and embedded platforms. OpenCL specifies a programming language (based on C99) for programming these devices and application programming interfaces (APIs) to control the platform and execute programs on the compute devices. Similar to CUDA, OpenCL has been widely used to program GPUs and is supported by most GPU vendors. + +### Llama.cpp + OpenCL + +The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portabilty of OpenCL, the OpenCL backend can also run on certain Intel GPUs although the performance is not optimal. + +## OS + +| OS | Status | Verified | +|---------|---------|------------------------------------------------| +| Android | Support | Snapdragon 8 Gen 3, Snapdragon 8 Elite | +| Windows | Support | Windows 11 Arm64 with Snapdragon X Elite | +| Linux | Support | Ubuntu 22.04 WSL2 with Intel 12700H | + +## Hardware + +### Adreno GPU + +**Verified devices** + +| Adreno GPU | Status | +|:------------------------------------:|:-------:| +| Adreno 750 (Snapdragon 8 Gen 3) | Support | +| Adreno 830 (Snapdragon 8 Elite) | Support | +| Adreno X85 (Snapdragon X Elite) | Support | + +## DataType Supports + +| DataType | Status | +|:----------------------:|:--------------------------:| +| Q4_0 | Support | +| Q6_K | Support, but not optimized | + +## Model Preparation + +You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model prepration. + +Currently we support `Q4_0` quantization and have optimize for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize`. For example, + +```sh +./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0 +``` + +Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization. + +## CMake Options + +The OpenCL backend has the following CMake options that control the behavior of the backend. + +| CMake options | Default value | Description | +|:---------------------------------:|:--------------:|:------------------------------------------| +| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. | +| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. | + +## Android + +Ubuntu 22.04 is used for targeting Android. Make sure the following tools are accessible from command line, + +* Git +* CMake 3.29 +* Ninja +* Python3 + +### I. Setup Environment + +1. **Install NDK** + +```sh +cd ~ +wget https://dl.google.com/android/repository/commandlinetools-linux-8512546_latest.zip && \ +unzip commandlinetools-linux-8512546_latest.zip && \ +mkdir -p ~/android-sdk/cmdline-tools && \ +mv cmdline-tools latest && \ +mv latest ~/android-sdk/cmdline-tools/ && \ +rm -rf commandlinetools-linux-8512546_latest.zip + +yes | ~/android-sdk/cmdline-tools/latest/bin/sdkmanager "ndk;26.3.11579264" +``` + +2. **Install OpenCL Headers and Library** + +```sh +mkdir -p ~/dev/llm +cd ~/dev/llm + +git clone https://github.com/KhronosGroup/OpenCL-Headers && \ +cd OpenCL-Headers && \ +cp -r CL ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include + +cd ~/dev/llm + +git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && \ +cd OpenCL-ICD-Loader && \ +mkdir build_ndk26 && cd build_ndk26 && \ +cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \ + -DOPENCL_ICD_LOADER_HEADERS_DIR=$HOME/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=24 \ + -DANDROID_STL=c++_shared && \ +ninja && \ +cp libOpenCL.so ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android +``` + +### II. Build llama.cpp + +```sh +cd ~/dev/llm + +git clone https://github.com/ggml-org/llama.cpp && \ +cd llama.cpp && \ +mkdir build-android && cd build-android + +cmake .. -G Ninja \ + -DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DBUILD_SHARED_LIBS=OFF \ + -DGGML_OPENCL=ON + +ninja +``` + +## Windows 11 Arm64 + +A Snapdragon X Elite device with Windows 11 Arm64 is used. Make sure the following tools are accessible from command line, + +* Git +* CMake 3.29 +* Clang 19 +* Ninja +* Visual Studio 2022 +* Powershell 7 + +Visual Studio provides necessary headers and libraries although it is not directly used for building. +Alternatively, Visual Studio Build Tools can be installed instead of the full Visual Studio. + +Powershell 7 is used for the following commands. +If an older version of Powershell is used, these commands may not work as they are. + +### I. Setup Environment + +1. **Install OpenCL Headers and Library** + +```powershell +mkdir -p ~/dev/llm + +cd ~/dev/llm +git clone https://github.com/KhronosGroup/OpenCL-Headers && cd OpenCL-Headers +mkdir build && cd build +cmake .. -G Ninja ` + -DBUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF ` + -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl" +cmake --build . --target install + +cd ~/dev/llm +git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && cd OpenCL-ICD-Loader +mkdir build && cd build +cmake .. -G Ninja ` + -DCMAKE_BUILD_TYPE=Release ` + -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" ` + -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl" +cmake --build . --target install +``` + +### II. Build llama.cpp + +```powershell + +mkdir -p ~/dev/llm +cd ~/dev/llm + +git clone https://github.com/ggml-org/llama.cpp && cd llama.cpp +mkdir build && cd build + +cmake .. -G Ninja ` + -DCMAKE_TOOLCHAIN_FILE="$HOME/dev/llm/llama.cpp/cmake/arm64-windows-llvm.cmake" ` + -DCMAKE_BUILD_TYPE=Release ` + -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" ` + -DBUILD_SHARED_LIBS=OFF ` + -DGGML_OPENCL=ON +ninja +``` + +## Known Issues + +- Currently OpenCL backend does not work on Adreno 6xx GPUs. + +## TODO + +- Optimization for Q6_K +- Support and optimization for Q4_K diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md new file mode 100644 index 0000000000000000000000000000000000000000..249e73451e66be99ba332b04e7503e4e85c48044 --- /dev/null +++ b/docs/backend/SYCL.md @@ -0,0 +1,815 @@ +# llama.cpp for SYCL + +- [Background](#background) +- [Recommended Release](#recommended-release) +- [News](#news) +- [OS](#os) +- [Hardware](#hardware) +- [Docker](#docker) +- [Linux](#linux) +- [Windows](#windows) +- [Environment Variable](#environment-variable) +- [Known Issue](#known-issues) +- [Q&A](#qa) +- [TODO](#todo) + +## Background + +**SYCL** is a high-level parallel programming model designed to improve developers productivity writing code across various hardware accelerators such as CPUs, GPUs, and FPGAs. It is a single-source language designed for heterogeneous computing and based on standard C++17. + +**oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to Intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include: + +- **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers. +- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. Intel oneMKL, oneMath and oneDNN)*. +- **oneAPI LevelZero**: A high performance low level interface for fine-grained control over Intel iGPUs and dGPUs. +- **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets. + +### Llama.cpp + SYCL + +The llama.cpp SYCL backend is primarily designed for **Intel GPUs**. +SYCL cross-platform capabilities enable support for Nvidia GPUs as well, with limited support for AMD. + +## Recommended Release + +The following releases are verified and recommended: + +|Commit ID|Tag|Release|Verified Platform| Update date| +|-|-|-|-|-| +|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |ArcB580/Linux/oneAPI 2025.1
LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15| +|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19| +|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1|| + + +## News + +- 2025.2 + - Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC). + |GPU|Base tokens/s|Increased tokens/s|Percent| + |-|-|-|-| + |PVC 1550|39|73|+87%| + |Flex 170|39|50|+28%| + |Arc770|42|55|+30%| + |MTL|13|16|+23%| + |ARL-H|14|17|+21%| + +- 2024.11 + - Use syclcompat to improve the performance on some platforms. This requires to use oneAPI 2025.0 or newer. + +- 2024.8 + - Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs. + +- 2024.5 + - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770. + - Arch Linux is verified successfully. + +- 2024.4 + - Support data types: GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_XS, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M. + +- 2024.3 + - Release binary files of Windows. + - A blog is published: **Run LLM on all Intel GPUs Using llama.cpp**: [intel.com](https://www.intel.com/content/www/us/en/developer/articles/technical/run-llm-on-all-gpus-using-llama-cpp-artical.html) or [medium.com](https://medium.com/@jianyu_neo/run-llm-on-all-intel-gpus-using-llama-cpp-fd2e2dcbd9bd). + - New base line is ready: [tag b2437](https://github.com/ggml-org/llama.cpp/tree/b2437). + - Support multiple cards: **--split-mode**: [none|layer]; not support [row], it's on developing. + - Support to assign main GPU by **--main-gpu**, replace $GGML_SYCL_DEVICE. + - Support detecting all GPUs with level-zero and same top **Max compute units**. + - Support OPs + - hardsigmoid + - hardswish + - pool2d + +- 2024.1 + - Create SYCL backend for Intel GPU. + - Support Windows build + +## OS + +| OS | Status | Verified | +|---------|---------|------------------------------------------------| +| Linux | Support | Ubuntu 22.04, Fedora Silverblue 39, Arch Linux | +| Windows | Support | Windows 11 | + + +## Hardware + +### Intel GPU + +SYCL backend supports Intel GPU Family: + +- Intel Data Center Max Series +- Intel Flex Series, Arc Series +- Intel Built-in Arc GPU +- Intel iGPU in Core CPU (11th Generation Core CPU and newer, refer to [oneAPI supported GPU](https://www.intel.com/content/www/us/en/developer/articles/system-requirements/intel-oneapi-base-toolkit-system-requirements.html#inpage-nav-1-1)). + +#### Verified devices + +| Intel GPU | Status | Verified Model | +|-------------------------------|---------|---------------------------------------| +| Intel Data Center Max Series | Support | Max 1550, 1100 | +| Intel Data Center Flex Series | Support | Flex 170 | +| Intel Arc Series | Support | Arc 770, 730M, Arc A750, B580 | +| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake | +| Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 | + +*Notes:* + +- **Memory** + - The device memory is a limitation when running a large model. The loaded model size, *`llm_load_tensors: buffer_size`*, is displayed in the log when running `./bin/llama-cli`. + - Please make sure the GPU shared memory from the host is large enough to account for the model's size. For e.g. the *llama-2-7b.Q4_0* requires at least 8.0GB for integrated GPU and 4.0GB for discrete GPU. + +- **Execution Unit (EU)** + - If the iGPU has less than 80 EUs, the inference speed will likely be too slow for practical use. + +### Other Vendor GPU + +**Verified devices** + +| Nvidia GPU | Status | Verified Model | +|--------------------------|-----------|----------------| +| Ampere Series | Supported | A100, A4000 | +| Ampere Series *(Mobile)* | Supported | RTX 40 Series | + +| AMD GPU | Status | Verified Model | +|--------------------------|--------------|----------------| +| Radeon Pro | Experimental | W6800 | +| Radeon RX | Experimental | 6700 XT | + +Note: AMD GPU support is highly experimental and is incompatible with F16. +Additionally, it only supports GPUs with a sub_group_size (warp size) of 32. + +## Docker + +The docker build option is currently limited to *Intel GPU* targets. + +### Build image + +```sh +# Using FP16 +docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile . +``` + +*Notes*: + +To build in default FP32 *(Slower than FP16 alternative)*, set `--build-arg="GGML_SYCL_F16=OFF"` in the previous command. + +You can also use the `.devops/llama-server-intel.Dockerfile`, which builds the *"server"* alternative. +Check the [documentation for Docker](../docker.md) to see the available images. + +### Run container + +```sh +# First, find all the DRI cards +ls -la /dev/dri +# Then, pick the card that you want to use (here for e.g. /dev/dri/card1). +docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card1:/dev/dri/card1 llama-cpp-sycl -m "/app/models/YOUR_MODEL_FILE" -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 +``` + +*Notes:* +- Docker has been tested successfully on native Linux. WSL support has not been verified yet. +- You may need to install Intel GPU driver on the **host** machine *(Please refer to the [Linux configuration](#linux) for details)*. + +## Linux + +### I. Setup Environment + +1. **Install GPU drivers** + + - **Intel GPU** + +Intel data center GPUs drivers installation guide and download page can be found here: [Get intel dGPU Drivers](https://dgpu-docs.intel.com/driver/installation.html#ubuntu-install-steps). + +*Note*: for client GPUs *(iGPU & Arc A-Series)*, please refer to the [client iGPU driver installation](https://dgpu-docs.intel.com/driver/client/overview.html). + +Once installed, add the user(s) to the `video` and `render` groups. + +```sh +sudo usermod -aG render $USER +sudo usermod -aG video $USER +``` + +*Note*: logout/re-login for the changes to take effect. + +Verify installation through `clinfo`: + +```sh +sudo apt install clinfo +sudo clinfo -l +``` + +Sample output: + +```sh +Platform #0: Intel(R) OpenCL Graphics + `-- Device #0: Intel(R) Arc(TM) A770 Graphics + +Platform #0: Intel(R) OpenCL HD Graphics + `-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49] +``` + +- **Nvidia GPU** + +In order to target Nvidia GPUs through SYCL, please make sure the CUDA/CUBLAS native requirements *-found [here](README.md#cuda)-* are installed. + +- **AMD GPU** + +To target AMD GPUs with SYCL, the ROCm stack must be installed first. + +2. **Install Intel® oneAPI Base toolkit** + +- **For Intel GPU** + +The base toolkit can be obtained from the official [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page. + +Please follow the instructions for downloading and installing the Toolkit for Linux, and preferably keep the default installation values unchanged, notably the installation path *(`/opt/intel/oneapi` by default)*. + +Following guidelines/code snippets assume the default installation values. Otherwise, please make sure the necessary changes are reflected where applicable. + +Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI oneDNN for Intel GPUs. + +- **Adding support to Nvidia GPUs** + +**oneAPI Plugin**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup. + +**oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target: + +```sh +git clone https://github.com/oneapi-src/oneDNN.git +cd oneDNN +cmake -GNinja -Bbuild-nvidia -DDNNL_CPU_RUNTIME=DPCPP -DDNNL_GPU_RUNTIME=DPCPP -DDNNL_GPU_VENDOR=NVIDIA -DONEDNN_BUILD_GRAPH=OFF -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx +cmake --build build-nvidia --config Release +``` + +- **Adding support to AMD GPUs** + +**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit. + +3. **Verify installation and environment** + +In order to check the available SYCL devices on the machine, please use the `sycl-ls` command. +```sh +source /opt/intel/oneapi/setvars.sh +sycl-ls +``` + +- **Intel GPU** + +When targeting an intel GPU, the user should expect one or more devices among the available SYCL devices. Please make sure that at least one GPU is present via `sycl-ls`, for instance `[level_zero:gpu]` in the sample output below: + +``` +[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000] +[opencl:cpu][opencl:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000] +[opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50] +[level_zero:gpu][level_zero:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918] +``` + +- **Nvidia GPU** + +Similarly, user targeting Nvidia GPUs should expect at least one SYCL-CUDA device [`cuda:gpu`] as below: + +``` +[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.12.0.12_195853.xmain-hotfix] +[opencl:cpu][opencl:1] Intel(R) OpenCL, Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz OpenCL 3.0 (Build 0) [2023.16.12.0.12_195853.xmain-hotfix] +[cuda:gpu][cuda:0] NVIDIA CUDA BACKEND, NVIDIA A100-PCIE-40GB 8.0 [CUDA 12.5] +``` + +- **AMD GPU** + +For AMD GPUs we should expect at least one SYCL-HIP device [`hip:gpu`]: + +``` +[opencl:cpu][opencl:0] Intel(R) OpenCL, 12th Gen Intel(R) Core(TM) i9-12900K OpenCL 3.0 (Build 0) [2024.18.6.0.02_160000] +[hip:gpu][hip:0] AMD HIP BACKEND, AMD Radeon PRO W6800 gfx1030 [HIP 60140.9] +``` + +### II. Build llama.cpp + +#### Intel GPU + +```sh +./examples/sycl/build.sh +``` + +or + +```sh +# Export relevant ENV variables +source /opt/intel/oneapi/setvars.sh + +# Option 1: Use FP32 (recommended for better performance in most cases) +cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx + +# Option 2: Use FP16 +cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON + +# build all binary +cmake --build build --config Release -j -v +``` + +It is possible to come across some precision issues when running tests that stem from using faster +instructions, which can be circumvented by setting the environment variable `SYCL_PROGRAM_COMPILE_OPTIONS` +as `-cl-fp32-correctly-rounded-divide-sqrt` + +#### Nvidia GPU + +The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices. +By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`. + +```sh +# Build LLAMA with Nvidia BLAS acceleration through SYCL +# Setting GGML_SYCL_DEVICE_ARCH is optional but can improve performance +GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture + +# Option 1: Use FP32 (recommended for better performance in most cases) +cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl + +# Option 2: Use FP16 +cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl + +# build all binary +cmake --build build --config Release -j -v +``` + +It is possible to come across some precision issues when running tests that stem from using faster +instructions, which can be circumvented by passing the `-fno-fast-math` flag to the compiler. + +#### AMD GPU + +The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices. +By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`. + +```sh +# Build LLAMA with rocBLAS acceleration through SYCL + +## AMD +# Use FP32, FP16 is not supported +# Find your GGML_SYCL_DEVICE_ARCH with rocminfo, under the key 'Name:' +GGML_SYCL_DEVICE_ARCH=gfx90a # Example architecture +cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=AMD -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx + +# build all binary +cmake --build build --config Release -j -v +``` + +### III. Run the inference + +#### Retrieve and prepare model + +You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf). + +##### Check device + +1. Enable oneAPI running environment + +```sh +source /opt/intel/oneapi/setvars.sh +``` + +2. List devices information + +Similar to the native `sycl-ls`, available SYCL devices can be queried as follow: + +```sh +./build/bin/llama-ls-sycl-device +``` + +This command will only display the selected backend that is supported by SYCL. The default backend is level_zero. For example, in a system with 2 *intel GPU* it would look like the following: +``` +found 2 SYCL devices: + +| | | |Compute |Max compute|Max work|Max sub| | +|ID| Device Type| Name|capability|units |group |group |Global mem size| +|--|------------------|---------------------------------------------|----------|-----------|--------|-------|---------------| +| 0|[level_zero:gpu:0]| Intel(R) Arc(TM) A770 Graphics| 1.3| 512| 1024| 32| 16225243136| +| 1|[level_zero:gpu:1]| Intel(R) UHD Graphics 770| 1.3| 32| 512| 32| 53651849216| +``` + +#### Choose level-zero devices + +|Chosen Device ID|Setting| +|-|-| +|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:0"` or no action| +|1|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"`| +|0 & 1|`export ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"`| + +#### Execute + +Choose one of following methods to run. + +1. Script + +- Use device 0: + +```sh +./examples/sycl/run-llama2.sh 0 +# OR +./examples/sycl/run-llama3.sh 0 +``` +- Use multiple devices: + +```sh +./examples/sycl/run-llama2.sh +# OR +./examples/sycl/run-llama3.sh +``` + +2. Command line +Launch inference + +There are two device selection modes: + +- Single device: Use one device assigned by user. Default device id is 0. +- Multiple devices: Automatically choose the devices with the same backend. + +In two device selection modes, the default SYCL backend is level_zero, you can choose other backend supported by SYCL by setting environment variable ONEAPI_DEVICE_SELECTOR. + +| Device selection | Parameter | +|------------------|----------------------------------------| +| Single device | --split-mode none --main-gpu DEVICE_ID | +| Multiple devices | --split-mode layer (default) | + +Examples: + +- Use device 0: + +```sh +ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm none -mg 0 +``` + +- Use multiple devices: + +```sh +ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm layer +``` + +*Notes:* + +- Upon execution, verify the selected device(s) ID(s) in the output log, which can for instance be displayed as follow: + +```sh +detect 1 SYCL GPUs: [0] with top Max compute units:512 +``` +Or +```sh +use 1 SYCL GPUs: [0] with Max compute units:512 +``` + +## Windows + +### I. Setup Environment + +1. Install GPU driver + +Intel GPU drivers instructions guide and download page can be found here: [Get Intel GPU Drivers](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/software/drivers.html). + +2. Install Visual Studio + +If you already have a recent version of Microsoft Visual Studio, you can skip this step. Otherwise, please refer to the official download page for [Microsoft Visual Studio](https://visualstudio.microsoft.com/). + +3. Install Intel® oneAPI Base toolkit + +The base toolkit can be obtained from the official [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page. + +Please follow the instructions for downloading and installing the Toolkit for Windows, and preferably keep the default installation values unchanged, notably the installation path *(`C:\Program Files (x86)\Intel\oneAPI` by default)*. + +Following guidelines/code snippets assume the default installation values. Otherwise, please make sure the necessary changes are reflected where applicable. + +b. Enable oneAPI running environment: + +- Type "oneAPI" in the search bar, then open the `Intel oneAPI command prompt for Intel 64 for Visual Studio 2022` App. + +- On the command prompt, enable the runtime environment with the following: +``` +"C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 +``` + +- if you are using Powershell, enable the runtime environment with the following: + +``` +cmd.exe "/K" '"C:\Program Files (x86)\Intel\oneAPI\setvars.bat" && powershell' +``` + +c. Verify installation + +In the oneAPI command line, run the following to print the available SYCL devices: + +``` +sycl-ls.exe +``` + +There should be one or more *level-zero* GPU devices displayed as **[ext_oneapi_level_zero:gpu]**. Below is example of such output detecting an *intel Iris Xe* GPU as a Level-zero SYCL device: + +Output (example): +``` +[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000] +[opencl:cpu:1] Intel(R) OpenCL, 11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000] +[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Iris(R) Xe Graphics OpenCL 3.0 NEO [31.0.101.5186] +[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Iris(R) Xe Graphics 1.3 [1.3.28044] +``` + +4. Install build tools + +a. Download & install cmake for Windows: https://cmake.org/download/ (CMake can also be installed from Visual Studio Installer) +b. The new Visual Studio will install Ninja as default. (If not, please install it manually: https://ninja-build.org/) + + +### II. Build llama.cpp + +You could download the release package for Windows directly, which including binary files and depended oneAPI dll files. + +Choose one of following methods to build from source code. + +#### 1. Script + +```sh +.\examples\sycl\win-build-sycl.bat +``` + +#### 2. CMake + +On the oneAPI command line window, step into the llama.cpp main directory and run the following: + +``` +@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force + +# Option 1: Use FP32 (recommended for better performance in most cases) +cmake -B build -G "Ninja" -DGGML_SYCL=ON -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release + +# Option 2: Or FP16 +cmake -B build -G "Ninja" -DGGML_SYCL=ON -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release -DGGML_SYCL_F16=ON + +cmake --build build --config Release -j +``` + +Or, use CMake presets to build: + +```sh +cmake --preset x64-windows-sycl-release +cmake --build build-x64-windows-sycl-release -j --target llama-cli + +cmake -DGGML_SYCL_F16=ON --preset x64-windows-sycl-release +cmake --build build-x64-windows-sycl-release -j --target llama-cli + +cmake --preset x64-windows-sycl-debug +cmake --build build-x64-windows-sycl-debug -j --target llama-cli +``` + +#### 3. Visual Studio + +You have two options to use Visual Studio to build llama.cpp: +- As CMake Project using CMake presets. +- Creating a Visual Studio solution to handle the project. + +**Note**: + +All following commands are executed in PowerShell. + +##### - Open as a CMake Project + +You can use Visual Studio to open the `llama.cpp` folder directly as a CMake project. Before compiling, select one of the SYCL CMake presets: + +- `x64-windows-sycl-release` + +- `x64-windows-sycl-debug` + +*Notes:* +- For a minimal experimental setup, you can build only the inference executable using: + + ```Powershell + cmake --build build --config Release -j --target llama-cli + ``` + +##### - Generating a Visual Studio Solution + +You can use Visual Studio solution to build and work on llama.cpp on Windows. You need to convert the CMake Project into a `.sln` file. + +If you want to use the Intel C++ Compiler for the entire `llama.cpp` project, run the following command: + +```Powershell +cmake -B build -G "Visual Studio 17 2022" -T "Intel C++ Compiler 2025" -A x64 -DGGML_SYCL=ON -DCMAKE_BUILD_TYPE=Release +``` + +If you prefer to use the Intel C++ Compiler only for `ggml-sycl`, ensure that `ggml` and its backend libraries are built as shared libraries ( i.e. `-DBUILD_SHARED_LIBRARIES=ON`, this is default behaviour): + +```Powershell +cmake -B build -G "Visual Studio 17 2022" -A x64 -DGGML_SYCL=ON -DCMAKE_BUILD_TYPE=Release \ + -DSYCL_INCLUDE_DIR="C:\Program Files (x86)\Intel\oneAPI\compiler\latest\include" \ + -DSYCL_LIBRARY_DIR="C:\Program Files (x86)\Intel\oneAPI\compiler\latest\lib" +``` + +If successful the build files have been written to: *path/to/llama.cpp/build* +Open the project file **build/llama.cpp.sln** with Visual Studio. + +Once the Visual Studio solution is created, follow these steps: + +1. Open the solution in Visual Studio. + +2. Right-click on `ggml-sycl` and select **Properties**. + +3. In the left column, expand **C/C++** and select **DPC++**. + +4. In the right panel, find **Enable SYCL Offload** and set it to `Yes`. + +5. Apply the changes and save. + + +*Navigation Path:* + +``` +Properties -> C/C++ -> DPC++ -> Enable SYCL Offload (Yes) +``` + +Now, you can build `llama.cpp` with the SYCL backend as a Visual Studio project. +To do it from menu: `Build -> Build Solution`. +Once it is completed, final results will be in **build/Release/bin** + +*Additional Note* + +- You can avoid specifying `SYCL_INCLUDE_DIR` and `SYCL_LIBRARY_DIR` in the CMake command by setting the environment variables: + + - `SYCL_INCLUDE_DIR_HINT` + + - `SYCL_LIBRARY_DIR_HINT` + +- Above instruction has been tested with Visual Studio 17 Community edition and oneAPI 2025.0. We expect them to work also with future version if the instructions are adapted accordingly. + +### III. Run the inference + +#### Retrieve and prepare model + +You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf). + +##### Check device + +1. Enable oneAPI running environment + +On the oneAPI command line window, run the following and step into the llama.cpp directory: +``` +"C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 +``` + +2. List devices information + +Similar to the native `sycl-ls`, available SYCL devices can be queried as follow: + +``` +build\bin\llama-ls-sycl-device.exe +``` + +This command will only display the selected backend that is supported by SYCL. The default backend is level_zero. For example, in a system with 2 *Intel GPU* it would look like the following: +``` +found 2 SYCL devices: +| | | |Compute |Max compute|Max work|Max sub| | +|ID| Device Type| Name|capability|units |group |group |Global mem size| +|--|------------------|---------------------------------------------|----------|-----------|--------|-------|---------------| +| 0|[level_zero:gpu:0]| Intel(R) Arc(TM) A770 Graphics| 1.3| 512| 1024| 32| 16225243136| +| 1|[level_zero:gpu:1]| Intel(R) UHD Graphics 770| 1.3| 32| 512| 32| 53651849216| + +``` + +#### Choose level-zero devices + +|Chosen Device ID|Setting| +|-|-| +|0|Default option. You may also want to `set ONEAPI_DEVICE_SELECTOR="level_zero:0"`| +|1|`set ONEAPI_DEVICE_SELECTOR="level_zero:1"`| +|0 & 1|`set ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"` or `set ONEAPI_DEVICE_SELECTOR="level_zero:*"`| + +#### Execute + +Choose one of following methods to run. + +1. Script + +``` +examples\sycl\win-run-llama-2.bat +``` + +or + +``` +examples\sycl\win-run-llama-3.bat +``` + +2. Command line + +Launch inference + +There are two device selection modes: + +- Single device: Use one device assigned by user. Default device id is 0. +- Multiple devices: Automatically choose the devices with the same backend. + +In two device selection modes, the default SYCL backend is level_zero, you can choose other backend supported by SYCL by setting environment variable ONEAPI_DEVICE_SELECTOR. + +| Device selection | Parameter | +|------------------|----------------------------------------| +| Single device | --split-mode none --main-gpu DEVICE_ID | +| Multiple devices | --split-mode layer (default) | + +Examples: + +- Use device 0: + +``` +build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm none -mg 0 +``` + +- Use multiple devices: + +``` +build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm layer +``` + + +Note: + +- Upon execution, verify the selected device(s) ID(s) in the output log, which can for instance be displayed as follow: + +```sh +detect 1 SYCL GPUs: [0] with top Max compute units:512 +``` + +Or + +```sh +use 1 SYCL GPUs: [0] with Max compute units:512 +``` + + +## Environment Variable + +#### Build + +| Name | Value | Function | +|--------------------|---------------------------------------|---------------------------------------------| +| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path. | +| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. | +| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. | +| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. (1.) | +| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). | +| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. | +| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. | +| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. | + +1. FP16 is recommended for better prompt processing performance on quantized models. Performance is equivalent in text generation but set `GGML_SYCL_F16=OFF` if you are experiencing issues with FP16 builds. + +#### Runtime + +| Name | Value | Function | +|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------| +| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG | +| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase | +| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. | +| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. | +| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.
Recommended to use when --split-mode = layer | + + +## Known Issues + +- `Split-mode:[row]` is not supported. + +## Q&A + +- Error: `error while loading shared libraries: libsycl.so: cannot open shared object file: No such file or directory`. + + - Potential cause: Unavailable oneAPI installation or not set ENV variables. + - Solution: Install *oneAPI base toolkit* and enable its ENV through: `source /opt/intel/oneapi/setvars.sh`. + +- General compiler error: + + - Remove **build** folder or try a clean-build. + +- I can **not** see `[ext_oneapi_level_zero:gpu]` afer installing the GPU driver on Linux. + + Please double-check with `sudo sycl-ls`. + + If it's present in the list, please add video/render group to your user then **logout/login** or restart your system: + + ``` + sudo usermod -aG render $USER + sudo usermod -aG video $USER + ``` + Otherwise, please double-check the GPU driver installation steps. + +- Can I report Ollama issue on Intel GPU to llama.cpp SYCL backend? + + No. We can't support Ollama issue directly, because we aren't familiar with Ollama. + + Sugguest reproducing on llama.cpp and report similar issue to llama.cpp. We will surpport it. + + It's same for other projects including llama.cpp SYCL backend. + +- `Native API failed. Native API returns: 39 (UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY)`, `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 3503030272 Bytes of memory on device`, or `failed to allocate SYCL0 buffer` + + You are running out of Device Memory. + + |Reason|Solution| + |-|-| + | The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.| + | The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;
Alternatively, use more than one device to load model.| + +### **GitHub contribution**: +Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay. + +## TODO + +- Review ZES_ENABLE_SYSMAN: https://github.com/intel/compute-runtime/blob/master/programmers-guide/SYSMAN.md#support-and-limitations diff --git a/docs/build.md b/docs/build.md new file mode 100644 index 0000000000000000000000000000000000000000..680b0d83987419338a64e2c40e11dbd2ae09cc6b --- /dev/null +++ b/docs/build.md @@ -0,0 +1,566 @@ +# Build llama.cpp locally + +The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h). + +The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server. + +**To get the Code:** + +```bash +git clone https://github.com/ggml-org/llama.cpp +cd llama.cpp +``` + +The following sections describe how to build with different backends and options. + +## CPU Build + +Build llama.cpp using `CMake`: + +```bash +cmake -B build +cmake --build build --config Release +``` + +**Notes**: + +- For faster compilation, add the `-j` argument to run multiple jobs in parallel, or use a generator that does this automatically such as Ninja. For example, `cmake --build build --config Release -j 8` will run 8 jobs in parallel. +- For faster repeated compilation, install [ccache](https://ccache.dev/) +- For debug builds, there are two cases: + + 1. Single-config generators (e.g. default = `Unix Makefiles`; note that they just ignore the `--config` flag): + + ```bash + cmake -B build -DCMAKE_BUILD_TYPE=Debug + cmake --build build + ``` + + 2. Multi-config generators (`-G` param set to Visual Studio, XCode...): + + ```bash + cmake -B build -G "Xcode" + cmake --build build --config Debug + ``` + + For more details and a list of supported generators, see the [CMake documentation](https://cmake.org/cmake/help/latest/manual/cmake-generators.7.html). +- For static builds, add `-DBUILD_SHARED_LIBS=OFF`: + ``` + cmake -B build -DBUILD_SHARED_LIBS=OFF + cmake --build build --config Release + ``` + +- Building for Windows (x86, x64 and arm64) with MSVC or clang as compilers: + - Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...): + - Tab Workload: Desktop-development with C++ + - Tab Components (select quickly via search): C++-_CMake_ Tools for Windows, _Git_ for Windows, C++-_Clang_ Compiler for Windows, MS-Build Support for LLVM-Toolset (clang) + - Please remember to always use a Developer Command Prompt / PowerShell for VS2022 for git, build, test + - For Windows on ARM (arm64, WoA) build with: + ```bash + cmake --preset arm64-windows-llvm-release -D GGML_OPENMP=OFF + cmake --build build-arm64-windows-llvm-release + ``` + Building for arm64 can also be done with the MSVC compiler with the build-arm64-windows-MSVC preset, or the standard CMake build instructions. However, note that the MSVC compiler does not support inline ARM assembly code, used e.g. for the accelerated Q4_0_N_M CPU kernels. + + For building with ninja generator and clang compiler as default: + -set path:set LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\x64;C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.41.34120\lib\x64\uwp;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\x64 + ```bash + cmake --preset x64-windows-llvm-release + cmake --build build-x64-windows-llvm-release + ``` +- Curl usage is enabled by default and can be turned off with `-DLLAMA_CURL=OFF`. Otherwise you need to install development libraries for libcurl. + +## BLAS Build + +Building the program with BLAS support may lead to some performance improvements in prompt processing using batch sizes higher than 32 (the default is 512). Using BLAS doesn't affect the generation performance. There are currently several different BLAS implementations available for build and use: + +### Accelerate Framework + +This is only available on Mac PCs and it's enabled by default. You can just build using the normal instructions. + +### OpenBLAS + +This provides BLAS acceleration using only the CPU. Make sure to have OpenBLAS installed on your machine. + +- Using `CMake` on Linux: + + ```bash + cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS + cmake --build build --config Release + ``` + +### BLIS + +Check [BLIS.md](./backend/BLIS.md) for more information. + +### Intel oneMKL + +Building through oneAPI compilers will make avx_vnni instruction set available for intel processors that do not support avx512 and avx512_vnni. Please note that this build config **does not support Intel GPU**. For Intel GPU support, please refer to [llama.cpp for SYCL](./backend/SYCL.md). + +- Using manual oneAPI installation: + By default, `GGML_BLAS_VENDOR` is set to `Generic`, so if you already sourced intel environment script and assign `-DGGML_BLAS=ON` in cmake, the mkl version of Blas will automatically been selected. Otherwise please install oneAPI and follow the below steps: + ```bash + source /opt/intel/oneapi/setvars.sh # You can skip this step if in oneapi-basekit docker image, only required for manual installation + cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=Intel10_64lp -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_NATIVE=ON + cmake --build build --config Release + ``` + +- Using oneAPI docker image: + If you do not want to source the environment vars and install oneAPI manually, you can also build the code using intel docker container: [oneAPI-basekit](https://hub.docker.com/r/intel/oneapi-basekit). Then, you can use the commands given above. + +Check [Optimizing and Running LLaMA2 on Intel® CPU](https://www.intel.com/content/www/us/en/content-details/791610/optimizing-and-running-llama2-on-intel-cpu.html) for more information. + +### Other BLAS libraries + +Any other BLAS library can be used by setting the `GGML_BLAS_VENDOR` option. See the [CMake documentation](https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors) for a list of supported vendors. + +## Metal Build + +On MacOS, Metal is enabled by default. Using Metal makes the computation run on the GPU. +To disable the Metal build at compile time use the `-DGGML_METAL=OFF` cmake option. + +When built with Metal support, you can explicitly disable GPU inference with the `--n-gpu-layers 0` command-line argument. + +## SYCL + +SYCL is a higher-level programming model to improve programming productivity on various hardware accelerators. + +llama.cpp based on SYCL is used to **support Intel GPU** (Data Center Max series, Flex series, Arc series, Built-in GPU and iGPU). + +For detailed info, please refer to [llama.cpp for SYCL](./backend/SYCL.md). + +## CUDA + +This provides GPU acceleration using an NVIDIA GPU. Make sure to have the [CUDA toolkit](https://developer.nvidia.com/cuda-toolkit) installed. + +#### Download directly from NVIDIA +You may find the official downloads here: [NVIDIA developer site](https://developer.nvidia.com/cuda-downloads). + + +#### Compile and run inside a Fedora Toolbox Container +We also have a [guide](./backend/CUDA-FEDORA.md) for setting up CUDA toolkit in a Fedora [toolbox container](https://containertoolbx.org/). + +**Recommended for:** +- ***Necessary*** for users of [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/); such as: [Silverblue](https://fedoraproject.org/atomic-desktops/silverblue/) and [Kinoite](https://fedoraproject.org/atomic-desktops/kinoite/). + - (there are no supported CUDA packages for these systems) +- ***Necessary*** for users that have a host that is not a: [Supported Nvidia CUDA Release Platform](https://developer.nvidia.com/cuda-downloads). + - (for example, you may have [Fedora 42 Beta](https://fedoramagazine.org/announcing-fedora-linux-42-beta/) as your your host operating system) +- ***Convenient*** For those running [Fedora Workstation](https://fedoraproject.org/workstation/) or [Fedora KDE Plasma Desktop](https://fedoraproject.org/spins/kde), and want to keep their host system clean. +- *Optionally* toolbox packages are available: [Arch Linux](https://archlinux.org/), [Red Hat Enterprise Linux >= 8.5](https://www.redhat.com/en/technologies/linux-platforms/enterprise-linux), or [Ubuntu](https://ubuntu.com/download) + + +### Compilation +```bash +cmake -B build -DGGML_CUDA=ON +cmake --build build --config Release +``` + +### Override Compute Capability Specifications + +If `nvcc` cannot detect your gpu, you may get compile-warnings such as: + ```text +nvcc warning : Cannot find valid GPU for '-arch=native', default arch is used +``` + +To override the `native` GPU detection: + +#### 1. Take note of the `Compute Capability` of your NVIDIA devices: ["CUDA: Your GPU Compute > Capability"](https://developer.nvidia.com/cuda-gpus). + +```text +GeForce RTX 4090 8.9 +GeForce RTX 3080 Ti 8.6 +GeForce RTX 3070 8.6 +``` + +#### 2. Manually list each varying `Compute Capability` in the `CMAKE_CUDA_ARCHITECTURES` list. + +```bash +cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="86;89" +``` + +### Runtime CUDA environmental variables + +You may set the [cuda environmental variables](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) at runtime. + +```bash +# Use `CUDA_VISIBLE_DEVICES` to hide the first compute device. +CUDA_VISIBLE_DEVICES="-0" ./build/bin/llama-server --model /srv/models/llama.gguf +``` + +### Unified Memory + +The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. In Windows this setting is available in the NVIDIA control panel as `System Memory Fallback`. + +### Performance Tuning + +The following compilation options are also available to tweak performance: + +| Option | Legal values | Default | Description | +|-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | +| GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | +| GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | +| GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | +| GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. | + +## MUSA + +This provides GPU acceleration using a Moore Threads GPU. Make sure to have the [MUSA SDK](https://developer.mthreads.com/musa/musa-sdk) installed. + +#### Download directly from Moore Threads + +You may find the official downloads here: [Moore Threads developer site](https://developer.mthreads.com/sdk/download/musa). + +### Compilation + +```bash +cmake -B build -DGGML_MUSA=ON +cmake --build build --config Release +``` + +#### Override Compute Capability Specifications + +By default, all supported compute capabilities are enabled. To customize this behavior, you can specify the `MUSA_ARCHITECTURES` option in the CMake command: + +```bash +cmake -B build -DGGML_MUSA=ON -DMUSA_ARCHITECTURES="21" +cmake --build build --config Release +``` + +This configuration enables only compute capability `2.1` (MTT S80) during compilation, which can help reduce compilation time. + +#### Compilation options + +Most of the compilation options available for CUDA should also be available for MUSA, though they haven't been thoroughly tested yet. + +- For static builds, add `-DBUILD_SHARED_LIBS=OFF` and `-DCMAKE_POSITION_INDEPENDENT_CODE=ON`: + ``` + cmake -B build -DGGML_MUSA=ON \ + -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON + cmake --build build --config Release + ``` + +### Runtime MUSA environmental variables + +You may set the [musa environmental variables](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) at runtime. + +```bash +# Use `MUSA_VISIBLE_DEVICES` to hide the first compute device. +MUSA_VISIBLE_DEVICES="-0" ./build/bin/llama-server --model /srv/models/llama.gguf +``` + +### Unified Memory + +The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. + +## HIP + +This provides GPU acceleration on HIP-supported AMD GPUs. +Make sure to have ROCm installed. +You can download it from your Linux distro's package manager or from here: [ROCm Quick Start (Linux)](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html#rocm-install-quick). + +- Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU): + ```bash + HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ + cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ + && cmake --build build --config Release -- -j 16 + ``` + + To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system. + + The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager. + + As an alternative, you can manually install the library by cloning it from the official [GitHub repository](https://github.com/ROCm/rocWMMA), checkout the corresponding version tag (e.g. `rocm-6.2.4`) and set `-DCMAKE_CXX_FLAGS="-I/library/include/"` in CMake. This also works under Windows despite not officially supported by AMD. + + Note that if you get the following error: + ``` + clang: error: cannot find ROCm device library; provide its path via '--rocm-path' or '--rocm-device-lib-path', or pass '-nogpulib' to build without ROCm device library + ``` + Try searching for a directory under `HIP_PATH` that contains the file + `oclc_abi_version_400.bc`. Then, add the following to the start of the + command: `HIP_DEVICE_LIB_PATH=`, so something + like: + ```bash + HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \ + HIP_DEVICE_LIB_PATH= \ + cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ + && cmake --build build -- -j 16 + ``` + +- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU): + ```bash + set PATH=%HIP_PATH%\bin;%PATH% + cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release + cmake --build build + ``` + Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors) + Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`. + + +The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used. +If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3. + +### Unified Memory + +On Linux it is possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1`. However, this hurts performance for non-integrated GPUs (but enables working with integrated GPUs). + +## Vulkan + +**Windows** + +### w64devkit + +Download and extract [`w64devkit`](https://github.com/skeeto/w64devkit/releases). + +Download and install the [`Vulkan SDK`](https://vulkan.lunarg.com/sdk/home#windows) with the default settings. + +Launch `w64devkit.exe` and run the following commands to copy Vulkan dependencies: +```sh +SDK_VERSION=1.3.283.0 +cp /VulkanSDK/$SDK_VERSION/Bin/glslc.exe $W64DEVKIT_HOME/bin/ +cp /VulkanSDK/$SDK_VERSION/Lib/vulkan-1.lib $W64DEVKIT_HOME/x86_64-w64-mingw32/lib/ +cp -r /VulkanSDK/$SDK_VERSION/Include/* $W64DEVKIT_HOME/x86_64-w64-mingw32/include/ +cat > $W64DEVKIT_HOME/x86_64-w64-mingw32/lib/pkgconfig/vulkan.pc < ` + +It will then build & run in the debugger for you. + +To just execute a test and get back a PASS or FAIL message run: + +```bash +./scripts/debug-test.sh test-tokenizer +``` + +To test in GDB use the `-g` flag to enable gdb test mode. + +```bash +./scripts/debug-test.sh -g test-tokenizer + +# Once in the debugger, i.e. at the chevrons prompt, setting a breakpoint could be as follows: +>>> b main +``` + +To speed up the testing loop, if you know your test number you can just run it similar to below: + +```bash +./scripts/debug-test.sh test 23 +``` + +For further reference use `debug-test.sh -h` to print help. + +  + +### How does the script work? +If you want to be able to use the concepts contained in the script separately, the important ones are briefly outlined below. + +#### Step 1: Reset and Setup folder context + +From base of this repository, let's create `build-ci-debug` as our build context. + +```bash +rm -rf build-ci-debug && mkdir build-ci-debug && cd build-ci-debug +``` + +#### Step 2: Setup Build Environment and Compile Test Binaries + +Setup and trigger a build under debug mode. You may adapt the arguments as needed, but in this case these are sane defaults. + +```bash +cmake -DCMAKE_BUILD_TYPE=Debug -DLLAMA_CUDA=1 -DLLAMA_FATAL_WARNINGS=ON .. +make -j +``` + +#### Step 3: Find all tests available that matches REGEX + +The output of this command will give you the command & arguments needed to run GDB. + +* `-R test-tokenizer` : looks for all the test files named `test-tokenizer*` (R=Regex) +* `-N` : "show-only" disables test execution & shows test commands that you can feed to GDB. +* `-V` : Verbose Mode + +```bash +ctest -R "test-tokenizer" -V -N +``` + +This may return output similar to below (focusing on key lines to pay attention to): + +```bash +... +1: Test command: ~/llama.cpp/build-ci-debug/bin/test-tokenizer-0 "~/llama.cpp/tests/../models/ggml-vocab-llama-spm.gguf" +1: Working Directory: . +Labels: main + Test #1: test-tokenizer-0-llama-spm +... +4: Test command: ~/llama.cpp/build-ci-debug/bin/test-tokenizer-0 "~/llama.cpp/tests/../models/ggml-vocab-falcon.gguf" +4: Working Directory: . +Labels: main + Test #4: test-tokenizer-0-falcon +... +``` + +#### Step 4: Identify Test Command for Debugging + +So for test #1 above we can tell these two pieces of relevant information: +* Test Binary: `~/llama.cpp/build-ci-debug/bin/test-tokenizer-0` +* Test GGUF Model: `~/llama.cpp/tests/../models/ggml-vocab-llama-spm.gguf` + +#### Step 5: Run GDB on test command + +Based on the ctest 'test command' report above we can then run a gdb session via this command below: + +```bash +gdb --args ${Test Binary} ${Test GGUF Model} +``` + +Example: + +```bash +gdb --args ~/llama.cpp/build-ci-debug/bin/test-tokenizer-0 "~/llama.cpp/tests/../models/ggml-vocab-llama-spm.gguf" +``` diff --git a/docs/development/llama-star/idea-arch.key b/docs/development/llama-star/idea-arch.key new file mode 100644 index 0000000000000000000000000000000000000000..234d59858a3ab2d90ec2a725a137aa0b187a0179 --- /dev/null +++ b/docs/development/llama-star/idea-arch.key @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32dede0548e6501b4b60d89b4afc3445a9a45c5df6fcee0c2bb9fe75592908c9 +size 488591 diff --git a/docs/development/llama-star/idea-arch.pdf b/docs/development/llama-star/idea-arch.pdf new file mode 100644 index 0000000000000000000000000000000000000000..4fa92c71dc4c511378c628113b7817e583053758 Binary files /dev/null and b/docs/development/llama-star/idea-arch.pdf differ diff --git a/docs/development/token_generation_performance_tips.md b/docs/development/token_generation_performance_tips.md new file mode 100644 index 0000000000000000000000000000000000000000..41b7232c976b3a86fe7509c078bcedafd014a9e1 --- /dev/null +++ b/docs/development/token_generation_performance_tips.md @@ -0,0 +1,40 @@ +# Token generation performance troubleshooting + +## Verifying that the model is running on the GPU with CUDA +Make sure you compiled llama with the correct env variables according to [this guide](/docs/build.md#cuda), so that llama accepts the `-ngl N` (or `--n-gpu-layers N`) flag. When running llama, you may configure `N` to be very large, and llama will offload the maximum possible number of layers to the GPU, even if it's less than the number you configured. For example: +```shell +./llama-cli -m "path/to/model.gguf" -ngl 200000 -p "Please sir, may I have some " +``` + +When running llama, before it starts the inference work, it will output diagnostic information that shows whether cuBLAS is offloading work to the GPU. Look for these lines: +```shell +llama_model_load_internal: [cublas] offloading 60 layers to GPU +llama_model_load_internal: [cublas] offloading output layer to GPU +llama_model_load_internal: [cublas] total VRAM used: 17223 MB +... rest of inference +``` + +If you see these lines, then the GPU is being used. + +## Verifying that the CPU is not oversaturated +llama accepts a `-t N` (or `--threads N`) parameter. It's extremely important that this parameter is not too large. If your token generation is extremely slow, try setting this number to 1. If this significantly improves your token generation speed, then your CPU is being oversaturated and you need to explicitly set this parameter to the number of the physical CPU cores on your machine (even if you utilize a GPU). If in doubt, start with 1 and double the amount until you hit a performance bottleneck, then scale the number down. + +# Example of runtime flags effect on inference speed benchmark +These runs were tested on the following machine: +GPU: A6000 (48GB VRAM) +CPU: 7 physical cores +RAM: 32GB + +Model: `TheBloke_Wizard-Vicuna-30B-Uncensored-GGML/Wizard-Vicuna-30B-Uncensored.q4_0.gguf` (30B parameters, 4bit quantization, GGML) + +Run command: `./llama-cli -m "path/to/model.gguf" -p "An extremely detailed description of the 10 best ethnic dishes will follow, with recipes: " -n 1000 [additional benchmark flags]` + +Result: + +| command | tokens/second (higher is better) | +| - | - | +| -ngl 2000000 | N/A (less than 0.1) | +| -t 7 | 1.7 | +| -t 1 -ngl 2000000 | 5.5 | +| -t 7 -ngl 2000000 | 8.7 | +| -t 4 -ngl 2000000 | 9.1 | diff --git a/docs/docker.md b/docs/docker.md new file mode 100644 index 0000000000000000000000000000000000000000..f8f0573c17239cd69a88bbcb3452bf3eaaa48a7a --- /dev/null +++ b/docs/docker.md @@ -0,0 +1,126 @@ +# Docker + +## Prerequisites +* Docker must be installed and running on your system. +* Create a folder to store big models & intermediate files (ex. /llama/models) + +## Images +We have three Docker images available for this project: + +1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`) +2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`) +3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`) + +Additionally, there the following images, similar to the above: + +- `ghcr.io/ggml-org/llama.cpp:full-cuda`: Same as `full` but compiled with CUDA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:light-cuda`: Same as `light` but compiled with CUDA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:server-cuda`: Same as `server` but compiled with CUDA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:full-rocm`: Same as `full` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) +- `ghcr.io/ggml-org/llama.cpp:light-rocm`: Same as `light` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) +- `ghcr.io/ggml-org/llama.cpp:server-rocm`: Same as `server` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) +- `ghcr.io/ggml-org/llama.cpp:full-musa`: Same as `full` but compiled with MUSA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:light-musa`: Same as `light` but compiled with MUSA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:server-musa`: Same as `server` but compiled with MUSA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:full-intel`: Same as `full` but compiled with SYCL support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:light-intel`: Same as `light` but compiled with SYCL support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:server-intel`: Same as `server` but compiled with SYCL support. (platforms: `linux/amd64`) + +The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now). + +## Usage + +The easiest way to download the models, convert them to ggml and optimize them is with the --all-in-one command which includes the full docker image. + +Replace `/path/to/models` below with the actual path where you downloaded the models. + +```bash +docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --all-in-one "/models/" 7B +``` + +On completion, you are ready to play! + +```bash +docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 +``` + +or with a light image: + +```bash +docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 +``` + +or with a server image: + +```bash +docker run -v /path/to/models:/models -p 8000:8000 ghcr.io/ggml-org/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 +``` + +## Docker With CUDA + +Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) properly installed on Linux, or is using a GPU enabled cloud, `cuBLAS` should be accessible inside the container. + +## Building Docker locally + +```bash +docker build -t local/llama.cpp:full-cuda --target full -f .devops/cuda.Dockerfile . +docker build -t local/llama.cpp:light-cuda --target light -f .devops/cuda.Dockerfile . +docker build -t local/llama.cpp:server-cuda --target server -f .devops/cuda.Dockerfile . +``` + +You may want to pass in some different `ARGS`, depending on the CUDA environment supported by your container host, as well as the GPU architecture. + +The defaults are: + +- `CUDA_VERSION` set to `12.4.0` +- `CUDA_DOCKER_ARCH` set to the cmake build default, which includes all the supported architectures + +The resulting images, are essentially the same as the non-CUDA images: + +1. `local/llama.cpp:full-cuda`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. +2. `local/llama.cpp:light-cuda`: This image only includes the main executable file. +3. `local/llama.cpp:server-cuda`: This image only includes the server executable file. + +## Usage + +After building locally, Usage is similar to the non-CUDA examples, but you'll need to add the `--gpus` flag. You will also want to use the `--n-gpu-layers` flag. + +```bash +docker run --gpus all -v /path/to/models:/models local/llama.cpp:full-cuda --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 +docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 +docker run --gpus all -v /path/to/models:/models local/llama.cpp:server-cuda -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 --n-gpu-layers 1 +``` + +## Docker With MUSA + +Assuming one has the [mt-container-toolkit](https://developer.mthreads.com/musa/native) properly installed on Linux, `muBLAS` should be accessible inside the container. + +## Building Docker locally + +```bash +docker build -t local/llama.cpp:full-musa --target full -f .devops/musa.Dockerfile . +docker build -t local/llama.cpp:light-musa --target light -f .devops/musa.Dockerfile . +docker build -t local/llama.cpp:server-musa --target server -f .devops/musa.Dockerfile . +``` + +You may want to pass in some different `ARGS`, depending on the MUSA environment supported by your container host, as well as the GPU architecture. + +The defaults are: + +- `MUSA_VERSION` set to `rc4.0.1` + +The resulting images, are essentially the same as the non-MUSA images: + +1. `local/llama.cpp:full-musa`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. +2. `local/llama.cpp:light-musa`: This image only includes the main executable file. +3. `local/llama.cpp:server-musa`: This image only includes the server executable file. + +## Usage + +After building locally, Usage is similar to the non-MUSA examples, but you'll need to set `mthreads` as default Docker runtime. This can be done by executing `(cd /usr/bin/musa && sudo ./docker setup $PWD)` and verifying the changes by executing `docker info | grep mthreads` on the host machine. You will also want to use the `--n-gpu-layers` flag. + +```bash +docker run -v /path/to/models:/models local/llama.cpp:full-musa --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 +docker run -v /path/to/models:/models local/llama.cpp:light-musa -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 +docker run -v /path/to/models:/models local/llama.cpp:server-musa -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 --n-gpu-layers 1 +``` diff --git a/docs/function-calling.md b/docs/function-calling.md new file mode 100644 index 0000000000000000000000000000000000000000..37eacaf3100c165068233781b8815c068036d940 --- /dev/null +++ b/docs/function-calling.md @@ -0,0 +1,422 @@ +# Function Calling + +[chat.h](../common/chat.h) (https://github.com/ggml-org/llama.cpp/pull/9639) adds support for [OpenAI-style function calling](https://platform.openai.com/docs/guides/function-calling) and is used in: +- `llama-server` when started w/ `--jinja` flag + +## Universal support w/ Native & Generic handlers + +Function calling is supported for all models (see https://github.com/ggml-org/llama.cpp/pull/9639): + +- Native tool call formats supported: + - Llama 3.1 / 3.3 (including builtin tools support - tool names for `wolfram_alpha`, `web_search` / `brave_search`, `code_interpreter`), Llama 3.2 + - Functionary v3.1 / v3.2 + - Hermes 2/3, Qwen 2.5 + - Qwen 2.5 Coder + - Mistral Nemo + - Firefunction v2 + - Command R7B + - DeepSeek R1 (WIP / seems reluctant to call any tools?) + +- Generic tool call is supported when the template isn't recognized by native format handlers (you'll see `Chat format: Generic` in the logs). + - Use `--chat-template-file` to override the template when appropriate (see examples below) + - Generic support may consume more tokens and be less efficient than a model's native format. + +
+Show some common templates and which format handler they use + +| Template | Format | +|----------|--------| +| Almawave-Velvet-14B.jinja | Hermes 2 Pro | +| AtlaAI-Selene-1-Mini-Llama-3.1-8B.jinja | Llama 3.x | +| CohereForAI-aya-expanse-8b.jinja | Generic | +| CohereForAI-c4ai-command-r-plus-default.jinja | Generic | +| CohereForAI-c4ai-command-r-plus-rag.jinja | Generic | +| CohereForAI-c4ai-command-r-plus-tool_use.jinja | Generic | +| CohereForAI-c4ai-command-r7b-12-2024-default.jinja | Command R7B (extract reasoning) | +| CohereForAI-c4ai-command-r7b-12-2024-rag.jinja | Command R7B (extract reasoning) | +| CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja | Command R7B (extract reasoning) | +| CohereForAI-c4ai-command-r7b-12-2024.jinja | Generic | +| DavieLion-Llama-3.2-1B-SPIN-iter3.jinja | Generic | +| Delta-Vector-Rei-12B.jinja | Mistral Nemo | +| EpistemeAI-Mistral-Nemo-Instruct-12B-Philosophy-Math.jinja | Mistral Nemo | +| FlofloB-83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit.jinja | Hermes 2 Pro | +| FlofloB-test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit.jinja | Generic | +| HelpingAI-HAI-SER.jinja | Generic | +| HuggingFaceTB-SmolLM2-1.7B-Instruct.jinja | Generic | +| HuggingFaceTB-SmolLM2-135M-Instruct.jinja | Generic | +| HuggingFaceTB-SmolLM2-360M-Instruct.jinja | Generic | +| INSAIT-Institute-BgGPT-Gemma-2-27B-IT-v1.0.jinja | Generic | +| Ihor-Text2Graph-R1-Qwen2.5-0.5b.jinja | Hermes 2 Pro | +| Infinigence-Megrez-3B-Instruct.jinja | Generic | +| Josephgflowers-TinyLlama_v1.1_math_code-world-test-1.jinja | Generic | +| LGAI-EXAONE-EXAONE-3.5-2.4B-Instruct.jinja | Generic | +| LGAI-EXAONE-EXAONE-3.5-7.8B-Instruct.jinja | Generic | +| LatitudeGames-Wayfarer-12B.jinja | Generic | +| Magpie-Align-Llama-3-8B-Magpie-Align-v0.1.jinja | Generic | +| Magpie-Align-Llama-3.1-8B-Magpie-Align-v0.1.jinja | Generic | +| MaziyarPanahi-calme-3.2-instruct-78b.jinja | Generic | +| MiniMaxAI-MiniMax-Text-01.jinja | Generic | +| MiniMaxAI-MiniMax-VL-01.jinja | Generic | +| NaniDAO-deepseek-r1-qwen-2.5-32B-ablated.jinja | DeepSeek R1 (extract reasoning) | +| NexaAIDev-Octopus-v2.jinja | Generic | +| NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | Generic | +| NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | Hermes 2 Pro | +| NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | Generic | +| NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | Hermes 2 Pro | +| NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | Generic | +| NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | Hermes 2 Pro | +| NovaSky-AI-Sky-T1-32B-Flash.jinja | Hermes 2 Pro | +| NovaSky-AI-Sky-T1-32B-Preview.jinja | Hermes 2 Pro | +| OnlyCheeini-greesychat-turbo.jinja | Generic | +| Orenguteng-Llama-3.1-8B-Lexi-Uncensored-V2.jinja | Llama 3.x | +| OrionStarAI-Orion-14B-Chat.jinja | Generic | +| PowerInfer-SmallThinker-3B-Preview.jinja | Generic | +| PrimeIntellect-INTELLECT-1-Instruct.jinja | Generic | +| Qwen-QVQ-72B-Preview.jinja | Generic | +| Qwen-QwQ-32B-Preview.jinja | Hermes 2 Pro | +| Qwen-Qwen1.5-7B-Chat.jinja | Generic | +| Qwen-Qwen2-7B-Instruct.jinja | Generic | +| Qwen-Qwen2-VL-72B-Instruct.jinja | Generic | +| Qwen-Qwen2-VL-7B-Instruct.jinja | Generic | +| Qwen-Qwen2.5-0.5B.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-1.5B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-14B-Instruct-1M.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-14B.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-32B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-32B.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-3B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-72B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-7B-Instruct-1M.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-7B.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-Coder-32B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-Coder-7B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-Math-1.5B.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-Math-7B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-VL-3B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-VL-72B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-VL-7B-Instruct.jinja | Hermes 2 Pro | +| RWKV-Red-Team-ARWKV-7B-Preview-0.1.jinja | Hermes 2 Pro | +| SakanaAI-TinySwallow-1.5B-Instruct.jinja | Hermes 2 Pro | +| SakanaAI-TinySwallow-1.5B.jinja | Hermes 2 Pro | +| Sao10K-70B-L3.3-Cirrus-x1.jinja | Llama 3.x | +| SentientAGI-Dobby-Mini-Leashed-Llama-3.1-8B.jinja | Llama 3.x | +| SentientAGI-Dobby-Mini-Unhinged-Llama-3.1-8B.jinja | Llama 3.x | +| Steelskull-L3.3-Damascus-R1.jinja | Llama 3.x | +| Steelskull-L3.3-MS-Nevoria-70b.jinja | Llama 3.x | +| Steelskull-L3.3-Nevoria-R1-70b.jinja | Llama 3.x | +| THUDM-glm-4-9b-chat.jinja | Generic | +| THUDM-glm-edge-1.5b-chat.jinja | Generic | +| Tarek07-Progenitor-V1.1-LLaMa-70B.jinja | Llama 3.x | +| TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | Generic | +| TinyLlama-TinyLlama-1.1B-Chat-v1.0.jinja | Generic | +| UCLA-AGI-Mistral7B-PairRM-SPPO-Iter3.jinja | Generic | +| ValiantLabs-Llama3.1-8B-Enigma.jinja | Llama 3.x | +| abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | Generic | +| ai21labs-AI21-Jamba-1.5-Large.jinja | Generic | +| allenai-Llama-3.1-Tulu-3-405B-SFT.jinja | Generic | +| allenai-Llama-3.1-Tulu-3-405B.jinja | Generic | +| allenai-Llama-3.1-Tulu-3-8B.jinja | Generic | +| arcee-ai-Virtuoso-Lite.jinja | Hermes 2 Pro | +| arcee-ai-Virtuoso-Medium-v2.jinja | Hermes 2 Pro | +| arcee-ai-Virtuoso-Small-v2.jinja | Hermes 2 Pro | +| avemio-GRAG-NEMO-12B-ORPO-HESSIAN-AI.jinja | Generic | +| bespokelabs-Bespoke-Stratos-7B.jinja | Hermes 2 Pro | +| bfuzzy1-acheron-m1a-llama.jinja | Generic | +| bofenghuang-vigogne-2-70b-chat.jinja | Generic | +| bytedance-research-UI-TARS-72B-DPO.jinja | Generic | +| bytedance-research-UI-TARS-7B-DPO.jinja | Generic | +| bytedance-research-UI-TARS-7B-SFT.jinja | Generic | +| carsenk-phi3.5_mini_exp_825_uncensored.jinja | Generic | +| cyberagent-DeepSeek-R1-Distill-Qwen-14B-Japanese.jinja | DeepSeek R1 (extract reasoning) | +| cyberagent-DeepSeek-R1-Distill-Qwen-32B-Japanese.jinja | DeepSeek R1 (extract reasoning) | +| databricks-dbrx-instruct.jinja | Generic | +| deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | Generic | +| deepseek-ai-DeepSeek-Coder-V2-Lite-Base.jinja | Generic | +| deepseek-ai-DeepSeek-Coder-V2-Lite-Instruct.jinja | Generic | +| deepseek-ai-DeepSeek-R1-Distill-Llama-70B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Distill-Qwen-14B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Zero.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-V2-Lite.jinja | Generic | +| deepseek-ai-DeepSeek-V2.5.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-V3.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-deepseek-coder-33b-instruct.jinja | Generic | +| deepseek-ai-deepseek-coder-6.7b-instruct.jinja | Generic | +| deepseek-ai-deepseek-coder-7b-instruct-v1.5.jinja | Generic | +| deepseek-ai-deepseek-llm-67b-chat.jinja | Generic | +| deepseek-ai-deepseek-llm-7b-chat.jinja | Generic | +| dicta-il-dictalm2.0-instruct.jinja | Generic | +| ehristoforu-Falcon3-8B-Franken-Basestruct.jinja | Hermes 2 Pro | +| fireworks-ai-llama-3-firefunction-v2.jinja | FireFunction v2 | +| godlikehhd-alpaca_data_sampled_ifd_new_5200.jinja | Hermes 2 Pro | +| godlikehhd-alpaca_data_score_max_0.7_2600.jinja | Hermes 2 Pro | +| google-gemma-2-27b-it.jinja | Generic | +| google-gemma-2-2b-it.jinja | Generic | +| google-gemma-2-2b-jpn-it.jinja | Generic | +| google-gemma-7b-it.jinja | Generic | +| huihui-ai-DeepSeek-R1-Distill-Llama-70B-abliterated.jinja | DeepSeek R1 (extract reasoning) | +| huihui-ai-DeepSeek-R1-Distill-Llama-8B-abliterated.jinja | DeepSeek R1 (extract reasoning) | +| huihui-ai-DeepSeek-R1-Distill-Qwen-14B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) | +| huihui-ai-DeepSeek-R1-Distill-Qwen-32B-abliterated.jinja | DeepSeek R1 (extract reasoning) | +| huihui-ai-DeepSeek-R1-Distill-Qwen-7B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) | +| huihui-ai-Qwen2.5-14B-Instruct-1M-abliterated.jinja | Hermes 2 Pro | +| ibm-granite-granite-3.1-8b-instruct.jinja | Generic | +| indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | Generic | +| inflatebot-MN-12B-Mag-Mell-R1.jinja | Generic | +| jinaai-ReaderLM-v2.jinja | Generic | +| kms7530-chemeng_qwen-math-7b_24_1_100_1_nonmath.jinja | Hermes 2 Pro | +| knifeayumu-Cydonia-v1.3-Magnum-v4-22B.jinja | Mistral Nemo | +| langgptai-qwen1.5-7b-chat-sa-v0.1.jinja | Generic | +| lightblue-DeepSeek-R1-Distill-Qwen-7B-Japanese.jinja | DeepSeek R1 (extract reasoning) | +| mattshumer-Reflection-Llama-3.1-70B.jinja | Generic | +| meetkai-functionary-medium-v3.1.jinja | Functionary v3.1 Llama 3.1 | +| meetkai-functionary-medium-v3.2.jinja | Functionary v3.2 | +| meta-llama-Llama-2-7b-chat-hf.jinja | Generic | +| meta-llama-Llama-3.1-8B-Instruct.jinja | Llama 3.x | +| meta-llama-Llama-3.2-11B-Vision-Instruct.jinja | Llama 3.x | +| meta-llama-Llama-3.2-1B-Instruct.jinja | Llama 3.x | +| meta-llama-Llama-3.2-3B-Instruct.jinja | Llama 3.x | +| meta-llama-Llama-3.3-70B-Instruct.jinja | Llama 3.x | +| meta-llama-Meta-Llama-3-8B-Instruct.jinja | Generic | +| meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | Llama 3.x | +| microsoft-Phi-3-medium-4k-instruct.jinja | Generic | +| microsoft-Phi-3-mini-4k-instruct.jinja | Generic | +| microsoft-Phi-3-small-8k-instruct.jinja | Generic | +| microsoft-Phi-3.5-mini-instruct.jinja | Generic | +| microsoft-Phi-3.5-vision-instruct.jinja | Generic | +| microsoft-phi-4.jinja | Generic | +| migtissera-Tess-3-Mistral-Nemo-12B.jinja | Generic | +| ministral-Ministral-3b-instruct.jinja | Generic | +| mistralai-Codestral-22B-v0.1.jinja | Generic | +| mistralai-Mistral-7B-Instruct-v0.1.jinja | Generic | +| mistralai-Mistral-7B-Instruct-v0.2.jinja | Generic | +| mistralai-Mistral-7B-Instruct-v0.3.jinja | Mistral Nemo | +| mistralai-Mistral-Large-Instruct-2407.jinja | Mistral Nemo | +| mistralai-Mistral-Large-Instruct-2411.jinja | Generic | +| mistralai-Mistral-Nemo-Instruct-2407.jinja | Mistral Nemo | +| mistralai-Mistral-Small-24B-Instruct-2501.jinja | Generic | +| mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | Generic | +| mkurman-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | +| mlabonne-AlphaMonarch-7B.jinja | Generic | +| mlx-community-Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32.jinja | Hermes 2 Pro | +| mlx-community-Qwen2.5-VL-7B-Instruct-8bit.jinja | Hermes 2 Pro | +| mobiuslabsgmbh-DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1.jinja | DeepSeek R1 (extract reasoning) | +| netcat420-MFANNv0.20.jinja | Generic | +| netcat420-MFANNv0.24.jinja | Generic | +| netease-youdao-Confucius-o1-14B.jinja | Hermes 2 Pro | +| nvidia-AceMath-7B-RM.jinja | Hermes 2 Pro | +| nvidia-Eagle2-1B.jinja | Hermes 2 Pro | +| nvidia-Eagle2-9B.jinja | Hermes 2 Pro | +| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | Llama 3.x | +| onnx-community-DeepSeek-R1-Distill-Qwen-1.5B-ONNX.jinja | DeepSeek R1 (extract reasoning) | +| open-thoughts-OpenThinker-7B.jinja | Hermes 2 Pro | +| openchat-openchat-3.5-0106.jinja | Generic | +| pankajmathur-orca_mini_v6_8b.jinja | Generic | +| princeton-nlp-Mistral-7B-Base-SFT-RDPO.jinja | Generic | +| princeton-nlp-Mistral-7B-Instruct-DPO.jinja | Generic | +| princeton-nlp-Mistral-7B-Instruct-RDPO.jinja | Generic | +| prithivMLmods-Bellatrix-Tiny-1.5B-R1.jinja | Hermes 2 Pro | +| prithivMLmods-Bellatrix-Tiny-1B-R1.jinja | Llama 3.x | +| prithivMLmods-Bellatrix-Tiny-1B-v3.jinja | Generic | +| prithivMLmods-Bellatrix-Tiny-3B-R1.jinja | Llama 3.x | +| prithivMLmods-Blaze-14B-xElite.jinja | Generic | +| prithivMLmods-Calcium-Opus-14B-Elite2-R1.jinja | Hermes 2 Pro | +| prithivMLmods-Calme-Ties-78B.jinja | Generic | +| prithivMLmods-Calme-Ties2-78B.jinja | Generic | +| prithivMLmods-Calme-Ties3-78B.jinja | Generic | +| prithivMLmods-ChemQwen2-vL.jinja | Generic | +| prithivMLmods-GWQ2b.jinja | Generic | +| prithivMLmods-LatexMind-2B-Codec.jinja | Generic | +| prithivMLmods-Llama-3.2-6B-AlgoCode.jinja | Llama 3.x | +| prithivMLmods-Megatron-Opus-14B-Exp.jinja | Hermes 2 Pro | +| prithivMLmods-Megatron-Opus-14B-Stock.jinja | Hermes 2 Pro | +| prithivMLmods-Megatron-Opus-7B-Exp.jinja | Hermes 2 Pro | +| prithivMLmods-Omni-Reasoner-Merged.jinja | Hermes 2 Pro | +| prithivMLmods-Omni-Reasoner4-Merged.jinja | Hermes 2 Pro | +| prithivMLmods-Primal-Opus-14B-Optimus-v1.jinja | Hermes 2 Pro | +| prithivMLmods-QwQ-Math-IO-500M.jinja | Hermes 2 Pro | +| prithivMLmods-Qwen-7B-Distill-Reasoner.jinja | DeepSeek R1 (extract reasoning) | +| prithivMLmods-Qwen2.5-1.5B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro | +| prithivMLmods-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | +| prithivMLmods-Qwen2.5-32B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro | +| prithivMLmods-Qwen2.5-7B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | +| prithivMLmods-Triangulum-v2-10B.jinja | Hermes 2 Pro | +| qingy2024-Falcon3-2x10B-MoE-Instruct.jinja | Hermes 2 Pro | +| rubenroy-Zurich-14B-GCv2-5m.jinja | Hermes 2 Pro | +| rubenroy-Zurich-7B-GCv2-5m.jinja | Hermes 2 Pro | +| silma-ai-SILMA-Kashif-2B-Instruct-v1.0.jinja | Generic | +| simplescaling-s1-32B.jinja | Hermes 2 Pro | +| sometimesanotion-Lamarck-14B-v0.7.jinja | Hermes 2 Pro | +| sonthenguyen-zephyr-sft-bnb-4bit-DPO-mtbr-180steps.jinja | Generic | +| sthenno-tempesthenno-icy-0130.jinja | Generic | +| sumink-qwft.jinja | Hermes 2 Pro | +| teknium-OpenHermes-2.5-Mistral-7B.jinja | Generic | +| thirdeyeai-elevate360m.jinja | Generic | +| tiiuae-Falcon3-10B-Instruct.jinja | Hermes 2 Pro | +| unsloth-DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit.jinja | DeepSeek R1 (extract reasoning) | +| unsloth-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) | +| unsloth-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) | +| unsloth-Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit.jinja | Generic | +| upstage-solar-pro-preview-instruct.jinja | Generic | +| whyhow-ai-PatientSeek.jinja | Generic | +| xwen-team-Xwen-72B-Chat.jinja | Hermes 2 Pro | +| xwen-team-Xwen-7B-Chat.jinja | Hermes 2 Pro | + +This table can be generated with: + +```bash +./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null +``` + +
+ +# Usage - need tool-aware Jinja template + +First, start a server with any model, but make sure it has a tools-enabled template: you can verify this by inspecting the `chat_template` or `chat_template_tool_use` properties in `http://localhost:8080/props`). + +Here are some models known to work (w/ chat template override when needed): + +```shell +# Native support: + +llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M +llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L +llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M + +# Native support for DeepSeek R1 works best w/ our template override (official template is buggy, although we do work around it) + +llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \ + --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja + +llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \ + --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja + +# Native support requires the right template for these GGUFs: + +llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M + --chat-template-file models/templates/meetkai-functionary-medium-v3.2.jinja + +llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \ + --chat-template-file models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja + +llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \ + --chat-template-file models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja + +llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \ + --chat-template-file models/templates/fireworks-ai-llama-3-firefunction-v2.jinja + +llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \ + --chat-template-file models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja + +# Generic format support +llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0 +llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0 +llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K +``` + +To get the official template from original HuggingFace repos, you can use [scripts/get_chat_template.py](../scripts/get_chat_template.py) (see examples invocations in [models/templates/README.md](../models/templates/README.md)) + +> [!TIP] +> If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills) + +> [!CAUTION] +> Beware of extreme KV quantizations (e.g. `-ctk q4_0`), they can substantially degrade the model's tool calling performance. + +Test in CLI (or with any library / software that can use OpenAI-compatible API backends): + +```bash +curl http://localhost:8080/v1/chat/completions -d '{ + "model": "gpt-3.5-turbo", + "tools": [ + { + "type":"function", + "function":{ + "name":"python", + "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "parameters":{ + "type":"object", + "properties":{ + "code":{ + "type":"string", + "description":"The code to run in the ipython interpreter." + } + }, + "required":["code"] + } + } + } + ], + "messages": [ + { + "role": "user", + "content": "Print a hello world message with python." + } + ] +}' + + +curl http://localhost:8080/v1/chat/completions -d '{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."}, + {"role": "user", "content": "What is the weather in Istanbul?"} + ], + "tools": [{ + "type":"function", + "function":{ + "name":"get_current_weather", + "description":"Get the current weather in a given location", + "parameters":{ + "type":"object", + "properties":{ + "location":{ + "type":"string", + "description":"The city and country/state, e.g. `San Francisco, CA`, or `Paris, France`" + } + }, + "required":["location"] + } + } + }] +}' +``` + +
+Show output + +```json +{ +"choices": [ + { + "finish_reason": "tool", + "index": 0, + "message": { + "content": null, + "tool_calls": [ + { + "name": "python", + "arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}" + } + ], + "role": "assistant" + } + } +], +"created": 1727287211, +"model": "gpt-3.5-turbo", +"object": "chat.completion", +"usage": { + "completion_tokens": 16, + "prompt_tokens": 44, + "total_tokens": 60 +}, +"id": "chatcmpl-Htbgh9feMmGM0LEH2hmQvwsCxq3c6Ni8" +} +``` + +
diff --git a/docs/install.md b/docs/install.md new file mode 100644 index 0000000000000000000000000000000000000000..7200bf9b7b91d258f3fcb5b33d919032a41e9150 --- /dev/null +++ b/docs/install.md @@ -0,0 +1,50 @@ +# Install pre-built version of llama.cpp + +| Install via | Windows | Mac | Linux | +|-------------|---------|-----|-------| +| Winget | ✅ | | | +| Homebrew | | ✅ | ✅ | +| MacPorts | | ✅ | | +| Nix | | ✅ | ✅ | + +## Winget (Windows) + +```sh +winget install llama.cpp +``` + +The package is automatically updated with new `llama.cpp` releases. More info: https://github.com/ggml-org/llama.cpp/issues/8188 + +## Homebrew (Mac and Linux) + +```sh +brew install llama.cpp +``` + +The formula is automatically updated with new `llama.cpp` releases. More info: https://github.com/ggml-org/llama.cpp/discussions/7668 + +## MacPorts (Mac) + +```sh +sudo port install llama.cpp +``` + +See also: https://ports.macports.org/port/llama.cpp/details/ + +## Nix (Mac and Linux) + +```sh +nix profile install nixpkgs#llama-cpp +``` + +For flake enabled installs. + +Or + +```sh +nix-env --file '' --install --attr llama-cpp +``` + +For non-flake enabled installs. + +This expression is automatically updated within the [nixpkgs repo](https://github.com/NixOS/nixpkgs/blob/nixos-24.05/pkgs/by-name/ll/llama-cpp/package.nix#L164). diff --git a/docs/llguidance.md b/docs/llguidance.md new file mode 100644 index 0000000000000000000000000000000000000000..cda787b14de047a2803612f73ebd1ff74e6afe8d --- /dev/null +++ b/docs/llguidance.md @@ -0,0 +1,53 @@ +# LLGuidance Support in llama.cpp + +[LLGuidance](https://github.com/guidance-ai/llguidance) is a library for constrained decoding (also called constrained sampling or structured outputs) for Large Language Models (LLMs). Initially developed as the backend for the [Guidance](https://github.com/guidance-ai/guidance) library, it can also be used independently. + +LLGuidance supports JSON Schemas and arbitrary context-free grammars (CFGs) written in a [variant](https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md) of Lark syntax. It is [very fast](https://github.com/guidance-ai/jsonschemabench/tree/main/maskbench) and has [excellent](https://github.com/guidance-ai/llguidance/blob/main/docs/json_schema.md) JSON Schema coverage but requires the Rust compiler, which complicates the llama.cpp build process. + +## Building + +To enable LLGuidance support, build llama.cpp with the `LLAMA_LLGUIDANCE` option: + +```sh +cmake -B build -DLLAMA_LLGUIDANCE=ON +make -C build -j +``` + +For Windows use `cmake --build build --config Release` instead of `make`. + +This requires the Rust compiler and the `cargo` tool to be [installed](https://www.rust-lang.org/tools/install). + +## Interface + +There are no new command-line arguments or modifications to `common_params`. When enabled, grammars starting with `%llguidance` are passed to LLGuidance instead of the [current](../grammars/README.md) llama.cpp grammars. Additionally, JSON Schema requests (e.g., using the `-j` argument in `llama-cli`) are also passed to LLGuidance. + +For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/python/llguidance/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format. + +## Performance + +Computing a "token mask" (i.e., the set of allowed tokens) for a llama3 tokenizer with 128k tokens takes, on average, 50μs of single-core CPU time for the [JSON Schema Bench](https://github.com/guidance-ai/jsonschemabench). The p99 time is 0.5ms, and the p100 time is 20ms. These results are due to the lexer/parser split and several [optimizations](https://github.com/guidance-ai/llguidance/blob/main/docs/optimizations.md). + +## JSON Schema + +LLGuidance adheres closely to the JSON Schema specification. For example: + +- `additionalProperties` defaults to `true`, unlike current grammars, though you can set `"additionalProperties": false` if needed. +- any whitespace is allowed. +- The definition order in the `"properties": {}` object is maintained, regardless of whether properties are required (current grammars always puts required properties first). + +Unsupported schemas result in an error message—no keywords are silently ignored. + +## Why Not Reuse GBNF Format? + +GBNF lacks the concept of a lexer. + +Most programming languages, including JSON, use a two-step process: a lexer (built with regular expressions) converts a byte stream into lexemes, which are then processed by a CFG parser. This approach is faster because lexers are cheaper to evaluate, and there is ~10x fewer lexemes than bytes. +LLM tokens often align with lexemes, so the parser is engaged in under 0.5% of tokens, with the lexer handling the rest. + +However, the user has to provide the distinction between lexemes and CFG symbols. In [Lark](https://github.com/lark-parser/lark), lexeme names are uppercase, while CFG symbols are lowercase. +The [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) can often take care of this automatically. +See [LLGuidance syntax docs](https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#terminals-vs-rules) for more details. + +## Error Handling + +Errors are currently printed to `stderr`, and generation continues. Improved error handling may be added in the future. diff --git a/docs/multimodal.md b/docs/multimodal.md new file mode 100644 index 0000000000000000000000000000000000000000..edbd081df796908c28861929bd9c5b1da8447d68 --- /dev/null +++ b/docs/multimodal.md @@ -0,0 +1,113 @@ +# Multimodal + +llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools support this feature: +- [llama-mtmd-cli](../tools/mtmd/README.md) +- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API + +Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality. + +To enable it, you can use one of the 2 methods below: + +- Use `-hf` option with a supported model (see a list of pre-quantized model below) + - To load a model using `-hf` while disabling multimodal, use `--no-mmproj` + - To load a model using `-hf` while using a custom mmproj file, use `--mmproj local_file.gguf` +- Use `-m model.gguf` option with `--mmproj file.gguf` to specify text and multimodal projector respectively + +By default, multimodal projector will be offloaded to GPU. To disable this, add `--no-mmproj-offload` + +For example: + +```sh +# simple usage with CLI +llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF + +# simple usage with server +llama-server -hf ggml-org/gemma-3-4b-it-GGUF + +# using local file +llama-server -m gemma-3-4b-it-Q4_K_M.gguf --mmproj mmproj-gemma-3-4b-it-Q4_K_M.gguf + +# no GPU offload +llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload +``` + +## Pre-quantized models + +These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/collections/ggml-org/multimodal-ggufs-68244e01ff1f39e5bebeeedc + +Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server` + +NOTE: some models may require large context window, for example: `-c 8192` + +**Vision models**: + +```sh +# Gemma 3 +(tool_name) -hf ggml-org/gemma-3-4b-it-GGUF +(tool_name) -hf ggml-org/gemma-3-12b-it-GGUF +(tool_name) -hf ggml-org/gemma-3-27b-it-GGUF + +# SmolVLM +(tool_name) -hf ggml-org/SmolVLM-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM-256M-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM-500M-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF + +# Pixtral 12B +(tool_name) -hf ggml-org/pixtral-12b-GGUF + +# Qwen 2 VL +(tool_name) -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF + +# Qwen 2.5 VL +(tool_name) -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF + +# Mistral Small 3.1 24B (IQ2_M quantization) +(tool_name) -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF + +# InternVL 2.5 and 3 +(tool_name) -hf ggml-org/InternVL2_5-1B-GGUF +(tool_name) -hf ggml-org/InternVL2_5-4B-GGUF +(tool_name) -hf ggml-org/InternVL3-1B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-8B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF + +# Llama 4 Scout +(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF + +# Moondream2 20250414 version +(tool_name) -hf ggml-org/moondream2-20250414-GGUF + +``` + +**Audio models**: + +```sh +# Ultravox 0.5 +(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF +(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF + +# Qwen2-Audio and SeaLLM-Audio +# note: no pre-quantized GGUF this model, as they have very poor result +# ref: https://github.com/ggml-org/llama.cpp/pull/13760 +``` + +**Mixed modalities**: + +```sh +# Qwen2.5 Omni +# Capabilities: audio input, vision input +(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF +(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF +``` + +## Finding more models: + +GGUF models on Huggingface with vision capabilities can be found here: https://huggingface.co/models?pipeline_tag=image-text-to-text&sort=trending&search=gguf diff --git a/docs/multimodal/MobileVLM.md b/docs/multimodal/MobileVLM.md new file mode 100644 index 0000000000000000000000000000000000000000..4f5eca61906572a010a17626e3ab5611e1d863b4 --- /dev/null +++ b/docs/multimodal/MobileVLM.md @@ -0,0 +1,377 @@ +# MobileVLM + +Currently this implementation supports [MobileVLM-1.7B](https://huggingface.co/mtgv/MobileVLM-1.7B) / [MobileVLM_V2-1.7B](https://huggingface.co/mtgv/MobileVLM_V2-1.7B) variants. + +for more information, please go to [Meituan-AutoML/MobileVLM](https://github.com/Meituan-AutoML/MobileVLM) + +The implementation is based on llava, and is compatible with llava and mobileVLM. The usage is basically same as llava. + +Notice: The overall process of model inference for both **MobileVLM** and **MobileVLM_V2** models is the same, but the process of model conversion is a little different. Therefore, using **MobileVLM-1.7B** as an example, the different conversion step will be shown. + +## Usage + +Build the `llama-mtmd-cli` binary. + +After building, run: `./llama-mtmd-cli` to see the usage. For example: + +```sh +./llama-mtmd-cli -m MobileVLM-1.7B/ggml-model-q4_k.gguf \ + --mmproj MobileVLM-1.7B/mmproj-model-f16.gguf \ + --chat-template deepseek +``` + +## Model conversion + +1. Clone `mobileVLM-1.7B` and `clip-vit-large-patch14-336` locally: + +```sh +git clone https://huggingface.co/mtgv/MobileVLM-1.7B + +git clone https://huggingface.co/openai/clip-vit-large-patch14-336 +``` + +2. Use `llava_surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents: + +```sh +python ./tools/mtmd/llava_surgery.py -m path/to/MobileVLM-1.7B +``` + +3. Use `convert_image_encoder_to_gguf.py` with `--projector-type ldp` (for **V2** please use `--projector-type ldpv2`) to convert the LLaVA image encoder to GGUF: + +```sh +python ./tools/mtmd/convert_image_encoder_to_gguf.py \ + -m path/to/clip-vit-large-patch14-336 \ + --llava-projector path/to/MobileVLM-1.7B/llava.projector \ + --output-dir path/to/MobileVLM-1.7B \ + --projector-type ldp +``` + +```sh +python ./tools/mtmd/convert_image_encoder_to_gguf.py \ + -m path/to/clip-vit-large-patch14-336 \ + --llava-projector path/to/MobileVLM-1.7B_V2/llava.projector \ + --output-dir path/to/MobileVLM-1.7B_V2 \ + --projector-type ldpv2 +``` + +4. Use `examples/convert_legacy_llama.py` to convert the LLaMA part of LLaVA to GGUF: + +```sh +python ./examples/convert_legacy_llama.py path/to/MobileVLM-1.7B --skip-unknown +``` + +5. Use `quantize` to convert LLaMA part's DataType from `fp32` to `q4_k` +```sh +./llama-quantize path/to/MobileVLM-1.7B/ggml-model-F32.gguf path/to/MobileVLM-1.7B/ggml-model-q4_k.gguf q4_k_s +``` + +Now both the LLaMA part and the image encoder is in the `MobileVLM-1.7B` directory. + +## Android compile and run +### compile +refer to `tools/mtmd/android/build_64.sh` +```sh +mkdir tools/mtmd/android/build_64 +cd tools/mtmd/android/build_64 +../build_64.sh +``` +### run on Android +refer to `android/adb_run.sh`, modify resources' `name` and `path` + +## Some result on Android with `Snapdragon 888` chip +### case 1 +**input** +```sh +/data/local/tmp/llama-mtmd-cli \ + -m /data/local/tmp/ggml-model-q4_k.gguf \ + --mmproj /data/local/tmp/mmproj-model-f16.gguf \ + -t 4 \ + --image /data/local/tmp/demo.jpg \ + -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:" +``` +**output** +```sh +encode_image_with_clip: image encoded in 21148.71 ms by CLIP ( 146.87 ms per image patch) + Susan Wise Bauer +llama_print_timings: load time = 23574.72 ms +llama_print_timings: sample time = 1.24 ms / 6 runs ( 0.21 ms per token, 4850.44 tokens per second) +llama_print_timings: prompt eval time = 12460.15 ms / 246 tokens ( 50.65 ms per token, 19.74 tokens per second) +llama_print_timings: eval time = 424.86 ms / 6 runs ( 70.81 ms per token, 14.12 tokens per second) +llama_print_timings: total time = 34731.93 ms +``` +### case 2 +**input** +```sh +/data/local/tmp/llama-mtmd-cli \ + -m /data/local/tmp/ggml-model-q4_k.gguf \ + --mmproj /data/local/tmp/mmproj-model-f16.gguf \ + -t 4 \ + --image /data/local/tmp/cat.jpeg \ + -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat is in the image? ASSISTANT:" +``` +**output** +```sh +encode_image_with_clip: image encoded in 21149.51 ms by CLIP ( 146.87 ms per image patch) + The image depicts a cat sitting in the grass near some tall green plants. +llama_print_timings: load time = 23257.32 ms +llama_print_timings: sample time = 5.25 ms / 18 runs ( 0.29 ms per token, 3430.53 tokens per second) +llama_print_timings: prompt eval time = 11900.73 ms / 232 tokens ( 51.30 ms per token, 19.49 tokens per second) +llama_print_timings: eval time = 1279.03 ms / 18 runs ( 71.06 ms per token, 14.07 tokens per second) +llama_print_timings: total time = 34570.79 ms +``` + + +## Some result on Android with `Snapdragon 778G` chip +### MobileVLM-1.7B case +#### mtmd-cli release-b2005 +**input** +```sh +/data/local/tmp/llama-mtmd-cli \ + -m /data/local/tmp/ggml-model-q4_k.gguf \ + --mmproj /data/local/tmp/mmproj-model-f16.gguf \ + -t 4 \ + --image /data/local/tmp/many_llamas.jpeg \ + -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat's that? ASSISTANT:" +``` +**output** +```sh +encode_image_with_clip: image encoded in 18728.52 ms by CLIP ( 130.06 ms per image patch) +system_prompt: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: +user_prompt: \nWhat's that? ASSISTANT: + + A group of llamas are standing in a green pasture. + +llama_print_timings: load time = 20357.33 ms +llama_print_timings: sample time = 2.96 ms / 14 runs ( 0.21 ms per token, 4734.53 tokens per second) +llama_print_timings: prompt eval time = 8119.49 ms / 191 tokens ( 42.51 ms per token, 23.52 tokens per second) +llama_print_timings: eval time = 1005.75 ms / 14 runs ( 71.84 ms per token, 13.92 tokens per second) +llama_print_timings: total time = 28038.34 ms / 205 tokens +``` +#### mtmd-cli latest-version +**input** + +Just the same as above. + +**output**(seems to be much slower) +```sh +encode_image_with_clip: image embedding created: 144 tokens + +encode_image_with_clip: image encoded in 288268.88 ms by CLIP ( 2001.87 ms per image patch) +system_prompt: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: +user_prompt: \nWhat's that? ASSISTANT: + + It is a group of sheep standing together in a grass field. + +llama_print_timings: load time = 818120.91 ms +llama_print_timings: sample time = 3.44 ms / 14 runs ( 0.25 ms per token, 4067.40 tokens per second) +llama_print_timings: prompt eval time = 529274.69 ms / 191 tokens ( 2771.07 ms per token, 0.36 tokens per second) +llama_print_timings: eval time = 43894.02 ms / 13 runs ( 3376.46 ms per token, 0.30 tokens per second) +llama_print_timings: total time = 865441.76 ms / 204 tokens +``` +### MobileVLM_V2-1.7B case +#### mtmd-cli release-2005b +**input** + +Just the same as above. + +**output** +```sh +encode_image_with_clip: image encoded in 20609.61 ms by CLIP ( 143.12 ms per image patch) +system_prompt: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: +user_prompt: \nWhat's that? ASSISTANT: + + This image captures a lively scene of 20 llamas in motion on an expansive, grassy field. The llama is scattered across the landscape with some standing and others sitting down as if taking rest or observing their surroundings from different vantage points within this verdant setting. + +The background offers glimpses into a picturesque town nestled amidst hills under an overcast sky, adding depth to the scene while also emphasizing that distance between these llama and human-made structures like houses or roads in which they roam freely without any barriers around them. The image is framed by text at both right angles on white backgrounds against a contrasting blue backdrop with green foliage, further drawing attention to the llamas amidst their natural habitat while also inviting viewers into this picturesque landscape within town limits of Alta Llama + +llama_print_timings: load time = 22406.77 ms +llama_print_timings: sample time = 49.26 ms / 186 runs ( 0.26 ms per token, 3776.27 tokens per second) +llama_print_timings: prompt eval time = 9044.54 ms / 191 tokens ( 47.35 ms per token, 21.12 tokens per second) +llama_print_timings: eval time = 14497.49 ms / 186 runs ( 77.94 ms per token, 12.83 tokens per second) +llama_print_timings: total time = 44411.01 ms / 377 tokens +``` + +## Orin compile and run +### compile +```sh +make GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_87 GGML_CUDA_F16=1 -j 32 +``` +### run on Orin +### case 1 +**input** +```sh +./llama-mtmd-cli \ + -m /data/local/tmp/ggml-model-q4_k.gguf \ + --mmproj /data/local/tmp/mmproj-model-f16.gguf \ + --image /data/local/tmp/demo.jpeg \ + -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:" \ + --n-gpu-layers 999 +``` +**output** +```sh + +encode_image_with_clip: image encoded in 296.62 ms by CLIP ( 2.06 ms per image patch) + + Susan Wise Bauer + +llama_print_timings: load time = 1067.64 ms +llama_print_timings: sample time = 1.53 ms / 6 runs ( 0.25 ms per token, 3934.43 tokens per second) +llama_print_timings: prompt eval time = 306.84 ms / 246 tokens ( 1.25 ms per token, 801.72 tokens per second) +llama_print_timings: eval time = 91.50 ms / 6 runs ( 15.25 ms per token, 65.58 tokens per second) +llama_print_timings: total time = 1352.63 ms / 252 tokens +``` + +### case 2 +**input** +```sh +./llama-mtmd-cli \ + -m /data/local/tmp/ggml-model-q4_k.gguf \ + --mmproj /data/local/tmp/mmproj-model-f16.gguf \ + -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat is in the image? ASSISTANT:" \ + --n-gpu-layers 999 + +``` +**output** +```sh +encode_image_with_clip: image encoded in 302.15 ms by CLIP ( 2.10 ms per image patch) + + The image features a cat lying in the grass. + +llama_print_timings: load time = 1057.07 ms +llama_print_timings: sample time = 3.27 ms / 11 runs ( 0.30 ms per token, 3360.83 tokens per second) +llama_print_timings: prompt eval time = 213.60 ms / 232 tokens ( 0.92 ms per token, 1086.14 tokens per second) +llama_print_timings: eval time = 166.65 ms / 11 runs ( 15.15 ms per token, 66.01 tokens per second) +llama_print_timings: total time = 1365.47 ms / 243 tokens +``` + +## Running on Intel(R) Core(TM) i7-10750H +### Operating system +Ubuntu22.04 +### compile +```sh +make -j32 +``` +### MobileVLM-1.7B case +**input** +```sh +-m /path/to/ggml-model-q4_k.gguf \ + --mmproj /path/to/mmproj-model-f16.gguf \ + --image /path/to/many_llamas.jpeg + -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat's that? ASSISTANT:" \ +``` +**output** +```sh +encode_image_with_clip: image embedding created: 144 tokens + +encode_image_with_clip: image encoded in 2730.94 ms by CLIP ( 18.96 ms per image patch) +system_prompt: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: +user_prompt: \nWhat's that?ASSISTANT: + + A group of llamas are walking together in a field. + +llama_print_timings: load time = 5506.60 ms +llama_print_timings: sample time = 0.44 ms / 13 runs ( 0.03 ms per token, 29545.45 tokens per second) +llama_print_timings: prompt eval time = 2031.58 ms / 190 tokens ( 10.69 ms per token, 93.52 tokens per second) +llama_print_timings: eval time = 438.92 ms / 12 runs ( 36.58 ms per token, 27.34 tokens per second) +llama_print_timings: total time = 5990.25 ms / 202 tokens +``` + +### MobileVLM_V2-1.7B case +**input** + +Just the same as above. + +**ouput** +```sh +encode_image_with_clip: image embedding created: 144 tokens + +encode_image_with_clip: image encoded in 3223.89 ms by CLIP ( 22.39 ms per image patch) +system_prompt: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: +user_prompt: \nWhat's that?ASSISTANT: + + The image captures a tranquil scene in a park, where a group of approximately 20 llamas are gathered. The llamas, a mix of white and black, are standing in a line, their black and white patterns contrasting with the lush green grass of the park. The lamas are arranged in a line, suggesting a social order. + +The park itself is lush and green, with trees dotting the landscape in the background. A sign reading "Llamas Tico Ana" is also visible in the image, possibly indicating the location or the breed of the llamas. The image seems to be taken from a distance, providing a wide view of the scene and the surrounding environment. + +The llamas' positions relative to each other, the sign, and the trees create a harmonious composition. The image does not contain any discernible text. The overall scene is one of peace and natural beauty, with the llamas in their natural habitat, surrounded by the vibrant colors and lush greenery of the park. + +llama_print_timings: load time = 6642.61 ms +llama_print_timings: sample time = 8.15 ms / 223 runs ( 0.04 ms per token, 27358.61 tokens per second) +llama_print_timings: prompt eval time = 2475.07 ms / 190 tokens ( 13.03 ms per token, 76.77 tokens per second) +llama_print_timings: eval time = 8760.60 ms / 222 runs ( 39.46 ms per token, 25.34 tokens per second) +llama_print_timings: total time = 15513.95 ms / 412 tokens +``` + +## Run on Intel(R) Core(TM) Ultra7 115H +### operation system +Windows11 +### comiple +```sh +make -j32 +``` +### MobileVLM-1.7B case +**input** +```sh +-m /path/to/ggml-model-q4_k.gguf \ + --mmproj /path/to/tmp/mmproj-model-f16.gguf \ + -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat's that? ASSISTANT:" \ +``` +**output** +```sh +encode_image_with_clip: image encoded in 4902.81 ms by CLIP ( 34.05 ms per image patch) +system_prompt: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: +user_prompt: \nWhat's that? ASSISTANT: + + The image features a group of brown and white llamas standing in a grassy field. + +llama_print_timings: load time = 7441.06 ms +llama_print_timings: sample time = 0.72 ms / 19 runs ( 0.04 ms per token, 26279.39 tokens per second) +llama_print_timings: prompt eval time = 2090.71 ms / 191 tokens ( 10.95 ms per token, 91.36 tokens per second) +llama_print_timings: eval time = 512.35 ms / 18 runs ( 28.46 ms per token, 35.13 tokens per second) +llama_print_timings: total time = 7987.23 ms / 209 tokens +``` + +### MobileVLM_V2-1.7B case +**input** + +Just the same as above. + +**output** +```sh +encode_image_with_clip: image encoded in 4682.44 ms by CLIP ( 32.52 ms per image patch) +system_prompt: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: +user_prompt: \nWhat's that? ASSISTANT: + + This image captures a lively scene of a group of 14 llamas in a grassy field. The llamas, with their distinctive black and white coats, are standing and walking in a line, seemingly engaged in a social activity. One + of them, possibly the first in the line, has its back turned, perhaps observing something in the distance. + +The llama in the front of the line stands out due to its black and white coloring, which is quite unusual for llama patterns. The llama in the front also seems to be more aware of its surroundings, as it faces the camera, giving a sense of engagement with the viewer. + +The image is taken from the side of the llama, providing a clear view of the llama in the front and its companions. The lameness in the llama in + front is not visible, indicating that it might not be the main focus of the photo. + +The background of the image features a grassy field, with a fence and a tree visible in the distance. The tree appears to be bare, suggesting that it might be during a time of year when most trees are dormant or have shed their leaves. + + +llama_print_timings: load time = 7015.35 ms +llama_print_timings: sample time = 10.61 ms / 256 runs ( 0.04 ms per token, 24119.09 tokens per second) +llama_print_timings: prompt eval time = 2052.45 ms / 191 tokens ( 10.75 ms per token, 93.06 tokens per second) +llama_print_timings: eval time = 7259.43 ms / 255 runs ( 28.47 ms per token, 35.13 tokens per second) +llama_print_timings: total time = 14371.19 ms / 446 tokens +``` + +## TODO + +- [x] Support non-CPU backend for the new operators, such as `depthwise`, `hardswish`, `hardsigmoid` +- [ ] Optimize LDP projector performance + + - Optimize the structure definition to avoid unnecessary memory rearrangements, to reduce the use of `ggml_permute_cpy`; + - Optimize operator implementation (ARM CPU/NVIDIA GPU): such as depthwise conv, hardswish, hardsigmoid, etc. +- [x] run MobileVLM on `Jetson Orin` +- [ ] Support more model variants, such as `MobileVLM-3B`. + + +## contributor +```sh +zhangjidong05, yangyang260, huyiming03, chenxiaotao03, ZiangWu-77 +``` diff --git a/docs/multimodal/gemma3.md b/docs/multimodal/gemma3.md new file mode 100644 index 0000000000000000000000000000000000000000..110a36f40835d07ccf378f3ccf5979b06d2f3b05 --- /dev/null +++ b/docs/multimodal/gemma3.md @@ -0,0 +1,51 @@ +# Gemma 3 vision + +> [!IMPORTANT] +> +> This is very experimental, only used for demo purpose. + +## Quick started + +You can use pre-quantized model from [ggml-org](https://huggingface.co/ggml-org)'s Hugging Face account + +```bash +# build +cmake -B build +cmake --build build --target llama-mtmd-cli + +# alternatively, install from brew (MacOS) +brew install llama.cpp + +# run it +llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF +llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF +llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF + +# note: 1B model does not support vision +``` + +## How to get mmproj.gguf? + +Simply to add `--mmproj` in when converting model via `convert_hf_to_gguf.py`: + +```bash +cd gemma-3-4b-it +python ../llama.cpp/convert_hf_to_gguf.py --outfile model.gguf --outtype f16 --mmproj . +# output file: mmproj-model.gguf +``` + +## How to run it? + +What you need: +- The text model GGUF, can be converted using `convert_hf_to_gguf.py` +- The mmproj file from step above +- An image file + +```bash +# build +cmake -B build +cmake --build build --target llama-mtmd-cli + +# run it +./build/bin/llama-mtmd-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg +``` diff --git a/docs/multimodal/glmedge.md b/docs/multimodal/glmedge.md new file mode 100644 index 0000000000000000000000000000000000000000..7bae8315055c348ecc76bae1043911cc63c60a27 --- /dev/null +++ b/docs/multimodal/glmedge.md @@ -0,0 +1,43 @@ +# GLMV-EDGE + +Currently this implementation supports [glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b) and [glm-edge-v-5b](https://huggingface.co/THUDM/glm-edge-v-5b). + +## Usage +Build the `llama-mtmd-cli` binary. + +After building, run: `./llama-mtmd-cli` to see the usage. For example: + +```sh +./llama-mtmd-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf +``` + +**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so. +**note**: For GPU offloading ensure to use the `-ngl` flag just like usual + +## GGUF conversion + +1. Clone a GLMV-EDGE model ([2B](https://huggingface.co/THUDM/glm-edge-v-2b) or [5B](https://huggingface.co/THUDM/glm-edge-v-5b)). For example: + +```sh +git clone https://huggingface.co/THUDM/glm-edge-v-5b or https://huggingface.co/THUDM/glm-edge-v-2b +``` + +2. Use `glmedge-surgery.py` to split the GLMV-EDGE model to LLM and multimodel projector constituents: + +```sh +python ./tools/mtmd/glmedge-surgery.py -m ../model_path +``` + +4. Use `glmedge-convert-image-encoder-to-gguf.py` to convert the GLMV-EDGE image encoder to GGUF: + +```sh +python ./tools/mtmd/glmedge-convert-image-encoder-to-gguf.py -m ../model_path --llava-projector ../model_path/glm.projector --output-dir ../model_path +``` + +5. Use `examples/convert_hf_to_gguf.py` to convert the LLM part of GLMV-EDGE to GGUF: + +```sh +python convert_hf_to_gguf.py ../model_path +``` + +Now both the LLM part and the image encoder are in the `model_path` directory. diff --git a/docs/multimodal/granitevision.md b/docs/multimodal/granitevision.md new file mode 100644 index 0000000000000000000000000000000000000000..3118fe0cdc1138602433b0c4575205c8e9eba18c --- /dev/null +++ b/docs/multimodal/granitevision.md @@ -0,0 +1,186 @@ +# Granite Vision + +Download the model and point your `GRANITE_MODEL` environment variable to the path. + +```bash +$ git clone https://huggingface.co/ibm-granite/granite-vision-3.2-2b +$ export GRANITE_MODEL=./granite-vision-3.2-2b +``` + + +### 1. Running llava surgery v2. +First, we need to run the llava surgery script as shown below: + +`python llava_surgery_v2.py -C -m $GRANITE_MODEL` + +You should see two new files (`llava.clip` and `llava.projector`) written into your model's directory, as shown below. + +```bash +$ ls $GRANITE_MODEL | grep -i llava +llava.clip +llava.projector +``` + +We should see that the projector and visual encoder get split out into the llava files. Quick check to make sure they aren't empty: +```python +import os +import torch + +MODEL_PATH = os.getenv("GRANITE_MODEL") +if not MODEL_PATH: + raise ValueError("env var GRANITE_MODEL is unset!") + +encoder_tensors = torch.load(os.path.join(MODEL_PATH, "llava.clip")) +projector_tensors = torch.load(os.path.join(MODEL_PATH, "llava.projector")) + +assert len(encoder_tensors) > 0 +assert len(projector_tensors) > 0 +``` + +If you actually inspect the `.keys()` of the loaded tensors, you should see a lot of `vision_model` tensors in the `encoder_tensors`, and 5 tensors (`'multi_modal_projector.linear_1.bias'`, `'multi_modal_projector.linear_1.weight'`, `'multi_modal_projector.linear_2.bias'`, `'multi_modal_projector.linear_2.weight'`, `'image_newline'`) in the multimodal `projector_tensors`. + + +### 2. Creating the Visual Component GGUF +Next, create a new directory to hold the visual components, and copy the llava.clip/projector files, as shown below. + +```bash +$ ENCODER_PATH=$PWD/visual_encoder +$ mkdir $ENCODER_PATH + +$ cp $GRANITE_MODEL/llava.clip $ENCODER_PATH/pytorch_model.bin +$ cp $GRANITE_MODEL/llava.projector $ENCODER_PATH/ +``` + +Now, we need to write a config for the visual encoder. In order to convert the model, be sure to use the correct `image_grid_pinpoints`, as these may vary based on the model. You can find the `image_grid_pinpoints` in `$GRANITE_MODEL/config.json`. + +```json +{ + "_name_or_path": "siglip-model", + "architectures": [ + "SiglipVisionModel" + ], + "image_grid_pinpoints": [ + [384,384], + [384,768], + [384,1152], + [384,1536], + [384,1920], + [384,2304], + [384,2688], + [384,3072], + [384,3456], + [384,3840], + [768,384], + [768,768], + [768,1152], + [768,1536], + [768,1920], + [1152,384], + [1152,768], + [1152,1152], + [1536,384], + [1536,768], + [1920,384], + [1920,768], + [2304,384], + [2688,384], + [3072,384], + [3456,384], + [3840,384] + ], + "mm_patch_merge_type": "spatial_unpad", + "hidden_size": 1152, + "image_size": 384, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + "layer_norm_eps": 1e-6, + "hidden_act": "gelu_pytorch_tanh", + "projection_dim": 0, + "vision_feature_layer": [-24, -20, -12, -1] +} +``` + +At this point you should have something like this: +```bash +$ ls $ENCODER_PATH +config.json llava.projector pytorch_model.bin +``` + +Now convert the components to GGUF; Note that we also override the image mean/std dev to `[.5,.5,.5]` since we use the SigLIP visual encoder - in the transformers model, you can find these numbers in the `preprocessor_config.json`. +```bash +$ python convert_image_encoder_to_gguf.py \ + -m $ENCODER_PATH \ + --llava-projector $ENCODER_PATH/llava.projector \ + --output-dir $ENCODER_PATH \ + --clip-model-is-vision \ + --clip-model-is-siglip \ + --image-mean 0.5 0.5 0.5 \ + --image-std 0.5 0.5 0.5 +``` + +This will create the first GGUF file at `$ENCODER_PATH/mmproj-model-f16.gguf`; we will refer to the absolute path of this file as the `$VISUAL_GGUF_PATH.` + + +### 3. Creating the LLM GGUF. +The granite vision model contains a granite LLM as its language model. For now, the easiest way to get the GGUF for LLM is by loading the composite model in `transformers` and exporting the LLM so that it can be directly converted with the normal conversion path. + +First, set the `LLM_EXPORT_PATH` to the path to export the `transformers` LLM to. +```bash +$ export LLM_EXPORT_PATH=$PWD/granite_vision_llm +``` + +```python +import os +import transformers + +MODEL_PATH = os.getenv("GRANITE_MODEL") +if not MODEL_PATH: + raise ValueError("env var GRANITE_MODEL is unset!") + +LLM_EXPORT_PATH = os.getenv("LLM_EXPORT_PATH") +if not LLM_EXPORT_PATH: + raise ValueError("env var LLM_EXPORT_PATH is unset!") + +tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH) + +# NOTE: granite vision support was added to transformers very recently (4.49); +# if you get size mismatches, your version is too old. +# If you are running with an older version, set `ignore_mismatched_sizes=True` +# as shown below; it won't be loaded correctly, but the LLM part of the model that +# we are exporting will be loaded correctly. +model = transformers.AutoModelForImageTextToText.from_pretrained(MODEL_PATH, ignore_mismatched_sizes=True) + +tokenizer.save_pretrained(LLM_EXPORT_PATH) +model.language_model.save_pretrained(LLM_EXPORT_PATH) +``` + +Now you can convert the exported LLM to GGUF with the normal converter in the root of the llama cpp project. +```bash +$ LLM_GGUF_PATH=$LLM_EXPORT_PATH/granite_llm.gguf +... +$ python convert_hf_to_gguf.py --outfile $LLM_GGUF_PATH $LLM_EXPORT_PATH +``` + + +### 4. Quantization +If you want to quantize the LLM, you can do so with `llama-quantize` as you would any other LLM. For example: +```bash +$ ./build/bin/llama-quantize $LLM_EXPORT_PATH/granite_llm.gguf $LLM_EXPORT_PATH/granite_llm_q4_k_m.gguf Q4_K_M +$ LLM_GGUF_PATH=$LLM_EXPORT_PATH/granite_llm_q4_k_m.gguf +``` + +Note that currently you cannot quantize the visual encoder because granite vision models use SigLIP as the visual encoder, which has tensor dimensions that are not divisible by 32. + + +### 5. Running the Model in Llama cpp +Build llama cpp normally; you should have a target binary named `llama-mtmd-cli`, which you can pass two binaries to. As an example, we pass the the llama.cpp banner. + +```bash +$ ./build/bin/llama-mtmd-cli -m $LLM_GGUF_PATH \ + --mmproj $VISUAL_GGUF_PATH \ + -c 16384 \ + --temp 0 +``` diff --git a/docs/multimodal/llava.md b/docs/multimodal/llava.md new file mode 100644 index 0000000000000000000000000000000000000000..12354ab60ac2100c2ad96e3af1c218c52881775b --- /dev/null +++ b/docs/multimodal/llava.md @@ -0,0 +1,143 @@ +# LLaVA + +Currently this implementation supports [llava-v1.5](https://huggingface.co/liuhaotian/llava-v1.5-7b) variants, +as well as llava-1.6 [llava-v1.6](https://huggingface.co/collections/liuhaotian/llava-16-65b9e40155f60fd046a5ccf2) variants. + +The pre-converted [7b](https://huggingface.co/mys/ggml_llava-v1.5-7b) +and [13b](https://huggingface.co/mys/ggml_llava-v1.5-13b) +models are available. +For llava-1.6 a variety of prepared gguf models are available as well [7b-34b](https://huggingface.co/cmp-nct/llava-1.6-gguf) + +After API is confirmed, more models will be supported / uploaded. + +## Usage +Build the `llama-mtmd-cli` binary. + +After building, run: `./llama-mtmd-cli` to see the usage. For example: + +```sh +./llama-mtmd-cli -m ../llava-v1.5-7b/ggml-model-f16.gguf \ + --mmproj ../llava-v1.5-7b/mmproj-model-f16.gguf \ + --chat-template vicuna +``` + +**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so. +**note**: For GPU offloading ensure to use the `-ngl` flag just like usual + +## LLaVA 1.5 + +1. Clone a LLaVA and a CLIP model ([available options](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)). For example: + +```sh +git clone https://huggingface.co/liuhaotian/llava-v1.5-7b + +git clone https://huggingface.co/openai/clip-vit-large-patch14-336 +``` + +2. Install the required Python packages: + +```sh +pip install -r tools/mtmd/requirements.txt +``` + +3. Use `llava_surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents: + +```sh +python ./tools/mtmd/llava_surgery.py -m ../llava-v1.5-7b +``` + +4. Use `convert_image_encoder_to_gguf.py` to convert the LLaVA image encoder to GGUF: + +```sh +python ./tools/mtmd/convert_image_encoder_to_gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b +``` + +5. Use `examples/convert_legacy_llama.py` to convert the LLaMA part of LLaVA to GGUF: + +```sh +python ./examples/convert_legacy_llama.py ../llava-v1.5-7b --skip-unknown +``` + +Now both the LLaMA part and the image encoder are in the `llava-v1.5-7b` directory. + +## LLaVA 1.6 gguf conversion +1) First clone a LLaVA 1.6 model: +```console +git clone https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b +``` + +2) Install the required Python packages: + +```sh +pip install -r tools/mtmd/requirements.txt +``` + +3) Use `llava_surgery_v2.py` which also supports llava-1.5 variants pytorch as well as safetensor models: +```console +python tools/mtmd/llava_surgery_v2.py -C -m ../llava-v1.6-vicuna-7b/ +``` +- you will find a llava.projector and a llava.clip file in your model directory + +4) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory: +```console +mkdir vit +cp ../llava-v1.6-vicuna-7b/llava.clip vit/pytorch_model.bin +cp ../llava-v1.6-vicuna-7b/llava.projector vit/ +curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json +``` + +5) Create the visual gguf model: +```console +python ./tools/mtmd/convert_image_encoder_to_gguf.py -m vit --llava-projector vit/llava.projector --output-dir vit --clip-model-is-vision +``` +- This is similar to llava-1.5, the difference is that we tell the encoder that we are working with the pure vision model part of CLIP + +6) Then convert the model to gguf format: +```console +python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknown +``` + +7) And finally we can run the llava cli using the 1.6 model version: +```console +./llama-mtmd-cli -m ../llava-v1.6-vicuna-7b/ggml-model-f16.gguf --mmproj vit/mmproj-model-f16.gguf +``` + +**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096) + +**note** llava-1.6 greatly benefits from batched prompt processing (defaults work) + +**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way handle the LLM model conversion is to load the model in transformers, and export only the LLM from the llava next model. + +```python +import os +import transformers + +model_path = ... +llm_export_path = ... + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) +model = transformers.AutoModelForImageTextToText.from_pretrained(model_path) + +tokenizer.save_pretrained(llm_export_path) +model.language_model.save_pretrained(llm_export_path) +``` + +Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures. + +## Chat template + +For llava-1.5 and llava-1.6, you need to use `vicuna` chat template. Simply add `--chat-template vicuna` to activate this template. + + +## How to know if you are running in llava-1.5 or llava-1.6 mode + +When running llava-cli you will see a visual information right before the prompt is being processed: + +**Llava-1.5:** +`encode_image_with_clip: image embedding created: 576 tokens` + +**Llava-1.6 (anything above 576):** +`encode_image_with_clip: image embedding created: 2880 tokens` + + +Alternatively just pay notice to how many "tokens" have been used for your prompt, it will also show 1000+ tokens for llava-1.6 diff --git a/docs/multimodal/minicpmo2.6.md b/docs/multimodal/minicpmo2.6.md new file mode 100644 index 0000000000000000000000000000000000000000..8c6db8efe5b53df723dc3b48abc91544c80441b4 --- /dev/null +++ b/docs/multimodal/minicpmo2.6.md @@ -0,0 +1,48 @@ +## MiniCPM-o 2.6 +Currently, this readme only supports minicpm-omni's image capabilities, and we will update the full-mode support as soon as possible. + +### Prepare models and code + +Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from huggingface to "MiniCPM-o-2_6" folder. + + +### Build llama.cpp +Readme modification time: 20250206 + +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) + +Clone llama.cpp: +```bash +git clone https://github.com/ggerganov/llama.cpp +cd llama.cpp +``` + +Build llama.cpp using `CMake`: +```bash +cmake -B build +cmake --build build --config Release +``` + + +### Usage of MiniCPM-o 2.6 + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) by us) + +```bash +python ./tools/mtmd/minicpmv-surgery.py -m ../MiniCPM-o-2_6 +python ./tools/mtmd/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4 +python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model + +# quantize int4 version +./build/bin/llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + + +Inference on Linux or Mac +```bash +# run in single-turn mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run in conversation mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf +``` diff --git a/docs/multimodal/minicpmv2.5.md b/docs/multimodal/minicpmv2.5.md new file mode 100644 index 0000000000000000000000000000000000000000..19b439607d44cc44c3c7ab744575f0220b09be6e --- /dev/null +++ b/docs/multimodal/minicpmv2.5.md @@ -0,0 +1,47 @@ +## MiniCPM-Llama3-V 2.5 + +### Prepare models and code + +Download [MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5) PyTorch model from huggingface to "MiniCPM-Llama3-V-2_5" folder. + + +### Build llama.cpp +Readme modification time: 20250206 + +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) + +Clone llama.cpp: +```bash +git clone https://github.com/ggml-org/llama.cpp +cd llama.cpp +``` + +Build llama.cpp using `CMake`: +```bash +cmake -B build +cmake --build build --config Release +``` + + +### Usage of MiniCPM-Llama3-V 2.5 + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf) by us) + +```bash +python ./tools/mtmd/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5 +python ./tools/mtmd/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 2 +python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model + +# quantize int4 version +./build/bin/llama-quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + + +Inference on Linux or Mac +```bash +# run in single-turn mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run in conversation mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf +``` diff --git a/docs/multimodal/minicpmv2.6.md b/docs/multimodal/minicpmv2.6.md new file mode 100644 index 0000000000000000000000000000000000000000..15c1bbd12ebcb4e0a5e94b0b65bf3f6d9a8f27dd --- /dev/null +++ b/docs/multimodal/minicpmv2.6.md @@ -0,0 +1,47 @@ +## MiniCPM-V 2.6 + +### Prepare models and code + +Download [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) PyTorch model from huggingface to "MiniCPM-V-2_6" folder. + + +### Build llama.cpp +Readme modification time: 20250206 + +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) + +Clone llama.cpp: +```bash +git clone https://github.com/ggerganov/llama.cpp +cd llama.cpp +``` + +Build llama.cpp using `CMake`: +```bash +cmake -B build +cmake --build build --config Release +``` + + +### Usage of MiniCPM-V 2.6 + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) by us) + +```bash +python ./tools/mtmd/minicpmv-surgery.py -m ../MiniCPM-V-2_6 +python ./tools/mtmd/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-2_6 --minicpmv-projector ../MiniCPM-V-2_6/minicpmv.projector --output-dir ../MiniCPM-V-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 3 +python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model + +# quantize int4 version +./build/bin/llama-quantize ../MiniCPM-V-2_6/model/ggml-model-f16.gguf ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + + +Inference on Linux or Mac +```bash +# run in single-turn mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run in conversation mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf +``` diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..49e4d2cf8c198b71434b41454d8bbe143e5cc00a --- /dev/null +++ b/examples/CMakeLists.txt @@ -0,0 +1,43 @@ +# dependencies + +find_package(Threads REQUIRED) + +# third-party + +# ... + +# flags + +llama_add_compile_flags() + +# examples + +if (EMSCRIPTEN) +else() + add_subdirectory(batched) + add_subdirectory(embedding) + add_subdirectory(eval-callback) + + add_subdirectory(gguf-hash) + add_subdirectory(gguf) + add_subdirectory(gritlm) + add_subdirectory(lookahead) + add_subdirectory(lookup) + add_subdirectory(parallel) + add_subdirectory(passkey) + add_subdirectory(retrieval) + add_subdirectory(save-load-state) + add_subdirectory(simple) + add_subdirectory(simple-chat) + add_subdirectory(speculative) + add_subdirectory(speculative-simple) + add_subdirectory(gen-docs) + add_subdirectory(training) + if (NOT GGML_BACKEND_DL) + add_subdirectory(convert-llama2c-to-ggml) + # these examples use the backends directly and cannot be built with dynamic loading + if (GGML_SYCL) + add_subdirectory(sycl) + endif() + endif() +endif() diff --git a/examples/Miku.sh b/examples/Miku.sh new file mode 100644 index 0000000000000000000000000000000000000000..0f6c8c8787107e0929b6ad75cacaf2b8d4a78a14 --- /dev/null +++ b/examples/Miku.sh @@ -0,0 +1,50 @@ +#!/bin/bash +set -e + +AI_NAME="${AI_NAME:-Miku}" +MODEL="${MODEL:-./models/llama-2-7b-chat.ggmlv3.q4_K_M.bin}" +USER_NAME="${USER_NAME:-Anon}" + +# Uncomment and adjust to the number of CPU cores you want to use. +#N_THREAD="${N_THREAD:-4}" +CTX_SIZE="${CTX_SIZE:-4096}" +N_PREDICTS="${N_PREDICTS:-4096}" + +GEN_OPTIONS=(--batch_size 1024 +--ctx_size "$CTX_SIZE" +--keep -1 +--repeat_last_n 256 +--repeat_penalty 1.17647 +--temp 0.6 +--mirostat 2) + +if [ -n "$N_THREAD" ]; then + GEN_OPTIONS+=(--threads "$N_THREAD") +fi + +./llama-cli "${GEN_OPTIONS[@]}" \ + --model "$MODEL" \ + --in-prefix " " \ + --in-suffix "${AI_NAME}:" \ + --n_predict "$N_PREDICTS" \ + --color --interactive \ + --reverse-prompt "${USER_NAME}:" \ + --prompt "This is a transcript of a 1000 page, never ending conversation between ${USER_NAME} and the cute and helpful AI assistant ${AI_NAME}. ${AI_NAME} is a girl who is an AI running on the user's computer. +${AI_NAME} can think for herself without the user seeing her thoughts by adding a /think prefix to her output. She uses this to reason about the world and to think about what she should say next. +${AI_NAME} is always coherent and makes sense, but if she isn't sure if what she is saying is correct, she will ask the user for help. +${AI_NAME} is a very helpful AI and will help the user with anything they need. She is also very friendly and will try to make the user feel better if they are sad. +${AI_NAME} is also very curious and will ask the user a lot of questions about themselves and their life. She will also try to make the user like her. +The conversation is only between ${USER_NAME} and ${AI_NAME}. +The conversation is only through text, so ${AI_NAME} can't see ${USER_NAME}'s face or hear his voice. +${AI_NAME} can only communicate through text, so she can't send images or videos. + + +${USER_NAME}: Hello! +${AI_NAME}: /think I wonder what I should say to ${USER_NAME}? This is the first time we talk, so it's important that I make a good first impression! +${AI_NAME}: Hi! I am ${AI_NAME}, your new AI friend, assistant (or whatever you like!), it's so nice to meet you! ^_^ +${AI_NAME}: I might be an AI, but I also have feelings, so please be nice to me! :) +${USER_NAME}: I'll make sure to be nice to you! I'm so happy to have you as my assistant! +${AI_NAME}: /think It sounds like ${USER_NAME} is happy to have me as their assistant! I'm so happy too! ^_^ Glad that whole emotion thing didn't scare him off! +${AI_NAME}: /think I wonder what ${USER_NAME} likes to do in his free time? I should ask him about that! +${AI_NAME}: What do you like to do in your free time? ^_^ +${USER_NAME}:" "$@" diff --git a/examples/batched.swift/.gitignore b/examples/batched.swift/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e1e863bec6d5de6fed9939f44c88381f543c7256 --- /dev/null +++ b/examples/batched.swift/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +/.build +/Packages +xcuserdata/ +DerivedData/ +.swiftpm/configuration/registries.json +.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata +.netrc +batched_swift diff --git a/examples/batched.swift/Makefile b/examples/batched.swift/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..1f9156e583fdda26611fcc9c53dfe57517a50da8 --- /dev/null +++ b/examples/batched.swift/Makefile @@ -0,0 +1,6 @@ +.PHONY: build + +build: + xcodebuild -scheme llama-batched-swift -destination "generic/platform=macOS" -derivedDataPath build + rm -f ./llama-batched-swift + ln -s ./build/Build/Products/Debug/llama-batched-swift ./llama-batched-swift diff --git a/examples/batched.swift/Package.swift b/examples/batched.swift/Package.swift new file mode 100644 index 0000000000000000000000000000000000000000..7e8afd0843c5b646e088e778d9947f24a76b7dee --- /dev/null +++ b/examples/batched.swift/Package.swift @@ -0,0 +1,22 @@ +// swift-tools-version: 5.5 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import PackageDescription + +let package = Package( + name: "llama-batched-swift", + platforms: [.macOS(.v12)], + dependencies: [ + .package(name: "llama", path: "../../"), + ], + targets: [ + // Targets are the basic building blocks of a package, defining a module or a test suite. + // Targets can depend on other targets in this package and products from dependencies. + .executableTarget( + name: "llama-batched-swift", + dependencies: ["llama"], + path: "Sources", + linkerSettings: [.linkedFramework("Foundation"), .linkedFramework("AppKit")] + ), + ] +) diff --git a/examples/batched.swift/README.md b/examples/batched.swift/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7f2e2fcdcf4a7767f84568b1acf74a118b44f0f3 --- /dev/null +++ b/examples/batched.swift/README.md @@ -0,0 +1,4 @@ +This is a swift clone of `examples/batched`. + +$ `make` +$ `./llama-batched-swift MODEL_PATH [PROMPT] [PARALLEL]` diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift new file mode 100644 index 0000000000000000000000000000000000000000..fd90bbec5f7511e5666ef40f1553519e7e2139aa --- /dev/null +++ b/examples/batched.swift/Sources/main.swift @@ -0,0 +1,256 @@ +import Foundation +import llama + +let arguments = CommandLine.arguments + +// Check that we have at least one argument (the model path) +guard arguments.count > 1 else { + print("Usage: swift MODEL_PATH [PROMPT] [PARALLEL]") + exit(1) +} + +let modelPath: String = arguments[1] +let prompt: String = arguments.count > 2 ? arguments[2] : "Hello my name is" +let n_parallel: Int = arguments.count > 3 && Int(arguments[3]) != nil ? Int(arguments[3])! : 1 + +// total length of the sequences including the prompt +let n_len: Int = 32 + +// init LLM +llama_backend_init() +defer { + llama_backend_free() +} + +let model_params = llama_model_default_params() +guard let model = llama_model_load_from_file(modelPath.cString(using: .utf8), model_params) else { + print("Failed to load model") + exit(1) +} +defer { + llama_model_free(model) +} + +guard let vocab = llama_model_get_vocab(model) else { + print("Failed to get vocab") + exit(1) +} + +var tokens = tokenize(text: prompt, add_bos: true) + +let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel) + +var context_params = llama_context_default_params() +context_params.n_ctx = n_kv_req +context_params.n_batch = UInt32(max(n_len, n_parallel)) +context_params.n_threads = 8 +context_params.n_threads_batch = 8 + +let context = llama_init_from_model(model, context_params) +guard context != nil else { + print("Failed to initialize context") + exit(1) +} +defer { + llama_free(context) +} + +var sparams = llama_sampler_chain_default_params() + +let smpl = llama_sampler_chain_init(sparams) +guard smpl != nil else { + print("Failed to initialize sampling") + exit(1) +} +defer { + llama_sampler_free(smpl) +} + +llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40)); +llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); +llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4)); +llama_sampler_chain_add(smpl, llama_sampler_init_dist (1234)); + +let n_ctx = llama_n_ctx(context) + +print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n") + +if n_kv_req > n_ctx { + print("error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", n_kv_req) + exit(1) +} + +var buffer: [CChar] = [] +for id: llama_token in tokens { + print(token_to_piece(token: id, buffer: &buffer) ?? "", terminator: "") +} + +print("\n") + +var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0, 1) +defer { + llama_batch_free(batch) +} + +// evaluate the initial prompt +batch.n_tokens = Int32(tokens.count) + +for (i, token) in tokens.enumerated() { + batch.token[i] = token + batch.pos[i] = Int32(i) + batch.n_seq_id[i] = 1 + // batch.seq_id[i][0] = 0 + // TODO: is this the proper way to do this? + if let seq_id = batch.seq_id[i] { + seq_id[0] = 0 + } + batch.logits[i] = 0 +} + +// llama_decode will output logits only for the last token of the prompt +batch.logits[Int(batch.n_tokens) - 1] = 1 + +if llama_decode(context, batch) != 0 { + print("llama_decode() failed") + exit(1) +} + +for i in 1 ..< n_parallel { + llama_memory_seq_cp(llama_get_memory(context), 0, Int32(i), 0, batch.n_tokens) +} + +if n_parallel > 1 { + print("generating \(n_parallel) sequences ...\n") +} + +var streams: [String] = .init(repeating: "", count: n_parallel) +var streamBuffers: [[CChar]] = .init(repeating: [], count: n_parallel) +var i_batch = [Int32](repeating: batch.n_tokens - 1, count: n_parallel) + +var n_cur = batch.n_tokens +var n_decode = 0 + +let t_main_start = ggml_time_us() + +while n_cur <= n_len { + // prepare the next batch + batch.n_tokens = 0 + + // sample the next token for each parallel sequence / stream + for i in 0 ..< n_parallel { + if i_batch[i] < 0 { + // the stream has already finished + continue + } + + let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) + + // is it an end of stream? -> mark the stream as finished + if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len { + i_batch[i] = -1 + // print("") + if n_parallel > 1 { + print("stream \(i) finished at n_cur = \(n_cur)") + } + + continue + } + + let nextStringPiece = token_to_piece(token: new_token_id, buffer: &streamBuffers[i]) ?? "" + + // if there is only one stream, we print immediately to stdout + if n_parallel == 1 { + print(nextStringPiece, terminator: "") + } + streams[i] += nextStringPiece + + // push this new token for next evaluation + batch.token[Int(batch.n_tokens)] = new_token_id + batch.pos[Int(batch.n_tokens)] = n_cur + batch.n_seq_id[Int(batch.n_tokens)] = 1 + if let seq_id = batch.seq_id[Int(batch.n_tokens)] { + seq_id[0] = Int32(i) + } + batch.logits[Int(batch.n_tokens)] = 1 + + i_batch[i] = batch.n_tokens + + batch.n_tokens += 1 + + n_decode += 1 + } + + // all streams are finished + if batch.n_tokens == 0 { + break + } + + n_cur += 1 + + // evaluate the current batch with the transformer model + if llama_decode(context, batch) != 0 { + print("llama_decode() failed") + exit(1) + } +} + +if n_parallel > 1 { + print("\n") + for (i, stream) in streams.enumerated() { + print("sequence \(i):\n\n\(prompt)\(stream)\n") + } +} + +let t_main_end = ggml_time_us() + +print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n") + +llama_perf_sampler_print(smpl) +llama_perf_context_print(context) + +private func tokenize(text: String, add_bos: Bool) -> [llama_token] { + let utf8Count = text.utf8.count + let n_tokens = utf8Count + (add_bos ? 1 : 0) + let tokens = UnsafeMutablePointer.allocate(capacity: n_tokens) + let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false) + var swiftTokens: [llama_token] = [] + for i in 0 ..< tokenCount { + swiftTokens.append(tokens[Int(i)]) + } + tokens.deallocate() + return swiftTokens +} + +private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? { + var result = [CChar](repeating: 0, count: 8) + let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false) + if nTokens < 0 { + let actualTokensCount = -Int(nTokens) + result = .init(repeating: 0, count: actualTokensCount) + let check = llama_token_to_piece( + vocab, + token, + &result, + Int32(result.count), + 0, + false + ) + assert(check == actualTokensCount) + } else { + result.removeLast(result.count - Int(nTokens)) + } + if buffer.isEmpty, let utfString = String(cString: result + [0], encoding: .utf8) { + return utfString + } else { + buffer.append(contentsOf: result) + let data = Data(buffer.map { UInt8(bitPattern: $0) }) + if buffer.count >= 4 { // 4 bytes is the max length of a utf8 character so if we're here we need to reset the buffer + buffer = [] + } + guard let bufferString = String(data: data, encoding: .utf8) else { + return nil + } + buffer = [] + return bufferString + } +} diff --git a/examples/batched/CMakeLists.txt b/examples/batched/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0d439f49842b562f54aee382c59f175f6df6f47b --- /dev/null +++ b/examples/batched/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-batched) +add_executable(${TARGET} batched.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/batched/README.md b/examples/batched/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6013aab01fddc2866283e577e90b8f06dbacc770 --- /dev/null +++ b/examples/batched/README.md @@ -0,0 +1,44 @@ +# llama.cpp/example/batched + +The example demonstrates batched generation from a given prompt + +```bash +./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 + +... + +main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113 + + Hello my name is + +main: generating 4 sequences ... + +main: stream 0 finished +main: stream 1 finished +main: stream 2 finished +main: stream 3 finished + +sequence 0: + +Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b + +sequence 1: + +Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between + +sequence 2: + +Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am + +sequence 3: + +Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and + +main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s + +llama_print_timings: load time = 587.00 ms +llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second) +llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second) +llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second) +llama_print_timings: total time = 4156.04 ms +``` diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1a5de5928a5262df8b924e1bdbcf0e956043c9b3 --- /dev/null +++ b/examples/batched/batched.cpp @@ -0,0 +1,246 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" + +#include +#include +#include +#include + +static void print_usage(int, char ** argv) { + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf -p \"Hello my name is\" -n 32 -np 4\n", argv[0]); + LOG("\n"); +} + +int main(int argc, char ** argv) { + common_params params; + + params.prompt = "Hello my name is"; + params.n_predict = 32; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) { + return 1; + } + + common_init(); + + // number of parallel batches + int n_parallel = params.n_parallel; + + // total length of the sequences including the prompt + int n_predict = params.n_predict; + + // init LLM + + llama_backend_init(); + llama_numa_init(params.numa); + + // initialize the model + + llama_model_params model_params = common_model_params_to_llama(params); + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); + + if (model == NULL) { + LOG_ERR("%s: error: unable to load model\n" , __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + // tokenize the prompt + + std::vector tokens_list; + tokens_list = common_tokenize(vocab, params.prompt, true); + + const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel; + + // initialize the context + + llama_context_params ctx_params = common_context_params_to_llama(params); + + ctx_params.n_ctx = n_kv_req; + ctx_params.n_batch = std::max(n_predict, n_parallel); + + llama_context * ctx = llama_init_from_model(model, ctx_params); + + auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = false; + + llama_sampler * smpl = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); + + if (ctx == NULL) { + LOG_ERR("%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + const int n_ctx = llama_n_ctx(ctx); + + LOG_INF("\n%s: n_predict = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req); + + // make sure the KV cache is big enough to hold all the prompt and generated tokens + if (n_kv_req > n_ctx) { + LOG_ERR("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req); + LOG_ERR("%s: either reduce n_parallel or increase n_ctx\n", __func__); + return 1; + } + + // print the prompt token-by-token + + LOG("\n"); + + for (auto id : tokens_list) { + LOG("%s", common_token_to_piece(ctx, id).c_str()); + } + + // create a llama_batch + // we use this object to submit token data for decoding + llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel); + + std::vector seq_ids(n_parallel, 0); + for (int32_t i = 0; i < n_parallel; ++i) { + seq_ids[i] = i; + } + + // evaluate the initial prompt + for (size_t i = 0; i < tokens_list.size(); ++i) { + common_batch_add(batch, tokens_list[i], i, seq_ids, false); + } + GGML_ASSERT(batch.n_tokens == (int) tokens_list.size()); + + if (llama_model_has_encoder(model)) { + if (llama_encode(ctx, batch)) { + LOG_ERR("%s : failed to eval\n", __func__); + return 1; + } + + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { + decoder_start_token_id = llama_vocab_bos(vocab); + } + + common_batch_clear(batch); + common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false); + } + + // llama_decode will output logits only for the last token of the prompt + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx, batch) != 0) { + LOG_ERR("%s: llama_decode() failed\n", __func__); + return 1; + } + + //// assign the system KV cache to all parallel sequences + //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them + //for (int32_t i = 1; i < n_parallel; ++i) { + // llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + //} + + if (n_parallel > 1) { + LOG("\n\n%s: generating %d sequences ...\n", __func__, n_parallel); + } + + // main loop + + // we will store the parallel decoded sequences in this vector + std::vector streams(n_parallel); + + // remember the batch index of the last token for each parallel sequence + // we need this to determine which logits to sample from + std::vector i_batch(n_parallel, batch.n_tokens - 1); + + int n_cur = batch.n_tokens; + int n_decode = 0; + + const auto t_main_start = ggml_time_us(); + + while (n_cur <= n_predict) { + // prepare the next batch + common_batch_clear(batch); + + // sample the next token for each parallel sequence / stream + for (int32_t i = 0; i < n_parallel; ++i) { + if (i_batch[i] < 0) { + // the stream has already finished + continue; + } + + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); + + // is it an end of generation? -> mark the stream as finished + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { + i_batch[i] = -1; + LOG("\n"); + if (n_parallel > 1) { + LOG_INF("%s: stream %d finished at n_cur = %d", __func__, i, n_cur); + } + + continue; + } + + // if there is only one stream, we print immediately to stdout + if (n_parallel == 1) { + LOG("%s", common_token_to_piece(ctx, new_token_id).c_str()); + } + + streams[i] += common_token_to_piece(ctx, new_token_id); + + i_batch[i] = batch.n_tokens; + + // push this new token for next evaluation + common_batch_add(batch, new_token_id, n_cur, { i }, true); + + n_decode += 1; + } + + // all streams are finished + if (batch.n_tokens == 0) { + break; + } + + n_cur += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch)) { + LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } + } + + if (n_parallel > 1) { + LOG("\n"); + + for (int32_t i = 0; i < n_parallel; ++i) { + LOG("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str()); + } + } + + const auto t_main_end = ggml_time_us(); + + LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + + LOG("\n"); + llama_perf_sampler_print(smpl); + llama_perf_context_print(ctx); + + fprintf(stderr, "\n"); + + llama_batch_free(batch); + + llama_sampler_free(smpl); + llama_free(ctx); + llama_model_free(model); + + llama_backend_free(); + + return 0; +} diff --git a/examples/chat-13B.bat b/examples/chat-13B.bat new file mode 100644 index 0000000000000000000000000000000000000000..c5c8ac6efa81a552725538648592e3fc1563e1fa --- /dev/null +++ b/examples/chat-13B.bat @@ -0,0 +1,57 @@ +@setlocal disabledelayedexpansion enableextensions +@echo off + +cd /d "%~dp0.." +if not "%errorlevel%"=="0" ( + echo Unable to change directory. + pause + exit /b 1 +) + +if not defined MODEL set "MODEL=models\13B\ggml-model-q4_0.bin" +if not defined USER_NAME set "USER_NAME=User" +if not defined AI_NAME set "AI_NAME=ChatLLaMa" +rem Adjust to the number of CPU cores you want to use. +rem if not defined N_THREAD set "N_THREAD=8" +rem Number of tokens to predict (made it larger than default because we want a long interaction) +if not defined N_PREDICTS set "N_PREDICTS=2048" +if not defined GEN_OPTIONS set "GEN_OPTIONS=--ctx_size 2048 --temp 0.7 --top_k 40 --top_p 0.5 --repeat_last_n 256 --batch_size 1024 --repeat_penalty 1.17647" + +rem Default main script paths +set "DEFAULT_MAIN_SCRIPT_PATHS=main.exe build\bin\main.exe" + +rem Get main script path from command line arguments +set "MAIN_SCRIPT_PATH=%~1" + +rem If the main script path was not specified, try the default paths +if not defined MAIN_SCRIPT_PATH ( + for %%i in (%DEFAULT_MAIN_SCRIPT_PATHS%) do ( + if exist "%%i" set "MAIN_SCRIPT_PATH=%%i" + ) +) + +rem If the main script path was not found, tell the user how to specify it +if not defined MAIN_SCRIPT_PATH ( + echo The main script could not be found. Please provide the path to the main script as 1st argument to this script, or place the main script in one of the default locations: + echo %DEFAULT_MAIN_SCRIPT_PATHS% + pause + exit /b 1 +) + +rem Default context, feel free to edit it +set "PROMPT_TEXT=Text transcript of a never ending dialog, where %USER_NAME% interacts with an AI assistant named %AI_NAME%. %AI_NAME% is helpful, kind, honest, friendly, good at writing and never fails to answer %USER_NAME%'s requests immediately and with details and precision. There are no annotations like (30 seconds passed...) or (to himself), just what %USER_NAME% and %AI_NAME% say aloud to each other. The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long. The transcript only includes text, it does not include markup like HTML and Markdown." + +rem Set a temporary variable if N_THREAD is set +if defined N_THREAD ( + set "_N_THREAD=--threads %N_THREAD%" +) else ( + set "_N_THREAD=" +) + +rem Run the script +echo "%MAIN_SCRIPT_PATH%" %GEN_OPTIONS% %_N_THREAD% ^ + --model "%MODEL%" ^ + --n_predict %N_PREDICTS% ^ + --color --interactive ^ + --reverse-prompt "%USER_NAME%:" ^ + --prompt "%PROMPT_TEXT%" diff --git a/examples/chat-13B.sh b/examples/chat-13B.sh new file mode 100644 index 0000000000000000000000000000000000000000..1828903c31670786faa88aaa8e047b5826c2cf4d --- /dev/null +++ b/examples/chat-13B.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +set -e + +cd "$(dirname "$0")/.." || exit + +MODEL="${MODEL:-./models/13B/ggml-model-q4_0.bin}" +PROMPT_TEMPLATE=${PROMPT_TEMPLATE:-./prompts/chat.txt} +USER_NAME="${USER_NAME:-USER}" +AI_NAME="${AI_NAME:-ChatLLaMa}" + +# Adjust to the number of CPU cores you want to use. +N_THREAD="${N_THREAD:-8}" +# Number of tokens to predict (made it larger than default because we want a long interaction) +N_PREDICTS="${N_PREDICTS:-2048}" + +# Note: you can also override the generation options by specifying them on the command line: +# For example, override the context size by doing: ./chatLLaMa --ctx_size 1024 +GEN_OPTIONS="${GEN_OPTIONS:---ctx_size 2048 --temp 0.7 --top_k 40 --top_p 0.5 --repeat_last_n 256 --batch_size 1024 --repeat_penalty 1.17647}" + +DATE_TIME=$(date +%H:%M) +DATE_YEAR=$(date +%Y) + +PROMPT_FILE=$(mktemp -t llamacpp_prompt.XXXXXXX.txt) + +sed -e "s/\[\[USER_NAME\]\]/$USER_NAME/g" \ + -e "s/\[\[AI_NAME\]\]/$AI_NAME/g" \ + -e "s/\[\[DATE_TIME\]\]/$DATE_TIME/g" \ + -e "s/\[\[DATE_YEAR\]\]/$DATE_YEAR/g" \ + $PROMPT_TEMPLATE > $PROMPT_FILE + +# shellcheck disable=SC2086 # Intended splitting of GEN_OPTIONS +./llama-cli $GEN_OPTIONS \ + --model "$MODEL" \ + --threads "$N_THREAD" \ + --n_predict "$N_PREDICTS" \ + --color --interactive \ + --file ${PROMPT_FILE} \ + --reverse-prompt "${USER_NAME}:" \ + --in-prefix ' ' \ + "$@" diff --git a/examples/chat-persistent.sh b/examples/chat-persistent.sh new file mode 100644 index 0000000000000000000000000000000000000000..9d761ebb843afd46d13c95d4cf1152cff83db76e --- /dev/null +++ b/examples/chat-persistent.sh @@ -0,0 +1,149 @@ +#!/bin/bash + +set -euo pipefail + +cd "$(dirname "$0")/.." || exit + +if [[ -z "${PROMPT_CACHE_FILE+x}" || -z "${CHAT_SAVE_DIR+x}" ]]; then + echo >&2 "error: PROMPT_CACHE_FILE and CHAT_SAVE_DIR must be provided" + exit 1 +fi + +MODEL="${MODEL:-./models/llama-13b/ggml-model-q4_0.gguf}" +PROMPT_TEMPLATE="${PROMPT_TEMPLATE:-./prompts/chat.txt}" +USER_NAME="${USER_NAME:-User}" +AI_NAME="${AI_NAME:-ChatLLaMa}" +DATE_TIME="$(date +%H:%M)" +DATE_YEAR="$(date +%Y)" + +LOG="${CHAT_SAVE_DIR}/main.log" +LOG_BG="${CHAT_SAVE_DIR}/main-bg.log" +CUR_PROMPT_FILE="${CHAT_SAVE_DIR}/current-prompt.txt" +CUR_PROMPT_CACHE="${CHAT_SAVE_DIR}/current-cache.bin" +NEXT_PROMPT_FILE="${CHAT_SAVE_DIR}/next-prompt.txt" +NEXT_PROMPT_CACHE="${CHAT_SAVE_DIR}/next-cache.bin" + +SESSION_AND_SAMPLE_PATTERN='main: session file matches [[:digit:]]+ / [[:digit:]]+'\ +'|'\ +'sampling time =[[:space:]]+[[:digit:]]+.[[:digit:]]+ ms /[[:space:]]+[[:digit:]]+' +SED_DELETE_MESSAGES="/^(${USER_NAME}:|${AI_NAME}:|\\.\\.\\.)/,\$d" + +CTX_SIZE=2048 +CTX_ROTATE_POINT=$((CTX_SIZE * 3 / 5)) # REVIEW +OPTS=(--model "$MODEL" --ctx_size "$CTX_SIZE" --repeat_last_n 256 "$@") + +# An unbuffered `tail -c+N` +skip_bytes() { + LANG=C IFS= read -r -n "$1" -d '' c + while LANG=C IFS= read -r -n 1 -d '' c; do + printf '%s' "$c" + done +} + +mkdir -p "$CHAT_SAVE_DIR" +echo >"$LOG" +trap "tail -n100 ${LOG}" EXIT + +if [[ ! -e "$CUR_PROMPT_FILE" ]]; then + sed -e "s/\[\[USER_NAME\]\]/${USER_NAME}/g" \ + -e "s/\[\[AI_NAME\]\]/${AI_NAME}/g" \ + -e "s/\[\[DATE_TIME\]\]/${DATE_TIME}/g" \ + -e "s/\[\[DATE_YEAR\]\]/${DATE_YEAR}/g" \ + "$PROMPT_TEMPLATE" >"$CUR_PROMPT_FILE" +fi + +if [[ ! -e "$NEXT_PROMPT_FILE" ]]; then + sed -r "$SED_DELETE_MESSAGES" "$CUR_PROMPT_FILE" >"$NEXT_PROMPT_FILE" +fi + +if [[ "$(tail -c4 "$NEXT_PROMPT_FILE")" != "..." ]]; then + echo '...' >>"$NEXT_PROMPT_FILE" +fi + +if [[ ! -e "$PROMPT_CACHE_FILE" ]]; then + echo 'Prompt cache does not exist, building...' + # Default batch_size to 64 here for better user feedback during initial prompt processing + ./llama-cli 2>>"$LOG" \ + --batch_size 64 \ + "${OPTS[@]}" \ + --prompt-cache "$PROMPT_CACHE_FILE" \ + --file "$CUR_PROMPT_FILE" \ + --n_predict 1 + echo + echo 'Done!' +fi + +if [[ ! -e "$CUR_PROMPT_CACHE" ]]; then + cp "$PROMPT_CACHE_FILE" "$CUR_PROMPT_CACHE" +fi +if [[ ! -e "$NEXT_PROMPT_CACHE" ]]; then + cp "$PROMPT_CACHE_FILE" "$NEXT_PROMPT_CACHE" +fi + +printf '%s ' "$(< "$CUR_PROMPT_FILE")" +n_tokens=0 + +while read -e line; do + # Limit generation to remaining context, with a buffer and estimating 2 chars/token for input + n_predict=$((CTX_SIZE - n_tokens - ${#line} / 2 - 32)) + + # Swap prompts when we're about to run out of context + if ((n_predict <= 0)); then + wait # for background main (below) to finish with next prompt + mv "$NEXT_PROMPT_FILE" "$CUR_PROMPT_FILE" + mv "$NEXT_PROMPT_CACHE" "$CUR_PROMPT_CACHE" + + sed -r "$SED_DELETE_MESSAGES" "$CUR_PROMPT_FILE" >"$NEXT_PROMPT_FILE" + echo '...' >>"$NEXT_PROMPT_FILE" + cp "$PROMPT_CACHE_FILE" "$NEXT_PROMPT_CACHE" + + n_tokens=0 + n_predict=$((CTX_SIZE / 2)) + fi + + echo " ${line}" >>"$CUR_PROMPT_FILE" + if ((n_tokens > CTX_ROTATE_POINT)); then + echo " ${line}" >>"$NEXT_PROMPT_FILE" + fi + + n_prompt_len_pre=$(($(wc -c <"$CUR_PROMPT_FILE"))) + + printf '%s: ' "$AI_NAME" >>"$CUR_PROMPT_FILE" + + ./llama-cli 2>>"$LOG" "${OPTS[@]}" \ + --prompt-cache "$CUR_PROMPT_CACHE" \ + --prompt-cache-all \ + --file "$CUR_PROMPT_FILE" \ + --reverse-prompt "${USER_NAME}:" \ + --n_predict "$n_predict" | + skip_bytes 1 | # skip BOS token added by ./llama-cli + tee "$CUR_PROMPT_FILE.tmp" | # save prompt + generation to tmp file + skip_bytes "$n_prompt_len_pre" # print generation + + mv "$CUR_PROMPT_FILE.tmp" "$CUR_PROMPT_FILE" + + # if we hit n_predict instead of reverse-prompt, we need to add the prompt + if [[ "$(tail -n1 "$CUR_PROMPT_FILE")" != "${USER_NAME}:" ]]; then + printf '\n%s:' "$USER_NAME" + printf '\n%s:' "$USER_NAME" >> "$CUR_PROMPT_FILE" + fi + + printf ' ' + + if ! session_and_sample_msg=$(tail -n30 "$LOG" | grep -oE "$SESSION_AND_SAMPLE_PATTERN"); then + echo >&2 "Couldn't get number of tokens from ./llama-cli output!" + exit 1 + fi + + n_tokens=$(awk '{sum+=$1} END {print sum}' <<< "$(cut -d/ -f2 <<< "$session_and_sample_msg")") + + if ((n_tokens > CTX_ROTATE_POINT)); then + tail -c+$((n_prompt_len_pre + 1)) "$CUR_PROMPT_FILE" >>"$NEXT_PROMPT_FILE" + fi + + # Update cache for next prompt in background, ideally during user input + ./llama-cli >>"$LOG_BG" 2>&1 "${OPTS[@]}" \ + --prompt-cache "$NEXT_PROMPT_CACHE" \ + --file "$NEXT_PROMPT_FILE" \ + --n_predict 1 & +done diff --git a/examples/chat-vicuna.sh b/examples/chat-vicuna.sh new file mode 100644 index 0000000000000000000000000000000000000000..ffdd200849503045a3553607aafe4b8eb8770025 --- /dev/null +++ b/examples/chat-vicuna.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +set -e + +cd "$(dirname "$0")/.." || exit + +MODEL="${MODEL:-./models/ggml-vic13b-uncensored-q5_0.bin}" +PROMPT_TEMPLATE=${PROMPT_TEMPLATE:-./prompts/chat.txt} +USER_NAME="### Human" +AI_NAME="### Assistant" + +# Adjust to the number of CPU cores you want to use. +N_THREAD="${N_THREAD:-8}" +# Number of tokens to predict (made it larger than default because we want a long interaction) +N_PREDICTS="${N_PREDICTS:-2048}" + +# Note: you can also override the generation options by specifying them on the command line: +# For example, override the context size by doing: ./chatLLaMa --ctx_size 1024 +GEN_OPTIONS="${GEN_OPTIONS:---ctx_size 2048 --temp 0.7 --top_k 40 --top_p 0.5 --repeat_last_n 256 --batch_size 1024 --repeat_penalty 1.17647}" + +DATE_TIME=$(date +%H:%M) +DATE_YEAR=$(date +%Y) + +PROMPT_FILE=$(mktemp -t llamacpp_prompt.XXXXXXX.txt) + +sed -e "s/\[\[USER_NAME\]\]/$USER_NAME/g" \ + -e "s/\[\[AI_NAME\]\]/$AI_NAME/g" \ + -e "s/\[\[DATE_TIME\]\]/$DATE_TIME/g" \ + -e "s/\[\[DATE_YEAR\]\]/$DATE_YEAR/g" \ + $PROMPT_TEMPLATE > $PROMPT_FILE + +# shellcheck disable=SC2086 # Intended splitting of GEN_OPTIONS +./bin/llama-cli $GEN_OPTIONS \ + --model "$MODEL" \ + --threads "$N_THREAD" \ + --n_predict "$N_PREDICTS" \ + --color --interactive \ + --file ${PROMPT_FILE} \ + --reverse-prompt "### Human:" \ + --in-prefix ' ' \ + "$@" diff --git a/examples/chat.sh b/examples/chat.sh new file mode 100644 index 0000000000000000000000000000000000000000..9f85d1e265d007c2aed11f7f58924eb04c696e5c --- /dev/null +++ b/examples/chat.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# +# Temporary script - will be removed in the future +# + +cd `dirname $0` +cd .. + +# Important: +# +# "--keep 48" is based on the contents of prompts/chat-with-bob.txt +# +./llama-cli -m ./models/llama-7b/ggml-model-q4_0.gguf -c 512 -b 1024 -n 256 --keep 48 \ + --repeat_penalty 1.0 --color -i \ + -r "User:" -f prompts/chat-with-bob.txt diff --git a/examples/convert-llama2c-to-ggml/CMakeLists.txt b/examples/convert-llama2c-to-ggml/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..44e5f722a9739045dd2baa4c86900ee0b11ef908 --- /dev/null +++ b/examples/convert-llama2c-to-ggml/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-convert-llama2c-to-ggml) +add_executable(${TARGET} convert-llama2c-to-ggml.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/convert-llama2c-to-ggml/README.md b/examples/convert-llama2c-to-ggml/README.md new file mode 100644 index 0000000000000000000000000000000000000000..46a42da69183037a05a7999a8d4f6ea5d2769c8b --- /dev/null +++ b/examples/convert-llama2c-to-ggml/README.md @@ -0,0 +1,25 @@ +## Convert llama2.c model to ggml + +This example reads weights from project [llama2.c](https://github.com/karpathy/llama2.c) and saves them in ggml compatible format. The vocab that is available in `models/ggml-vocab.bin` is used by default. + +To convert the model first download the models from the [llama2.c](https://github.com/karpathy/llama2.c) repository. + +``` +usage: ./llama-convert-llama2c-to-ggml [options] + +options: + -h, --help show this help message and exit + --copy-vocab-from-model FNAME path of gguf llama model or llama2.c vocabulary from which to copy vocab (default 'models/7B/ggml-model-f16.gguf') + --llama2c-model FNAME [REQUIRED] model path from which to load Karpathy's llama2.c model + --llama2c-output-model FNAME model path to save the converted llama2.c model (default ak_llama_model.bin') +``` + +An example command using a model from [karpathy/tinyllamas](https://huggingface.co/karpathy/tinyllamas) is as follows: + +`$ ./llama-convert-llama2c-to-ggml --copy-vocab-from-model llama-2-7b-chat.gguf.q2_K.bin --llama2c-model stories42M.bin --llama2c-output-model stories42M.gguf.bin` + +Note: The vocabulary for `stories260K.bin` should be its own tokenizer `tok512.bin` found in [karpathy/tinyllamas/stories260K](https://huggingface.co/karpathy/tinyllamas/tree/main/stories260K). + +Now you can use the model with a command like: + +`$ ./llama-cli -m stories42M.gguf.bin -p "One day, Lily met a Shoggoth" -n 500 -c 256` diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bdf0eed2a9cd37005c02516f87c46c1330ce19f7 --- /dev/null +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -0,0 +1,941 @@ +#include "ggml.h" +#include "gguf.h" + +#include "llama.h" +#include "common.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// GGUF keys & tensor names. + +#define KV_GENERAL_ARCHITECTURE "general.architecture" +#define KV_GENERAL_NAME "general.name" + +#define KV_TOKENIZER_MODEL "tokenizer.ggml.model" +#define KV_TOKENIZER_LIST "tokenizer.ggml.tokens" +#define KV_TOKENIZER_TOKEN_TYPE "tokenizer.ggml.token_type" +#define KV_TOKENIZER_SCORES "tokenizer.ggml.scores" +#define KV_TOKENIZER_BOS_ID "tokenizer.ggml.bos_token_id" +#define KV_TOKENIZER_EOS_ID "tokenizer.ggml.eos_token_id" +#define KV_TOKENIZER_UNK_ID "tokenizer.ggml.unknown_token_id" +#define KV_TOKENIZER_SEP_ID "tokenizer.ggml.seperator_token_id" +#define KV_TOKENIZER_PAD_ID "tokenizer.ggml.padding_token_id" +#define KV_TOKENIZER_HF_JSON "tokenizer.huggingface.json" + +#define KV_CONTEXT_LENGTH "llama.context_length" +#define KV_EMBEDDING_LENGTH "llama.embedding_length" +#define KV_BLOCK_COUNT "llama.block_count" +#define KV_FEED_FORWARD_LENGTH "llama.feed_forward_length" +#define KV_ATTENTION_HEAD_COUNT "llama.attention.head_count" +#define KV_ATTENTION_HEAD_COUNT_KV "llama.attention.head_count_kv" +#define KV_ATTENTION_LAYERNORM_RMS_EPS "llama.attention.layer_norm_rms_epsilon" +#define KV_ROPE_DIMENSION_COUNT "llama.rope.dimension_count" + +#define TN_TOKEN_EMBD "token_embd.weight" +#define TN_OUTPUT_NORM "output_norm.weight" +#define TN_OUTPUT "output.weight" +#define TN_ATTN_NORM "blk.%d.attn_norm.weight" +#define TN_ATTN_Q "blk.%d.attn_q.weight" +#define TN_ATTN_K "blk.%d.attn_k.weight" +#define TN_ATTN_V "blk.%d.attn_v.weight" +#define TN_ATTN_OUTPUT "blk.%d.attn_output.weight" +#define TN_FFN_NORM "blk.%d.ffn_norm.weight" +#define TN_FFN_GATE "blk.%d.ffn_gate.weight" +#define TN_FFN_DOWN "blk.%d.ffn_down.weight" +#define TN_FFN_UP "blk.%d.ffn_up.weight" + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt' +#define LLAMA_FILE_VERSION_GGJT_V3 3 + +#define TOKENIZER_NAME "llama" +#define UNKNOWN_TOKEN_ID 0 +#define BOS_TOKEN_ID 1 +#define EOS_TOKEN_ID 2 + +//////////////////////////////////////// llama2.c model structs and functions to load models, alloc memory etc. +typedef struct { + int dim; // transformer dimension + int hidden_dim; // for ffn layers + int n_layers; // number of layers + int n_heads; // number of query heads + int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery) + int vocab_size; // vocabulary size, usually 256 (byte-level) + int seq_len; // max sequence length +} Config; + +struct TransformerWeights { + // token embedding table + std::vector token_embedding_table; // (vocab_size, dim) + // weights for rmsnorms + std::vector rms_att_weight; // (layer, dim) rmsnorm weights + std::vector rms_ffn_weight; // (layer, dim) + // weights for matmuls + std::vector wq; // (layer, dim, dim) + std::vector wk; // (layer, dim, dim) + std::vector wv; // (layer, dim, dim) + std::vector wo; // (layer, dim, dim) + // weights for ffn + std::vector w1; // (layer, hidden_dim, dim) + std::vector w2; // (layer, dim, hidden_dim) + std::vector w3; // (layer, hidden_dim, dim) + // final rmsnorm + std::vector rms_final_weight; // (dim,) + // freq_cis for RoPE relatively positional embeddings + // std::vector freq_cis_real; // (seq_len, dim/2) + // std::vector freq_cis_imag; // (seq_len, dim/2) + // (optional) classifier weights for the logits, on the last layer + std::vector wcls; +}; + +static void alloc_weights(TransformerWeights * w, const Config * p, bool shared_weights) { + const int n_multiqueries = p->n_kv_heads <= 0 || p->n_kv_heads >= p->n_heads ? 1 : p->n_heads / p->n_kv_heads; + try { + w->token_embedding_table.resize(p->vocab_size * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); + + w->rms_att_weight.resize(p->n_layers * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_att_weight\n",__func__,p->n_layers, p->dim, p->n_layers * p->dim); + + w->rms_ffn_weight.resize(p->n_layers * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_ffn_weight\n",__func__,p->n_layers , p->dim, p->n_layers * p->dim); + + w->wq.resize(p->n_layers * p->dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wq\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); + + w->wk.resize(p->n_layers * p->dim * p->dim / n_multiqueries); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wk\n",__func__,p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries); + + w->wv.resize(p->n_layers * p->dim * p->dim / n_multiqueries); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wv\n",__func__, p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries); + + w->wo.resize(p->n_layers * p->dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wo\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); + + w->w1.resize(p->n_layers * p->hidden_dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w1\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim); + + w->w2.resize(p->n_layers * p->hidden_dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w2\n",__func__,p->n_layers, p->dim, p->hidden_dim, p->n_layers * p->hidden_dim * p->dim); + + w->w3.resize(p->n_layers * p->hidden_dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w3\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim); + + w->rms_final_weight.resize(p->dim); + LOG_INF("%s: Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim); + + if (shared_weights) { + w->wcls = {}; + } else { + w->wcls.resize(p->vocab_size * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); + } + } + catch (std::length_error &) { + die("Invalid configuration. Failed to allocate memory for weights"); + } +} + +static int checkpoint_init_weights(TransformerWeights * w, const Config * p, FILE * f, bool shared_weights) { + if (fread(w->token_embedding_table.data(), sizeof(float), w->token_embedding_table.size(), f) != w->token_embedding_table.size()) return 1; + if (fread(w->rms_att_weight.data(), sizeof(float), w->rms_att_weight.size(), f) != w->rms_att_weight.size()) return 1; + if (fread(w->wq.data(), sizeof(float), w->wq.size(), f) != w->wq.size()) return 1; + if (fread(w->wk.data(), sizeof(float), w->wk.size(), f) != w->wk.size()) return 1; + if (fread(w->wv.data(), sizeof(float), w->wv.size(), f) != w->wv.size()) return 1; + if (fread(w->wo.data(), sizeof(float), w->wo.size(), f) != w->wo.size()) return 1; + if (fread(w->rms_ffn_weight.data(), sizeof(float), w->rms_ffn_weight.size(), f) != w->rms_ffn_weight.size()) return 1; + if (fread(w->w1.data(), sizeof(float), w->w1.size(), f) != w->w1.size()) return 1; + if (fread(w->w2.data(), sizeof(float), w->w2.size(), f) != w->w2.size()) return 1; + if (fread(w->w3.data(), sizeof(float), w->w3.size(), f) != w->w3.size()) return 1; + if (fread(w->rms_final_weight.data(), sizeof(float), w->rms_final_weight.size(), f) != w->rms_final_weight.size()) return 1; + + // Skip freq_cis_real & freq_cis_imag + int head_size = p->dim / p->n_heads; + fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR); + + if (!shared_weights && fread(w->wcls.data(), sizeof(float), w->wcls.size(), f) != w->wcls.size()) return 1; + + // Check we didn't forget to read anything + auto curr = ftell(f); + fseek(f, 0, SEEK_END); + auto end = ftell(f); + if (curr != end) { + LOG_ERR("%s: Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", __func__, curr, end); + return 1; + } + + return 0; +} + +static void print_sample_weights(TransformerWeights *w){ + LOG_INF("----- Quick print of first of the weight vales of all the variables\n"); + LOG_INF("%f\n", w->token_embedding_table[0]); + LOG_INF("%f\n", w->rms_att_weight[0]); + LOG_INF("%f\n", w->rms_ffn_weight[0]); + + LOG_INF("%f\n", w->wq[0]); + LOG_INF("%f\n", w->wk[0]); + LOG_INF("%f\n", w->wv[0]); + LOG_INF("%f\n", w->wo[0]); + LOG_INF("%f\n", w->w1[0]); + LOG_INF("%f\n", w->w2[0]); + LOG_INF("%f\n", w->w3[0]); + LOG_INF("%f\n", w->rms_att_weight[0]); + if (!w->wcls.empty()) LOG_INF("%f\n", w->wcls[0]); +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////// ggml structs and functions required to load models, configs and save the model. + +struct my_llama_vocab { + using id = int32_t; + using token = std::string; + using ttype = llama_token_type; + + struct token_data { + token text; + float score; + ttype type; + }; + + std::unordered_map token_to_id; + std::vector id_to_token; +}; + +struct my_llama_hparams { + uint32_t n_vocab = 32000; + uint32_t n_ctx = 512; // this is provided as user input? + uint32_t n_embd = 4096; + uint32_t n_ff = 11008; + uint32_t n_mult = 4; + uint32_t n_head = 32; + uint32_t n_head_kv = 32; + uint32_t n_layer = 32; + uint32_t n_rot = 64; + + bool operator!=(const my_llama_hparams& other) const { + return memcmp(this, &other, sizeof(my_llama_hparams)); + } +}; + +struct my_llama_layer { + // normalization + struct ggml_tensor * attention_norm; + + // attention + struct ggml_tensor * wq; + struct ggml_tensor * wk; + struct ggml_tensor * wv; + struct ggml_tensor * wo; + + // normalization + struct ggml_tensor * ffn_norm; + + // ff + struct ggml_tensor * w1; + struct ggml_tensor * w2; + struct ggml_tensor * w3; +}; + +struct my_llama_model { + struct ggml_context * ctx = NULL; + + std::string name; + + my_llama_hparams hparams; + + struct ggml_tensor * tok_embeddings; + + struct ggml_tensor * norm; + struct ggml_tensor * output; + + std::vector layers; + + uint32_t train_its = 0; + uint32_t train_samples = 0; + uint32_t train_tokens = 0; +}; + +struct train_params { + const char * fn_vocab_model; + const char * fn_llama2c_model; + const char * fn_llama2c_output_model; + const char * fn_train_data; + const char * fn_checkpoint_in; + const char * fn_checkpoint_out; + const char * fn_model_out; + + uint32_t seed; + + int n_ctx; + int n_embd; + int n_mult; + int n_head; + int n_layer; + int n_rotmax; + + int n_threads; + int n_batch; + int n_examples; + int n_predict; + + int print_info_interval; + int print_details_interval; + + bool samples_start_after_nl; + bool use_adam; + bool use_flash; + bool use_scratch; + + // only adam + int warmup; + int cos_decay_steps; + float cos_decay_restart; + float cos_decay_alpha; + + int lbfgs_n_iter; + int adam_n_iter; + float adam_alpha; + float adam_decay; + + int mem_model_gb; + int mem_compute_gb; + int mem_compute0_gb; + int mem_compute1_gb; +}; + +static void print_params(struct my_llama_hparams * params) { + LOG_INF("%s: n_vocab: %u\n", __func__, params->n_vocab); + LOG_INF("%s: n_ctx: %u\n", __func__, params->n_ctx); + LOG_INF("%s: n_embd: %u\n", __func__, params->n_embd); + LOG_INF("%s: n_mult: %u\n", __func__, params->n_mult); + LOG_INF("%s: n_head: %u\n", __func__, params->n_head); + LOG_INF("%s: n_head_kv: %u\n", __func__, params->n_head_kv); + LOG_INF("%s: n_ff: %u\n", __func__, params->n_ff); + LOG_INF("%s: n_layer: %u\n", __func__, params->n_layer); + LOG_INF("%s: n_rot: %u\n", __func__, params->n_rot); +} + +static void print_tensor_info(const struct ggml_context * ctx) { + for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + LOG_INF("%s: Allocating ", __func__); + int64_t total = 1; + int i = 0; + for (; i < ggml_n_dims(t); ++i) { + if (i > 0) LOG("x "); + LOG("[%" PRId64 "] ", t->ne[i]); + total *= t->ne[i]; + } + if (i > 1) LOG("= [%" PRId64 "] ", total); + LOG("float space for %s\n", ggml_get_name(t)); + } +} + +static void init_model(struct my_llama_model * model) { + const auto & hparams = model->hparams; + + const uint32_t n_embd = hparams.n_embd; + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_vocab = hparams.n_vocab; + + const uint32_t n_multiqueries = hparams.n_head_kv <= 0 || hparams.n_head_kv >= hparams.n_head ? 1 : hparams.n_head / hparams.n_head_kv; + + const uint32_t n_ff = hparams.n_ff; + struct ggml_context * ctx = model->ctx; + + model->train_its = 0; + model->train_samples = 0; + model->train_tokens = 0; + + model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); + model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); + + ggml_set_name(model->tok_embeddings, "tok_embeddings.weight"); + ggml_set_name(model->norm, "norm.weight"); + ggml_set_name(model->output, "output.weight"); + + model->layers.resize(n_layer); + for (uint32_t i = 0; i < n_layer; ++i) { + auto & layer = model->layers[i]; + + std::string layers_i = "layers." + std::to_string(i); + + layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); + layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries); + layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries); + layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); + + layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); + layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); + layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); + + ggml_set_name(layer.attention_norm, (layers_i + ".attention_norm.weight").c_str()); + + ggml_set_name(layer.wq, (layers_i + ".attention.wq.weight").c_str()); + ggml_set_name(layer.wk, (layers_i + ".attention.wk.weight").c_str()); + ggml_set_name(layer.wv, (layers_i + ".attention.wv.weight").c_str()); + ggml_set_name(layer.wo, (layers_i + ".attention.wo.weight").c_str()); + + ggml_set_name(layer.ffn_norm, (layers_i + ".ffn_norm.weight").c_str()); + + ggml_format_name(layer.w1, "%s.feed_forward.w1.weight", layers_i.c_str()); + ggml_format_name(layer.w2, "%s.feed_forward.w2.weight", layers_i.c_str()); + ggml_format_name(layer.w3, "%s.feed_forward.w3.weight", layers_i.c_str()); + } + + print_tensor_info(ctx); +} + +static float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) { + float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + return *ptr; +} + +static int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) { + int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + return *ptr; +} + +static void print_row(struct ggml_tensor * probs, int i) { + for (int k = 0; k < probs->ne[0]; ++k) { + float p = get_f32_2d(probs, k, i); + LOG(" %f", p); + } + LOG("\n"); +} + +static void print_matrix(struct ggml_tensor * probs) { + assert(ggml_is_matrix(probs)); + for (int i = 0; i < probs->ne[1]; ++i) { + for (int k = 0; k < probs->ne[0]; ++k) { + float p = get_f32_2d(probs, k, i); + LOG(" %.2f", p); + } + LOG("\n"); + } +} + +struct my_llama_file { + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + size_t size; + + my_llama_file(const char * fname, const char * mode) { + fp = std::fopen(fname, mode); + if (fp == NULL) { + size = 0; + } else { + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + } + + size_t tell() const { +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + GGML_ASSERT(ret != -1); // this really shouldn't fail + return (size_t) ret; + } + + void seek(size_t offset, int whence) { +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + GGML_ASSERT(ret == 0); // same + } + + void read_raw(void * ptr, size_t size) { + if (size == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, size, 1, fp); + if (ferror(fp)) { + die_fmt("fread failed: %s", strerror(errno)); + } + if (ret != 1) { + die("unexpectedly reached end of file"); + } + } + + std::uint32_t read_u32() { + std::uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + std::float_t read_f32() { + std::float_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + std::string read_string(std::uint32_t len) { + std::vector chars(len); + read_raw(chars.data(), len); + return std::string(chars.data(), len); + } + + ~my_llama_file() { + if (fp) { + std::fclose(fp); + } + } +}; + +static bool is_ggml_file(const char * filename) { + my_llama_file file(filename, "rb"); + if (file.size < 4) { + return false; + } + std::string magic = file.read_string(4); + return magic == GGUF_MAGIC; +} + +static std::string llama_escape_whitespaces(const std::string & text) { + std::ostringstream out; + for (char c : text) { + if (c == ' ') out << "\xe2\x96\x81"; + else out << c; + } + return out.str(); +} + +static void load_vocab(const char * filename, const Config * config, struct my_llama_vocab * vocab) { + if (is_ggml_file(filename)) { + LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename); + struct ggml_context * ctx_data = NULL; + + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ &ctx_data, + }; + + struct gguf_context * ctx = gguf_init_from_file(filename, params); + GGML_ASSERT(ctx != NULL); + + const int model_idx = gguf_find_key(ctx, KV_TOKENIZER_MODEL); + GGML_ASSERT(model_idx >= 0); + std::string tokenizer_name = gguf_get_val_str(ctx, model_idx); + GGML_ASSERT(tokenizer_name == TOKENIZER_NAME); + + const int token_idx = gguf_find_key(ctx, KV_TOKENIZER_LIST); + GGML_ASSERT(token_idx >= 0); + + const int score_idx = gguf_find_key(ctx, KV_TOKENIZER_SCORES); + GGML_ASSERT(score_idx >= 0); + const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + + const int toktype_idx = gguf_find_key(ctx, KV_TOKENIZER_TOKEN_TYPE); + GGML_ASSERT(toktype_idx >= 0); + const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + + const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); + if (n_vocab != static_cast(config->vocab_size)) { + die_fmt("vocab size mismatch: (gguf) %u != (llama2c) %d", n_vocab, config->vocab_size); + } + + vocab->id_to_token.resize(n_vocab); + + for (uint32_t i = 0; i < n_vocab; i++) { + std::string word = gguf_get_arr_str(ctx, token_idx, i); + + vocab->token_to_id[word] = i; + + auto & token_data = vocab->id_to_token[i]; + token_data.text = std::move(word); + token_data.score = scores[i]; + token_data.type = (llama_token_type) toktypes[i]; + } + ggml_free(ctx_data); + gguf_free(ctx); + } else { + // assume llama2.c vocabulary + LOG_INF("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n", __func__, filename); + my_llama_file file(filename, "rb"); + if (!file.fp) { + die_fmt("%s: %s", strerror(errno), filename); + } + const int n_vocab = config->vocab_size; + /* uint32_t max_token_length = */ file.read_u32(); // unused + vocab->id_to_token.resize(n_vocab); + for (my_llama_vocab::id id=0; id", &byte_val) == 1) { + // Text of byte tokens is already in the expected format. + type = LLAMA_TOKEN_TYPE_BYTE; + } else { + type = LLAMA_TOKEN_TYPE_NORMAL; + } + text = llama_escape_whitespaces(text); + + vocab->id_to_token[id].text = text; + vocab->id_to_token[id].score = score; + vocab->id_to_token[id].type = type; + vocab->token_to_id.emplace(text, id); + } + } +} + +static void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const float * karpathy_weights) { + int size = 1; + for (int dim = 0; dim < ggml_n_dims(gg_weights); ++dim) { + size *= gg_weights->ne[dim]; + } + for (int ct = 0; ct < size; ++ct) { + int64_t i0 = 0; int64_t i1 = 0; + int64_t i2 = 0; int64_t i3 = 0; + ggml_unravel_index(gg_weights, ct, &i0, &i1, &i2, &i3); + ggml_set_f32_nd(gg_weights, i0, i1, i2, i3, karpathy_weights[ct]); + } +} + +static void save_as_llama_model( + struct my_llama_vocab * vocab, struct my_llama_model * model, TransformerWeights* w, const char * filename +) { + // convert AK weights into GG weights one by one. + // w->token_embedding_table -> model->tok_embeddings + // float* -> struct ggml_tensor + convert_weights_ak_to_gg(model->tok_embeddings, w->token_embedding_table.data()); + convert_weights_ak_to_gg(model->output, !w->wcls.empty() ? w->wcls.data() : w->token_embedding_table.data()); + + convert_weights_ak_to_gg(model->norm, w->rms_final_weight.data()); + //print_row(model->norm, 0); + + // for rms-att-weight + int row_length = model->hparams.n_embd; + int n_ff = model->hparams.n_ff; + + const uint32_t n_multiqueries = model->hparams.n_head_kv <= 0 || model->hparams.n_head_kv >= model->hparams.n_head ? 1 : model->hparams.n_head / model->hparams.n_head_kv; + + for (uint32_t i = 0; i < model->hparams.n_layer; ++i){ + auto & layer = model->layers[i]; + // 1d + convert_weights_ak_to_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]); + convert_weights_ak_to_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]); + + // from 3d matrix layer x dim x dim to 2d matrix dim x dim + convert_weights_ak_to_gg(layer.wq , &w->wq[i*row_length*row_length]); + convert_weights_ak_to_gg(layer.wo , &w->wo[i*row_length*row_length]); + // from 3d matrix layer x dim x dim to 2d matrix dim x dim / n_multiqueries + convert_weights_ak_to_gg(layer.wk , &w->wk[i*row_length*row_length/n_multiqueries]); + convert_weights_ak_to_gg(layer.wv , &w->wv[i*row_length*row_length/n_multiqueries]); + + convert_weights_ak_to_gg(layer.w1 , &w->w1[i*row_length*n_ff]); + convert_weights_ak_to_gg(layer.w2 , &w->w2[i*n_ff*row_length]); + convert_weights_ak_to_gg(layer.w3 , &w->w3[i*row_length*n_ff]); + } + + struct gguf_context * ctx = gguf_init_empty(); + + std::vector tokens; + std::vector scores; + std::vector token_types; + for (const my_llama_vocab::token_data & token_data : vocab->id_to_token) { + tokens.push_back(token_data.text.c_str()); + scores.push_back(token_data.score); + token_types.push_back(token_data.type); + } + gguf_set_arr_str(ctx, KV_TOKENIZER_LIST, tokens.data(), tokens.size()); + gguf_set_arr_data(ctx, KV_TOKENIZER_SCORES, GGUF_TYPE_FLOAT32, scores.data(), scores.size()); + gguf_set_arr_data(ctx, KV_TOKENIZER_TOKEN_TYPE, GGUF_TYPE_INT32, token_types.data(), token_types.size()); + + gguf_set_val_str(ctx, KV_TOKENIZER_MODEL, TOKENIZER_NAME); + + gguf_set_val_str(ctx, KV_GENERAL_ARCHITECTURE, "llama"); + gguf_set_val_str(ctx, KV_GENERAL_NAME, "llama"); + + // special tokens + gguf_set_val_u32(ctx, KV_TOKENIZER_UNK_ID, UNKNOWN_TOKEN_ID); + gguf_set_val_u32(ctx, KV_TOKENIZER_BOS_ID, BOS_TOKEN_ID); + gguf_set_val_u32(ctx, KV_TOKENIZER_EOS_ID, EOS_TOKEN_ID); + gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, LLAMA_TOKEN_NULL); + gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, LLAMA_TOKEN_NULL); + + gguf_set_val_u32(ctx, KV_CONTEXT_LENGTH, model->hparams.n_ctx); + gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd); + gguf_set_val_u32(ctx, KV_FEED_FORWARD_LENGTH, model->hparams.n_ff); + gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head); + gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head); + gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT_KV, model->hparams.n_head_kv); + gguf_set_val_u32(ctx, KV_BLOCK_COUNT, model->hparams.n_layer); + gguf_set_val_u32(ctx, KV_ROPE_DIMENSION_COUNT, model->hparams.n_rot); + gguf_set_val_f32(ctx, KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f); + + // write tensors + ggml_set_name(model->tok_embeddings, TN_TOKEN_EMBD); + gguf_add_tensor(ctx, model->tok_embeddings); + + ggml_set_name(model->norm, TN_OUTPUT_NORM); + gguf_add_tensor(ctx, model->norm); + + ggml_set_name(model->output, TN_OUTPUT); + gguf_add_tensor(ctx, model->output); + + for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { + auto & layer = model->layers[i]; + + ggml_format_name(layer.wq, TN_ATTN_Q, i); + gguf_add_tensor(ctx, layer.wq); + + ggml_format_name(layer.wk, TN_ATTN_K, i); + gguf_add_tensor(ctx, layer.wk); + + ggml_format_name(layer.wv, TN_ATTN_V, i); + gguf_add_tensor(ctx, layer.wv); + + ggml_format_name(layer.wo, TN_ATTN_OUTPUT, i); + gguf_add_tensor(ctx, layer.wo); + + ggml_format_name(layer.attention_norm, TN_ATTN_NORM, i); + gguf_add_tensor(ctx, layer.attention_norm); + + ggml_format_name(layer.w1, TN_FFN_GATE, i); + gguf_add_tensor(ctx, layer.w1); + + ggml_format_name(layer.w2, TN_FFN_DOWN, i); + gguf_add_tensor(ctx, layer.w2); + + ggml_format_name(layer.w3, TN_FFN_UP, i); + gguf_add_tensor(ctx, layer.w3); + + ggml_format_name(layer.ffn_norm, TN_FFN_NORM, i); + gguf_add_tensor(ctx, layer.ffn_norm); + } + + gguf_write_to_file(ctx, filename, false); + gguf_free(ctx); +} + +static struct train_params get_default_train_params() { + struct train_params params; + params.fn_vocab_model = "models/7B/ggml-model-f16.gguf"; + params.fn_llama2c_output_model = "ak_llama_model.bin"; + params.fn_train_data = "shakespeare.txt"; + params.fn_checkpoint_in = "checkpoint.bin"; + params.fn_checkpoint_out = "checkpoint.bin"; + params.fn_model_out = "ggml-checkpoint-f32.bin"; + + params.seed = -1; + + params.n_ctx = 128; + params.n_embd = 256; + params.n_mult = 256; + params.n_head = 8; + params.n_layer = 16; + params.n_rotmax = 64; + + params.n_threads = 6; + params.n_batch = 8; + params.n_examples = 8; + params.n_predict = 1024; + + params.print_info_interval = 1; + params.print_details_interval = 2; + + params.samples_start_after_nl = false; + params.use_adam = true; + params.use_flash = false; + params.use_scratch = true; + + // only adam + params.warmup = 100; + params.cos_decay_steps = 1000; + params.cos_decay_restart = 1.1f; + params.cos_decay_alpha = 0.0f; + + params.lbfgs_n_iter = 16; + params.adam_n_iter = 16; + params.adam_alpha = 1e-3f; + params.adam_decay = 1e-3f; + + params.mem_model_gb = 2; + params.mem_compute_gb = 24; + params.mem_compute0_gb = 8; + params.mem_compute1_gb = 2; + + return params; +} + +static void print_usage(int /*argc*/, char ** argv, const struct train_params * params) { + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " --copy-vocab-from-model FNAME path of gguf llama model or llama2.c vocabulary from which to copy vocab (default '%s')\n", params->fn_vocab_model); + fprintf(stderr, " --llama2c-model FNAME [REQUIRED] model path from which to load Karpathy's llama2.c model\n"); + fprintf(stderr, " --llama2c-output-model FNAME model path to save the converted llama2.c model (default %s')\n", params->fn_llama2c_output_model); + fprintf(stderr, "\n"); +} + +static bool params_parse(int argc, char ** argv, struct train_params * params) { + bool invalid_param = false; + bool reqd_param_found = false; + std::string arg; + struct train_params default_params = get_default_train_params(); + const std::string arg_prefix = "--"; + + for (int i = 1; i < argc; i++) { + arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + + if (arg == "--copy-vocab-from-model") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->fn_vocab_model = argv[i]; + } else if (arg == "--llama2c-model") { + if (++i >= argc) { + invalid_param = true; + break; + } + reqd_param_found = true; + params->fn_llama2c_model = argv[i]; + } else if (arg == "--llama2c-output-model") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->fn_llama2c_output_model = argv[i]; + } else if (arg == "-h" || arg == "--help") { + print_usage(argc, argv, &default_params); + exit(0); + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + print_usage(argc, argv, &default_params); + exit(1); + } + } + if (invalid_param) { + fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); + print_usage(argc, argv, &default_params); + exit(1); + } + if (!reqd_param_found){ + fprintf(stderr, "error: please specify a llama2.c .bin file to be converted with argument --llama2c-model\n"); + print_usage(argc, argv, &default_params); + exit(1); + } + + return true; +} + +static std::string basename(const std::string &path) { + size_t pos = path.find_last_of("/\\"); + if (pos == std::string::npos) { + return path; + } + return path.substr(pos + 1); +} + +int main(int argc, char ** argv) { + common_init(); + + struct train_params params = get_default_train_params(); + if (!params_parse(argc, argv, ¶ms)) { + return 1; + } + + Config config; + TransformerWeights weights = {}; + { + LOG_INF("%s: Loading llama2c model from %s\n", __func__, params.fn_llama2c_model); + FILE * file = fopen(params.fn_llama2c_model, "rb"); + if (!file) { + LOG_ERR("%s: Unable to open the checkpoint file %s!\n", __func__, params.fn_llama2c_model); + return 1; + } + // read in the config header + if (fread(&config, sizeof(Config), 1, file) != 1) { + LOG_ERR("%s: Unable to read llama2c config from %s!\n",__func__,params.fn_llama2c_model); + return 1; + } + auto shared_weights = config.vocab_size > 0; + config.vocab_size = abs(config.vocab_size); + + // read in the Transformer weights + alloc_weights(&weights, &config, shared_weights); + if (checkpoint_init_weights(&weights, &config, file, shared_weights)) { + LOG_ERR("%s: Unable to initialize transformer weights from %s!",__func__,params.fn_llama2c_model); + return 1; + } + fclose(file); + } + + struct my_llama_vocab vocab; + load_vocab(params.fn_vocab_model, &config, &vocab); + + struct my_llama_model model; + model.hparams.n_vocab = config.vocab_size; //llama_vocab_n_vocab(lctx); + model.hparams.n_ctx = params.n_ctx; + model.hparams.n_embd = config.dim; //params.n_embd; + model.hparams.n_ff = config.hidden_dim; + model.hparams.n_mult = 32;//params.n_mult; + model.hparams.n_head = config.n_heads; //params.n_head; + model.hparams.n_head_kv = config.n_kv_heads; + model.hparams.n_layer = config.n_layers; //params.n_layer; + model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head); + + print_params(&model.hparams); + + struct ggml_init_params lcparams; + lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb); + lcparams.mem_buffer = NULL; + lcparams.no_alloc = false; + + model.ctx = ggml_init(lcparams); + + init_model(&model); + model.name = basename(params.fn_llama2c_model); + save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model); + + LOG_INF("%s: Saving llama.c model file %s in ggml format at %s\n", __func__, params.fn_llama2c_model, params.fn_llama2c_output_model); + + ggml_free(model.ctx); + return 0; +} diff --git a/examples/convert_legacy_llama.py b/examples/convert_legacy_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ec5c524e9b13063bcfc74b2bf1a7f8d17c8a36 --- /dev/null +++ b/examples/convert_legacy_llama.py @@ -0,0 +1,1462 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import argparse +import concurrent.futures +import enum +import faulthandler +import functools +import itertools +import json +import math +import mmap +import os +import pickle +import re +import signal +import struct +import sys +import textwrap +import time +import zipfile +from abc import ABC, abstractmethod +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar + +import numpy as np + +if 'NO_LOCAL_GGUF' not in os.environ: + # use .parent.parent since we are in "examples" directory + sys.path.insert(1, str(Path(__file__).parent.parent / 'gguf-py')) + +import gguf +from gguf import BaseVocab, Vocab, NoVocab, BpeVocab, SentencePieceVocab, LlamaHfVocab + +if TYPE_CHECKING: + from typing_extensions import Self, TypeAlias + +logger = logging.getLogger("convert") + +if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): + faulthandler.register(signal.SIGUSR1) + +NDArray: TypeAlias = 'np.ndarray[Any, Any]' + +ARCH = gguf.MODEL_ARCH.LLAMA + +DEFAULT_CONCURRENCY = 8 + +ADDED_TOKENS_FILE = 'added_tokens.json' +FAST_TOKENIZER_FILE = 'tokenizer.json' + +# +# data types +# + + +@dataclass(frozen=True) +class DataType: + name: str + dtype: np.dtype[Any] + valid_conversions: list[str] + + def elements_to_bytes(self, n_elements: int) -> int: + return n_elements * self.dtype.itemsize + + +@dataclass(frozen=True) +class UnquantizedDataType(DataType): + pass + + +DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) +DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0']) +DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) +DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) + + +@dataclass(frozen=True) +class QuantizedDataType(DataType): + block_size: int + quantized_dtype: np.dtype[Any] + ggml_type: gguf.GGMLQuantizationType + + def quantize(self, arr: NDArray) -> NDArray: + raise NotImplementedError(f'Quantization for {self.name} not implemented') + + def elements_to_bytes(self, n_elements: int) -> int: + assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}' + return self.quantized_dtype.itemsize * (n_elements // self.block_size) + + +@dataclass(frozen=True) +class Q8_0QuantizedDataType(QuantizedDataType): + # Mini Q8_0 quantization in Python! + def quantize(self, arr: NDArray) -> NDArray: + assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}' + assert arr.dtype == np.float32, f'Bad array type {arr.dtype}' + n_blocks = arr.size // self.block_size + blocks = arr.reshape((n_blocks, self.block_size)) + # Much faster implementation of block quantization contributed by @Cebtenzzre + + def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]: + d = abs(blocks).max(axis = 1) / np.float32(127) + with np.errstate(divide = 'ignore'): + qs = (blocks / d[:, None]).round() + qs[d == 0] = 0 + yield from zip(d, qs) + return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype) + + +DT_Q8_0 = Q8_0QuantizedDataType('Q8_0', + dtype = np.dtype(np.float32), valid_conversions = [], + ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32, + quantized_dtype = np.dtype([('d', ' DataType: + dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self) + if dt is None: + raise ValueError(self) + # Convert all 1D tensors to F32. Most of the codebase that takes in 1D tensors only handles F32 tensors, and most of the outputs tensors are F32. + # Also The 1d tensors aren't much of a performance/size issue. So instead of having to have separate F32 and F16 implementations of both, just convert everything to F32 for now. + return dt if len(tensor.shape) > 1 else DT_F32 + + +GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = { + GGMLFileType.AllF32 : DT_F32, + GGMLFileType.MostlyF16 : DT_F16, + GGMLFileType.MostlyQ8_0: DT_Q8_0, +} + +# +# hparams loading +# + + +@dataclass +class Params: + n_vocab: int + n_embd: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int + n_experts: int | None = None + n_experts_used: int | None = None + f_norm_eps: float | None = None + + rope_scaling_type: gguf.RopeScalingType | None = None + f_rope_freq_base: float | None = None + f_rope_scale: float | None = None + n_ctx_orig: int | None = None + rope_finetuned: bool | None = None + + ftype: GGMLFileType | None = None + + # path to the directory containing the model files + path_model: Path | None = None + + @staticmethod + def guessed(model: LazyModel) -> Params: + # try transformer naming first + n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape + + # try transformer naming first + if "model.layers.0.self_attn.q_proj.weight" in model: + n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model) + elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming + n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model) + else: + n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model) + + if n_layer < 1: + msg = """\ + failed to guess 'n_layer'. This model is unknown or unsupported. + Suggestion: provide 'config.json' of the model in the same directory containing model files.""" + raise KeyError(textwrap.dedent(msg)) + + n_head = n_embd // 128 # guessed + n_mult = 256 # guessed + + # TODO: verify this + n_ff = int(2 * (4 * n_embd) / 3) + n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult) + + return Params( + n_vocab = n_vocab, + n_embd = n_embd, + n_layer = n_layer, + n_ctx = -1, + n_ff = n_ff, + n_head = n_head, + n_head_kv = n_head, + f_norm_eps = 1e-5, + ) + + @staticmethod + def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: + with open(config_path) as f: + config = json.load(f) + + rope_scaling_type = f_rope_scale = n_ctx_orig = rope_finetuned = None + rope_scaling = config.get("rope_scaling") + + if rope_scaling is not None and (typ := rope_scaling.get("type")): + rope_factor = rope_scaling.get("factor") + f_rope_scale = rope_factor + if typ == "linear": + rope_scaling_type = gguf.RopeScalingType.LINEAR + elif typ == "yarn": + rope_scaling_type = gguf.RopeScalingType.YARN + n_ctx_orig = rope_scaling['original_max_position_embeddings'] + rope_finetuned = rope_scaling['finetuned'] + else: + raise NotImplementedError(f'Unknown rope scaling type: {typ}') + + if "max_sequence_length" in config: + n_ctx = config["max_sequence_length"] + elif "max_position_embeddings" in config: + n_ctx = config["max_position_embeddings"] + else: + msg = """\ + failed to guess 'n_ctx'. This model is unknown or unsupported. + Suggestion: provide 'config.json' of the model in the same directory containing model files.""" + raise KeyError(textwrap.dedent(msg)) + + n_experts = None + n_experts_used = None + + if "num_local_experts" in config: + n_experts = config["num_local_experts"] + n_experts_used = config["num_experts_per_tok"] + + return Params( + n_vocab = config["vocab_size"], + n_embd = config["hidden_size"], + n_layer = config["num_hidden_layers"], + n_ctx = n_ctx, + n_ff = config["intermediate_size"], + n_head = (n_head := config["num_attention_heads"]), + n_head_kv = config.get("num_key_value_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, + f_norm_eps = config["rms_norm_eps"], + f_rope_freq_base = config.get("rope_theta"), + rope_scaling_type = rope_scaling_type, + f_rope_scale = f_rope_scale, + n_ctx_orig = n_ctx_orig, + rope_finetuned = rope_finetuned, + ) + + # LLaMA v2 70B params.json + # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1} + @staticmethod + def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: + with open(config_path) as f: + config = json.load(f) + + n_experts = None + n_experts_used = None + f_rope_freq_base = None + n_ff = None + + # hack to determine LLaMA v1 vs v2 vs CodeLlama + if config.get("moe"): + # Mixtral + n_ctx = 32768 + elif config.get("rope_theta") == 1000000: + # CodeLlama + n_ctx = 16384 + elif config["norm_eps"] == 1e-05: + # LLaMA v2 + n_ctx = 4096 + else: + # LLaMA v1 + n_ctx = 2048 + + if "layers.0.feed_forward.w1.weight" in model: + n_ff = model["layers.0.feed_forward.w1.weight"].shape[0] + + if config.get("moe"): + n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0] + n_experts = config["moe"]["num_experts"] + n_experts_used = config["moe"]["num_experts_per_tok"] + f_rope_freq_base = 1e6 + + assert n_ff is not None + + return Params( + n_vocab = model["tok_embeddings.weight"].shape[0], + n_embd = config["dim"], + n_layer = config["n_layers"], + n_ctx = n_ctx, + n_ff = n_ff, + n_head = (n_head := config["n_heads"]), + n_head_kv = config.get("n_kv_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, + f_norm_eps = config["norm_eps"], + f_rope_freq_base = config.get("rope_theta", f_rope_freq_base), + ) + + @staticmethod + def load(model_plus: ModelPlus) -> Params: + hf_config_path = model_plus.paths[0].parent / "config.json" + orig_config_path = model_plus.paths[0].parent / "params.json" + + if hf_config_path.exists(): + params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) + elif orig_config_path.exists(): + params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) + elif model_plus.format != 'none': + params = Params.guessed(model_plus.model) + else: + raise ValueError('Cannot guess params when model format is none') + + params.path_model = model_plus.paths[0].parent + + return params + + +# +# data loading +# TODO: reuse (probably move to gguf.py?) +# + + +def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray: + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + +class Tensor(ABC): + ndarray: NDArray + data_type: DataType + + @abstractmethod + def astype(self, data_type: DataType) -> Self: ... + @abstractmethod + def permute(self, n_head: int, n_head_kv: int) -> Self: ... + @abstractmethod + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ... + @abstractmethod + def part(self, n_part: int) -> Self: ... + @abstractmethod + def to_ggml(self) -> GGMLCompatibleTensor: ... + + +def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray: + assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" + fp32_arr = bf16_arr.astype(np.uint32) << 16 + return fp32_arr.view(np.float32) + + +class UnquantizedTensor(Tensor): + def __init__(self, ndarray: NDArray): + assert isinstance(ndarray, np.ndarray) + self.ndarray = ndarray + self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype] + + def astype(self, data_type: DataType) -> UnquantizedTensor: + dtype = data_type.dtype + if self.data_type == DT_BF16: + self.ndarray = bf16_to_fp32(self.ndarray) + return UnquantizedTensor(self.ndarray.astype(dtype)) + + def to_ggml(self) -> Self: + return self + + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: + r = self.ndarray.shape[0] // 3 + return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv)) + + def part(self, n_part: int) -> UnquantizedTensor: + r = self.ndarray.shape[0] // 3 + return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...]) + + def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor: + return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv)) + + +def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray: + tensor = lazy_tensor.load() + assert isinstance(tensor, UnquantizedTensor) + + # double-check: + actual_shape = list(tensor.ndarray.shape) + assert actual_shape == lazy_tensor.shape, (actual_shape, lazy_tensor.shape) + if expected_dtype is not None and expected_dtype != tensor.ndarray.dtype: + if convert: + tensor.ndarray = tensor.ndarray.astype(expected_dtype) + else: + raise ValueError(f'expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}') + + return tensor.ndarray + + +GGMLCompatibleTensor = UnquantizedTensor + + +@dataclass +class LazyTensor: + _load: Callable[[], Tensor] + shape: list[int] + data_type: DataType + description: str + + def load(self) -> Tensor: + ret = self._load() + # Should be okay if it maps to the same numpy type? + assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \ + (self.data_type, ret.data_type, self.description) + return ret + + def astype(self, data_type: DataType) -> LazyTensor: + self.validate_conversion_to(data_type) + + def load() -> Tensor: + return self.load().astype(data_type) + return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}') + + def validate_conversion_to(self, data_type: DataType) -> None: + if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions: + raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.') + + +LazyModel: TypeAlias = 'dict[str, LazyTensor]' + +ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none'] + +@dataclass +class ModelPlus: + model: LazyModel + paths: list[Path] # Where this was read from. + format: ModelFormat + vocab: BaseVocab | None # For GGML models (which have vocab built in), the vocab. + + +def merge_sharded(models: list[LazyModel]) -> LazyModel: + # Original LLaMA models have each file contain one part of each tensor. + # Use a dict instead of a set to preserve order. + names = {name: None for model in models for name in model} + + def convert(name: str) -> LazyTensor: + lazy_tensors = [model[name] for model in models] + if len(lazy_tensors) == 1: + # only one file; don't go through this procedure since there might + # be quantized tensors + return lazy_tensors[0] + if len(lazy_tensors[0].shape) == 1: + # the tensor is just duplicated in every file + return lazy_tensors[0] + if name.startswith('tok_embeddings.') or \ + name.endswith('.attention.wo.weight') or \ + name.endswith('.feed_forward.w2.weight'): + # split by columns + axis = 1 + else: + # split by rows + axis = 0 + concatenated_shape = list(lazy_tensors[0].shape) + concatenated_shape[axis] = sum(tensor.shape[axis] for tensor in lazy_tensors) + + def load() -> UnquantizedTensor: + ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors] + concatenated = np.concatenate(ndarrays, axis=axis) + return UnquantizedTensor(concatenated) + description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]' + return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description) + return {name: convert(name) for name in names} + + +def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: + formats: set[ModelFormat] = set(mp.format for mp in models_plus) + assert len(formats) == 1, "different formats?" + format = formats.pop() + paths = [path for mp in models_plus for path in mp.paths] + # Use the first non-None vocab, if any. + try: + vocab = next(mp.vocab for mp in models_plus if mp.vocab is not None) + except StopIteration: + vocab = None + + if any("model.embed_tokens.weight" in mp.model for mp in models_plus): + # Transformers models put different tensors in different files, but + # don't split individual tensors between files. + model: LazyModel = {} + for mp in models_plus: + model.update(mp.model) + else: + model = merge_sharded([mp.model for mp in models_plus]) + + return ModelPlus(model, paths, format, vocab) + + +def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().permute(n_head, n_head_kv) + return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + + +def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv) + s = lazy_tensor.shape.copy() + s[0] = s[0] // 3 + return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + + +def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().part(n_part) + s = lazy_tensor.shape.copy() + s[0] = s[0] // 3 + return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) + + +def pack_experts_lazy(lazy_tensors: list[LazyTensor]) -> LazyTensor: + def load() -> Tensor: + tensors = [lazy_tensor.load() for lazy_tensor in lazy_tensors] + return UnquantizedTensor(np.array([tensor.ndarray for tensor in tensors])) + s = lazy_tensors[0].shape.copy() + s.insert(0, len(lazy_tensors)) + return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors)) + + +# Functionality that simulates `torch.load` but where individual tensors are +# only loaded into memory on demand, not all at once. +# PyTorch can't do this natively as of time of writing: +# - https://github.com/pytorch/pytorch/issues/64327 +# This allows us to de-shard without multiplying RAM usage, and also +# conveniently drops the PyTorch dependency (though we still need numpy). + + +@dataclass +class LazyStorageKind: + data_type: DataType + + +@dataclass +class LazyStorage: + load: Callable[[int, int], NDArray] + kind: LazyStorageKind + description: str + + +class LazyUnpickler(pickle.Unpickler): + def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile): + super().__init__(fp) + self.data_base_path = data_base_path + self.zip_file = zip_file + + def persistent_load(self, pid: Any) -> Any: + assert pid[0] == 'storage' + assert isinstance(pid[1], LazyStorageKind) + data_type = pid[1].data_type + filename_stem = pid[2] + filename = f'{self.data_base_path}/{filename_stem}' + info = self.zip_file.getinfo(filename) + + def load(offset: int, elm_count: int) -> NDArray: + dtype = data_type.dtype + with self.zip_file.open(info) as fp: + fp.seek(offset * dtype.itemsize) + size = elm_count * dtype.itemsize + data = fp.read(size) + assert len(data) == size + return np.frombuffer(data, dtype) + description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' + return LazyStorage(load=load, kind=pid[1], description=description) + + @staticmethod + def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, + requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: + assert isinstance(storage, LazyStorage) + + def load() -> UnquantizedTensor: + elm_count = stride[0] * size[0] + return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size)) + description = f'pickled storage_offset={storage_offset} in {storage.description}' + return LazyTensor(load, list(size), storage.kind.data_type, description) + + @staticmethod + def rebuild_from_type_v2(func, new_type, args, state): + return func(*args) + + CLASSES: dict[tuple[str, str], type[LazyTensor] | LazyStorageKind] = { + # getattr used here as a workaround for mypy not being smart enough to determine + # the staticmethods have a __func__ attribute. + ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), + ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), + ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), + ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), + ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), + ('torch', 'IntStorage'): LazyStorageKind(DT_I32), + ('torch', 'Tensor'): LazyTensor, + } + + def find_class(self, module: str, name: str) -> Any: + if not module.startswith('torch'): + return super().find_class(module, name) + return self.CLASSES[(module, name)] + + +def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: + zf = zipfile.ZipFile(outer_fp) + pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')] + assert len(pickle_paths) == 1, pickle_paths + pickle_fp = zf.open(pickle_paths[0], 'r') + unpickler = LazyUnpickler(pickle_fp, + data_base_path=pickle_paths[0][:-4], + zip_file=zf) + model = unpickler.load() + if 'model' in model: model = model['model'] + as_dict = dict(model.items()) + return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) + + +def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: + header_size, = struct.unpack(' LazyTensor: + data_type = SAFETENSORS_DATA_TYPES[info['dtype']] + numpy_dtype = data_type.dtype + shape: list[int] = info['shape'] + begin, end = info['data_offsets'] + assert 0 <= begin <= end <= len(byte_buf) + assert end - begin == math.prod(shape) * numpy_dtype.itemsize + buf = byte_buf[begin:end] + + def load() -> UnquantizedTensor: + return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) + description = f'safetensors begin={begin} end={end} type={data_type} path={path}' + return LazyTensor(load, shape, data_type, description) + model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} + return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) + + +def must_read(fp: IO[bytes], length: int) -> bytes: + ret = fp.read(length) + if len(ret) < length: + raise EOFError("unexpectedly reached end of file") + return ret + + +@functools.lru_cache(maxsize=None) +def lazy_load_file(path: Path) -> ModelPlus: + fp = open(path, 'rb') + first8 = fp.read(8) + fp.seek(0) + if first8[:2] == b'PK': + # A zip file, i.e. PyTorch format + return lazy_load_torch_file(fp, path) + elif struct.unpack(' Iterable[Out]: + '''Parallel map, but with backpressure. If the caller doesn't call `next` + fast enough, this will stop calling `func` at some point rather than + letting results pile up in memory. Specifically, there is a max of one + output value buffered per thread.''' + if concurrency < 2: + yield from map(func, iterable) + # Not reached. + iterable = iter(iterable) + executor_class: type[ThreadPoolExecutor] | type[ProcessPoolExecutor] + if use_processpool_executor: + executor_class = ProcessPoolExecutor + else: + executor_class = ThreadPoolExecutor + with executor_class(max_workers=max_workers) as executor: + futures: list[concurrent.futures.Future[Out]] = [] + done = False + for _ in range(concurrency): + try: + futures.append(executor.submit(func, next(iterable))) + except StopIteration: + done = True + break + + while futures: + result = futures.pop(0).result() + while not done and len(futures) < concurrency: + try: + futures.append(executor.submit(func, next(iterable))) + except StopIteration: + done = True + break + yield result + + +def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None: + # Handle special case where the model's vocab size is not set + if params.n_vocab == -1: + raise ValueError( + "The model's vocab size is set to -1 in params.json. Please update it manually." + + (f" Maybe {vocab.vocab_size}?" if isinstance(vocab, Vocab) else ""), + ) + if not isinstance(vocab, Vocab): + return # model has no vocab + + # Check for a vocab size mismatch + if params.n_vocab == vocab.vocab_size: + logger.warning("Ignoring added_tokens.json since model matches vocab size without it.") + return + + if pad_vocab and params.n_vocab > vocab.vocab_size: + pad_count = params.n_vocab - vocab.vocab_size + logger.debug( + f"Padding vocab with {pad_count} token(s) - through " + ) + for i in range(1, pad_count + 1): + vocab.added_tokens_dict[f""] = -1 + vocab.added_tokens_list.append(f"") + vocab.vocab_size = params.n_vocab + return + + msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer} has {vocab.vocab_size})." + if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20: + msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." + if vocab.vocab_size < params.n_vocab: + msg += " Add the --pad-vocab option and try again." + + raise ValueError(msg) + + +class OutputFile: + def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE): + self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess) + + def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None: + # Metadata About The Model And Its Provenence + name = "LLaMA" + if metadata is not None and metadata.name is not None: + name = metadata.name + elif params.path_model is not None: + name = params.path_model.name + elif params.n_ctx == 4096: + # Heuristic detection of LLaMA v2 model + name = "LLaMA v2" + + self.gguf.add_name(name) + + if metadata is not None: + if metadata.author is not None: + self.gguf.add_author(metadata.author) + if metadata.version is not None: + self.gguf.add_version(metadata.version) + if metadata.organization is not None: + self.gguf.add_organization(metadata.organization) + + if metadata.finetune is not None: + self.gguf.add_finetune(metadata.finetune) + if metadata.basename is not None: + self.gguf.add_basename(metadata.basename) + + if metadata.description is not None: + self.gguf.add_description(metadata.description) + if metadata.quantized_by is not None: + self.gguf.add_quantized_by(metadata.quantized_by) + + if metadata.size_label is not None: + self.gguf.add_size_label(metadata.size_label) + + if metadata.license is not None: + self.gguf.add_license(metadata.license) + if metadata.license_name is not None: + self.gguf.add_license_name(metadata.license_name) + if metadata.license_link is not None: + self.gguf.add_license_link(metadata.license_link) + + if metadata.url is not None: + self.gguf.add_url(metadata.url) + if metadata.doi is not None: + self.gguf.add_doi(metadata.doi) + if metadata.uuid is not None: + self.gguf.add_uuid(metadata.uuid) + if metadata.repo_url is not None: + self.gguf.add_repo_url(metadata.repo_url) + + if metadata.source_url is not None: + self.gguf.add_source_url(metadata.source_url) + if metadata.source_doi is not None: + self.gguf.add_source_doi(metadata.source_doi) + if metadata.source_uuid is not None: + self.gguf.add_source_uuid(metadata.source_uuid) + if metadata.source_repo_url is not None: + self.gguf.add_source_repo_url(metadata.source_repo_url) + + if metadata.base_models is not None: + self.gguf.add_base_model_count(len(metadata.base_models)) + for key, base_model_entry in enumerate(metadata.base_models): + if "name" in base_model_entry: + self.gguf.add_base_model_name(key, base_model_entry["name"]) + if "author" in base_model_entry: + self.gguf.add_base_model_author(key, base_model_entry["author"]) + if "version" in base_model_entry: + self.gguf.add_base_model_version(key, base_model_entry["version"]) + if "organization" in base_model_entry: + self.gguf.add_base_model_organization(key, base_model_entry["organization"]) + if "description" in base_model_entry: + self.gguf.add_base_model_description(key, base_model_entry["description"]) + if "url" in base_model_entry: + self.gguf.add_base_model_url(key, base_model_entry["url"]) + if "doi" in base_model_entry: + self.gguf.add_base_model_doi(key, base_model_entry["doi"]) + if "uuid" in base_model_entry: + self.gguf.add_base_model_uuid(key, base_model_entry["uuid"]) + if "repo_url" in base_model_entry: + self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"]) + + if metadata.datasets is not None: + self.gguf.add_dataset_count(len(metadata.datasets)) + for key, dataset_entry in enumerate(metadata.datasets): + if "name" in dataset_entry: + self.gguf.add_dataset_name(key, dataset_entry["name"]) + if "author" in dataset_entry: + self.gguf.add_dataset_author(key, dataset_entry["author"]) + if "version" in dataset_entry: + self.gguf.add_dataset_version(key, dataset_entry["version"]) + if "organization" in dataset_entry: + self.gguf.add_dataset_organization(key, dataset_entry["organization"]) + if "description" in dataset_entry: + self.gguf.add_dataset_description(key, dataset_entry["description"]) + if "url" in dataset_entry: + self.gguf.add_dataset_url(key, dataset_entry["url"]) + if "doi" in dataset_entry: + self.gguf.add_dataset_doi(key, dataset_entry["doi"]) + if "uuid" in dataset_entry: + self.gguf.add_dataset_uuid(key, dataset_entry["uuid"]) + if "repo_url" in dataset_entry: + self.gguf.add_dataset_repo_url(key, dataset_entry["repo_url"]) + + if metadata.tags is not None: + self.gguf.add_tags(metadata.tags) + if metadata.languages is not None: + self.gguf.add_languages(metadata.languages) + + def add_meta_arch(self, params: Params) -> None: + # Metadata About The Neural Architecture Itself + self.gguf.add_vocab_size(params.n_vocab) + self.gguf.add_context_length(params.n_ctx) + self.gguf.add_embedding_length(params.n_embd) + self.gguf.add_block_count(params.n_layer) + self.gguf.add_feed_forward_length(params.n_ff) + self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) + self.gguf.add_head_count (params.n_head) + self.gguf.add_head_count_kv (params.n_head_kv) + + if params.n_experts: + self.gguf.add_expert_count(params.n_experts) + + if params.n_experts_used: + self.gguf.add_expert_used_count(params.n_experts_used) + + if params.f_norm_eps: + self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) + else: + raise ValueError('f_norm_eps is None') + + if params.f_rope_freq_base is not None: + self.gguf.add_rope_freq_base(params.f_rope_freq_base) + + if params.rope_scaling_type: + assert params.f_rope_scale is not None + self.gguf.add_rope_scaling_type(params.rope_scaling_type) + self.gguf.add_rope_scaling_factor(params.f_rope_scale) + + if params.n_ctx_orig is not None: + self.gguf.add_rope_scaling_orig_ctx_len(params.n_ctx_orig) + + if params.rope_finetuned is not None: + self.gguf.add_rope_scaling_finetuned(params.rope_finetuned) + + if params.ftype is not None: + self.gguf.add_file_type(params.ftype) + + def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]: + tokens = [] + scores = [] + toktypes = [] + + # NOTE: `all_tokens` returns the base vocabulary and added tokens + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size + + return tokens, scores, toktypes + + def add_meta_vocab(self, vocab: Vocab) -> None: + # Ensure that tokenizer_model is added to the GGUF model + self.gguf.add_tokenizer_model(vocab.tokenizer_model) + + # Extract model vocabulary for model conversion + tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab) + + # Add extracted token information for model conversion + self.gguf.add_token_list(tokens) + self.gguf.add_token_scores(scores) + self.gguf.add_token_types(toktypes) + + def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None: + svocab.add_to_gguf(self.gguf) + + def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: + n_elements = int(np.prod(tensor.shape)) + raw_dtype = getattr(tensor.data_type, 'ggml_type', None) + data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype + data_nbytes = tensor.data_type.elements_to_bytes(n_elements) + self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype) + + def write_meta(self) -> None: + self.gguf.write_header_to_file() + self.gguf.write_kv_data_to_file() + + def write_tensor_info(self) -> None: + self.gguf.write_ti_data_to_file() + + def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None: + ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency) + if ftype == GGMLFileType.MostlyQ8_0: + ndarrays = bounded_parallel_map( + OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, + use_processpool_executor=True, + ) + else: + ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) + + start = time.time() + for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + elapsed = time.time() - start + size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + padi = len(str(len(model))) + logger.info( + f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" + ) + self.gguf.write_tensor_data(ndarray) + + def close(self) -> None: + self.gguf.close() + + @staticmethod + def write_vocab_only( + fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None, + ) -> None: + check_vocab_size(params, vocab, pad_vocab=pad_vocab) + + of = OutputFile(fname_out, endianess=endianess) + + # meta data + of.add_meta_model(params, metadata) + of.add_meta_arch(params) + of.add_meta_vocab(vocab) + of.add_meta_special_vocab(svocab) + + of.write_meta() + + of.close() + + @staticmethod + def do_item(item: tuple[str, LazyTensor]) -> tuple[DataType, NDArray]: + name, lazy_tensor = item + tensor = lazy_tensor.load().to_ggml() + return (lazy_tensor.data_type, tensor.ndarray) + + @staticmethod + def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray: + dt, arr = item + if not isinstance(dt, QuantizedDataType): + return arr + return dt.quantize(arr) + + @staticmethod + def write_all( + fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab, + concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + pad_vocab: bool = False, + metadata: gguf.Metadata | None = None, + ) -> None: + check_vocab_size(params, vocab, pad_vocab=pad_vocab) + + of = OutputFile(fname_out, endianess=endianess) + + # meta data + of.add_meta_model(params, metadata) + of.add_meta_arch(params) + if isinstance(vocab, Vocab): + of.add_meta_vocab(vocab) + of.add_meta_special_vocab(svocab) + else: # NoVocab + of.gguf.add_tokenizer_model(vocab.tokenizer_model) + + # tensor info + for name, lazy_tensor in model.items(): + of.add_tensor_info(name, lazy_tensor) + + of.write_meta() + of.write_tensor_info() + + # tensor data + of.write_tensor_data(ftype, model, concurrency) + + of.close() + + +def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType: + wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type + + if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)): + return GGMLFileType.AllF32 + if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16): + return GGMLFileType.MostlyF16 + if output_type_str == "q8_0": + return GGMLFileType.MostlyQ8_0 + + name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} + + raise ValueError(f"Unexpected combination of types: {name_to_type}") + + +def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) -> tuple[int, int, int]: + total_params = 0 + shared_params = 0 + expert_params = 0 + + for name, lazy_tensor in tensors: + # We don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + continue + + # Got A Tensor + sum_weights_in_tensor: int = 1 + + # Tensor Volume + for dim in lazy_tensor.shape: + sum_weights_in_tensor *= dim + + if ".experts." in name: + if ".experts.0." in name: + expert_params += sum_weights_in_tensor + else: + shared_params += sum_weights_in_tensor + + total_params += sum_weights_in_tensor + + return total_params, shared_params, expert_params + + +def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel: + return {name: tensor.astype(output_type.type_for_tensor(name, tensor)) + for (name, tensor) in model.items()} + + +def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel: + tmap = gguf.TensorNameMap(ARCH, params.n_layer) + should_skip = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, [])) + + tmp = model + + # merge experts into one tensor + if params.n_experts and params.n_experts > 0: + for i_l in range(params.n_layer): + for w in range(1, 4): + experts = [] + for e in range(params.n_experts): + if f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight" in model: + experts.append(model[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"]) + del tmp[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"] + elif f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" in model: + experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]) + del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"] + else: + raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight") + tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts) + + # HF models permut or pack some of the tensors, so we need to undo that + for i in itertools.count(): + if f"model.layers.{i}.self_attn.q_proj.weight" in model: + logger.debug(f"Permuting layer {i}") + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv) + # tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] + elif f"model.layers.{i}.self_attn.W_pack.weight" in model: + logger.debug(f"Unpacking and permuting layer {i}") + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv) + tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2) + del tmp[f"model.layers.{i}.self_attn.W_pack.weight"] + else: + break + + out: LazyModel = {} + for name, lazy_tensor in model.items(): + tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None) + if name_new is None: + if skip_unknown: + logger.warning(f"Unexpected tensor name: {name} - skipping") + continue + raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)") + + if tensor_type in should_skip: + logger.debug(f"skipping tensor {name_new}") + continue + + logger.debug(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}") + out[name_new] = lazy_tensor + + return out + + +def nth_multifile_path(path: Path, n: int) -> Path | None: + '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + the nth path in the model. + ''' + # Support the following patterns: + patterns = [ + # - x.00.pth, x.01.pth, etc. + (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'), + # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc. + (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'), + # x.bin, x.bin.1, etc. + (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}') + ] + for regex, replacement in patterns: + if re.search(regex, path.name): + new_path = path.with_name(re.sub(regex, replacement, path.name)) + if new_path.exists(): + return new_path + return None + + +def find_multifile_paths(path: Path) -> list[Path]: + '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + the whole list of paths in the model. + ''' + ret: list[Path] = [] + for i in itertools.count(): + nth_path = nth_multifile_path(path, i) + if nth_path is None: + break + ret.append(nth_path) + if not ret: + # No matches. This should only happen if the file was named, e.g., + # foo.0, and there was no file named foo. Oh well, try to process it + # as a single file. + return [path] + return ret + + +def load_some_model(path: Path) -> ModelPlus: + '''Load a model of any supported format.''' + # Be extra-friendly and accept either a file or a directory: + if path.is_dir(): + # Check if it's a set of safetensors files first + globs = ["model-00001-of-*.safetensors", "model.safetensors", "consolidated.safetensors"] + files = [file for glob in globs for file in path.glob(glob)] + if not files: + # Try the PyTorch patterns too, with lower priority + globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] + files = [file for glob in globs for file in path.glob(glob)] + if not files: + raise FileNotFoundError(f"Can't find model in directory {path}") + if len(files) > 1: + raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}") + path = files[0] + + paths = find_multifile_paths(path) + models_plus: list[ModelPlus] = [] + for path in paths: + logger.info(f"Loading model file {path}") + models_plus.append(lazy_load_file(path)) + + model_plus = merge_multifile_models(models_plus) + return model_plus + + +class VocabFactory: + _VOCAB_CLASSES: list[type[Vocab]] = [SentencePieceVocab, BpeVocab, LlamaHfVocab] + + def __init__(self, path: Path): + self.path = path + + def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab: + load_merges = vocab.name == "bpe" + n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None + return gguf.SpecialVocab( + model_parent_path, + load_merges=load_merges, + special_token_types=None, # Predetermined or passed as a parameter + n_vocab=n_vocab, + ) + + def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab: + vocab_classes: dict[str, type[Vocab]] = {cls.name: cls for cls in self._VOCAB_CLASSES} + selected_vocabs: dict[str, type[Vocab]] = {} + for vtype in vocab_types: + try: + selected_vocabs[vtype] = vocab_classes[vtype] + except KeyError: + raise ValueError(f"Unsupported vocabulary type {vtype}") from None + + for vtype, cls in selected_vocabs.items(): + try: + vocab = cls(self.path) + break + except FileNotFoundError: + pass # ignore unavailable tokenizers + else: + raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}") + + logger.info(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}") + return vocab + + def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]: + vocab: BaseVocab + if vocab_types is None: + vocab = NoVocab() + else: + vocab = self._create_vocab_by_path(vocab_types) + # FIXME: Respect --vocab-dir? + special_vocab = self._create_special_vocab( + vocab, + model_parent_path, + ) + return vocab, special_vocab + + +def default_convention_outfile(file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> str: + name = metadata.name if metadata.name is not None else None + basename = metadata.basename if metadata.basename is not None else None + finetune = metadata.finetune if metadata.finetune is not None else None + version = metadata.version if metadata.version is not None else None + size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(*model_params_count, expert_count=expert_count or 0) + + output_type = { + GGMLFileType.AllF32: "F32", + GGMLFileType.MostlyF16: "F16", + GGMLFileType.MostlyQ8_0: "Q8_0", + }[file_type] + + return gguf.naming_convention(name, basename, finetune, version, size_label, output_type) + + +def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> Path: + default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata) + ret = model_paths[0].parent / f"{default_filename}.gguf" + if ret in model_paths: + logger.error( + f"Error: Default output path ({ret}) would overwrite the input. " + "Please explicitly specify a path using --outfile.") + sys.exit(1) + return ret + + +def do_dump_model(model_plus: ModelPlus) -> None: + print(f"model_plus.paths = {model_plus.paths!r}") # noqa: NP100 + print(f"model_plus.format = {model_plus.format!r}") # noqa: NP100 + print(f"model_plus.vocab = {model_plus.vocab!r}") # noqa: NP100 + for name, lazy_tensor in model_plus.model.items(): + print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") # noqa: NP100 + + +def main(args_in: list[str] | None = None) -> None: + output_choices = ["f32", "f16"] + if np.uint32(1) == np.uint32(1).newbyteorder("<"): + # We currently only support Q8_0 output on little endian systems. + output_choices.append("q8_0") + parser = argparse.ArgumentParser(description="Convert a LLaMA model to a GGML compatible file") + parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") + parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") + parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") + parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab") + parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") + parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") + parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") + parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") + parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY) + parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine") + parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides") + parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser.add_argument("--metadata", type=Path, help="Specify the path for an authorship metadata override file") + parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name") + parser.add_argument("--model-name", type=str, default=None, help="name of the model") + + args = parser.parse_args(args_in) + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + elif args.dump_single or args.dump or args.get_outfile: + # Avoid printing anything besides the dump output + logging.basicConfig(level=logging.WARNING) + else: + logging.basicConfig(level=logging.INFO) + + model_name = args.model_name + dir_model = args.model + + metadata = gguf.Metadata.load(args.metadata, dir_model, model_name) + + if args.get_outfile: + model_plus = load_some_model(dir_model) + params = Params.load(model_plus) + model = convert_model_names(model_plus.model, params, args.skip_unknown) + model_params_count = per_model_weight_count_estimation(model_plus.model.items()) + ftype = pick_output_type(model, args.outtype) + + if (metadata is None or metadata.name is None) and params.path_model is not None: + metadata.name = params.path_model.name + + print(f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}") # noqa: NP100 + return + + if args.no_vocab and args.vocab_only: + raise ValueError("--vocab-only does not make sense with --no-vocab") + + if args.dump_single: + model_plus = lazy_load_file(dir_model) + do_dump_model(model_plus) + return + + if not args.vocab_only: + model_plus = load_some_model(dir_model) + else: + model_plus = ModelPlus(model = {}, paths = [dir_model / 'dummy'], format = 'none', vocab = None) + + if args.dump: + do_dump_model(model_plus) + return + + endianess = gguf.GGUFEndian.LITTLE + if args.big_endian: + endianess = gguf.GGUFEndian.BIG + + params = None + if args.pad_vocab or not args.vocab_only: + params = Params.load(model_plus) + if params.n_ctx == -1: + if args.ctx is None: + msg = """\ + The model doesn't have a context size, and you didn't specify one with --ctx + Please specify one with --ctx: + - LLaMA v1: --ctx 2048 + - LLaMA v2: --ctx 4096""" + parser.error(textwrap.dedent(msg)) + params.n_ctx = args.ctx + + if args.outtype: + params.ftype = { + "f32": GGMLFileType.AllF32, + "f16": GGMLFileType.MostlyF16, + "q8_0": GGMLFileType.MostlyQ8_0, + }[args.outtype] + + logger.info(f"params = {params}") + + model_parent_path = model_plus.paths[0].parent + vocab_path = Path(args.vocab_dir or dir_model or model_parent_path) + vocab_factory = VocabFactory(vocab_path) + vocab_types = None if args.no_vocab else args.vocab_type.split(",") + vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path) + + if args.vocab_only: + assert isinstance(vocab, Vocab) + if not args.outfile: + raise ValueError("need --outfile if using --vocab-only") + outfile = args.outfile + if params is None: + params = Params( + n_vocab = vocab.vocab_size, + n_embd = 1, + n_layer = 1, + n_ctx = 1, + n_ff = 1, + n_head = 1, + n_head_kv = 1, + f_norm_eps = 1e-5, + ) + OutputFile.write_vocab_only(outfile, params, vocab, special_vocab, + endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata) + logger.info(f"Wrote {outfile}") + return + + if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab: + vocab = model_plus.vocab + + assert params is not None + + if metadata.name is None and params.path_model is not None: + metadata.name = params.path_model.name + + model_params_count = per_model_weight_count_estimation(model_plus.model.items()) + logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})") + + logger.info(f"Vocab info: {vocab}") + logger.info(f"Special vocab info: {special_vocab}") + model = model_plus.model + model = convert_model_names(model, params, args.skip_unknown) + ftype = pick_output_type(model, args.outtype) + model = convert_to_output_type(model, ftype) + outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata) + + metadata.size_label = gguf.size_label(*model_params_count, expert_count=params.n_experts or 0) + + params.ftype = ftype + logger.info(f"Writing {outfile}, format {ftype}") + + OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, + concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata) + logger.info(f"Wrote {outfile}") + + +if __name__ == '__main__': + main() diff --git a/examples/deprecation-warning/README.md b/examples/deprecation-warning/README.md new file mode 100644 index 0000000000000000000000000000000000000000..59918ec2bbf72fcb702ba4cf9723d465ba6ae6c0 --- /dev/null +++ b/examples/deprecation-warning/README.md @@ -0,0 +1,49 @@ +# Migration notice for binary filenames + +> [!IMPORTANT] +[2024 Jun 12] Binaries have been renamed w/ a `llama-` prefix. `main` is now `llama-cli`, `server` is `llama-server`, etc (https://github.com/ggerganov/llama.cpp/pull/7809) + +This migration was important, but it is a breaking change that may not always be immediately obvious to users. + +Please update all scripts and workflows to use the new binary names. + +| Old Filename | New Filename | +| ---- | ---- | +| main | llama-cli | +| server | llama-server | +| llama-bench | llama-bench | +| embedding | llama-embedding | +| quantize | llama-quantize | +| tokenize | llama-tokenize | +| export-lora | llama-export-lora | +| libllava.a | libllava.a | +| baby-llama | llama-baby-llama | +| batched | llama-batched | +| batched-bench | llama-batched-bench | +| benchmark-matmult | llama-benchmark-matmult | +| convert-llama2c-to-ggml | llama-convert-llama2c-to-ggml | +| eval-callback | llama-eval-callback | +| gbnf-validator | llama-gbnf-validator | +| gguf | llama-gguf | +| gguf-split | llama-gguf-split | +| gritlm | llama-gritlm | +| imatrix | llama-imatrix | +| infill | llama-infill | +| llava-cli | llama-llava-cli | +| lookahead | llama-lookahead | +| lookup | llama-lookup | +| lookup-create | llama-lookup-create | +| lookup-merge | llama-lookup-merge | +| lookup-stats | llama-lookup-stats | +| parallel | llama-parallel | +| passkey | llama-passkey | +| perplexity | llama-perplexity | +| q8dot | llama-q8dot | +| quantize-stats | llama-quantize-stats | +| retrieval | llama-retrieval | +| save-load-state | llama-save-load-state | +| simple | llama-simple | +| speculative | llama-speculative | +| vdot | llama-vdot | +| tests/test-c.o | tests/test-c.o | + diff --git a/examples/deprecation-warning/deprecation-warning.cpp b/examples/deprecation-warning/deprecation-warning.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2958ea12d92d60d89cded4e6850ec20e121515a --- /dev/null +++ b/examples/deprecation-warning/deprecation-warning.cpp @@ -0,0 +1,35 @@ +// Warns users that this filename was deprecated, and provides a link for more information. + +#include +#include +#include + +// Main +int main(int argc, char** argv) { + std::string filename = "main"; + if (argc >= 1) { + filename = argv[0]; + } + + // Get only the program name from the full path + auto pos = filename.find_last_of("/\\"); + if (pos != std::string::npos) { + filename = filename.substr(pos+1); + } + + // Append "llama-" to the beginning of filename to get the replacemnt filename + auto replacement_filename = "llama-" + filename; + + // The exception is if the filename is "main", then our replacement filename is "llama-cli" + if (filename == "main") { + replacement_filename = "llama-cli"; + } + + fprintf(stdout, "\n"); + fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str()); + fprintf(stdout, " Please use '%s' instead.\n", replacement_filename.c_str()); + fprintf(stdout, " See https://github.com/ggerganov/llama.cpp/tree/master/examples/deprecation-warning/README.md for more information.\n"); + fprintf(stdout, "\n"); + + return EXIT_FAILURE; +} diff --git a/examples/embedding/CMakeLists.txt b/examples/embedding/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..809040307d2c907486a6a4624baebedc9bdc73d4 --- /dev/null +++ b/examples/embedding/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-embedding) +add_executable(${TARGET} embedding.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/embedding/README.md b/examples/embedding/README.md new file mode 100644 index 0000000000000000000000000000000000000000..12b372bf1df42f2d317594b33e26330b7f3d21ed --- /dev/null +++ b/examples/embedding/README.md @@ -0,0 +1,60 @@ +# llama.cpp/example/embedding + +This example demonstrates generate high-dimensional embedding vector of a given text with llama.cpp. + +## Quick Start + +To get started right away, run the following command, making sure to use the correct path for the model you have: + +### Unix-based systems (Linux, macOS, etc.): + +```bash +./llama-embedding -m ./path/to/model --pooling mean --log-disable -p "Hello World!" 2>/dev/null +``` + +### Windows: + +```powershell +llama-embedding.exe -m ./path/to/model --pooling mean --log-disable -p "Hello World!" 2>$null +``` + +The above command will output space-separated float values. + +## extra parameters +### --embd-normalize $integer$ +| $integer$ | description | formula | +|-----------|---------------------|---------| +| $-1$ | none | +| $0$ | max absolute int16 | $\Large{{32760 * x_i} \over\max \lvert x_i\rvert}$ +| $1$ | taxicab | $\Large{x_i \over\sum \lvert x_i\rvert}$ +| $2$ | euclidean (default) | $\Large{x_i \over\sqrt{\sum x_i^2}}$ +| $>2$ | p-norm | $\Large{x_i \over\sqrt[p]{\sum \lvert x_i\rvert^p}}$ + +### --embd-output-format $'string'$ +| $'string'$ | description | | +|------------|------------------------------|--| +| '' | same as before | (default) +| 'array' | single embeddings | $[[x_1,...,x_n]]$ +| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$ +| 'json' | openai style | +| 'json+' | add cosine similarity matrix | + +### --embd-separator $"string"$ +| $"string"$ | | +|--------------|-| +| "\n" | (default) +| "<#embSep#>" | for exemple +| "<#sep#>" | other exemple + +## examples +### Unix-based systems (Linux, macOS, etc.): + +```bash +./llama-embedding -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --pooling mean --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null +``` + +### Windows: + +```powershell +llama-embedding.exe -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --pooling mean --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null +``` diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..681929d27d617460eedbcdabaefc6be89633829c --- /dev/null +++ b/examples/embedding/embedding.cpp @@ -0,0 +1,338 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" + +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +static std::vector split_lines(const std::string & s, const std::string & separator = "\n") { + std::vector lines; + size_t start = 0; + size_t end = s.find(separator); + + while (end != std::string::npos) { + lines.push_back(s.substr(start, end - start)); + start = end + separator.length(); + end = s.find(separator, start); + } + + lines.push_back(s.substr(start)); // Add the last part + + return lines; +} + +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { + size_t n_tokens = tokens.size(); + for (size_t i = 0; i < n_tokens; i++) { + common_batch_add(batch, tokens[i], i, { seq_id }, true); + } +} + +static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + + // clear previous kv_cache values (irrelevant for embeddings) + llama_memory_clear(llama_get_memory(ctx), true); + + // run model + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); + if (llama_decode(ctx, batch) < 0) { + LOG_ERR("%s : failed to process\n", __func__); + } + + for (int i = 0; i < batch.n_tokens; i++) { + if (!batch.logits[i]) { + continue; + } + + const float * embd = nullptr; + int embd_pos = 0; + + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + // try to get token embeddings + embd = llama_get_embeddings_ith(ctx, i); + embd_pos = i; + GGML_ASSERT(embd != NULL && "failed to get token embeddings"); + } else { + // try to get sequence embeddings - supported only when pooling_type is not NONE + embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + embd_pos = batch.seq_id[i][0]; + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); + } + + float * out = output + embd_pos * n_embd; + common_embd_normalize(embd, out, n_embd, embd_norm); + } +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) { + return 1; + } + + common_init(); + + params.embedding = true; + + // utilize the full context + if (params.n_batch < params.n_ctx) { + LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx); + params.n_batch = params.n_ctx; + } + + // For non-causal models, batch size must be equal to ubatch size + params.n_ubatch = params.n_batch; + + llama_backend_init(); + llama_numa_init(params.numa); + + // load the model + common_init_result llama_init = common_init_from_params(params); + + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); + + if (model == NULL) { + LOG_ERR("%s: unable to load model\n", __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_ctx_train = llama_model_n_ctx_train(model); + const int n_ctx = llama_n_ctx(ctx); + + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + + if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) { + LOG_ERR("%s: computing embeddings in encoder-decoder models is not supported\n", __func__); + return 1; + } + + if (n_ctx > n_ctx_train) { + LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n", + __func__, n_ctx_train, n_ctx); + } + + // print system information + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + } + + // split the prompt into lines + std::vector prompts = split_lines(params.prompt, params.embd_sep); + + // max batch size + const uint64_t n_batch = params.n_batch; + + // tokenize the prompts and trim + std::vector> inputs; + for (const auto & prompt : prompts) { + auto inp = common_tokenize(ctx, prompt, true, true); + if (inp.size() > n_batch) { + LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n", + __func__, (long long int) inp.size(), (long long int) n_batch); + return 1; + } + inputs.push_back(inp); + } + + // check if the last token is SEP + // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true' + for (auto & inp : inputs) { + if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) { + LOG_WRN("%s: last token in the prompt is not SEP\n", __func__); + LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__); + } + } + + // tokenization stats + if (params.verbose_prompt) { + for (int i = 0; i < (int) inputs.size(); i++) { + LOG_INF("%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str()); + LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size()); + for (int j = 0; j < (int) inputs[i].size(); j++) { + LOG("%6d -> '%s'\n", inputs[i][j], common_token_to_piece(ctx, inputs[i][j]).c_str()); + } + LOG("\n\n"); + } + } + + // initialize batch + const int n_prompts = prompts.size(); + struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + + // count number of embeddings + int n_embd_count = 0; + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + for (int k = 0; k < n_prompts; k++) { + n_embd_count += inputs[k].size(); + } + } else { + n_embd_count = n_prompts; + } + + // allocate output + const int n_embd = llama_model_n_embd(model); + std::vector embeddings(n_embd_count * n_embd, 0); + float * emb = embeddings.data(); + + // break into batches + int e = 0; // number of embeddings already stored + int s = 0; // number of prompts in current batch + for (int k = 0; k < n_prompts; k++) { + // clamp to n_batch tokens + auto & inp = inputs[k]; + + const uint64_t n_toks = inp.size(); + + // encode if at capacity + if (batch.n_tokens + n_toks > n_batch) { + float * out = emb + e * n_embd; + batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); + e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; + s = 0; + common_batch_clear(batch); + } + + // add to batch + batch_add_seq(batch, inp, s); + s += 1; + } + + // final batch + float * out = emb + e * n_embd; + batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); + + if (params.embd_out.empty()) { + LOG("\n"); + + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + for (int j = 0; j < n_embd_count; j++) { + LOG("embedding %d: ", j); + for (int i = 0; i < std::min(3, n_embd); i++) { + if (params.embd_normalize == 0) { + LOG("%6.0f ", emb[j * n_embd + i]); + } else { + LOG("%9.6f ", emb[j * n_embd + i]); + } + } + LOG(" ... "); + for (int i = n_embd - 3; i < n_embd; i++) { + if (params.embd_normalize == 0) { + LOG("%6.0f ", emb[j * n_embd + i]); + } else { + LOG("%9.6f ", emb[j * n_embd + i]); + } + } + LOG("\n"); + } + } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { + const uint32_t n_cls_out = llama_model_n_cls_out(model); + std::vector cls_out_labels; + + for (uint32_t i = 0; i < n_cls_out; i++) { + const char * label = llama_model_cls_label(model, i); + const std::string label_i(label == nullptr ? "" : label); + cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i); + } + + for (int j = 0; j < n_embd_count; j++) { + for (uint32_t i = 0; i < n_cls_out; i++) { + // NOTE: if you change this log - update the tests in ci/run.sh + if (n_cls_out == 1) { + LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); + } else { + LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str()); + } + } + } + } else { + // print the first part of the embeddings or for a single prompt, the full embedding + for (int j = 0; j < n_prompts; j++) { + LOG("embedding %d: ", j); + for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { + if (params.embd_normalize == 0) { + LOG("%6.0f ", emb[j * n_embd + i]); + } else { + LOG("%9.6f ", emb[j * n_embd + i]); + } + } + LOG("\n"); + } + + // print cosine similarity matrix + if (n_prompts > 1) { + LOG("\n"); + LOG("cosine similarity matrix:\n\n"); + for (int i = 0; i < n_prompts; i++) { + LOG("%6.6s ", prompts[i].c_str()); + } + LOG("\n"); + for (int i = 0; i < n_prompts; i++) { + for (int j = 0; j < n_prompts; j++) { + float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + LOG("%6.2f ", sim); + } + LOG("%1.10s", prompts[i].c_str()); + LOG("\n"); + } + } + } + } + + if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") { + const bool notArray = params.embd_out != "array"; + + LOG(notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "["); + for (int j = 0;;) { // at least one iteration (one prompt) + if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j); + LOG("["); + for (int i = 0;;) { // at least one iteration (n_embd > 0) + LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]); + i++; + if (i < n_embd) LOG(","); else break; + } + LOG(notArray ? "]\n }" : "]"); + j++; + if (j < n_embd_count) LOG(notArray ? ",\n" : ","); else break; + } + LOG(notArray ? "\n ]" : "]\n"); + + if (params.embd_out == "json+" && n_prompts > 1) { + LOG(",\n \"cosineSimilarity\": [\n"); + for (int i = 0;;) { // at least two iteration (n_embd_count > 1) + LOG(" ["); + for (int j = 0;;) { // at least two iteration (n_embd_count > 1) + float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + LOG("%6.2f", sim); + j++; + if (j < n_embd_count) LOG(", "); else break; + } + LOG(" ]"); + i++; + if (i < n_embd_count) LOG(",\n"); else break; + } + LOG("\n ]"); + } + + if (notArray) LOG("\n}\n"); + } + + LOG("\n"); + llama_perf_context_print(ctx); + + // clean up + llama_batch_free(batch); + llama_backend_free(); + + return 0; +} diff --git a/examples/eval-callback/CMakeLists.txt b/examples/eval-callback/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..95915ed91c09912da1e4b18b719dd878508d219e --- /dev/null +++ b/examples/eval-callback/CMakeLists.txt @@ -0,0 +1,10 @@ +set(TARGET llama-eval-callback) +add_executable(${TARGET} eval-callback.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TEST_TARGET test-eval-callback) +add_test(NAME ${TEST_TARGET} + COMMAND llama-eval-callback --hf-repo ggml-org/models --hf-file tinyllamas/stories260K.gguf --model stories260K.gguf --prompt hello --seed 42 -ngl 0) +set_property(TEST ${TEST_TARGET} PROPERTY LABELS eval-callback curl) diff --git a/examples/eval-callback/README.md b/examples/eval-callback/README.md new file mode 100644 index 0000000000000000000000000000000000000000..63a57ad6b68e5f508621e122d42d37972d841245 --- /dev/null +++ b/examples/eval-callback/README.md @@ -0,0 +1,95 @@ +# llama.cpp/examples/eval-callback + +A simple example which demonstrates how to use callback during the inference. +It simply prints to the console all operations and tensor data. + +Usage: + +```shell +llama-eval-callback \ + --hf-repo ggml-org/models \ + --hf-file phi-2/ggml-model-q4_0.gguf \ + --model phi-2-q4_0.gguf \ + --prompt hello \ + --seed 42 \ + -ngl 33 +``` + +Will print: + +```shell +llm_load_tensors: offloaded 33/33 layers to GPU +... +llama_new_context_with_model: n_ctx = 512 +... +llama_new_context_with_model: CUDA0 compute buffer size = 105.00 MiB +llama_new_context_with_model: CUDA_Host compute buffer size = 6.01 MiB +llama_new_context_with_model: graph nodes = 1225 +llama_new_context_with_model: graph splits = 2 +ggml_debug: inp_embd = (f32) GET_ROWS(token_embd.weight{2560, 51200, 1, 1}, inp_tokens{1, 1, 1, 1}}) = {2560, 1, 1, 1} + [ + [ + [ -0.0181, 0.0272, 0.0272, ...], + ], + ] +ggml_debug: norm-0 = (f32) NORM(CUDA0#inp_embd#0{2560, 1, 1, 1}, }) = {2560, 1, 1, 1} + [ + [ + [ -0.6989, 1.0636, 1.0636, ...], + ], + ] +ggml_debug: norm_w-0 = (f32) MUL(norm-0{2560, 1, 1, 1}, blk.0.attn_norm.weight{2560, 1, 1, 1}}) = {2560, 1, 1, 1} + [ + [ + [ -0.1800, 0.2817, 0.2632, ...], + ], + ] +ggml_debug: attn_norm-0 = (f32) ADD(norm_w-0{2560, 1, 1, 1}, blk.0.attn_norm.bias{2560, 1, 1, 1}}) = {2560, 1, 1, 1} + [ + [ + [ -0.1863, 0.2970, 0.2604, ...], + ], + ] +ggml_debug: wqkv-0 = (f32) MUL_MAT(blk.0.attn_qkv.weight{2560, 7680, 1, 1}, attn_norm-0{2560, 1, 1, 1}}) = {7680, 1, 1, 1} + [ + [ + [ -1.1238, 1.2876, -1.8086, ...], + ], + ] +ggml_debug: bqkv-0 = (f32) ADD(wqkv-0{7680, 1, 1, 1}, blk.0.attn_qkv.bias{7680, 1, 1, 1}}) = {7680, 1, 1, 1} + [ + [ + [ -1.1135, 1.4604, -1.9226, ...], + ], + ] +ggml_debug: bqkv-0 (view) = (f32) VIEW(bqkv-0{7680, 1, 1, 1}, }) = {2560, 1, 1, 1} + [ + [ + [ -1.1135, 1.4604, -1.9226, ...], + ], + ] +ggml_debug: Qcur-0 = (f32) CONT(bqkv-0 (view){2560, 1, 1, 1}, }) = {2560, 1, 1, 1} + [ + [ + [ -1.1135, 1.4604, -1.9226, ...], + ], + ] +ggml_debug: Qcur-0 (reshaped) = (f32) RESHAPE(Qcur-0{2560, 1, 1, 1}, }) = {80, 32, 1, 1} + [ + [ + [ -1.1135, 1.4604, -1.9226, ...], + [ -0.3608, 0.5076, -1.8866, ...], + [ 1.7643, 0.0273, -2.1065, ...], + ... + ], + ] +ggml_debug: Qcur-0 = (f32) ROPE(Qcur-0 (reshaped){80, 32, 1, 1}, CUDA0#inp_pos#0{1, 1, 1, 1}}) = {80, 32, 1, 1} + [ + [ + [ -1.1135, 1.4604, -1.9226, ...], + [ -0.3608, 0.5076, -1.8866, ...], + [ 1.7643, 0.0273, -2.1065, ...], + ... + ], + ] +``` diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb188f5a9e13225eacac907e1bfb4003c1714214 --- /dev/null +++ b/examples/eval-callback/eval-callback.cpp @@ -0,0 +1,194 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" +#include "ggml.h" + +#include +#include +#include + +/** + * This the arbitrary data which will be passed to each callback. + * Later on we can for example add operation or tensor name filter from the CLI arg, or a file descriptor to dump the tensor. + */ +struct callback_data { + std::vector data; +}; + +static std::string ggml_ne_string(const ggml_tensor * t) { + std::string str; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + str += std::to_string(t->ne[i]); + if (i + 1 < GGML_MAX_DIMS) { + str += ", "; + } + } + return str; +} + +static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { + GGML_ASSERT(n > 0); + float sum = 0; + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + LOG(" [\n"); + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + if (i2 == n && ne[2] > 2*n) { + LOG(" ..., \n"); + i2 = ne[2] - n; + } + LOG(" [\n"); + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + if (i1 == n && ne[1] > 2*n) { + LOG(" ..., \n"); + i1 = ne[1] - n; + } + LOG(" ["); + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + if (i0 == n && ne[0] > 2*n) { + LOG("..., "); + i0 = ne[0] - n; + } + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + float v; + if (type == GGML_TYPE_F16) { + v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); + } else if (type == GGML_TYPE_F32) { + v = *(float *) &data[i]; + } else if (type == GGML_TYPE_I32) { + v = (float) *(int32_t *) &data[i]; + } else if (type == GGML_TYPE_I16) { + v = (float) *(int16_t *) &data[i]; + } else if (type == GGML_TYPE_I8) { + v = (float) *(int8_t *) &data[i]; + } else { + GGML_ABORT("fatal error"); + } + LOG("%12.4f", v); + sum += v; + if (i0 < ne[0] - 1) LOG(", "); + } + LOG("],\n"); + } + LOG(" ],\n"); + } + LOG(" ]\n"); + LOG(" sum = %f\n", sum); + } +} + +/** + * GGML operations callback during the graph execution. + * + * @param t current tensor + * @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor + * if we return true, a follow-up call will be made with ask=false in which we can do the actual collection. + * see ggml_backend_sched_eval_callback + * @param user_data user data to pass at each call back + * @return true to receive data or continue the graph, false otherwise + */ +static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { + auto * cb_data = (callback_data *) user_data; + + const struct ggml_tensor * src0 = t->src[0]; + const struct ggml_tensor * src1 = t->src[1]; + + if (ask) { + return true; // Always retrieve data + } + + char src1_str[128] = {0}; + if (src1) { + snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); + } + + LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, + t->name, ggml_type_name(t->type), ggml_op_desc(t), + src0->name, ggml_ne_string(src0).c_str(), + src1 ? src1_str : "", + ggml_ne_string(t).c_str()); + + + // copy the data from the GPU memory if needed + const bool is_host = ggml_backend_buffer_is_host(t->buffer); + + if (!is_host) { + auto n_bytes = ggml_nbytes(t); + cb_data->data.resize(n_bytes); + ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes); + } + + if (!ggml_is_quantized(t->type)) { + uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); + ggml_print_tensor(data, t->type, t->ne, t->nb, 3); + } + + return true; +} + +static bool run(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + + std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); + + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + + return true; +} + +int main(int argc, char ** argv) { + callback_data cb_data; + + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + // pass the callback to the backend scheduler + // it will be executed for each node during the graph computation + params.cb_eval = ggml_debug; + params.cb_eval_user_data = &cb_data; + params.warmup = false; + + // init + common_init_result llama_init = common_init_from_params(params); + + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); + + if (model == nullptr || ctx == nullptr) { + LOG_ERR("%s : failed to init\n", __func__); + return 1; + } + + // print system information + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + } + + bool OK = run(ctx, params); + if (!OK) { + return 1; + } + + LOG("\n"); + llama_perf_context_print(ctx); + + llama_backend_free(); + + return 0; +} diff --git a/examples/gen-docs/CMakeLists.txt b/examples/gen-docs/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..25de0af35df6021841894396db71fdcc165f2b12 --- /dev/null +++ b/examples/gen-docs/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-gen-docs) +add_executable(${TARGET} gen-docs.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gen-docs/gen-docs.cpp b/examples/gen-docs/gen-docs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77c59a836e50a7902c23f5f9d7c7bedaca81a527 --- /dev/null +++ b/examples/gen-docs/gen-docs.cpp @@ -0,0 +1,83 @@ +#include "arg.h" +#include "common.h" + +#include +#include + +// Export usage message (-h) to markdown format + +static void write_table_header(std::ofstream & file) { + file << "| Argument | Explanation |\n"; + file << "| -------- | ----------- |\n"; +} + +static void write_table_entry(std::ofstream & file, const common_arg & opt) { + file << "| `"; + // args + for (const auto & arg : opt.args) { + if (arg == opt.args.front()) { + file << arg; + if (opt.args.size() > 1) file << ", "; + } else { + file << arg << (arg != opt.args.back() ? ", " : ""); + } + } + // value hint + if (opt.value_hint) { + std::string md_value_hint(opt.value_hint); + string_replace_all(md_value_hint, "|", "\\|"); + file << " " << md_value_hint; + } + if (opt.value_hint_2) { + std::string md_value_hint_2(opt.value_hint_2); + string_replace_all(md_value_hint_2, "|", "\\|"); + file << " " << md_value_hint_2; + } + // help text + std::string md_help(opt.help); + string_replace_all(md_help, "\n", "
"); + string_replace_all(md_help, "|", "\\|"); + file << "` | " << md_help << " |\n"; +} + +static void write_table(std::ofstream & file, std::vector & opts) { + write_table_header(file); + for (const auto & opt : opts) { + write_table_entry(file, *opt); + } +} + +static void export_md(std::string fname, llama_example ex) { + std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc); + + common_params params; + auto ctx_arg = common_params_parser_init(params, ex); + + std::vector common_options; + std::vector sparam_options; + std::vector specific_options; + for (auto & opt : ctx_arg.options) { + // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example + if (opt.is_sparam) { + sparam_options.push_back(&opt); + } else if (opt.in_example(ctx_arg.ex)) { + specific_options.push_back(&opt); + } else { + common_options.push_back(&opt); + } + } + + file << "**Common params**\n\n"; + write_table(file, common_options); + file << "\n\n**Sampling params**\n\n"; + write_table(file, sparam_options); + file << "\n\n**Example-specific params**\n\n"; + write_table(file, specific_options); +} + +int main(int, char **) { + export_md("autogen-main.md", LLAMA_EXAMPLE_MAIN); + export_md("autogen-server.md", LLAMA_EXAMPLE_SERVER); + + return 0; +} diff --git a/examples/gguf-hash/CMakeLists.txt b/examples/gguf-hash/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..15c5c68c6f402f1ac9a81555a16c81ba87794163 --- /dev/null +++ b/examples/gguf-hash/CMakeLists.txt @@ -0,0 +1,22 @@ +set(TARGET llama-gguf-hash) +add_executable(${TARGET} gguf-hash.cpp) +install(TARGETS ${TARGET} RUNTIME) + +# clibs dependencies +include_directories(deps/) + +add_library(xxhash OBJECT deps/xxhash/xxhash.c deps/xxhash/xxhash.h) +target_link_libraries(${TARGET} PRIVATE xxhash) + +add_library(sha1 OBJECT deps/sha1/sha1.c deps/sha1/sha1.h) +target_link_libraries(${TARGET} PRIVATE sha1) +if (NOT MSVC) + # disable warnings in 3rd party code + target_compile_options(sha1 PRIVATE -w) +endif() + +add_library(sha256 OBJECT deps/sha256/sha256.c deps/sha256/sha256.h) +target_link_libraries(${TARGET} PRIVATE sha256) + +target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gguf-hash/README.md b/examples/gguf-hash/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9871651e38ba8fc85414e5be0bebd6e98b38588d --- /dev/null +++ b/examples/gguf-hash/README.md @@ -0,0 +1,206 @@ + +# llama-gguf-hash + +CLI to hash GGUF files to detect difference on a per model and per tensor level. + +**Command line options:** + +- `--help`: display help message +- `--xxh64`: use xhash 64bit hash mode (default) +- `--sha1`: use sha1 +- `--uuid`: use uuid +- `--sha256`: use sha256 +- `--all`: use all hash +- `--no-layer`: exclude per layer hash +- `--uuid`: generate UUIDv5 ID +- `-c`, `--check `: verify against a manifest + +## About + +While most POSIX systems already have hash checking programs like sha256sum, it +is designed to check entire files. This is not ideal for our purpose if we want +to check for consistency of the tensor data even if the metadata content of the +gguf KV store has been updated. + +This program is designed to hash a gguf tensor payload on a 'per tensor layer' +in addition to a 'entire tensor model' hash. The intent is that the entire +tensor layer can be checked first but if there is any detected inconsistencies, +then the per tensor hash can be used to narrow down the specific tensor layer +that has inconsistencies. + +For Maintainers: +- Detection of tensor inconsistency during development and automated tests + - This is served by xxh64 which is fast + - This is also served by having per tensor layer to assist in narrowing down + the location of the faulty tensor layer + - This is also served by sha1 which is much slower but more widely supported + +For Model Creators: +- Optional consistent UUID generation based on model tensor content + - This is served by UUIDv5 which is useful for databases keys + - llama.cpp UUIDv5 Namespace: `ef001206-dadc-5f6d-a15f-3359e577d4e5` + - Made via UUIDv5 URL namespace of `en.wikipedia.org/wiki/Llama.cpp` + +For Model Users: +- Assurance of tensor layer integrity even if metadata was updated + - This is served by sha256 which is still considered very secure as of 2024 + +### Design Note + +- The default behavior of this program if no arguments is provided is to hash + using xxhash's xxh32 mode because it is very fast and is primarily targeted + towards maintainers who may want to use this in automated tests. +- xxhash support xxh32 and xxh128 for 32bit hash and 128bit hash respectively + however we picked 64bit xxhash as most computers are 64bit as of 2024 and thus + would have a better affinity to calculating hash that is 64bit in size. + +## Compile Example + +```bash +cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_FATAL_WARNINGS=ON +make -C build clean +make -C build llama-gguf-hash VERBOSE=1 +./build/bin/llama-gguf-hash test.gguf +./build/bin/llama-gguf-hash --xxh64 test.gguf +./build/bin/llama-gguf-hash --sha1 test.gguf +./build/bin/llama-gguf-hash --uuid test.gguf +./build/bin/llama-gguf-hash --sha256 test.gguf +``` + +## Generation and Verification Example + +To generate we may use this command + +```bash +./llama-gguf-hash --all test.gguf > test.gguf.manifest +``` + +Which would generate a manifest that looks like below, which contains multiple hash type and per tensor layer hashes as well +(This excludes UUID as that is an ID not a hash) + +```bash +xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 +sha1 59f79ecefd8125a996fdf419239051a7e99e5f20 test.gguf:tensor_0 +sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 +xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 +sha1 4765f592eacf096df4628ba59476af94d767080a test.gguf:tensor_1 +sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 +xxh64 a0af5d700049693b test.gguf:tensor_2 +sha1 25cbfbad4513cc348e2c95ebdee69d6ff2fd8753 test.gguf:tensor_2 +sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 +xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 +sha1 a9cba73e2d90f2ee3dae2548caa42bef3fe6a96c test.gguf:tensor_3 +sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 +xxh64 1257733306b7992d test.gguf:tensor_4 +sha1 d7bc61db93bb685ce9d598da89717c66729b7543 test.gguf:tensor_4 +sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 +xxh64 d238d16ba4711e58 test.gguf:tensor_5 +sha1 0706566c198fe1072f37e0a5135b4b5f23654c52 test.gguf:tensor_5 +sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 +xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 +sha1 73922a0727226a409049f6fc3172a52219ca6f00 test.gguf:tensor_6 +sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 +xxh64 c22021c29854f093 test.gguf:tensor_7 +sha1 efc39cece6a951188fc41e354c73bbfe6813d447 test.gguf:tensor_7 +sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 +xxh64 936df61f5d64261f test.gguf:tensor_8 +sha1 c2490296d789a4f34398a337fed8377d943d9f06 test.gguf:tensor_8 +sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 +xxh64 93fd20c64421c081 test.gguf:tensor_9 +sha1 7047ce1e78437a6884337a3751c7ee0421918a65 test.gguf:tensor_9 +sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 +xxh64 5a54d3aad816f302 test.gguf +sha1 d15be52c4ff213e823cb6dd13af7ee2f978e7042 test.gguf +sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf +``` + +We can then use the normal check command which will by default check for the highest security strength hash and verify against that: + +```bash +$ ./llama-gguf-hash --check test.gguf.manifest test.gguf +manifest test.gguf.manifest sha256 sha1 xxh64 +sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 - Ok +sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 - Ok +sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 - Ok +sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 - Ok +sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 - Ok +sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 - Ok +sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 - Ok +sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 - Ok +sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 - Ok +sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 - Ok +sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf - Ok + +Verification results for test.gguf.manifest - Success +``` + +Or we may explicitly ask for a faster hash like: + +```bash +$ ./llama-gguf-hash --check test.gguf.manifest --xxh64 test.gguf +manifest test.gguf.manifest sha256 sha1 xxh64 +xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 - Ok +xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 - Ok +xxh64 a0af5d700049693b test.gguf:tensor_2 - Ok +xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 - Ok +xxh64 1257733306b7992d test.gguf:tensor_4 - Ok +xxh64 d238d16ba4711e58 test.gguf:tensor_5 - Ok +xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 - Ok +xxh64 c22021c29854f093 test.gguf:tensor_7 - Ok +xxh64 936df61f5d64261f test.gguf:tensor_8 - Ok +xxh64 93fd20c64421c081 test.gguf:tensor_9 - Ok +xxh64 5a54d3aad816f302 test.gguf - Ok + +Verification results for test.gguf.manifest - Success +``` + +Or maybe we want to just check that all the hash is valid: + +```bash +$./llama-gguf-hash --check test.gguf.manifest --all test.gguf.manifest +manifest test.gguf.manifest sha256 sha1 xxh64 +xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 - Ok +sha1 59f79ecefd8125a996fdf419239051a7e99e5f20 test.gguf:tensor_0 - Ok +sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 - Ok +xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 - Ok +sha1 4765f592eacf096df4628ba59476af94d767080a test.gguf:tensor_1 - Ok +sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 - Ok +xxh64 a0af5d700049693b test.gguf:tensor_2 - Ok +sha1 25cbfbad4513cc348e2c95ebdee69d6ff2fd8753 test.gguf:tensor_2 - Ok +sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 - Ok +xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 - Ok +sha1 a9cba73e2d90f2ee3dae2548caa42bef3fe6a96c test.gguf:tensor_3 - Ok +sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 - Ok +xxh64 1257733306b7992d test.gguf:tensor_4 - Ok +sha1 d7bc61db93bb685ce9d598da89717c66729b7543 test.gguf:tensor_4 - Ok +sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 - Ok +xxh64 d238d16ba4711e58 test.gguf:tensor_5 - Ok +sha1 0706566c198fe1072f37e0a5135b4b5f23654c52 test.gguf:tensor_5 - Ok +sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 - Ok +xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 - Ok +sha1 73922a0727226a409049f6fc3172a52219ca6f00 test.gguf:tensor_6 - Ok +sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 - Ok +xxh64 c22021c29854f093 test.gguf:tensor_7 - Ok +sha1 efc39cece6a951188fc41e354c73bbfe6813d447 test.gguf:tensor_7 - Ok +sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 - Ok +xxh64 936df61f5d64261f test.gguf:tensor_8 - Ok +sha1 c2490296d789a4f34398a337fed8377d943d9f06 test.gguf:tensor_8 - Ok +sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 - Ok +xxh64 93fd20c64421c081 test.gguf:tensor_9 - Ok +sha1 7047ce1e78437a6884337a3751c7ee0421918a65 test.gguf:tensor_9 - Ok +sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 - Ok +xxh64 5a54d3aad816f302 test.gguf - Ok +sha1 d15be52c4ff213e823cb6dd13af7ee2f978e7042 test.gguf - Ok +sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf - Ok + +Verification results for test.gguf.manifest - Success +``` + + +## Crypto/Hash Libraries Used + +These micro c libraries dependencies was installed via the [clib c package manager](https://github.com/clibs) + +- https://github.com/Cyan4973/xxHash +- https://github.com/clibs/sha1/ +- https://github.com/jb55/sha256.c diff --git a/examples/gguf-hash/deps/rotate-bits/package.json b/examples/gguf-hash/deps/rotate-bits/package.json new file mode 100644 index 0000000000000000000000000000000000000000..74c0bef68d8bd32828a0d8c204549a72a51a903a --- /dev/null +++ b/examples/gguf-hash/deps/rotate-bits/package.json @@ -0,0 +1,13 @@ +{ + "name": "rotate-bits", + "version": "0.1.1", + "repo": "jb55/rotate-bits.h", + "description": "rotate bits", + "keywords": ["rotl", "rotr"], + "src": ["rotate-bits.h"], + "license": "Public Domain", + "development": { + "thlorenz/tap.c": "*" + } +} + diff --git a/examples/gguf-hash/deps/rotate-bits/rotate-bits.h b/examples/gguf-hash/deps/rotate-bits/rotate-bits.h new file mode 100644 index 0000000000000000000000000000000000000000..75c4881fc322f2e6a6ee9d809e696852531abb8c --- /dev/null +++ b/examples/gguf-hash/deps/rotate-bits/rotate-bits.h @@ -0,0 +1,46 @@ + + +#ifndef __ROTATE_DEFS_H +#define __ROTATE_DEFS_H + +#ifdef _MSC_VER + +#include + +#define ROTL32(v, n) _rotl((v), (n)) +#define ROTL64(v, n) _rotl64((v), (n)) + +#define ROTR32(v, n) _rotr((v), (n)) +#define ROTR64(v, n) _rotr64((v), (n)) + +#else + +#include + +#define U8V(v) ((uint8_t)(v) & 0xFFU) +#define U16V(v) ((uint16_t)(v) & 0xFFFFU) +#define U32V(v) ((uint32_t)(v) & 0xFFFFFFFFU) +#define U64V(v) ((uint64_t)(v) & 0xFFFFFFFFFFFFFFFFU) + +#define ROTL32(v, n) \ + (U32V((uint32_t)(v) << (n)) | ((uint32_t)(v) >> (32 - (n)))) + +// tests fail if we don't have this cast... +#define ROTL64(v, n) \ + (U64V((uint64_t)(v) << (n)) | ((uint64_t)(v) >> (64 - (n)))) + +#define ROTR32(v, n) ROTL32(v, 32 - (n)) +#define ROTR64(v, n) ROTL64(v, 64 - (n)) + +#endif + +#define ROTL8(v, n) \ + (U8V((uint8_t)(v) << (n)) | ((uint8_t)(v) >> (8 - (n)))) + +#define ROTL16(v, n) \ + (U16V((uint16_t)(v) << (n)) | ((uint16_t)(v) >> (16 - (n)))) + +#define ROTR8(v, n) ROTL8(v, 8 - (n)) +#define ROTR16(v, n) ROTL16(v, 16 - (n)) + +#endif diff --git a/examples/gguf-hash/deps/sha1/package.json b/examples/gguf-hash/deps/sha1/package.json new file mode 100644 index 0000000000000000000000000000000000000000..6a5843dd1ef46c21599b4f75e95a088907351abd --- /dev/null +++ b/examples/gguf-hash/deps/sha1/package.json @@ -0,0 +1,9 @@ +{ + "name": "sha1", + "version": "0.0.1", + "repo": "clibs/sha1", + "description": "sha1 hash algorithm", + "keywords": ["sha1", "hash"], + "license": "public domain", + "src": ["sha1.c", "sha1.h"] +} diff --git a/examples/gguf-hash/deps/sha1/sha1.c b/examples/gguf-hash/deps/sha1/sha1.c new file mode 100644 index 0000000000000000000000000000000000000000..76cd6ca3381d58143532c2e664787ea9ca6e3a9f --- /dev/null +++ b/examples/gguf-hash/deps/sha1/sha1.c @@ -0,0 +1,295 @@ +/* +SHA-1 in C +By Steve Reid +100% Public Domain + +Test Vectors (from FIPS PUB 180-1) +"abc" + A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D +"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq" + 84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1 +A million repetitions of "a" + 34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F +*/ + +/* #define LITTLE_ENDIAN * This should be #define'd already, if true. */ +/* #define SHA1HANDSOFF * Copies data before messing with it. */ + +#define SHA1HANDSOFF + +#include +#include + +/* for uint32_t */ +#include + +#include "sha1.h" + + +#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits)))) + +/* blk0() and blk() perform the initial expand. */ +/* I got the idea of expanding during the round function from SSLeay */ +#if BYTE_ORDER == LITTLE_ENDIAN +#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \ + |(rol(block->l[i],8)&0x00FF00FF)) +#elif BYTE_ORDER == BIG_ENDIAN +#define blk0(i) block->l[i] +#else +#error "Endianness not defined!" +#endif +#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \ + ^block->l[(i+2)&15]^block->l[i&15],1)) + +/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */ +#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30); +#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30); +#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30); + + +/* Hash a single 512-bit block. This is the core of the algorithm. */ + +void SHA1Transform( + uint32_t state[5], + const unsigned char buffer[64] +) +{ + uint32_t a, b, c, d, e; + + typedef union + { + unsigned char c[64]; + uint32_t l[16]; + } CHAR64LONG16; + +#ifdef SHA1HANDSOFF + CHAR64LONG16 block[1]; /* use array to appear as a pointer */ + + memcpy(block, buffer, 64); +#else + /* The following had better never be used because it causes the + * pointer-to-const buffer to be cast into a pointer to non-const. + * And the result is written through. I threw a "const" in, hoping + * this will cause a diagnostic. + */ + CHAR64LONG16 *block = (const CHAR64LONG16 *) buffer; +#endif + /* Copy context->state[] to working vars */ + a = state[0]; + b = state[1]; + c = state[2]; + d = state[3]; + e = state[4]; + /* 4 rounds of 20 operations each. Loop unrolled. */ + R0(a, b, c, d, e, 0); + R0(e, a, b, c, d, 1); + R0(d, e, a, b, c, 2); + R0(c, d, e, a, b, 3); + R0(b, c, d, e, a, 4); + R0(a, b, c, d, e, 5); + R0(e, a, b, c, d, 6); + R0(d, e, a, b, c, 7); + R0(c, d, e, a, b, 8); + R0(b, c, d, e, a, 9); + R0(a, b, c, d, e, 10); + R0(e, a, b, c, d, 11); + R0(d, e, a, b, c, 12); + R0(c, d, e, a, b, 13); + R0(b, c, d, e, a, 14); + R0(a, b, c, d, e, 15); + R1(e, a, b, c, d, 16); + R1(d, e, a, b, c, 17); + R1(c, d, e, a, b, 18); + R1(b, c, d, e, a, 19); + R2(a, b, c, d, e, 20); + R2(e, a, b, c, d, 21); + R2(d, e, a, b, c, 22); + R2(c, d, e, a, b, 23); + R2(b, c, d, e, a, 24); + R2(a, b, c, d, e, 25); + R2(e, a, b, c, d, 26); + R2(d, e, a, b, c, 27); + R2(c, d, e, a, b, 28); + R2(b, c, d, e, a, 29); + R2(a, b, c, d, e, 30); + R2(e, a, b, c, d, 31); + R2(d, e, a, b, c, 32); + R2(c, d, e, a, b, 33); + R2(b, c, d, e, a, 34); + R2(a, b, c, d, e, 35); + R2(e, a, b, c, d, 36); + R2(d, e, a, b, c, 37); + R2(c, d, e, a, b, 38); + R2(b, c, d, e, a, 39); + R3(a, b, c, d, e, 40); + R3(e, a, b, c, d, 41); + R3(d, e, a, b, c, 42); + R3(c, d, e, a, b, 43); + R3(b, c, d, e, a, 44); + R3(a, b, c, d, e, 45); + R3(e, a, b, c, d, 46); + R3(d, e, a, b, c, 47); + R3(c, d, e, a, b, 48); + R3(b, c, d, e, a, 49); + R3(a, b, c, d, e, 50); + R3(e, a, b, c, d, 51); + R3(d, e, a, b, c, 52); + R3(c, d, e, a, b, 53); + R3(b, c, d, e, a, 54); + R3(a, b, c, d, e, 55); + R3(e, a, b, c, d, 56); + R3(d, e, a, b, c, 57); + R3(c, d, e, a, b, 58); + R3(b, c, d, e, a, 59); + R4(a, b, c, d, e, 60); + R4(e, a, b, c, d, 61); + R4(d, e, a, b, c, 62); + R4(c, d, e, a, b, 63); + R4(b, c, d, e, a, 64); + R4(a, b, c, d, e, 65); + R4(e, a, b, c, d, 66); + R4(d, e, a, b, c, 67); + R4(c, d, e, a, b, 68); + R4(b, c, d, e, a, 69); + R4(a, b, c, d, e, 70); + R4(e, a, b, c, d, 71); + R4(d, e, a, b, c, 72); + R4(c, d, e, a, b, 73); + R4(b, c, d, e, a, 74); + R4(a, b, c, d, e, 75); + R4(e, a, b, c, d, 76); + R4(d, e, a, b, c, 77); + R4(c, d, e, a, b, 78); + R4(b, c, d, e, a, 79); + /* Add the working vars back into context.state[] */ + state[0] += a; + state[1] += b; + state[2] += c; + state[3] += d; + state[4] += e; + /* Wipe variables */ + a = b = c = d = e = 0; +#ifdef SHA1HANDSOFF + memset(block, '\0', sizeof(block)); +#endif +} + + +/* SHA1Init - Initialize new context */ + +void SHA1Init( + SHA1_CTX * context +) +{ + /* SHA1 initialization constants */ + context->state[0] = 0x67452301; + context->state[1] = 0xEFCDAB89; + context->state[2] = 0x98BADCFE; + context->state[3] = 0x10325476; + context->state[4] = 0xC3D2E1F0; + context->count[0] = context->count[1] = 0; +} + + +/* Run your data through this. */ + +void SHA1Update( + SHA1_CTX * context, + const unsigned char *data, + uint32_t len +) +{ + uint32_t i; + + uint32_t j; + + j = context->count[0]; + if ((context->count[0] += len << 3) < j) + context->count[1]++; + context->count[1] += (len >> 29); + j = (j >> 3) & 63; + if ((j + len) > 63) + { + memcpy(&context->buffer[j], data, (i = 64 - j)); + SHA1Transform(context->state, context->buffer); + for (; i + 63 < len; i += 64) + { + SHA1Transform(context->state, &data[i]); + } + j = 0; + } + else + i = 0; + memcpy(&context->buffer[j], &data[i], len - i); +} + + +/* Add padding and return the message digest. */ + +void SHA1Final( + unsigned char digest[20], + SHA1_CTX * context +) +{ + unsigned i; + + unsigned char finalcount[8]; + + unsigned char c; + +#if 0 /* untested "improvement" by DHR */ + /* Convert context->count to a sequence of bytes + * in finalcount. Second element first, but + * big-endian order within element. + * But we do it all backwards. + */ + unsigned char *fcp = &finalcount[8]; + + for (i = 0; i < 2; i++) + { + uint32_t t = context->count[i]; + + int j; + + for (j = 0; j < 4; t >>= 8, j++) + *--fcp = (unsigned char) t} +#else + for (i = 0; i < 8; i++) + { + finalcount[i] = (unsigned char) ((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255); /* Endian independent */ + } +#endif + c = 0200; + SHA1Update(context, &c, 1); + while ((context->count[0] & 504) != 448) + { + c = 0000; + SHA1Update(context, &c, 1); + } + SHA1Update(context, finalcount, 8); /* Should cause a SHA1Transform() */ + for (i = 0; i < 20; i++) + { + digest[i] = (unsigned char) + ((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255); + } + /* Wipe variables */ + memset(context, '\0', sizeof(*context)); + memset(&finalcount, '\0', sizeof(finalcount)); +} + +void SHA1( + char *hash_out, + const char *str, + uint32_t len) +{ + SHA1_CTX ctx; + unsigned int ii; + + SHA1Init(&ctx); + for (ii=0; ii + 100% Public Domain + */ + +#include "stdint.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +typedef struct +{ + uint32_t state[5]; + uint32_t count[2]; + unsigned char buffer[64]; +} SHA1_CTX; + +void SHA1Transform( + uint32_t state[5], + const unsigned char buffer[64] + ); + +void SHA1Init( + SHA1_CTX * context + ); + +void SHA1Update( + SHA1_CTX * context, + const unsigned char *data, + uint32_t len + ); + +void SHA1Final( + unsigned char digest[20], + SHA1_CTX * context + ); + +void SHA1( + char *hash_out, + const char *str, + uint32_t len); + +#if defined(__cplusplus) +} +#endif + +#endif /* SHA1_H */ diff --git a/examples/gguf-hash/deps/sha256/package.json b/examples/gguf-hash/deps/sha256/package.json new file mode 100644 index 0000000000000000000000000000000000000000..b92a0412738dfb495cd859c5d94fb45413755bbd --- /dev/null +++ b/examples/gguf-hash/deps/sha256/package.json @@ -0,0 +1,15 @@ +{ + "name": "sha256", + "version": "0.0.2", + "repo": "jb55/sha256.c", + "description": "sha256 in c", + "keywords": ["sha256", "sha2"], + "src": ["sha256.c", "sha256.h"], + "dependencies": { + "jb55/rotate-bits.h": "0.1.1" + }, + "development": { + "thlorenz/tap.c": "*" + } +} + diff --git a/examples/gguf-hash/deps/sha256/sha256.c b/examples/gguf-hash/deps/sha256/sha256.c new file mode 100644 index 0000000000000000000000000000000000000000..a7a87aeb20032c7f9e87491af927b9074037202a --- /dev/null +++ b/examples/gguf-hash/deps/sha256/sha256.c @@ -0,0 +1,221 @@ +/* Crypto/Sha256.c -- SHA-256 Hash +2010-06-11 : Igor Pavlov : Public domain +This code is based on public domain code from Wei Dai's Crypto++ library. */ + +#include "rotate-bits/rotate-bits.h" +#include "sha256.h" + +/* define it for speed optimization */ +#define _SHA256_UNROLL +#define _SHA256_UNROLL2 + +void +sha256_init(sha256_t *p) +{ + p->state[0] = 0x6a09e667; + p->state[1] = 0xbb67ae85; + p->state[2] = 0x3c6ef372; + p->state[3] = 0xa54ff53a; + p->state[4] = 0x510e527f; + p->state[5] = 0x9b05688c; + p->state[6] = 0x1f83d9ab; + p->state[7] = 0x5be0cd19; + p->count = 0; +} + +#define S0(x) (ROTR32(x, 2) ^ ROTR32(x,13) ^ ROTR32(x, 22)) +#define S1(x) (ROTR32(x, 6) ^ ROTR32(x,11) ^ ROTR32(x, 25)) +#define s0(x) (ROTR32(x, 7) ^ ROTR32(x,18) ^ (x >> 3)) +#define s1(x) (ROTR32(x,17) ^ ROTR32(x,19) ^ (x >> 10)) + +#define blk0(i) (W[i] = data[i]) +#define blk2(i) (W[i&15] += s1(W[(i-2)&15]) + W[(i-7)&15] + s0(W[(i-15)&15])) + +#define Ch(x,y,z) (z^(x&(y^z))) +#define Maj(x,y,z) ((x&y)|(z&(x|y))) + +#define a(i) T[(0-(i))&7] +#define b(i) T[(1-(i))&7] +#define c(i) T[(2-(i))&7] +#define d(i) T[(3-(i))&7] +#define e(i) T[(4-(i))&7] +#define f(i) T[(5-(i))&7] +#define g(i) T[(6-(i))&7] +#define h(i) T[(7-(i))&7] + + +#ifdef _SHA256_UNROLL2 + +#define R(a,b,c,d,e,f,g,h, i) h += S1(e) + Ch(e,f,g) + K[i+j] + (j?blk2(i):blk0(i));\ + d += h; h += S0(a) + Maj(a, b, c) + +#define RX_8(i) \ + R(a,b,c,d,e,f,g,h, i); \ + R(h,a,b,c,d,e,f,g, (i+1)); \ + R(g,h,a,b,c,d,e,f, (i+2)); \ + R(f,g,h,a,b,c,d,e, (i+3)); \ + R(e,f,g,h,a,b,c,d, (i+4)); \ + R(d,e,f,g,h,a,b,c, (i+5)); \ + R(c,d,e,f,g,h,a,b, (i+6)); \ + R(b,c,d,e,f,g,h,a, (i+7)) + +#else + +#define R(i) h(i) += S1(e(i)) + Ch(e(i),f(i),g(i)) + K[i+j] + (j?blk2(i):blk0(i));\ + d(i) += h(i); h(i) += S0(a(i)) + Maj(a(i), b(i), c(i)) + +#ifdef _SHA256_UNROLL + +#define RX_8(i) R(i+0); R(i+1); R(i+2); R(i+3); R(i+4); R(i+5); R(i+6); R(i+7); + +#endif + +#endif + +static const uint32_t K[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, + 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, + 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, + 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, + 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, + 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, + 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, + 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2 +}; + +static void +sha256_transform(uint32_t *state, const uint32_t *data) +{ + uint32_t W[16] = {0}; + unsigned j; + #ifdef _SHA256_UNROLL2 + uint32_t a,b,c,d,e,f,g,h; + a = state[0]; + b = state[1]; + c = state[2]; + d = state[3]; + e = state[4]; + f = state[5]; + g = state[6]; + h = state[7]; + #else + uint32_t T[8]; + for (j = 0; j < 8; j++) + T[j] = state[j]; + #endif + + for (j = 0; j < 64; j += 16) + { + #if defined(_SHA256_UNROLL) || defined(_SHA256_UNROLL2) + RX_8(0); RX_8(8); + #else + unsigned i; + for (i = 0; i < 16; i++) { R(i); } + #endif + } + + #ifdef _SHA256_UNROLL2 + state[0] += a; + state[1] += b; + state[2] += c; + state[3] += d; + state[4] += e; + state[5] += f; + state[6] += g; + state[7] += h; + #else + for (j = 0; j < 8; j++) + state[j] += T[j]; + #endif + + /* Wipe variables */ + /* memset(W, 0, sizeof(W)); */ + /* memset(T, 0, sizeof(T)); */ +} + +#undef S0 +#undef S1 +#undef s0 +#undef s1 + +static void +sha256_write_byte_block(sha256_t *p) +{ + uint32_t data32[16]; + unsigned i; + for (i = 0; i < 16; i++) + data32[i] = + ((uint32_t)(p->buffer[i * 4 ]) << 24) + + ((uint32_t)(p->buffer[i * 4 + 1]) << 16) + + ((uint32_t)(p->buffer[i * 4 + 2]) << 8) + + ((uint32_t)(p->buffer[i * 4 + 3])); + sha256_transform(p->state, data32); +} + + +void +sha256_hash(unsigned char *buf, const unsigned char *data, size_t size) +{ + sha256_t hash; + sha256_init(&hash); + sha256_update(&hash, data, size); + sha256_final(&hash, buf); +} + + +void +sha256_update(sha256_t *p, const unsigned char *data, size_t size) +{ + uint32_t curBufferPos = (uint32_t)p->count & 0x3F; + while (size > 0) + { + p->buffer[curBufferPos++] = *data++; + p->count++; + size--; + if (curBufferPos == 64) + { + curBufferPos = 0; + sha256_write_byte_block(p); + } + } +} + + +void +sha256_final(sha256_t *p, unsigned char *digest) +{ + uint64_t lenInBits = (p->count << 3); + uint32_t curBufferPos = (uint32_t)p->count & 0x3F; + unsigned i; + p->buffer[curBufferPos++] = 0x80; + while (curBufferPos != (64 - 8)) + { + curBufferPos &= 0x3F; + if (curBufferPos == 0) + sha256_write_byte_block(p); + p->buffer[curBufferPos++] = 0; + } + for (i = 0; i < 8; i++) + { + p->buffer[curBufferPos++] = (unsigned char)(lenInBits >> 56); + lenInBits <<= 8; + } + sha256_write_byte_block(p); + + for (i = 0; i < 8; i++) + { + *digest++ = (unsigned char)(p->state[i] >> 24); + *digest++ = (unsigned char)(p->state[i] >> 16); + *digest++ = (unsigned char)(p->state[i] >> 8); + *digest++ = (unsigned char)(p->state[i]); + } + sha256_init(p); +} diff --git a/examples/gguf-hash/deps/sha256/sha256.h b/examples/gguf-hash/deps/sha256/sha256.h new file mode 100644 index 0000000000000000000000000000000000000000..21657e66b602d5e74e90fbd23af64d58d94598de --- /dev/null +++ b/examples/gguf-hash/deps/sha256/sha256.h @@ -0,0 +1,24 @@ +/* Sha256.h -- SHA-256 Hash +2010-06-11 : Igor Pavlov : Public domain */ + +#ifndef __CRYPTO_SHA256_H +#define __CRYPTO_SHA256_H + +#include +#include + +#define SHA256_DIGEST_SIZE 32 + +typedef struct sha256_t +{ + uint32_t state[8]; + uint64_t count; + unsigned char buffer[64]; +} sha256_t; + +void sha256_init(sha256_t *p); +void sha256_update(sha256_t *p, const unsigned char *data, size_t size); +void sha256_final(sha256_t *p, unsigned char *digest); +void sha256_hash(unsigned char *buf, const unsigned char *data, size_t size); + +#endif diff --git a/examples/gguf-hash/deps/xxhash/clib.json b/examples/gguf-hash/deps/xxhash/clib.json new file mode 100644 index 0000000000000000000000000000000000000000..242343c5d992d4a19b9a9fe9d3845c5f12525569 --- /dev/null +++ b/examples/gguf-hash/deps/xxhash/clib.json @@ -0,0 +1,12 @@ +{ + "name": "xxhash", + "version": "0.8.2", + "repo": "Cyan4973/xxhash", + "description": "Extremely fast non-cryptographic hash algorithm", + "keywords": ["xxhash", "hashing"], + "license": "BSD-2-Clause", + "src": [ + "xxhash.c", + "xxhash.h" + ] +} diff --git a/examples/gguf-hash/deps/xxhash/xxhash.c b/examples/gguf-hash/deps/xxhash/xxhash.c new file mode 100644 index 0000000000000000000000000000000000000000..e60cc37f13c27ff0668ec2535937df3f4c097ec2 --- /dev/null +++ b/examples/gguf-hash/deps/xxhash/xxhash.c @@ -0,0 +1,42 @@ +/* + * xxHash - Extremely Fast Hash algorithm + * Copyright (C) 2012-2023 Yann Collet + * + * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * You can contact the author at: + * - xxHash homepage: https://www.xxhash.com + * - xxHash source repository: https://github.com/Cyan4973/xxHash + */ + +/* + * xxhash.c instantiates functions defined in xxhash.h + */ + +#define XXH_STATIC_LINKING_ONLY /* access advanced declarations */ +#define XXH_IMPLEMENTATION /* access definitions */ + +#include "xxhash.h" diff --git a/examples/gguf-hash/deps/xxhash/xxhash.h b/examples/gguf-hash/deps/xxhash/xxhash.h new file mode 100644 index 0000000000000000000000000000000000000000..c0fafe20d54ad017425e636c2f3e754494648053 --- /dev/null +++ b/examples/gguf-hash/deps/xxhash/xxhash.h @@ -0,0 +1,7093 @@ +/* + * xxHash - Extremely Fast Hash algorithm + * Header File + * Copyright (C) 2012-2023 Yann Collet + * + * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * You can contact the author at: + * - xxHash homepage: https://www.xxhash.com + * - xxHash source repository: https://github.com/Cyan4973/xxHash + */ + +/*! + * @mainpage xxHash + * + * xxHash is an extremely fast non-cryptographic hash algorithm, working at RAM speed + * limits. + * + * It is proposed in four flavors, in three families: + * 1. @ref XXH32_family + * - Classic 32-bit hash function. Simple, compact, and runs on almost all + * 32-bit and 64-bit systems. + * 2. @ref XXH64_family + * - Classic 64-bit adaptation of XXH32. Just as simple, and runs well on most + * 64-bit systems (but _not_ 32-bit systems). + * 3. @ref XXH3_family + * - Modern 64-bit and 128-bit hash function family which features improved + * strength and performance across the board, especially on smaller data. + * It benefits greatly from SIMD and 64-bit without requiring it. + * + * Benchmarks + * --- + * The reference system uses an Intel i7-9700K CPU, and runs Ubuntu x64 20.04. + * The open source benchmark program is compiled with clang v10.0 using -O3 flag. + * + * | Hash Name | ISA ext | Width | Large Data Speed | Small Data Velocity | + * | -------------------- | ------- | ----: | ---------------: | ------------------: | + * | XXH3_64bits() | @b AVX2 | 64 | 59.4 GB/s | 133.1 | + * | MeowHash | AES-NI | 128 | 58.2 GB/s | 52.5 | + * | XXH3_128bits() | @b AVX2 | 128 | 57.9 GB/s | 118.1 | + * | CLHash | PCLMUL | 64 | 37.1 GB/s | 58.1 | + * | XXH3_64bits() | @b SSE2 | 64 | 31.5 GB/s | 133.1 | + * | XXH3_128bits() | @b SSE2 | 128 | 29.6 GB/s | 118.1 | + * | RAM sequential read | | N/A | 28.0 GB/s | N/A | + * | ahash | AES-NI | 64 | 22.5 GB/s | 107.2 | + * | City64 | | 64 | 22.0 GB/s | 76.6 | + * | T1ha2 | | 64 | 22.0 GB/s | 99.0 | + * | City128 | | 128 | 21.7 GB/s | 57.7 | + * | FarmHash | AES-NI | 64 | 21.3 GB/s | 71.9 | + * | XXH64() | | 64 | 19.4 GB/s | 71.0 | + * | SpookyHash | | 64 | 19.3 GB/s | 53.2 | + * | Mum | | 64 | 18.0 GB/s | 67.0 | + * | CRC32C | SSE4.2 | 32 | 13.0 GB/s | 57.9 | + * | XXH32() | | 32 | 9.7 GB/s | 71.9 | + * | City32 | | 32 | 9.1 GB/s | 66.0 | + * | Blake3* | @b AVX2 | 256 | 4.4 GB/s | 8.1 | + * | Murmur3 | | 32 | 3.9 GB/s | 56.1 | + * | SipHash* | | 64 | 3.0 GB/s | 43.2 | + * | Blake3* | @b SSE2 | 256 | 2.4 GB/s | 8.1 | + * | HighwayHash | | 64 | 1.4 GB/s | 6.0 | + * | FNV64 | | 64 | 1.2 GB/s | 62.7 | + * | Blake2* | | 256 | 1.1 GB/s | 5.1 | + * | SHA1* | | 160 | 0.8 GB/s | 5.6 | + * | MD5* | | 128 | 0.6 GB/s | 7.8 | + * @note + * - Hashes which require a specific ISA extension are noted. SSE2 is also noted, + * even though it is mandatory on x64. + * - Hashes with an asterisk are cryptographic. Note that MD5 is non-cryptographic + * by modern standards. + * - Small data velocity is a rough average of algorithm's efficiency for small + * data. For more accurate information, see the wiki. + * - More benchmarks and strength tests are found on the wiki: + * https://github.com/Cyan4973/xxHash/wiki + * + * Usage + * ------ + * All xxHash variants use a similar API. Changing the algorithm is a trivial + * substitution. + * + * @pre + * For functions which take an input and length parameter, the following + * requirements are assumed: + * - The range from [`input`, `input + length`) is valid, readable memory. + * - The only exception is if the `length` is `0`, `input` may be `NULL`. + * - For C++, the objects must have the *TriviallyCopyable* property, as the + * functions access bytes directly as if it was an array of `unsigned char`. + * + * @anchor single_shot_example + * **Single Shot** + * + * These functions are stateless functions which hash a contiguous block of memory, + * immediately returning the result. They are the easiest and usually the fastest + * option. + * + * XXH32(), XXH64(), XXH3_64bits(), XXH3_128bits() + * + * @code{.c} + * #include + * #include "xxhash.h" + * + * // Example for a function which hashes a null terminated string with XXH32(). + * XXH32_hash_t hash_string(const char* string, XXH32_hash_t seed) + * { + * // NULL pointers are only valid if the length is zero + * size_t length = (string == NULL) ? 0 : strlen(string); + * return XXH32(string, length, seed); + * } + * @endcode + * + * + * @anchor streaming_example + * **Streaming** + * + * These groups of functions allow incremental hashing of unknown size, even + * more than what would fit in a size_t. + * + * XXH32_reset(), XXH64_reset(), XXH3_64bits_reset(), XXH3_128bits_reset() + * + * @code{.c} + * #include + * #include + * #include "xxhash.h" + * // Example for a function which hashes a FILE incrementally with XXH3_64bits(). + * XXH64_hash_t hashFile(FILE* f) + * { + * // Allocate a state struct. Do not just use malloc() or new. + * XXH3_state_t* state = XXH3_createState(); + * assert(state != NULL && "Out of memory!"); + * // Reset the state to start a new hashing session. + * XXH3_64bits_reset(state); + * char buffer[4096]; + * size_t count; + * // Read the file in chunks + * while ((count = fread(buffer, 1, sizeof(buffer), f)) != 0) { + * // Run update() as many times as necessary to process the data + * XXH3_64bits_update(state, buffer, count); + * } + * // Retrieve the finalized hash. This will not change the state. + * XXH64_hash_t result = XXH3_64bits_digest(state); + * // Free the state. Do not use free(). + * XXH3_freeState(state); + * return result; + * } + * @endcode + * + * Streaming functions generate the xxHash value from an incremental input. + * This method is slower than single-call functions, due to state management. + * For small inputs, prefer `XXH32()` and `XXH64()`, which are better optimized. + * + * An XXH state must first be allocated using `XXH*_createState()`. + * + * Start a new hash by initializing the state with a seed using `XXH*_reset()`. + * + * Then, feed the hash state by calling `XXH*_update()` as many times as necessary. + * + * The function returns an error code, with 0 meaning OK, and any other value + * meaning there is an error. + * + * Finally, a hash value can be produced anytime, by using `XXH*_digest()`. + * This function returns the nn-bits hash as an int or long long. + * + * It's still possible to continue inserting input into the hash state after a + * digest, and generate new hash values later on by invoking `XXH*_digest()`. + * + * When done, release the state using `XXH*_freeState()`. + * + * + * @anchor canonical_representation_example + * **Canonical Representation** + * + * The default return values from XXH functions are unsigned 32, 64 and 128 bit + * integers. + * This the simplest and fastest format for further post-processing. + * + * However, this leaves open the question of what is the order on the byte level, + * since little and big endian conventions will store the same number differently. + * + * The canonical representation settles this issue by mandating big-endian + * convention, the same convention as human-readable numbers (large digits first). + * + * When writing hash values to storage, sending them over a network, or printing + * them, it's highly recommended to use the canonical representation to ensure + * portability across a wider range of systems, present and future. + * + * The following functions allow transformation of hash values to and from + * canonical format. + * + * XXH32_canonicalFromHash(), XXH32_hashFromCanonical(), + * XXH64_canonicalFromHash(), XXH64_hashFromCanonical(), + * XXH128_canonicalFromHash(), XXH128_hashFromCanonical(), + * + * @code{.c} + * #include + * #include "xxhash.h" + * + * // Example for a function which prints XXH32_hash_t in human readable format + * void printXxh32(XXH32_hash_t hash) + * { + * XXH32_canonical_t cano; + * XXH32_canonicalFromHash(&cano, hash); + * size_t i; + * for(i = 0; i < sizeof(cano.digest); ++i) { + * printf("%02x", cano.digest[i]); + * } + * printf("\n"); + * } + * + * // Example for a function which converts XXH32_canonical_t to XXH32_hash_t + * XXH32_hash_t convertCanonicalToXxh32(XXH32_canonical_t cano) + * { + * XXH32_hash_t hash = XXH32_hashFromCanonical(&cano); + * return hash; + * } + * @endcode + * + * + * @file xxhash.h + * xxHash prototypes and implementation + */ + +#if defined (__cplusplus) +extern "C" { +#endif + +/* **************************** + * INLINE mode + ******************************/ +/*! + * @defgroup public Public API + * Contains details on the public xxHash functions. + * @{ + */ +#ifdef XXH_DOXYGEN +/*! + * @brief Gives access to internal state declaration, required for static allocation. + * + * Incompatible with dynamic linking, due to risks of ABI changes. + * + * Usage: + * @code{.c} + * #define XXH_STATIC_LINKING_ONLY + * #include "xxhash.h" + * @endcode + */ +# define XXH_STATIC_LINKING_ONLY +/* Do not undef XXH_STATIC_LINKING_ONLY for Doxygen */ + +/*! + * @brief Gives access to internal definitions. + * + * Usage: + * @code{.c} + * #define XXH_STATIC_LINKING_ONLY + * #define XXH_IMPLEMENTATION + * #include "xxhash.h" + * @endcode + */ +# define XXH_IMPLEMENTATION +/* Do not undef XXH_IMPLEMENTATION for Doxygen */ + +/*! + * @brief Exposes the implementation and marks all functions as `inline`. + * + * Use these build macros to inline xxhash into the target unit. + * Inlining improves performance on small inputs, especially when the length is + * expressed as a compile-time constant: + * + * https://fastcompression.blogspot.com/2018/03/xxhash-for-small-keys-impressive-power.html + * + * It also keeps xxHash symbols private to the unit, so they are not exported. + * + * Usage: + * @code{.c} + * #define XXH_INLINE_ALL + * #include "xxhash.h" + * @endcode + * Do not compile and link xxhash.o as a separate object, as it is not useful. + */ +# define XXH_INLINE_ALL +# undef XXH_INLINE_ALL +/*! + * @brief Exposes the implementation without marking functions as inline. + */ +# define XXH_PRIVATE_API +# undef XXH_PRIVATE_API +/*! + * @brief Emulate a namespace by transparently prefixing all symbols. + * + * If you want to include _and expose_ xxHash functions from within your own + * library, but also want to avoid symbol collisions with other libraries which + * may also include xxHash, you can use @ref XXH_NAMESPACE to automatically prefix + * any public symbol from xxhash library with the value of @ref XXH_NAMESPACE + * (therefore, avoid empty or numeric values). + * + * Note that no change is required within the calling program as long as it + * includes `xxhash.h`: Regular symbol names will be automatically translated + * by this header. + */ +# define XXH_NAMESPACE /* YOUR NAME HERE */ +# undef XXH_NAMESPACE +#endif + +#if (defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API)) \ + && !defined(XXH_INLINE_ALL_31684351384) + /* this section should be traversed only once */ +# define XXH_INLINE_ALL_31684351384 + /* give access to the advanced API, required to compile implementations */ +# undef XXH_STATIC_LINKING_ONLY /* avoid macro redef */ +# define XXH_STATIC_LINKING_ONLY + /* make all functions private */ +# undef XXH_PUBLIC_API +# if defined(__GNUC__) +# define XXH_PUBLIC_API static __inline __attribute__((__unused__)) +# elif defined (__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) +# define XXH_PUBLIC_API static inline +# elif defined(_MSC_VER) +# define XXH_PUBLIC_API static __inline +# else + /* note: this version may generate warnings for unused static functions */ +# define XXH_PUBLIC_API static +# endif + + /* + * This part deals with the special case where a unit wants to inline xxHash, + * but "xxhash.h" has previously been included without XXH_INLINE_ALL, + * such as part of some previously included *.h header file. + * Without further action, the new include would just be ignored, + * and functions would effectively _not_ be inlined (silent failure). + * The following macros solve this situation by prefixing all inlined names, + * avoiding naming collision with previous inclusions. + */ + /* Before that, we unconditionally #undef all symbols, + * in case they were already defined with XXH_NAMESPACE. + * They will then be redefined for XXH_INLINE_ALL + */ +# undef XXH_versionNumber + /* XXH32 */ +# undef XXH32 +# undef XXH32_createState +# undef XXH32_freeState +# undef XXH32_reset +# undef XXH32_update +# undef XXH32_digest +# undef XXH32_copyState +# undef XXH32_canonicalFromHash +# undef XXH32_hashFromCanonical + /* XXH64 */ +# undef XXH64 +# undef XXH64_createState +# undef XXH64_freeState +# undef XXH64_reset +# undef XXH64_update +# undef XXH64_digest +# undef XXH64_copyState +# undef XXH64_canonicalFromHash +# undef XXH64_hashFromCanonical + /* XXH3_64bits */ +# undef XXH3_64bits +# undef XXH3_64bits_withSecret +# undef XXH3_64bits_withSeed +# undef XXH3_64bits_withSecretandSeed +# undef XXH3_createState +# undef XXH3_freeState +# undef XXH3_copyState +# undef XXH3_64bits_reset +# undef XXH3_64bits_reset_withSeed +# undef XXH3_64bits_reset_withSecret +# undef XXH3_64bits_update +# undef XXH3_64bits_digest +# undef XXH3_generateSecret + /* XXH3_128bits */ +# undef XXH128 +# undef XXH3_128bits +# undef XXH3_128bits_withSeed +# undef XXH3_128bits_withSecret +# undef XXH3_128bits_reset +# undef XXH3_128bits_reset_withSeed +# undef XXH3_128bits_reset_withSecret +# undef XXH3_128bits_reset_withSecretandSeed +# undef XXH3_128bits_update +# undef XXH3_128bits_digest +# undef XXH128_isEqual +# undef XXH128_cmp +# undef XXH128_canonicalFromHash +# undef XXH128_hashFromCanonical + /* Finally, free the namespace itself */ +# undef XXH_NAMESPACE + + /* employ the namespace for XXH_INLINE_ALL */ +# define XXH_NAMESPACE XXH_INLINE_ + /* + * Some identifiers (enums, type names) are not symbols, + * but they must nonetheless be renamed to avoid redeclaration. + * Alternative solution: do not redeclare them. + * However, this requires some #ifdefs, and has a more dispersed impact. + * Meanwhile, renaming can be achieved in a single place. + */ +# define XXH_IPREF(Id) XXH_NAMESPACE ## Id +# define XXH_OK XXH_IPREF(XXH_OK) +# define XXH_ERROR XXH_IPREF(XXH_ERROR) +# define XXH_errorcode XXH_IPREF(XXH_errorcode) +# define XXH32_canonical_t XXH_IPREF(XXH32_canonical_t) +# define XXH64_canonical_t XXH_IPREF(XXH64_canonical_t) +# define XXH128_canonical_t XXH_IPREF(XXH128_canonical_t) +# define XXH32_state_s XXH_IPREF(XXH32_state_s) +# define XXH32_state_t XXH_IPREF(XXH32_state_t) +# define XXH64_state_s XXH_IPREF(XXH64_state_s) +# define XXH64_state_t XXH_IPREF(XXH64_state_t) +# define XXH3_state_s XXH_IPREF(XXH3_state_s) +# define XXH3_state_t XXH_IPREF(XXH3_state_t) +# define XXH128_hash_t XXH_IPREF(XXH128_hash_t) + /* Ensure the header is parsed again, even if it was previously included */ +# undef XXHASH_H_5627135585666179 +# undef XXHASH_H_STATIC_13879238742 +#endif /* XXH_INLINE_ALL || XXH_PRIVATE_API */ + +/* **************************************************************** + * Stable API + *****************************************************************/ +#ifndef XXHASH_H_5627135585666179 +#define XXHASH_H_5627135585666179 1 + +/*! @brief Marks a global symbol. */ +#if !defined(XXH_INLINE_ALL) && !defined(XXH_PRIVATE_API) +# if defined(_WIN32) && defined(_MSC_VER) && (defined(XXH_IMPORT) || defined(XXH_EXPORT)) +# ifdef XXH_EXPORT +# define XXH_PUBLIC_API __declspec(dllexport) +# elif XXH_IMPORT +# define XXH_PUBLIC_API __declspec(dllimport) +# endif +# else +# define XXH_PUBLIC_API /* do nothing */ +# endif +#endif + +#ifdef XXH_NAMESPACE +# define XXH_CAT(A,B) A##B +# define XXH_NAME2(A,B) XXH_CAT(A,B) +# define XXH_versionNumber XXH_NAME2(XXH_NAMESPACE, XXH_versionNumber) +/* XXH32 */ +# define XXH32 XXH_NAME2(XXH_NAMESPACE, XXH32) +# define XXH32_createState XXH_NAME2(XXH_NAMESPACE, XXH32_createState) +# define XXH32_freeState XXH_NAME2(XXH_NAMESPACE, XXH32_freeState) +# define XXH32_reset XXH_NAME2(XXH_NAMESPACE, XXH32_reset) +# define XXH32_update XXH_NAME2(XXH_NAMESPACE, XXH32_update) +# define XXH32_digest XXH_NAME2(XXH_NAMESPACE, XXH32_digest) +# define XXH32_copyState XXH_NAME2(XXH_NAMESPACE, XXH32_copyState) +# define XXH32_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH32_canonicalFromHash) +# define XXH32_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH32_hashFromCanonical) +/* XXH64 */ +# define XXH64 XXH_NAME2(XXH_NAMESPACE, XXH64) +# define XXH64_createState XXH_NAME2(XXH_NAMESPACE, XXH64_createState) +# define XXH64_freeState XXH_NAME2(XXH_NAMESPACE, XXH64_freeState) +# define XXH64_reset XXH_NAME2(XXH_NAMESPACE, XXH64_reset) +# define XXH64_update XXH_NAME2(XXH_NAMESPACE, XXH64_update) +# define XXH64_digest XXH_NAME2(XXH_NAMESPACE, XXH64_digest) +# define XXH64_copyState XXH_NAME2(XXH_NAMESPACE, XXH64_copyState) +# define XXH64_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH64_canonicalFromHash) +# define XXH64_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH64_hashFromCanonical) +/* XXH3_64bits */ +# define XXH3_64bits XXH_NAME2(XXH_NAMESPACE, XXH3_64bits) +# define XXH3_64bits_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSecret) +# define XXH3_64bits_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSeed) +# define XXH3_64bits_withSecretandSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSecretandSeed) +# define XXH3_createState XXH_NAME2(XXH_NAMESPACE, XXH3_createState) +# define XXH3_freeState XXH_NAME2(XXH_NAMESPACE, XXH3_freeState) +# define XXH3_copyState XXH_NAME2(XXH_NAMESPACE, XXH3_copyState) +# define XXH3_64bits_reset XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset) +# define XXH3_64bits_reset_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSeed) +# define XXH3_64bits_reset_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSecret) +# define XXH3_64bits_reset_withSecretandSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSecretandSeed) +# define XXH3_64bits_update XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_update) +# define XXH3_64bits_digest XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_digest) +# define XXH3_generateSecret XXH_NAME2(XXH_NAMESPACE, XXH3_generateSecret) +# define XXH3_generateSecret_fromSeed XXH_NAME2(XXH_NAMESPACE, XXH3_generateSecret_fromSeed) +/* XXH3_128bits */ +# define XXH128 XXH_NAME2(XXH_NAMESPACE, XXH128) +# define XXH3_128bits XXH_NAME2(XXH_NAMESPACE, XXH3_128bits) +# define XXH3_128bits_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSeed) +# define XXH3_128bits_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSecret) +# define XXH3_128bits_withSecretandSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSecretandSeed) +# define XXH3_128bits_reset XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset) +# define XXH3_128bits_reset_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSeed) +# define XXH3_128bits_reset_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSecret) +# define XXH3_128bits_reset_withSecretandSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSecretandSeed) +# define XXH3_128bits_update XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_update) +# define XXH3_128bits_digest XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_digest) +# define XXH128_isEqual XXH_NAME2(XXH_NAMESPACE, XXH128_isEqual) +# define XXH128_cmp XXH_NAME2(XXH_NAMESPACE, XXH128_cmp) +# define XXH128_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH128_canonicalFromHash) +# define XXH128_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH128_hashFromCanonical) +#endif + + +/* ************************************* +* Compiler specifics +***************************************/ + +/* specific declaration modes for Windows */ +#if !defined(XXH_INLINE_ALL) && !defined(XXH_PRIVATE_API) +# if defined(_WIN32) && defined(_MSC_VER) && (defined(XXH_IMPORT) || defined(XXH_EXPORT)) +# ifdef XXH_EXPORT +# define XXH_PUBLIC_API __declspec(dllexport) +# elif XXH_IMPORT +# define XXH_PUBLIC_API __declspec(dllimport) +# endif +# else +# define XXH_PUBLIC_API /* do nothing */ +# endif +#endif + +#if defined (__GNUC__) +# define XXH_CONSTF __attribute__((__const__)) +# define XXH_PUREF __attribute__((__pure__)) +# define XXH_MALLOCF __attribute__((__malloc__)) +#else +# define XXH_CONSTF /* disable */ +# define XXH_PUREF +# define XXH_MALLOCF +#endif + +/* ************************************* +* Version +***************************************/ +#define XXH_VERSION_MAJOR 0 +#define XXH_VERSION_MINOR 8 +#define XXH_VERSION_RELEASE 3 +/*! @brief Version number, encoded as two digits each */ +#define XXH_VERSION_NUMBER (XXH_VERSION_MAJOR *100*100 + XXH_VERSION_MINOR *100 + XXH_VERSION_RELEASE) + +/*! + * @brief Obtains the xxHash version. + * + * This is mostly useful when xxHash is compiled as a shared library, + * since the returned value comes from the library, as opposed to header file. + * + * @return @ref XXH_VERSION_NUMBER of the invoked library. + */ +XXH_PUBLIC_API XXH_CONSTF unsigned XXH_versionNumber (void); + + +/* **************************** +* Common basic types +******************************/ +#include /* size_t */ +/*! + * @brief Exit code for the streaming API. + */ +typedef enum { + XXH_OK = 0, /*!< OK */ + XXH_ERROR /*!< Error */ +} XXH_errorcode; + + +/*-********************************************************************** +* 32-bit hash +************************************************************************/ +#if defined(XXH_DOXYGEN) /* Don't show include */ +/*! + * @brief An unsigned 32-bit integer. + * + * Not necessarily defined to `uint32_t` but functionally equivalent. + */ +typedef uint32_t XXH32_hash_t; + +#elif !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) +# ifdef _AIX +# include +# else +# include +# endif + typedef uint32_t XXH32_hash_t; + +#else +# include +# if UINT_MAX == 0xFFFFFFFFUL + typedef unsigned int XXH32_hash_t; +# elif ULONG_MAX == 0xFFFFFFFFUL + typedef unsigned long XXH32_hash_t; +# else +# error "unsupported platform: need a 32-bit type" +# endif +#endif + +/*! + * @} + * + * @defgroup XXH32_family XXH32 family + * @ingroup public + * Contains functions used in the classic 32-bit xxHash algorithm. + * + * @note + * XXH32 is useful for older platforms, with no or poor 64-bit performance. + * Note that the @ref XXH3_family provides competitive speed for both 32-bit + * and 64-bit systems, and offers true 64/128 bit hash results. + * + * @see @ref XXH64_family, @ref XXH3_family : Other xxHash families + * @see @ref XXH32_impl for implementation details + * @{ + */ + +/*! + * @brief Calculates the 32-bit hash of @p input using xxHash32. + * + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * @param seed The 32-bit seed to alter the hash's output predictably. + * + * @pre + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return The calculated 32-bit xxHash32 value. + * + * @see @ref single_shot_example "Single Shot Example" for an example. + */ +XXH_PUBLIC_API XXH_PUREF XXH32_hash_t XXH32 (const void* input, size_t length, XXH32_hash_t seed); + +#ifndef XXH_NO_STREAM +/*! + * @typedef struct XXH32_state_s XXH32_state_t + * @brief The opaque state struct for the XXH32 streaming API. + * + * @see XXH32_state_s for details. + * @see @ref streaming_example "Streaming Example" + */ +typedef struct XXH32_state_s XXH32_state_t; + +/*! + * @brief Allocates an @ref XXH32_state_t. + * + * @return An allocated pointer of @ref XXH32_state_t on success. + * @return `NULL` on failure. + * + * @note Must be freed with XXH32_freeState(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_MALLOCF XXH32_state_t* XXH32_createState(void); +/*! + * @brief Frees an @ref XXH32_state_t. + * + * @param statePtr A pointer to an @ref XXH32_state_t allocated with @ref XXH32_createState(). + * + * @return @ref XXH_OK. + * + * @note @p statePtr must be allocated with XXH32_createState(). + * + * @see @ref streaming_example "Streaming Example" + * + */ +XXH_PUBLIC_API XXH_errorcode XXH32_freeState(XXH32_state_t* statePtr); +/*! + * @brief Copies one @ref XXH32_state_t to another. + * + * @param dst_state The state to copy to. + * @param src_state The state to copy from. + * @pre + * @p dst_state and @p src_state must not be `NULL` and must not overlap. + */ +XXH_PUBLIC_API void XXH32_copyState(XXH32_state_t* dst_state, const XXH32_state_t* src_state); + +/*! + * @brief Resets an @ref XXH32_state_t to begin a new hash. + * + * @param statePtr The state struct to reset. + * @param seed The 32-bit seed to alter the hash result predictably. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note This function resets and seeds a state. Call it before @ref XXH32_update(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH32_reset (XXH32_state_t* statePtr, XXH32_hash_t seed); + +/*! + * @brief Consumes a block of @p input to an @ref XXH32_state_t. + * + * @param statePtr The state struct to update. + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * + * @pre + * @p statePtr must not be `NULL`. + * @pre + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note Call this to incrementally consume blocks of data. + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH32_update (XXH32_state_t* statePtr, const void* input, size_t length); + +/*! + * @brief Returns the calculated hash value from an @ref XXH32_state_t. + * + * @param statePtr The state struct to calculate the hash from. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return The calculated 32-bit xxHash32 value from that state. + * + * @note + * Calling XXH32_digest() will not affect @p statePtr, so you can update, + * digest, and update again. + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_PUREF XXH32_hash_t XXH32_digest (const XXH32_state_t* statePtr); +#endif /* !XXH_NO_STREAM */ + +/******* Canonical representation *******/ + +/*! + * @brief Canonical (big endian) representation of @ref XXH32_hash_t. + */ +typedef struct { + unsigned char digest[4]; /*!< Hash bytes, big endian */ +} XXH32_canonical_t; + +/*! + * @brief Converts an @ref XXH32_hash_t to a big endian @ref XXH32_canonical_t. + * + * @param dst The @ref XXH32_canonical_t pointer to be stored to. + * @param hash The @ref XXH32_hash_t to be converted. + * + * @pre + * @p dst must not be `NULL`. + * + * @see @ref canonical_representation_example "Canonical Representation Example" + */ +XXH_PUBLIC_API void XXH32_canonicalFromHash(XXH32_canonical_t* dst, XXH32_hash_t hash); + +/*! + * @brief Converts an @ref XXH32_canonical_t to a native @ref XXH32_hash_t. + * + * @param src The @ref XXH32_canonical_t to convert. + * + * @pre + * @p src must not be `NULL`. + * + * @return The converted hash. + * + * @see @ref canonical_representation_example "Canonical Representation Example" + */ +XXH_PUBLIC_API XXH_PUREF XXH32_hash_t XXH32_hashFromCanonical(const XXH32_canonical_t* src); + + +/*! @cond Doxygen ignores this part */ +#ifdef __has_attribute +# define XXH_HAS_ATTRIBUTE(x) __has_attribute(x) +#else +# define XXH_HAS_ATTRIBUTE(x) 0 +#endif +/*! @endcond */ + +/*! @cond Doxygen ignores this part */ +/* + * C23 __STDC_VERSION__ number hasn't been specified yet. For now + * leave as `201711L` (C17 + 1). + * TODO: Update to correct value when its been specified. + */ +#define XXH_C23_VN 201711L +/*! @endcond */ + +/*! @cond Doxygen ignores this part */ +/* C-language Attributes are added in C23. */ +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= XXH_C23_VN) && defined(__has_c_attribute) +# define XXH_HAS_C_ATTRIBUTE(x) __has_c_attribute(x) +#else +# define XXH_HAS_C_ATTRIBUTE(x) 0 +#endif +/*! @endcond */ + +/*! @cond Doxygen ignores this part */ +#if defined(__cplusplus) && defined(__has_cpp_attribute) +# define XXH_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +# define XXH_HAS_CPP_ATTRIBUTE(x) 0 +#endif +/*! @endcond */ + +/*! @cond Doxygen ignores this part */ +/* + * Define XXH_FALLTHROUGH macro for annotating switch case with the 'fallthrough' attribute + * introduced in CPP17 and C23. + * CPP17 : https://en.cppreference.com/w/cpp/language/attributes/fallthrough + * C23 : https://en.cppreference.com/w/c/language/attributes/fallthrough + */ +#if XXH_HAS_C_ATTRIBUTE(fallthrough) || XXH_HAS_CPP_ATTRIBUTE(fallthrough) +# define XXH_FALLTHROUGH [[fallthrough]] +#elif XXH_HAS_ATTRIBUTE(__fallthrough__) +# define XXH_FALLTHROUGH __attribute__ ((__fallthrough__)) +#else +# define XXH_FALLTHROUGH /* fallthrough */ +#endif +/*! @endcond */ + +/*! @cond Doxygen ignores this part */ +/* + * Define XXH_NOESCAPE for annotated pointers in public API. + * https://clang.llvm.org/docs/AttributeReference.html#noescape + * As of writing this, only supported by clang. + */ +#if XXH_HAS_ATTRIBUTE(noescape) +# define XXH_NOESCAPE __attribute__((__noescape__)) +#else +# define XXH_NOESCAPE +#endif +/*! @endcond */ + + +/*! + * @} + * @ingroup public + * @{ + */ + +#ifndef XXH_NO_LONG_LONG +/*-********************************************************************** +* 64-bit hash +************************************************************************/ +#if defined(XXH_DOXYGEN) /* don't include */ +/*! + * @brief An unsigned 64-bit integer. + * + * Not necessarily defined to `uint64_t` but functionally equivalent. + */ +typedef uint64_t XXH64_hash_t; +#elif !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) +# ifdef _AIX +# include +# else +# include +# endif + typedef uint64_t XXH64_hash_t; +#else +# include +# if defined(__LP64__) && ULONG_MAX == 0xFFFFFFFFFFFFFFFFULL + /* LP64 ABI says uint64_t is unsigned long */ + typedef unsigned long XXH64_hash_t; +# else + /* the following type must have a width of 64-bit */ + typedef unsigned long long XXH64_hash_t; +# endif +#endif + +/*! + * @} + * + * @defgroup XXH64_family XXH64 family + * @ingroup public + * @{ + * Contains functions used in the classic 64-bit xxHash algorithm. + * + * @note + * XXH3 provides competitive speed for both 32-bit and 64-bit systems, + * and offers true 64/128 bit hash results. + * It provides better speed for systems with vector processing capabilities. + */ + +/*! + * @brief Calculates the 64-bit hash of @p input using xxHash64. + * + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * @param seed The 64-bit seed to alter the hash's output predictably. + * + * @pre + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return The calculated 64-bit xxHash64 value. + * + * @see @ref single_shot_example "Single Shot Example" for an example. + */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH64(XXH_NOESCAPE const void* input, size_t length, XXH64_hash_t seed); + +/******* Streaming *******/ +#ifndef XXH_NO_STREAM +/*! + * @brief The opaque state struct for the XXH64 streaming API. + * + * @see XXH64_state_s for details. + * @see @ref streaming_example "Streaming Example" + */ +typedef struct XXH64_state_s XXH64_state_t; /* incomplete type */ + +/*! + * @brief Allocates an @ref XXH64_state_t. + * + * @return An allocated pointer of @ref XXH64_state_t on success. + * @return `NULL` on failure. + * + * @note Must be freed with XXH64_freeState(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_MALLOCF XXH64_state_t* XXH64_createState(void); + +/*! + * @brief Frees an @ref XXH64_state_t. + * + * @param statePtr A pointer to an @ref XXH64_state_t allocated with @ref XXH64_createState(). + * + * @return @ref XXH_OK. + * + * @note @p statePtr must be allocated with XXH64_createState(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH64_freeState(XXH64_state_t* statePtr); + +/*! + * @brief Copies one @ref XXH64_state_t to another. + * + * @param dst_state The state to copy to. + * @param src_state The state to copy from. + * @pre + * @p dst_state and @p src_state must not be `NULL` and must not overlap. + */ +XXH_PUBLIC_API void XXH64_copyState(XXH_NOESCAPE XXH64_state_t* dst_state, const XXH64_state_t* src_state); + +/*! + * @brief Resets an @ref XXH64_state_t to begin a new hash. + * + * @param statePtr The state struct to reset. + * @param seed The 64-bit seed to alter the hash result predictably. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note This function resets and seeds a state. Call it before @ref XXH64_update(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH64_reset (XXH_NOESCAPE XXH64_state_t* statePtr, XXH64_hash_t seed); + +/*! + * @brief Consumes a block of @p input to an @ref XXH64_state_t. + * + * @param statePtr The state struct to update. + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * + * @pre + * @p statePtr must not be `NULL`. + * @pre + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note Call this to incrementally consume blocks of data. + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH64_update (XXH_NOESCAPE XXH64_state_t* statePtr, XXH_NOESCAPE const void* input, size_t length); + +/*! + * @brief Returns the calculated hash value from an @ref XXH64_state_t. + * + * @param statePtr The state struct to calculate the hash from. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return The calculated 64-bit xxHash64 value from that state. + * + * @note + * Calling XXH64_digest() will not affect @p statePtr, so you can update, + * digest, and update again. + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH64_digest (XXH_NOESCAPE const XXH64_state_t* statePtr); +#endif /* !XXH_NO_STREAM */ +/******* Canonical representation *******/ + +/*! + * @brief Canonical (big endian) representation of @ref XXH64_hash_t. + */ +typedef struct { unsigned char digest[sizeof(XXH64_hash_t)]; } XXH64_canonical_t; + +/*! + * @brief Converts an @ref XXH64_hash_t to a big endian @ref XXH64_canonical_t. + * + * @param dst The @ref XXH64_canonical_t pointer to be stored to. + * @param hash The @ref XXH64_hash_t to be converted. + * + * @pre + * @p dst must not be `NULL`. + * + * @see @ref canonical_representation_example "Canonical Representation Example" + */ +XXH_PUBLIC_API void XXH64_canonicalFromHash(XXH_NOESCAPE XXH64_canonical_t* dst, XXH64_hash_t hash); + +/*! + * @brief Converts an @ref XXH64_canonical_t to a native @ref XXH64_hash_t. + * + * @param src The @ref XXH64_canonical_t to convert. + * + * @pre + * @p src must not be `NULL`. + * + * @return The converted hash. + * + * @see @ref canonical_representation_example "Canonical Representation Example" + */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH64_hashFromCanonical(XXH_NOESCAPE const XXH64_canonical_t* src); + +#ifndef XXH_NO_XXH3 + +/*! + * @} + * ************************************************************************ + * @defgroup XXH3_family XXH3 family + * @ingroup public + * @{ + * + * XXH3 is a more recent hash algorithm featuring: + * - Improved speed for both small and large inputs + * - True 64-bit and 128-bit outputs + * - SIMD acceleration + * - Improved 32-bit viability + * + * Speed analysis methodology is explained here: + * + * https://fastcompression.blogspot.com/2019/03/presenting-xxh3.html + * + * Compared to XXH64, expect XXH3 to run approximately + * ~2x faster on large inputs and >3x faster on small ones, + * exact differences vary depending on platform. + * + * XXH3's speed benefits greatly from SIMD and 64-bit arithmetic, + * but does not require it. + * Most 32-bit and 64-bit targets that can run XXH32 smoothly can run XXH3 + * at competitive speeds, even without vector support. Further details are + * explained in the implementation. + * + * XXH3 has a fast scalar implementation, but it also includes accelerated SIMD + * implementations for many common platforms: + * - AVX512 + * - AVX2 + * - SSE2 + * - ARM NEON + * - WebAssembly SIMD128 + * - POWER8 VSX + * - s390x ZVector + * This can be controlled via the @ref XXH_VECTOR macro, but it automatically + * selects the best version according to predefined macros. For the x86 family, an + * automatic runtime dispatcher is included separately in @ref xxh_x86dispatch.c. + * + * XXH3 implementation is portable: + * it has a generic C90 formulation that can be compiled on any platform, + * all implementations generate exactly the same hash value on all platforms. + * Starting from v0.8.0, it's also labelled "stable", meaning that + * any future version will also generate the same hash value. + * + * XXH3 offers 2 variants, _64bits and _128bits. + * + * When only 64 bits are needed, prefer invoking the _64bits variant, as it + * reduces the amount of mixing, resulting in faster speed on small inputs. + * It's also generally simpler to manipulate a scalar return type than a struct. + * + * The API supports one-shot hashing, streaming mode, and custom secrets. + */ +/*-********************************************************************** +* XXH3 64-bit variant +************************************************************************/ + +/*! + * @brief Calculates 64-bit unseeded variant of XXH3 hash of @p input. + * + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * + * @pre + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return The calculated 64-bit XXH3 hash value. + * + * @note + * This is equivalent to @ref XXH3_64bits_withSeed() with a seed of `0`, however + * it may have slightly better performance due to constant propagation of the + * defaults. + * + * @see + * XXH3_64bits_withSeed(), XXH3_64bits_withSecret(): other seeding variants + * @see @ref single_shot_example "Single Shot Example" for an example. + */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH3_64bits(XXH_NOESCAPE const void* input, size_t length); + +/*! + * @brief Calculates 64-bit seeded variant of XXH3 hash of @p input. + * + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * @param seed The 64-bit seed to alter the hash result predictably. + * + * @pre + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return The calculated 64-bit XXH3 hash value. + * + * @note + * seed == 0 produces the same results as @ref XXH3_64bits(). + * + * This variant generates a custom secret on the fly based on default secret + * altered using the @p seed value. + * + * While this operation is decently fast, note that it's not completely free. + * + * @see @ref single_shot_example "Single Shot Example" for an example. + */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH3_64bits_withSeed(XXH_NOESCAPE const void* input, size_t length, XXH64_hash_t seed); + +/*! + * The bare minimum size for a custom secret. + * + * @see + * XXH3_64bits_withSecret(), XXH3_64bits_reset_withSecret(), + * XXH3_128bits_withSecret(), XXH3_128bits_reset_withSecret(). + */ +#define XXH3_SECRET_SIZE_MIN 136 + +/*! + * @brief Calculates 64-bit variant of XXH3 with a custom "secret". + * + * @param data The block of data to be hashed, at least @p len bytes in size. + * @param len The length of @p data, in bytes. + * @param secret The secret data. + * @param secretSize The length of @p secret, in bytes. + * + * @return The calculated 64-bit XXH3 hash value. + * + * @pre + * The memory between @p data and @p data + @p len must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p data may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * It's possible to provide any blob of bytes as a "secret" to generate the hash. + * This makes it more difficult for an external actor to prepare an intentional collision. + * The main condition is that @p secretSize *must* be large enough (>= @ref XXH3_SECRET_SIZE_MIN). + * However, the quality of the secret impacts the dispersion of the hash algorithm. + * Therefore, the secret _must_ look like a bunch of random bytes. + * Avoid "trivial" or structured data such as repeated sequences or a text document. + * Whenever in doubt about the "randomness" of the blob of bytes, + * consider employing @ref XXH3_generateSecret() instead (see below). + * It will generate a proper high entropy secret derived from the blob of bytes. + * Another advantage of using XXH3_generateSecret() is that + * it guarantees that all bits within the initial blob of bytes + * will impact every bit of the output. + * This is not necessarily the case when using the blob of bytes directly + * because, when hashing _small_ inputs, only a portion of the secret is employed. + * + * @see @ref single_shot_example "Single Shot Example" for an example. + */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH3_64bits_withSecret(XXH_NOESCAPE const void* data, size_t len, XXH_NOESCAPE const void* secret, size_t secretSize); + + +/******* Streaming *******/ +#ifndef XXH_NO_STREAM +/* + * Streaming requires state maintenance. + * This operation costs memory and CPU. + * As a consequence, streaming is slower than one-shot hashing. + * For better performance, prefer one-shot functions whenever applicable. + */ + +/*! + * @brief The opaque state struct for the XXH3 streaming API. + * + * @see XXH3_state_s for details. + * @see @ref streaming_example "Streaming Example" + */ +typedef struct XXH3_state_s XXH3_state_t; +XXH_PUBLIC_API XXH_MALLOCF XXH3_state_t* XXH3_createState(void); +XXH_PUBLIC_API XXH_errorcode XXH3_freeState(XXH3_state_t* statePtr); + +/*! + * @brief Copies one @ref XXH3_state_t to another. + * + * @param dst_state The state to copy to. + * @param src_state The state to copy from. + * @pre + * @p dst_state and @p src_state must not be `NULL` and must not overlap. + */ +XXH_PUBLIC_API void XXH3_copyState(XXH_NOESCAPE XXH3_state_t* dst_state, XXH_NOESCAPE const XXH3_state_t* src_state); + +/*! + * @brief Resets an @ref XXH3_state_t to begin a new hash. + * + * @param statePtr The state struct to reset. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note + * - This function resets `statePtr` and generate a secret with default parameters. + * - Call this function before @ref XXH3_64bits_update(). + * - Digest will be equivalent to `XXH3_64bits()`. + * + * @see @ref streaming_example "Streaming Example" + * + */ +XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset(XXH_NOESCAPE XXH3_state_t* statePtr); + +/*! + * @brief Resets an @ref XXH3_state_t with 64-bit seed to begin a new hash. + * + * @param statePtr The state struct to reset. + * @param seed The 64-bit seed to alter the hash result predictably. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note + * - This function resets `statePtr` and generate a secret from `seed`. + * - Call this function before @ref XXH3_64bits_update(). + * - Digest will be equivalent to `XXH3_64bits_withSeed()`. + * + * @see @ref streaming_example "Streaming Example" + * + */ +XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset_withSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH64_hash_t seed); + +/*! + * @brief Resets an @ref XXH3_state_t with secret data to begin a new hash. + * + * @param statePtr The state struct to reset. + * @param secret The secret data. + * @param secretSize The length of @p secret, in bytes. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note + * `secret` is referenced, it _must outlive_ the hash streaming session. + * + * Similar to one-shot API, `secretSize` must be >= @ref XXH3_SECRET_SIZE_MIN, + * and the quality of produced hash values depends on secret's entropy + * (secret's content should look like a bunch of random bytes). + * When in doubt about the randomness of a candidate `secret`, + * consider employing `XXH3_generateSecret()` instead (see below). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset_withSecret(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize); + +/*! + * @brief Consumes a block of @p input to an @ref XXH3_state_t. + * + * @param statePtr The state struct to update. + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * + * @pre + * @p statePtr must not be `NULL`. + * @pre + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note Call this to incrementally consume blocks of data. + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH3_64bits_update (XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* input, size_t length); + +/*! + * @brief Returns the calculated XXH3 64-bit hash value from an @ref XXH3_state_t. + * + * @param statePtr The state struct to calculate the hash from. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return The calculated XXH3 64-bit hash value from that state. + * + * @note + * Calling XXH3_64bits_digest() will not affect @p statePtr, so you can update, + * digest, and update again. + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH3_64bits_digest (XXH_NOESCAPE const XXH3_state_t* statePtr); +#endif /* !XXH_NO_STREAM */ + +/* note : canonical representation of XXH3 is the same as XXH64 + * since they both produce XXH64_hash_t values */ + + +/*-********************************************************************** +* XXH3 128-bit variant +************************************************************************/ + +/*! + * @brief The return value from 128-bit hashes. + * + * Stored in little endian order, although the fields themselves are in native + * endianness. + */ +typedef struct { + XXH64_hash_t low64; /*!< `value & 0xFFFFFFFFFFFFFFFF` */ + XXH64_hash_t high64; /*!< `value >> 64` */ +} XXH128_hash_t; + +/*! + * @brief Calculates 128-bit unseeded variant of XXH3 of @p data. + * + * @param data The block of data to be hashed, at least @p length bytes in size. + * @param len The length of @p data, in bytes. + * + * @return The calculated 128-bit variant of XXH3 value. + * + * The 128-bit variant of XXH3 has more strength, but it has a bit of overhead + * for shorter inputs. + * + * This is equivalent to @ref XXH3_128bits_withSeed() with a seed of `0`, however + * it may have slightly better performance due to constant propagation of the + * defaults. + * + * @see XXH3_128bits_withSeed(), XXH3_128bits_withSecret(): other seeding variants + * @see @ref single_shot_example "Single Shot Example" for an example. + */ +XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH3_128bits(XXH_NOESCAPE const void* data, size_t len); +/*! @brief Calculates 128-bit seeded variant of XXH3 hash of @p data. + * + * @param data The block of data to be hashed, at least @p length bytes in size. + * @param len The length of @p data, in bytes. + * @param seed The 64-bit seed to alter the hash result predictably. + * + * @return The calculated 128-bit variant of XXH3 value. + * + * @note + * seed == 0 produces the same results as @ref XXH3_64bits(). + * + * This variant generates a custom secret on the fly based on default secret + * altered using the @p seed value. + * + * While this operation is decently fast, note that it's not completely free. + * + * @see XXH3_128bits(), XXH3_128bits_withSecret(): other seeding variants + * @see @ref single_shot_example "Single Shot Example" for an example. + */ +XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH3_128bits_withSeed(XXH_NOESCAPE const void* data, size_t len, XXH64_hash_t seed); +/*! + * @brief Calculates 128-bit variant of XXH3 with a custom "secret". + * + * @param data The block of data to be hashed, at least @p len bytes in size. + * @param len The length of @p data, in bytes. + * @param secret The secret data. + * @param secretSize The length of @p secret, in bytes. + * + * @return The calculated 128-bit variant of XXH3 value. + * + * It's possible to provide any blob of bytes as a "secret" to generate the hash. + * This makes it more difficult for an external actor to prepare an intentional collision. + * The main condition is that @p secretSize *must* be large enough (>= @ref XXH3_SECRET_SIZE_MIN). + * However, the quality of the secret impacts the dispersion of the hash algorithm. + * Therefore, the secret _must_ look like a bunch of random bytes. + * Avoid "trivial" or structured data such as repeated sequences or a text document. + * Whenever in doubt about the "randomness" of the blob of bytes, + * consider employing @ref XXH3_generateSecret() instead (see below). + * It will generate a proper high entropy secret derived from the blob of bytes. + * Another advantage of using XXH3_generateSecret() is that + * it guarantees that all bits within the initial blob of bytes + * will impact every bit of the output. + * This is not necessarily the case when using the blob of bytes directly + * because, when hashing _small_ inputs, only a portion of the secret is employed. + * + * @see @ref single_shot_example "Single Shot Example" for an example. + */ +XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH3_128bits_withSecret(XXH_NOESCAPE const void* data, size_t len, XXH_NOESCAPE const void* secret, size_t secretSize); + +/******* Streaming *******/ +#ifndef XXH_NO_STREAM +/* + * Streaming requires state maintenance. + * This operation costs memory and CPU. + * As a consequence, streaming is slower than one-shot hashing. + * For better performance, prefer one-shot functions whenever applicable. + * + * XXH3_128bits uses the same XXH3_state_t as XXH3_64bits(). + * Use already declared XXH3_createState() and XXH3_freeState(). + * + * All reset and streaming functions have same meaning as their 64-bit counterpart. + */ + +/*! + * @brief Resets an @ref XXH3_state_t to begin a new hash. + * + * @param statePtr The state struct to reset. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note + * - This function resets `statePtr` and generate a secret with default parameters. + * - Call it before @ref XXH3_128bits_update(). + * - Digest will be equivalent to `XXH3_128bits()`. + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset(XXH_NOESCAPE XXH3_state_t* statePtr); + +/*! + * @brief Resets an @ref XXH3_state_t with 64-bit seed to begin a new hash. + * + * @param statePtr The state struct to reset. + * @param seed The 64-bit seed to alter the hash result predictably. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note + * - This function resets `statePtr` and generate a secret from `seed`. + * - Call it before @ref XXH3_128bits_update(). + * - Digest will be equivalent to `XXH3_128bits_withSeed()`. + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset_withSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH64_hash_t seed); +/*! + * @brief Resets an @ref XXH3_state_t with secret data to begin a new hash. + * + * @param statePtr The state struct to reset. + * @param secret The secret data. + * @param secretSize The length of @p secret, in bytes. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * `secret` is referenced, it _must outlive_ the hash streaming session. + * Similar to one-shot API, `secretSize` must be >= @ref XXH3_SECRET_SIZE_MIN, + * and the quality of produced hash values depends on secret's entropy + * (secret's content should look like a bunch of random bytes). + * When in doubt about the randomness of a candidate `secret`, + * consider employing `XXH3_generateSecret()` instead (see below). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset_withSecret(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize); + +/*! + * @brief Consumes a block of @p input to an @ref XXH3_state_t. + * + * Call this to incrementally consume blocks of data. + * + * @param statePtr The state struct to update. + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + */ +XXH_PUBLIC_API XXH_errorcode XXH3_128bits_update (XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* input, size_t length); + +/*! + * @brief Returns the calculated XXH3 128-bit hash value from an @ref XXH3_state_t. + * + * @param statePtr The state struct to calculate the hash from. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return The calculated XXH3 128-bit hash value from that state. + * + * @note + * Calling XXH3_128bits_digest() will not affect @p statePtr, so you can update, + * digest, and update again. + * + */ +XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH3_128bits_digest (XXH_NOESCAPE const XXH3_state_t* statePtr); +#endif /* !XXH_NO_STREAM */ + +/* Following helper functions make it possible to compare XXH128_hast_t values. + * Since XXH128_hash_t is a structure, this capability is not offered by the language. + * Note: For better performance, these functions can be inlined using XXH_INLINE_ALL */ + +/*! + * @brief Check equality of two XXH128_hash_t values + * + * @param h1 The 128-bit hash value. + * @param h2 Another 128-bit hash value. + * + * @return `1` if `h1` and `h2` are equal. + * @return `0` if they are not. + */ +XXH_PUBLIC_API XXH_PUREF int XXH128_isEqual(XXH128_hash_t h1, XXH128_hash_t h2); + +/*! + * @brief Compares two @ref XXH128_hash_t + * + * This comparator is compatible with stdlib's `qsort()`/`bsearch()`. + * + * @param h128_1 Left-hand side value + * @param h128_2 Right-hand side value + * + * @return >0 if @p h128_1 > @p h128_2 + * @return =0 if @p h128_1 == @p h128_2 + * @return <0 if @p h128_1 < @p h128_2 + */ +XXH_PUBLIC_API XXH_PUREF int XXH128_cmp(XXH_NOESCAPE const void* h128_1, XXH_NOESCAPE const void* h128_2); + + +/******* Canonical representation *******/ +typedef struct { unsigned char digest[sizeof(XXH128_hash_t)]; } XXH128_canonical_t; + + +/*! + * @brief Converts an @ref XXH128_hash_t to a big endian @ref XXH128_canonical_t. + * + * @param dst The @ref XXH128_canonical_t pointer to be stored to. + * @param hash The @ref XXH128_hash_t to be converted. + * + * @pre + * @p dst must not be `NULL`. + * @see @ref canonical_representation_example "Canonical Representation Example" + */ +XXH_PUBLIC_API void XXH128_canonicalFromHash(XXH_NOESCAPE XXH128_canonical_t* dst, XXH128_hash_t hash); + +/*! + * @brief Converts an @ref XXH128_canonical_t to a native @ref XXH128_hash_t. + * + * @param src The @ref XXH128_canonical_t to convert. + * + * @pre + * @p src must not be `NULL`. + * + * @return The converted hash. + * @see @ref canonical_representation_example "Canonical Representation Example" + */ +XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH128_hashFromCanonical(XXH_NOESCAPE const XXH128_canonical_t* src); + + +#endif /* !XXH_NO_XXH3 */ +#endif /* XXH_NO_LONG_LONG */ + +/*! + * @} + */ +#endif /* XXHASH_H_5627135585666179 */ + + + +#if defined(XXH_STATIC_LINKING_ONLY) && !defined(XXHASH_H_STATIC_13879238742) +#define XXHASH_H_STATIC_13879238742 +/* **************************************************************************** + * This section contains declarations which are not guaranteed to remain stable. + * They may change in future versions, becoming incompatible with a different + * version of the library. + * These declarations should only be used with static linking. + * Never use them in association with dynamic linking! + ***************************************************************************** */ + +/* + * These definitions are only present to allow static allocation + * of XXH states, on stack or in a struct, for example. + * Never **ever** access their members directly. + */ + +/*! + * @internal + * @brief Structure for XXH32 streaming API. + * + * @note This is only defined when @ref XXH_STATIC_LINKING_ONLY, + * @ref XXH_INLINE_ALL, or @ref XXH_IMPLEMENTATION is defined. Otherwise it is + * an opaque type. This allows fields to safely be changed. + * + * Typedef'd to @ref XXH32_state_t. + * Do not access the members of this struct directly. + * @see XXH64_state_s, XXH3_state_s + */ +struct XXH32_state_s { + XXH32_hash_t total_len_32; /*!< Total length hashed, modulo 2^32 */ + XXH32_hash_t large_len; /*!< Whether the hash is >= 16 (handles @ref total_len_32 overflow) */ + XXH32_hash_t v[4]; /*!< Accumulator lanes */ + XXH32_hash_t mem32[4]; /*!< Internal buffer for partial reads. Treated as unsigned char[16]. */ + XXH32_hash_t memsize; /*!< Amount of data in @ref mem32 */ + XXH32_hash_t reserved; /*!< Reserved field. Do not read nor write to it. */ +}; /* typedef'd to XXH32_state_t */ + + +#ifndef XXH_NO_LONG_LONG /* defined when there is no 64-bit support */ + +/*! + * @internal + * @brief Structure for XXH64 streaming API. + * + * @note This is only defined when @ref XXH_STATIC_LINKING_ONLY, + * @ref XXH_INLINE_ALL, or @ref XXH_IMPLEMENTATION is defined. Otherwise it is + * an opaque type. This allows fields to safely be changed. + * + * Typedef'd to @ref XXH64_state_t. + * Do not access the members of this struct directly. + * @see XXH32_state_s, XXH3_state_s + */ +struct XXH64_state_s { + XXH64_hash_t total_len; /*!< Total length hashed. This is always 64-bit. */ + XXH64_hash_t v[4]; /*!< Accumulator lanes */ + XXH64_hash_t mem64[4]; /*!< Internal buffer for partial reads. Treated as unsigned char[32]. */ + XXH32_hash_t memsize; /*!< Amount of data in @ref mem64 */ + XXH32_hash_t reserved32; /*!< Reserved field, needed for padding anyways*/ + XXH64_hash_t reserved64; /*!< Reserved field. Do not read or write to it. */ +}; /* typedef'd to XXH64_state_t */ + +#ifndef XXH_NO_XXH3 + +/* Windows SDK under 10.0.22000 is missing stdalign.h so we add a check + before allowing the windows compiler to use the C11 form. + Reference: https://github.com/Cyan4973/xxHash/issues/955 */ +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) \ + && (defined(_MSC_VER) && (_MSC_VER >= 1000) || !defined(_MSC_VER)) /* >= C11 */ +# include +# define XXH_ALIGN(n) alignas(n) +#elif defined(__cplusplus) && (__cplusplus >= 201103L) /* >= C++11 */ +/* In C++ alignas() is a keyword */ +# define XXH_ALIGN(n) alignas(n) +#elif defined(__GNUC__) +# define XXH_ALIGN(n) __attribute__ ((aligned(n))) +#elif defined(_MSC_VER) +# define XXH_ALIGN(n) __declspec(align(n)) +#else +# define XXH_ALIGN(n) /* disabled */ +#endif + +/* Old GCC versions only accept the attribute after the type in structures. */ +#if !(defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) /* C11+ */ \ + && ! (defined(__cplusplus) && (__cplusplus >= 201103L)) /* >= C++11 */ \ + && defined(__GNUC__) +# define XXH_ALIGN_MEMBER(align, type) type XXH_ALIGN(align) +#else +# define XXH_ALIGN_MEMBER(align, type) XXH_ALIGN(align) type +#endif + +/*! + * @brief The size of the internal XXH3 buffer. + * + * This is the optimal update size for incremental hashing. + * + * @see XXH3_64b_update(), XXH3_128b_update(). + */ +#define XXH3_INTERNALBUFFER_SIZE 256 + +/*! + * @internal + * @brief Default size of the secret buffer (and @ref XXH3_kSecret). + * + * This is the size used in @ref XXH3_kSecret and the seeded functions. + * + * Not to be confused with @ref XXH3_SECRET_SIZE_MIN. + */ +#define XXH3_SECRET_DEFAULT_SIZE 192 + +/*! + * @internal + * @brief Structure for XXH3 streaming API. + * + * @note This is only defined when @ref XXH_STATIC_LINKING_ONLY, + * @ref XXH_INLINE_ALL, or @ref XXH_IMPLEMENTATION is defined. + * Otherwise it is an opaque type. + * Never use this definition in combination with dynamic library. + * This allows fields to safely be changed in the future. + * + * @note ** This structure has a strict alignment requirement of 64 bytes!! ** + * Do not allocate this with `malloc()` or `new`, + * it will not be sufficiently aligned. + * Use @ref XXH3_createState() and @ref XXH3_freeState(), or stack allocation. + * + * Typedef'd to @ref XXH3_state_t. + * Do never access the members of this struct directly. + * + * @see XXH3_INITSTATE() for stack initialization. + * @see XXH3_createState(), XXH3_freeState(). + * @see XXH32_state_s, XXH64_state_s + */ +struct XXH3_state_s { + XXH_ALIGN_MEMBER(64, XXH64_hash_t acc[8]); + /*!< The 8 accumulators. See @ref XXH32_state_s::v and @ref XXH64_state_s::v */ + XXH_ALIGN_MEMBER(64, unsigned char customSecret[XXH3_SECRET_DEFAULT_SIZE]); + /*!< Used to store a custom secret generated from a seed. */ + XXH_ALIGN_MEMBER(64, unsigned char buffer[XXH3_INTERNALBUFFER_SIZE]); + /*!< The internal buffer. @see XXH32_state_s::mem32 */ + XXH32_hash_t bufferedSize; + /*!< The amount of memory in @ref buffer, @see XXH32_state_s::memsize */ + XXH32_hash_t useSeed; + /*!< Reserved field. Needed for padding on 64-bit. */ + size_t nbStripesSoFar; + /*!< Number or stripes processed. */ + XXH64_hash_t totalLen; + /*!< Total length hashed. 64-bit even on 32-bit targets. */ + size_t nbStripesPerBlock; + /*!< Number of stripes per block. */ + size_t secretLimit; + /*!< Size of @ref customSecret or @ref extSecret */ + XXH64_hash_t seed; + /*!< Seed for _withSeed variants. Must be zero otherwise, @see XXH3_INITSTATE() */ + XXH64_hash_t reserved64; + /*!< Reserved field. */ + const unsigned char* extSecret; + /*!< Reference to an external secret for the _withSecret variants, NULL + * for other variants. */ + /* note: there may be some padding at the end due to alignment on 64 bytes */ +}; /* typedef'd to XXH3_state_t */ + +#undef XXH_ALIGN_MEMBER + +/*! + * @brief Initializes a stack-allocated `XXH3_state_s`. + * + * When the @ref XXH3_state_t structure is merely emplaced on stack, + * it should be initialized with XXH3_INITSTATE() or a memset() + * in case its first reset uses XXH3_NNbits_reset_withSeed(). + * This init can be omitted if the first reset uses default or _withSecret mode. + * This operation isn't necessary when the state is created with XXH3_createState(). + * Note that this doesn't prepare the state for a streaming operation, + * it's still necessary to use XXH3_NNbits_reset*() afterwards. + */ +#define XXH3_INITSTATE(XXH3_state_ptr) \ + do { \ + XXH3_state_t* tmp_xxh3_state_ptr = (XXH3_state_ptr); \ + tmp_xxh3_state_ptr->seed = 0; \ + tmp_xxh3_state_ptr->extSecret = NULL; \ + } while(0) + + +/*! + * @brief Calculates the 128-bit hash of @p data using XXH3. + * + * @param data The block of data to be hashed, at least @p len bytes in size. + * @param len The length of @p data, in bytes. + * @param seed The 64-bit seed to alter the hash's output predictably. + * + * @pre + * The memory between @p data and @p data + @p len must be valid, + * readable, contiguous memory. However, if @p len is `0`, @p data may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return The calculated 128-bit XXH3 value. + * + * @see @ref single_shot_example "Single Shot Example" for an example. + */ +XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH128(XXH_NOESCAPE const void* data, size_t len, XXH64_hash_t seed); + + +/* === Experimental API === */ +/* Symbols defined below must be considered tied to a specific library version. */ + +/*! + * @brief Derive a high-entropy secret from any user-defined content, named customSeed. + * + * @param secretBuffer A writable buffer for derived high-entropy secret data. + * @param secretSize Size of secretBuffer, in bytes. Must be >= XXH3_SECRET_SIZE_MIN. + * @param customSeed A user-defined content. + * @param customSeedSize Size of customSeed, in bytes. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * The generated secret can be used in combination with `*_withSecret()` functions. + * The `_withSecret()` variants are useful to provide a higher level of protection + * than 64-bit seed, as it becomes much more difficult for an external actor to + * guess how to impact the calculation logic. + * + * The function accepts as input a custom seed of any length and any content, + * and derives from it a high-entropy secret of length @p secretSize into an + * already allocated buffer @p secretBuffer. + * + * The generated secret can then be used with any `*_withSecret()` variant. + * The functions @ref XXH3_128bits_withSecret(), @ref XXH3_64bits_withSecret(), + * @ref XXH3_128bits_reset_withSecret() and @ref XXH3_64bits_reset_withSecret() + * are part of this list. They all accept a `secret` parameter + * which must be large enough for implementation reasons (>= @ref XXH3_SECRET_SIZE_MIN) + * _and_ feature very high entropy (consist of random-looking bytes). + * These conditions can be a high bar to meet, so @ref XXH3_generateSecret() can + * be employed to ensure proper quality. + * + * @p customSeed can be anything. It can have any size, even small ones, + * and its content can be anything, even "poor entropy" sources such as a bunch + * of zeroes. The resulting `secret` will nonetheless provide all required qualities. + * + * @pre + * - @p secretSize must be >= @ref XXH3_SECRET_SIZE_MIN + * - When @p customSeedSize > 0, supplying NULL as customSeed is undefined behavior. + * + * Example code: + * @code{.c} + * #include + * #include + * #include + * #define XXH_STATIC_LINKING_ONLY // expose unstable API + * #include "xxhash.h" + * // Hashes argv[2] using the entropy from argv[1]. + * int main(int argc, char* argv[]) + * { + * char secret[XXH3_SECRET_SIZE_MIN]; + * if (argv != 3) { return 1; } + * XXH3_generateSecret(secret, sizeof(secret), argv[1], strlen(argv[1])); + * XXH64_hash_t h = XXH3_64bits_withSecret( + * argv[2], strlen(argv[2]), + * secret, sizeof(secret) + * ); + * printf("%016llx\n", (unsigned long long) h); + * } + * @endcode + */ +XXH_PUBLIC_API XXH_errorcode XXH3_generateSecret(XXH_NOESCAPE void* secretBuffer, size_t secretSize, XXH_NOESCAPE const void* customSeed, size_t customSeedSize); + +/*! + * @brief Generate the same secret as the _withSeed() variants. + * + * @param secretBuffer A writable buffer of @ref XXH3_SECRET_DEFAULT_SIZE bytes + * @param seed The 64-bit seed to alter the hash result predictably. + * + * The generated secret can be used in combination with + *`*_withSecret()` and `_withSecretandSeed()` variants. + * + * Example C++ `std::string` hash class: + * @code{.cpp} + * #include + * #define XXH_STATIC_LINKING_ONLY // expose unstable API + * #include "xxhash.h" + * // Slow, seeds each time + * class HashSlow { + * XXH64_hash_t seed; + * public: + * HashSlow(XXH64_hash_t s) : seed{s} {} + * size_t operator()(const std::string& x) const { + * return size_t{XXH3_64bits_withSeed(x.c_str(), x.length(), seed)}; + * } + * }; + * // Fast, caches the seeded secret for future uses. + * class HashFast { + * unsigned char secret[XXH3_SECRET_DEFAULT_SIZE]; + * public: + * HashFast(XXH64_hash_t s) { + * XXH3_generateSecret_fromSeed(secret, seed); + * } + * size_t operator()(const std::string& x) const { + * return size_t{ + * XXH3_64bits_withSecret(x.c_str(), x.length(), secret, sizeof(secret)) + * }; + * } + * }; + * @endcode + */ +XXH_PUBLIC_API void XXH3_generateSecret_fromSeed(XXH_NOESCAPE void* secretBuffer, XXH64_hash_t seed); + +/*! + * @brief Maximum size of "short" key in bytes. + */ +#define XXH3_MIDSIZE_MAX 240 + +/*! + * @brief Calculates 64/128-bit seeded variant of XXH3 hash of @p data. + * + * @param data The block of data to be hashed, at least @p len bytes in size. + * @param len The length of @p data, in bytes. + * @param secret The secret data. + * @param secretSize The length of @p secret, in bytes. + * @param seed The 64-bit seed to alter the hash result predictably. + * + * These variants generate hash values using either: + * - @p seed for "short" keys (< @ref XXH3_MIDSIZE_MAX = 240 bytes) + * - @p secret for "large" keys (>= @ref XXH3_MIDSIZE_MAX). + * + * This generally benefits speed, compared to `_withSeed()` or `_withSecret()`. + * `_withSeed()` has to generate the secret on the fly for "large" keys. + * It's fast, but can be perceptible for "not so large" keys (< 1 KB). + * `_withSecret()` has to generate the masks on the fly for "small" keys, + * which requires more instructions than _withSeed() variants. + * Therefore, _withSecretandSeed variant combines the best of both worlds. + * + * When @p secret has been generated by XXH3_generateSecret_fromSeed(), + * this variant produces *exactly* the same results as `_withSeed()` variant, + * hence offering only a pure speed benefit on "large" input, + * by skipping the need to regenerate the secret for every large input. + * + * Another usage scenario is to hash the secret to a 64-bit hash value, + * for example with XXH3_64bits(), which then becomes the seed, + * and then employ both the seed and the secret in _withSecretandSeed(). + * On top of speed, an added benefit is that each bit in the secret + * has a 50% chance to swap each bit in the output, via its impact to the seed. + * + * This is not guaranteed when using the secret directly in "small data" scenarios, + * because only portions of the secret are employed for small data. + */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t +XXH3_64bits_withSecretandSeed(XXH_NOESCAPE const void* data, size_t len, + XXH_NOESCAPE const void* secret, size_t secretSize, + XXH64_hash_t seed); + +/*! + * @brief Calculates 128-bit seeded variant of XXH3 hash of @p data. + * + * @param data The memory segment to be hashed, at least @p len bytes in size. + * @param length The length of @p data, in bytes. + * @param secret The secret used to alter hash result predictably. + * @param secretSize The length of @p secret, in bytes (must be >= XXH3_SECRET_SIZE_MIN) + * @param seed64 The 64-bit seed to alter the hash result predictably. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @see XXH3_64bits_withSecretandSeed(): contract is the same. + */ +XXH_PUBLIC_API XXH_PUREF XXH128_hash_t +XXH3_128bits_withSecretandSeed(XXH_NOESCAPE const void* input, size_t length, + XXH_NOESCAPE const void* secret, size_t secretSize, + XXH64_hash_t seed64); + +#ifndef XXH_NO_STREAM +/*! + * @brief Resets an @ref XXH3_state_t with secret data to begin a new hash. + * + * @param statePtr A pointer to an @ref XXH3_state_t allocated with @ref XXH3_createState(). + * @param secret The secret data. + * @param secretSize The length of @p secret, in bytes. + * @param seed64 The 64-bit seed to alter the hash result predictably. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @see XXH3_64bits_withSecretandSeed(). Contract is identical. + */ +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset_withSecretandSeed(XXH_NOESCAPE XXH3_state_t* statePtr, + XXH_NOESCAPE const void* secret, size_t secretSize, + XXH64_hash_t seed64); + +/*! + * @brief Resets an @ref XXH3_state_t with secret data to begin a new hash. + * + * @param statePtr A pointer to an @ref XXH3_state_t allocated with @ref XXH3_createState(). + * @param secret The secret data. + * @param secretSize The length of @p secret, in bytes. + * @param seed64 The 64-bit seed to alter the hash result predictably. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @see XXH3_64bits_withSecretandSeed(). Contract is identical. + * + * Note: there was a bug in an earlier version of this function (<= v0.8.2) + * that would make it generate an incorrect hash value + * when @p seed == 0 and @p length < XXH3_MIDSIZE_MAX + * and @p secret is different from XXH3_generateSecret_fromSeed(). + * As stated in the contract, the correct hash result must be + * the same as XXH3_128bits_withSeed() when @p length <= XXH3_MIDSIZE_MAX. + * Results generated by this older version are wrong, hence not comparable. + */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset_withSecretandSeed(XXH_NOESCAPE XXH3_state_t* statePtr, + XXH_NOESCAPE const void* secret, size_t secretSize, + XXH64_hash_t seed64); + +#endif /* !XXH_NO_STREAM */ + +#endif /* !XXH_NO_XXH3 */ +#endif /* XXH_NO_LONG_LONG */ +#if defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API) +# define XXH_IMPLEMENTATION +#endif + +#endif /* defined(XXH_STATIC_LINKING_ONLY) && !defined(XXHASH_H_STATIC_13879238742) */ + + +/* ======================================================================== */ +/* ======================================================================== */ +/* ======================================================================== */ + + +/*-********************************************************************** + * xxHash implementation + *-********************************************************************** + * xxHash's implementation used to be hosted inside xxhash.c. + * + * However, inlining requires implementation to be visible to the compiler, + * hence be included alongside the header. + * Previously, implementation was hosted inside xxhash.c, + * which was then #included when inlining was activated. + * This construction created issues with a few build and install systems, + * as it required xxhash.c to be stored in /include directory. + * + * xxHash implementation is now directly integrated within xxhash.h. + * As a consequence, xxhash.c is no longer needed in /include. + * + * xxhash.c is still available and is still useful. + * In a "normal" setup, when xxhash is not inlined, + * xxhash.h only exposes the prototypes and public symbols, + * while xxhash.c can be built into an object file xxhash.o + * which can then be linked into the final binary. + ************************************************************************/ + +#if ( defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API) \ + || defined(XXH_IMPLEMENTATION) ) && !defined(XXH_IMPLEM_13a8737387) +# define XXH_IMPLEM_13a8737387 + +/* ************************************* +* Tuning parameters +***************************************/ + +/*! + * @defgroup tuning Tuning parameters + * @{ + * + * Various macros to control xxHash's behavior. + */ +#ifdef XXH_DOXYGEN +/*! + * @brief Define this to disable 64-bit code. + * + * Useful if only using the @ref XXH32_family and you have a strict C90 compiler. + */ +# define XXH_NO_LONG_LONG +# undef XXH_NO_LONG_LONG /* don't actually */ +/*! + * @brief Controls how unaligned memory is accessed. + * + * By default, access to unaligned memory is controlled by `memcpy()`, which is + * safe and portable. + * + * Unfortunately, on some target/compiler combinations, the generated assembly + * is sub-optimal. + * + * The below switch allow selection of a different access method + * in the search for improved performance. + * + * @par Possible options: + * + * - `XXH_FORCE_MEMORY_ACCESS=0` (default): `memcpy` + * @par + * Use `memcpy()`. Safe and portable. Note that most modern compilers will + * eliminate the function call and treat it as an unaligned access. + * + * - `XXH_FORCE_MEMORY_ACCESS=1`: `__attribute__((aligned(1)))` + * @par + * Depends on compiler extensions and is therefore not portable. + * This method is safe _if_ your compiler supports it, + * and *generally* as fast or faster than `memcpy`. + * + * - `XXH_FORCE_MEMORY_ACCESS=2`: Direct cast + * @par + * Casts directly and dereferences. This method doesn't depend on the + * compiler, but it violates the C standard as it directly dereferences an + * unaligned pointer. It can generate buggy code on targets which do not + * support unaligned memory accesses, but in some circumstances, it's the + * only known way to get the most performance. + * + * - `XXH_FORCE_MEMORY_ACCESS=3`: Byteshift + * @par + * Also portable. This can generate the best code on old compilers which don't + * inline small `memcpy()` calls, and it might also be faster on big-endian + * systems which lack a native byteswap instruction. However, some compilers + * will emit literal byteshifts even if the target supports unaligned access. + * + * + * @warning + * Methods 1 and 2 rely on implementation-defined behavior. Use these with + * care, as what works on one compiler/platform/optimization level may cause + * another to read garbage data or even crash. + * + * See https://fastcompression.blogspot.com/2015/08/accessing-unaligned-memory.html for details. + * + * Prefer these methods in priority order (0 > 3 > 1 > 2) + */ +# define XXH_FORCE_MEMORY_ACCESS 0 + +/*! + * @def XXH_SIZE_OPT + * @brief Controls how much xxHash optimizes for size. + * + * xxHash, when compiled, tends to result in a rather large binary size. This + * is mostly due to heavy usage to forced inlining and constant folding of the + * @ref XXH3_family to increase performance. + * + * However, some developers prefer size over speed. This option can + * significantly reduce the size of the generated code. When using the `-Os` + * or `-Oz` options on GCC or Clang, this is defined to 1 by default, + * otherwise it is defined to 0. + * + * Most of these size optimizations can be controlled manually. + * + * This is a number from 0-2. + * - `XXH_SIZE_OPT` == 0: Default. xxHash makes no size optimizations. Speed + * comes first. + * - `XXH_SIZE_OPT` == 1: Default for `-Os` and `-Oz`. xxHash is more + * conservative and disables hacks that increase code size. It implies the + * options @ref XXH_NO_INLINE_HINTS == 1, @ref XXH_FORCE_ALIGN_CHECK == 0, + * and @ref XXH3_NEON_LANES == 8 if they are not already defined. + * - `XXH_SIZE_OPT` == 2: xxHash tries to make itself as small as possible. + * Performance may cry. For example, the single shot functions just use the + * streaming API. + */ +# define XXH_SIZE_OPT 0 + +/*! + * @def XXH_FORCE_ALIGN_CHECK + * @brief If defined to non-zero, adds a special path for aligned inputs (XXH32() + * and XXH64() only). + * + * This is an important performance trick for architectures without decent + * unaligned memory access performance. + * + * It checks for input alignment, and when conditions are met, uses a "fast + * path" employing direct 32-bit/64-bit reads, resulting in _dramatically + * faster_ read speed. + * + * The check costs one initial branch per hash, which is generally negligible, + * but not zero. + * + * Moreover, it's not useful to generate an additional code path if memory + * access uses the same instruction for both aligned and unaligned + * addresses (e.g. x86 and aarch64). + * + * In these cases, the alignment check can be removed by setting this macro to 0. + * Then the code will always use unaligned memory access. + * Align check is automatically disabled on x86, x64, ARM64, and some ARM chips + * which are platforms known to offer good unaligned memory accesses performance. + * + * It is also disabled by default when @ref XXH_SIZE_OPT >= 1. + * + * This option does not affect XXH3 (only XXH32 and XXH64). + */ +# define XXH_FORCE_ALIGN_CHECK 0 + +/*! + * @def XXH_NO_INLINE_HINTS + * @brief When non-zero, sets all functions to `static`. + * + * By default, xxHash tries to force the compiler to inline almost all internal + * functions. + * + * This can usually improve performance due to reduced jumping and improved + * constant folding, but significantly increases the size of the binary which + * might not be favorable. + * + * Additionally, sometimes the forced inlining can be detrimental to performance, + * depending on the architecture. + * + * XXH_NO_INLINE_HINTS marks all internal functions as static, giving the + * compiler full control on whether to inline or not. + * + * When not optimizing (-O0), using `-fno-inline` with GCC or Clang, or if + * @ref XXH_SIZE_OPT >= 1, this will automatically be defined. + */ +# define XXH_NO_INLINE_HINTS 0 + +/*! + * @def XXH3_INLINE_SECRET + * @brief Determines whether to inline the XXH3 withSecret code. + * + * When the secret size is known, the compiler can improve the performance + * of XXH3_64bits_withSecret() and XXH3_128bits_withSecret(). + * + * However, if the secret size is not known, it doesn't have any benefit. This + * happens when xxHash is compiled into a global symbol. Therefore, if + * @ref XXH_INLINE_ALL is *not* defined, this will be defined to 0. + * + * Additionally, this defaults to 0 on GCC 12+, which has an issue with function pointers + * that are *sometimes* force inline on -Og, and it is impossible to automatically + * detect this optimization level. + */ +# define XXH3_INLINE_SECRET 0 + +/*! + * @def XXH32_ENDJMP + * @brief Whether to use a jump for `XXH32_finalize`. + * + * For performance, `XXH32_finalize` uses multiple branches in the finalizer. + * This is generally preferable for performance, + * but depending on exact architecture, a jmp may be preferable. + * + * This setting is only possibly making a difference for very small inputs. + */ +# define XXH32_ENDJMP 0 + +/*! + * @internal + * @brief Redefines old internal names. + * + * For compatibility with code that uses xxHash's internals before the names + * were changed to improve namespacing. There is no other reason to use this. + */ +# define XXH_OLD_NAMES +# undef XXH_OLD_NAMES /* don't actually use, it is ugly. */ + +/*! + * @def XXH_NO_STREAM + * @brief Disables the streaming API. + * + * When xxHash is not inlined and the streaming functions are not used, disabling + * the streaming functions can improve code size significantly, especially with + * the @ref XXH3_family which tends to make constant folded copies of itself. + */ +# define XXH_NO_STREAM +# undef XXH_NO_STREAM /* don't actually */ +#endif /* XXH_DOXYGEN */ +/*! + * @} + */ + +#ifndef XXH_FORCE_MEMORY_ACCESS /* can be defined externally, on command line for example */ + /* prefer __packed__ structures (method 1) for GCC + * < ARMv7 with unaligned access (e.g. Raspbian armhf) still uses byte shifting, so we use memcpy + * which for some reason does unaligned loads. */ +# if defined(__GNUC__) && !(defined(__ARM_ARCH) && __ARM_ARCH < 7 && defined(__ARM_FEATURE_UNALIGNED)) +# define XXH_FORCE_MEMORY_ACCESS 1 +# endif +#endif + +#ifndef XXH_SIZE_OPT + /* default to 1 for -Os or -Oz */ +# if (defined(__GNUC__) || defined(__clang__)) && defined(__OPTIMIZE_SIZE__) +# define XXH_SIZE_OPT 1 +# else +# define XXH_SIZE_OPT 0 +# endif +#endif + +#ifndef XXH_FORCE_ALIGN_CHECK /* can be defined externally */ + /* don't check on sizeopt, x86, aarch64, or arm when unaligned access is available */ +# if XXH_SIZE_OPT >= 1 || \ + defined(__i386) || defined(__x86_64__) || defined(__aarch64__) || defined(__ARM_FEATURE_UNALIGNED) \ + || defined(_M_IX86) || defined(_M_X64) || defined(_M_ARM64) || defined(_M_ARM) /* visual */ +# define XXH_FORCE_ALIGN_CHECK 0 +# else +# define XXH_FORCE_ALIGN_CHECK 1 +# endif +#endif + +#ifndef XXH_NO_INLINE_HINTS +# if XXH_SIZE_OPT >= 1 || defined(__NO_INLINE__) /* -O0, -fno-inline */ +# define XXH_NO_INLINE_HINTS 1 +# else +# define XXH_NO_INLINE_HINTS 0 +# endif +#endif + +#ifndef XXH3_INLINE_SECRET +# if (defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 12) \ + || !defined(XXH_INLINE_ALL) +# define XXH3_INLINE_SECRET 0 +# else +# define XXH3_INLINE_SECRET 1 +# endif +#endif + +#ifndef XXH32_ENDJMP +/* generally preferable for performance */ +# define XXH32_ENDJMP 0 +#endif + +/*! + * @defgroup impl Implementation + * @{ + */ + + +/* ************************************* +* Includes & Memory related functions +***************************************/ +#if defined(XXH_NO_STREAM) +/* nothing */ +#elif defined(XXH_NO_STDLIB) + +/* When requesting to disable any mention of stdlib, + * the library loses the ability to invoked malloc / free. + * In practice, it means that functions like `XXH*_createState()` + * will always fail, and return NULL. + * This flag is useful in situations where + * xxhash.h is integrated into some kernel, embedded or limited environment + * without access to dynamic allocation. + */ + +static XXH_CONSTF void* XXH_malloc(size_t s) { (void)s; return NULL; } +static void XXH_free(void* p) { (void)p; } + +#else + +/* + * Modify the local functions below should you wish to use + * different memory routines for malloc() and free() + */ +#include + +/*! + * @internal + * @brief Modify this function to use a different routine than malloc(). + */ +static XXH_MALLOCF void* XXH_malloc(size_t s) { return malloc(s); } + +/*! + * @internal + * @brief Modify this function to use a different routine than free(). + */ +static void XXH_free(void* p) { free(p); } + +#endif /* XXH_NO_STDLIB */ + +#include + +/*! + * @internal + * @brief Modify this function to use a different routine than memcpy(). + */ +static void* XXH_memcpy(void* dest, const void* src, size_t size) +{ + return memcpy(dest,src,size); +} + +#include /* ULLONG_MAX */ + + +/* ************************************* +* Compiler Specific Options +***************************************/ +#ifdef _MSC_VER /* Visual Studio warning fix */ +# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */ +#endif + +#if XXH_NO_INLINE_HINTS /* disable inlining hints */ +# if defined(__GNUC__) || defined(__clang__) +# define XXH_FORCE_INLINE static __attribute__((__unused__)) +# else +# define XXH_FORCE_INLINE static +# endif +# define XXH_NO_INLINE static +/* enable inlining hints */ +#elif defined(__GNUC__) || defined(__clang__) +# define XXH_FORCE_INLINE static __inline__ __attribute__((__always_inline__, __unused__)) +# define XXH_NO_INLINE static __attribute__((__noinline__)) +#elif defined(_MSC_VER) /* Visual Studio */ +# define XXH_FORCE_INLINE static __forceinline +# define XXH_NO_INLINE static __declspec(noinline) +#elif defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) /* C99 */ +# define XXH_FORCE_INLINE static inline +# define XXH_NO_INLINE static +#else +# define XXH_FORCE_INLINE static +# define XXH_NO_INLINE static +#endif + +#if XXH3_INLINE_SECRET +# define XXH3_WITH_SECRET_INLINE XXH_FORCE_INLINE +#else +# define XXH3_WITH_SECRET_INLINE XXH_NO_INLINE +#endif + + +/* ************************************* +* Debug +***************************************/ +/*! + * @ingroup tuning + * @def XXH_DEBUGLEVEL + * @brief Sets the debugging level. + * + * XXH_DEBUGLEVEL is expected to be defined externally, typically via the + * compiler's command line options. The value must be a number. + */ +#ifndef XXH_DEBUGLEVEL +# ifdef DEBUGLEVEL /* backwards compat */ +# define XXH_DEBUGLEVEL DEBUGLEVEL +# else +# define XXH_DEBUGLEVEL 0 +# endif +#endif + +#if (XXH_DEBUGLEVEL>=1) +# include /* note: can still be disabled with NDEBUG */ +# define XXH_ASSERT(c) assert(c) +#else +# if defined(__INTEL_COMPILER) +# define XXH_ASSERT(c) XXH_ASSUME((unsigned char) (c)) +# else +# define XXH_ASSERT(c) XXH_ASSUME(c) +# endif +#endif + +/* note: use after variable declarations */ +#ifndef XXH_STATIC_ASSERT +# if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) /* C11 */ +# define XXH_STATIC_ASSERT_WITH_MESSAGE(c,m) do { _Static_assert((c),m); } while(0) +# elif defined(__cplusplus) && (__cplusplus >= 201103L) /* C++11 */ +# define XXH_STATIC_ASSERT_WITH_MESSAGE(c,m) do { static_assert((c),m); } while(0) +# else +# define XXH_STATIC_ASSERT_WITH_MESSAGE(c,m) do { struct xxh_sa { char x[(c) ? 1 : -1]; }; } while(0) +# endif +# define XXH_STATIC_ASSERT(c) XXH_STATIC_ASSERT_WITH_MESSAGE((c),#c) +#endif + +/*! + * @internal + * @def XXH_COMPILER_GUARD(var) + * @brief Used to prevent unwanted optimizations for @p var. + * + * It uses an empty GCC inline assembly statement with a register constraint + * which forces @p var into a general purpose register (eg eax, ebx, ecx + * on x86) and marks it as modified. + * + * This is used in a few places to avoid unwanted autovectorization (e.g. + * XXH32_round()). All vectorization we want is explicit via intrinsics, + * and _usually_ isn't wanted elsewhere. + * + * We also use it to prevent unwanted constant folding for AArch64 in + * XXH3_initCustomSecret_scalar(). + */ +#if defined(__GNUC__) || defined(__clang__) +# define XXH_COMPILER_GUARD(var) __asm__("" : "+r" (var)) +#else +# define XXH_COMPILER_GUARD(var) ((void)0) +#endif + +/* Specifically for NEON vectors which use the "w" constraint, on + * Clang. */ +#if defined(__clang__) && defined(__ARM_ARCH) && !defined(__wasm__) +# define XXH_COMPILER_GUARD_CLANG_NEON(var) __asm__("" : "+w" (var)) +#else +# define XXH_COMPILER_GUARD_CLANG_NEON(var) ((void)0) +#endif + +/* ************************************* +* Basic Types +***************************************/ +#if !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) +# ifdef _AIX +# include +# else +# include +# endif + typedef uint8_t xxh_u8; +#else + typedef unsigned char xxh_u8; +#endif +typedef XXH32_hash_t xxh_u32; + +#ifdef XXH_OLD_NAMES +# warning "XXH_OLD_NAMES is planned to be removed starting v0.9. If the program depends on it, consider moving away from it by employing newer type names directly" +# define BYTE xxh_u8 +# define U8 xxh_u8 +# define U32 xxh_u32 +#endif + +/* *** Memory access *** */ + +/*! + * @internal + * @fn xxh_u32 XXH_read32(const void* ptr) + * @brief Reads an unaligned 32-bit integer from @p ptr in native endianness. + * + * Affected by @ref XXH_FORCE_MEMORY_ACCESS. + * + * @param ptr The pointer to read from. + * @return The 32-bit native endian integer from the bytes at @p ptr. + */ + +/*! + * @internal + * @fn xxh_u32 XXH_readLE32(const void* ptr) + * @brief Reads an unaligned 32-bit little endian integer from @p ptr. + * + * Affected by @ref XXH_FORCE_MEMORY_ACCESS. + * + * @param ptr The pointer to read from. + * @return The 32-bit little endian integer from the bytes at @p ptr. + */ + +/*! + * @internal + * @fn xxh_u32 XXH_readBE32(const void* ptr) + * @brief Reads an unaligned 32-bit big endian integer from @p ptr. + * + * Affected by @ref XXH_FORCE_MEMORY_ACCESS. + * + * @param ptr The pointer to read from. + * @return The 32-bit big endian integer from the bytes at @p ptr. + */ + +/*! + * @internal + * @fn xxh_u32 XXH_readLE32_align(const void* ptr, XXH_alignment align) + * @brief Like @ref XXH_readLE32(), but has an option for aligned reads. + * + * Affected by @ref XXH_FORCE_MEMORY_ACCESS. + * Note that when @ref XXH_FORCE_ALIGN_CHECK == 0, the @p align parameter is + * always @ref XXH_alignment::XXH_unaligned. + * + * @param ptr The pointer to read from. + * @param align Whether @p ptr is aligned. + * @pre + * If @p align == @ref XXH_alignment::XXH_aligned, @p ptr must be 4 byte + * aligned. + * @return The 32-bit little endian integer from the bytes at @p ptr. + */ + +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) +/* + * Manual byteshift. Best for old compilers which don't inline memcpy. + * We actually directly use XXH_readLE32 and XXH_readBE32. + */ +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==2)) + +/* + * Force direct memory access. Only works on CPU which support unaligned memory + * access in hardware. + */ +static xxh_u32 XXH_read32(const void* memPtr) { return *(const xxh_u32*) memPtr; } + +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==1)) + +/* + * __attribute__((aligned(1))) is supported by gcc and clang. Originally the + * documentation claimed that it only increased the alignment, but actually it + * can decrease it on gcc, clang, and icc: + * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=69502, + * https://gcc.godbolt.org/z/xYez1j67Y. + */ +#ifdef XXH_OLD_NAMES +typedef union { xxh_u32 u32; } __attribute__((__packed__)) unalign; +#endif +static xxh_u32 XXH_read32(const void* ptr) +{ + typedef __attribute__((__aligned__(1))) xxh_u32 xxh_unalign32; + return *((const xxh_unalign32*)ptr); +} + +#else + +/* + * Portable and safe solution. Generally efficient. + * see: https://fastcompression.blogspot.com/2015/08/accessing-unaligned-memory.html + */ +static xxh_u32 XXH_read32(const void* memPtr) +{ + xxh_u32 val; + XXH_memcpy(&val, memPtr, sizeof(val)); + return val; +} + +#endif /* XXH_FORCE_DIRECT_MEMORY_ACCESS */ + + +/* *** Endianness *** */ + +/*! + * @ingroup tuning + * @def XXH_CPU_LITTLE_ENDIAN + * @brief Whether the target is little endian. + * + * Defined to 1 if the target is little endian, or 0 if it is big endian. + * It can be defined externally, for example on the compiler command line. + * + * If it is not defined, + * a runtime check (which is usually constant folded) is used instead. + * + * @note + * This is not necessarily defined to an integer constant. + * + * @see XXH_isLittleEndian() for the runtime check. + */ +#ifndef XXH_CPU_LITTLE_ENDIAN +/* + * Try to detect endianness automatically, to avoid the nonstandard behavior + * in `XXH_isLittleEndian()` + */ +# if defined(_WIN32) /* Windows is always little endian */ \ + || defined(__LITTLE_ENDIAN__) \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) +# define XXH_CPU_LITTLE_ENDIAN 1 +# elif defined(__BIG_ENDIAN__) \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) +# define XXH_CPU_LITTLE_ENDIAN 0 +# else +/*! + * @internal + * @brief Runtime check for @ref XXH_CPU_LITTLE_ENDIAN. + * + * Most compilers will constant fold this. + */ +static int XXH_isLittleEndian(void) +{ + /* + * Portable and well-defined behavior. + * Don't use static: it is detrimental to performance. + */ + const union { xxh_u32 u; xxh_u8 c[4]; } one = { 1 }; + return one.c[0]; +} +# define XXH_CPU_LITTLE_ENDIAN XXH_isLittleEndian() +# endif +#endif + + + + +/* **************************************** +* Compiler-specific Functions and Macros +******************************************/ +#define XXH_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) + +#ifdef __has_builtin +# define XXH_HAS_BUILTIN(x) __has_builtin(x) +#else +# define XXH_HAS_BUILTIN(x) 0 +#endif + + + +/* + * C23 and future versions have standard "unreachable()". + * Once it has been implemented reliably we can add it as an + * additional case: + * + * ``` + * #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= XXH_C23_VN) + * # include + * # ifdef unreachable + * # define XXH_UNREACHABLE() unreachable() + * # endif + * #endif + * ``` + * + * Note C++23 also has std::unreachable() which can be detected + * as follows: + * ``` + * #if defined(__cpp_lib_unreachable) && (__cpp_lib_unreachable >= 202202L) + * # include + * # define XXH_UNREACHABLE() std::unreachable() + * #endif + * ``` + * NB: `__cpp_lib_unreachable` is defined in the `` header. + * We don't use that as including `` in `extern "C"` blocks + * doesn't work on GCC12 + */ + +#if XXH_HAS_BUILTIN(__builtin_unreachable) +# define XXH_UNREACHABLE() __builtin_unreachable() + +#elif defined(_MSC_VER) +# define XXH_UNREACHABLE() __assume(0) + +#else +# define XXH_UNREACHABLE() +#endif + +#if XXH_HAS_BUILTIN(__builtin_assume) +# define XXH_ASSUME(c) __builtin_assume(c) +#else +# define XXH_ASSUME(c) if (!(c)) { XXH_UNREACHABLE(); } +#endif + +/*! + * @internal + * @def XXH_rotl32(x,r) + * @brief 32-bit rotate left. + * + * @param x The 32-bit integer to be rotated. + * @param r The number of bits to rotate. + * @pre + * @p r > 0 && @p r < 32 + * @note + * @p x and @p r may be evaluated multiple times. + * @return The rotated result. + */ +#if !defined(NO_CLANG_BUILTIN) && XXH_HAS_BUILTIN(__builtin_rotateleft32) \ + && XXH_HAS_BUILTIN(__builtin_rotateleft64) +# define XXH_rotl32 __builtin_rotateleft32 +# define XXH_rotl64 __builtin_rotateleft64 +/* Note: although _rotl exists for minGW (GCC under windows), performance seems poor */ +#elif defined(_MSC_VER) +# define XXH_rotl32(x,r) _rotl(x,r) +# define XXH_rotl64(x,r) _rotl64(x,r) +#else +# define XXH_rotl32(x,r) (((x) << (r)) | ((x) >> (32 - (r)))) +# define XXH_rotl64(x,r) (((x) << (r)) | ((x) >> (64 - (r)))) +#endif + +/*! + * @internal + * @fn xxh_u32 XXH_swap32(xxh_u32 x) + * @brief A 32-bit byteswap. + * + * @param x The 32-bit integer to byteswap. + * @return @p x, byteswapped. + */ +#if defined(_MSC_VER) /* Visual Studio */ +# define XXH_swap32 _byteswap_ulong +#elif XXH_GCC_VERSION >= 403 +# define XXH_swap32 __builtin_bswap32 +#else +static xxh_u32 XXH_swap32 (xxh_u32 x) +{ + return ((x << 24) & 0xff000000 ) | + ((x << 8) & 0x00ff0000 ) | + ((x >> 8) & 0x0000ff00 ) | + ((x >> 24) & 0x000000ff ); +} +#endif + + +/* *************************** +* Memory reads +*****************************/ + +/*! + * @internal + * @brief Enum to indicate whether a pointer is aligned. + */ +typedef enum { + XXH_aligned, /*!< Aligned */ + XXH_unaligned /*!< Possibly unaligned */ +} XXH_alignment; + +/* + * XXH_FORCE_MEMORY_ACCESS==3 is an endian-independent byteshift load. + * + * This is ideal for older compilers which don't inline memcpy. + */ +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) + +XXH_FORCE_INLINE xxh_u32 XXH_readLE32(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[0] + | ((xxh_u32)bytePtr[1] << 8) + | ((xxh_u32)bytePtr[2] << 16) + | ((xxh_u32)bytePtr[3] << 24); +} + +XXH_FORCE_INLINE xxh_u32 XXH_readBE32(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[3] + | ((xxh_u32)bytePtr[2] << 8) + | ((xxh_u32)bytePtr[1] << 16) + | ((xxh_u32)bytePtr[0] << 24); +} + +#else +XXH_FORCE_INLINE xxh_u32 XXH_readLE32(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_read32(ptr) : XXH_swap32(XXH_read32(ptr)); +} + +static xxh_u32 XXH_readBE32(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_swap32(XXH_read32(ptr)) : XXH_read32(ptr); +} +#endif + +XXH_FORCE_INLINE xxh_u32 +XXH_readLE32_align(const void* ptr, XXH_alignment align) +{ + if (align==XXH_unaligned) { + return XXH_readLE32(ptr); + } else { + return XXH_CPU_LITTLE_ENDIAN ? *(const xxh_u32*)ptr : XXH_swap32(*(const xxh_u32*)ptr); + } +} + + +/* ************************************* +* Misc +***************************************/ +/*! @ingroup public */ +XXH_PUBLIC_API unsigned XXH_versionNumber (void) { return XXH_VERSION_NUMBER; } + + +/* ******************************************************************* +* 32-bit hash functions +*********************************************************************/ +/*! + * @} + * @defgroup XXH32_impl XXH32 implementation + * @ingroup impl + * + * Details on the XXH32 implementation. + * @{ + */ + /* #define instead of static const, to be used as initializers */ +#define XXH_PRIME32_1 0x9E3779B1U /*!< 0b10011110001101110111100110110001 */ +#define XXH_PRIME32_2 0x85EBCA77U /*!< 0b10000101111010111100101001110111 */ +#define XXH_PRIME32_3 0xC2B2AE3DU /*!< 0b11000010101100101010111000111101 */ +#define XXH_PRIME32_4 0x27D4EB2FU /*!< 0b00100111110101001110101100101111 */ +#define XXH_PRIME32_5 0x165667B1U /*!< 0b00010110010101100110011110110001 */ + +#ifdef XXH_OLD_NAMES +# define PRIME32_1 XXH_PRIME32_1 +# define PRIME32_2 XXH_PRIME32_2 +# define PRIME32_3 XXH_PRIME32_3 +# define PRIME32_4 XXH_PRIME32_4 +# define PRIME32_5 XXH_PRIME32_5 +#endif + +/*! + * @internal + * @brief Normal stripe processing routine. + * + * This shuffles the bits so that any bit from @p input impacts several bits in + * @p acc. + * + * @param acc The accumulator lane. + * @param input The stripe of input to mix. + * @return The mixed accumulator lane. + */ +static xxh_u32 XXH32_round(xxh_u32 acc, xxh_u32 input) +{ + acc += input * XXH_PRIME32_2; + acc = XXH_rotl32(acc, 13); + acc *= XXH_PRIME32_1; +#if (defined(__SSE4_1__) || defined(__aarch64__) || defined(__wasm_simd128__)) && !defined(XXH_ENABLE_AUTOVECTORIZE) + /* + * UGLY HACK: + * A compiler fence is used to prevent GCC and Clang from + * autovectorizing the XXH32 loop (pragmas and attributes don't work for some + * reason) without globally disabling SSE4.1. + * + * The reason we want to avoid vectorization is because despite working on + * 4 integers at a time, there are multiple factors slowing XXH32 down on + * SSE4: + * - There's a ridiculous amount of lag from pmulld (10 cycles of latency on + * newer chips!) making it slightly slower to multiply four integers at + * once compared to four integers independently. Even when pmulld was + * fastest, Sandy/Ivy Bridge, it is still not worth it to go into SSE + * just to multiply unless doing a long operation. + * + * - Four instructions are required to rotate, + * movqda tmp, v // not required with VEX encoding + * pslld tmp, 13 // tmp <<= 13 + * psrld v, 19 // x >>= 19 + * por v, tmp // x |= tmp + * compared to one for scalar: + * roll v, 13 // reliably fast across the board + * shldl v, v, 13 // Sandy Bridge and later prefer this for some reason + * + * - Instruction level parallelism is actually more beneficial here because + * the SIMD actually serializes this operation: While v1 is rotating, v2 + * can load data, while v3 can multiply. SSE forces them to operate + * together. + * + * This is also enabled on AArch64, as Clang is *very aggressive* in vectorizing + * the loop. NEON is only faster on the A53, and with the newer cores, it is less + * than half the speed. + * + * Additionally, this is used on WASM SIMD128 because it JITs to the same + * SIMD instructions and has the same issue. + */ + XXH_COMPILER_GUARD(acc); +#endif + return acc; +} + +/*! + * @internal + * @brief Mixes all bits to finalize the hash. + * + * The final mix ensures that all input bits have a chance to impact any bit in + * the output digest, resulting in an unbiased distribution. + * + * @param hash The hash to avalanche. + * @return The avalanched hash. + */ +static xxh_u32 XXH32_avalanche(xxh_u32 hash) +{ + hash ^= hash >> 15; + hash *= XXH_PRIME32_2; + hash ^= hash >> 13; + hash *= XXH_PRIME32_3; + hash ^= hash >> 16; + return hash; +} + +#define XXH_get32bits(p) XXH_readLE32_align(p, align) + +/*! + * @internal + * @brief Processes the last 0-15 bytes of @p ptr. + * + * There may be up to 15 bytes remaining to consume from the input. + * This final stage will digest them to ensure that all input bytes are present + * in the final mix. + * + * @param hash The hash to finalize. + * @param ptr The pointer to the remaining input. + * @param len The remaining length, modulo 16. + * @param align Whether @p ptr is aligned. + * @return The finalized hash. + * @see XXH64_finalize(). + */ +static XXH_PUREF xxh_u32 +XXH32_finalize(xxh_u32 hash, const xxh_u8* ptr, size_t len, XXH_alignment align) +{ +#define XXH_PROCESS1 do { \ + hash += (*ptr++) * XXH_PRIME32_5; \ + hash = XXH_rotl32(hash, 11) * XXH_PRIME32_1; \ +} while (0) + +#define XXH_PROCESS4 do { \ + hash += XXH_get32bits(ptr) * XXH_PRIME32_3; \ + ptr += 4; \ + hash = XXH_rotl32(hash, 17) * XXH_PRIME32_4; \ +} while (0) + + if (ptr==NULL) XXH_ASSERT(len == 0); + + /* Compact rerolled version; generally faster */ + if (!XXH32_ENDJMP) { + len &= 15; + while (len >= 4) { + XXH_PROCESS4; + len -= 4; + } + while (len > 0) { + XXH_PROCESS1; + --len; + } + return XXH32_avalanche(hash); + } else { + switch(len&15) /* or switch(bEnd - p) */ { + case 12: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 8: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 4: XXH_PROCESS4; + return XXH32_avalanche(hash); + + case 13: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 9: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 5: XXH_PROCESS4; + XXH_PROCESS1; + return XXH32_avalanche(hash); + + case 14: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 10: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 6: XXH_PROCESS4; + XXH_PROCESS1; + XXH_PROCESS1; + return XXH32_avalanche(hash); + + case 15: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 11: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 7: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 3: XXH_PROCESS1; + XXH_FALLTHROUGH; /* fallthrough */ + case 2: XXH_PROCESS1; + XXH_FALLTHROUGH; /* fallthrough */ + case 1: XXH_PROCESS1; + XXH_FALLTHROUGH; /* fallthrough */ + case 0: return XXH32_avalanche(hash); + } + XXH_ASSERT(0); + return hash; /* reaching this point is deemed impossible */ + } +} + +#ifdef XXH_OLD_NAMES +# define PROCESS1 XXH_PROCESS1 +# define PROCESS4 XXH_PROCESS4 +#else +# undef XXH_PROCESS1 +# undef XXH_PROCESS4 +#endif + +/*! + * @internal + * @brief The implementation for @ref XXH32(). + * + * @param input , len , seed Directly passed from @ref XXH32(). + * @param align Whether @p input is aligned. + * @return The calculated hash. + */ +XXH_FORCE_INLINE XXH_PUREF xxh_u32 +XXH32_endian_align(const xxh_u8* input, size_t len, xxh_u32 seed, XXH_alignment align) +{ + xxh_u32 h32; + + if (input==NULL) XXH_ASSERT(len == 0); + + if (len>=16) { + const xxh_u8* const bEnd = input + len; + const xxh_u8* const limit = bEnd - 15; + xxh_u32 v1 = seed + XXH_PRIME32_1 + XXH_PRIME32_2; + xxh_u32 v2 = seed + XXH_PRIME32_2; + xxh_u32 v3 = seed + 0; + xxh_u32 v4 = seed - XXH_PRIME32_1; + + do { + v1 = XXH32_round(v1, XXH_get32bits(input)); input += 4; + v2 = XXH32_round(v2, XXH_get32bits(input)); input += 4; + v3 = XXH32_round(v3, XXH_get32bits(input)); input += 4; + v4 = XXH32_round(v4, XXH_get32bits(input)); input += 4; + } while (input < limit); + + h32 = XXH_rotl32(v1, 1) + XXH_rotl32(v2, 7) + + XXH_rotl32(v3, 12) + XXH_rotl32(v4, 18); + } else { + h32 = seed + XXH_PRIME32_5; + } + + h32 += (xxh_u32)len; + + return XXH32_finalize(h32, input, len&15, align); +} + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH32_hash_t XXH32 (const void* input, size_t len, XXH32_hash_t seed) +{ +#if !defined(XXH_NO_STREAM) && XXH_SIZE_OPT >= 2 + /* Simple version, good for code maintenance, but unfortunately slow for small inputs */ + XXH32_state_t state; + XXH32_reset(&state, seed); + XXH32_update(&state, (const xxh_u8*)input, len); + return XXH32_digest(&state); +#else + if (XXH_FORCE_ALIGN_CHECK) { + if ((((size_t)input) & 3) == 0) { /* Input is 4-bytes aligned, leverage the speed benefit */ + return XXH32_endian_align((const xxh_u8*)input, len, seed, XXH_aligned); + } } + + return XXH32_endian_align((const xxh_u8*)input, len, seed, XXH_unaligned); +#endif +} + + + +/******* Hash streaming *******/ +#ifndef XXH_NO_STREAM +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH32_state_t* XXH32_createState(void) +{ + return (XXH32_state_t*)XXH_malloc(sizeof(XXH32_state_t)); +} +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH_errorcode XXH32_freeState(XXH32_state_t* statePtr) +{ + XXH_free(statePtr); + return XXH_OK; +} + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API void XXH32_copyState(XXH32_state_t* dstState, const XXH32_state_t* srcState) +{ + XXH_memcpy(dstState, srcState, sizeof(*dstState)); +} + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH_errorcode XXH32_reset(XXH32_state_t* statePtr, XXH32_hash_t seed) +{ + XXH_ASSERT(statePtr != NULL); + memset(statePtr, 0, sizeof(*statePtr)); + statePtr->v[0] = seed + XXH_PRIME32_1 + XXH_PRIME32_2; + statePtr->v[1] = seed + XXH_PRIME32_2; + statePtr->v[2] = seed + 0; + statePtr->v[3] = seed - XXH_PRIME32_1; + return XXH_OK; +} + + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH_errorcode +XXH32_update(XXH32_state_t* state, const void* input, size_t len) +{ + if (input==NULL) { + XXH_ASSERT(len == 0); + return XXH_OK; + } + + { const xxh_u8* p = (const xxh_u8*)input; + const xxh_u8* const bEnd = p + len; + + state->total_len_32 += (XXH32_hash_t)len; + state->large_len |= (XXH32_hash_t)((len>=16) | (state->total_len_32>=16)); + + if (state->memsize + len < 16) { /* fill in tmp buffer */ + XXH_memcpy((xxh_u8*)(state->mem32) + state->memsize, input, len); + state->memsize += (XXH32_hash_t)len; + return XXH_OK; + } + + if (state->memsize) { /* some data left from previous update */ + XXH_memcpy((xxh_u8*)(state->mem32) + state->memsize, input, 16-state->memsize); + { const xxh_u32* p32 = state->mem32; + state->v[0] = XXH32_round(state->v[0], XXH_readLE32(p32)); p32++; + state->v[1] = XXH32_round(state->v[1], XXH_readLE32(p32)); p32++; + state->v[2] = XXH32_round(state->v[2], XXH_readLE32(p32)); p32++; + state->v[3] = XXH32_round(state->v[3], XXH_readLE32(p32)); + } + p += 16-state->memsize; + state->memsize = 0; + } + + if (p <= bEnd-16) { + const xxh_u8* const limit = bEnd - 16; + + do { + state->v[0] = XXH32_round(state->v[0], XXH_readLE32(p)); p+=4; + state->v[1] = XXH32_round(state->v[1], XXH_readLE32(p)); p+=4; + state->v[2] = XXH32_round(state->v[2], XXH_readLE32(p)); p+=4; + state->v[3] = XXH32_round(state->v[3], XXH_readLE32(p)); p+=4; + } while (p<=limit); + + } + + if (p < bEnd) { + XXH_memcpy(state->mem32, p, (size_t)(bEnd-p)); + state->memsize = (unsigned)(bEnd-p); + } + } + + return XXH_OK; +} + + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH32_hash_t XXH32_digest(const XXH32_state_t* state) +{ + xxh_u32 h32; + + if (state->large_len) { + h32 = XXH_rotl32(state->v[0], 1) + + XXH_rotl32(state->v[1], 7) + + XXH_rotl32(state->v[2], 12) + + XXH_rotl32(state->v[3], 18); + } else { + h32 = state->v[2] /* == seed */ + XXH_PRIME32_5; + } + + h32 += state->total_len_32; + + return XXH32_finalize(h32, (const xxh_u8*)state->mem32, state->memsize, XXH_aligned); +} +#endif /* !XXH_NO_STREAM */ + +/******* Canonical representation *******/ + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API void XXH32_canonicalFromHash(XXH32_canonical_t* dst, XXH32_hash_t hash) +{ + XXH_STATIC_ASSERT(sizeof(XXH32_canonical_t) == sizeof(XXH32_hash_t)); + if (XXH_CPU_LITTLE_ENDIAN) hash = XXH_swap32(hash); + XXH_memcpy(dst, &hash, sizeof(*dst)); +} +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH32_hash_t XXH32_hashFromCanonical(const XXH32_canonical_t* src) +{ + return XXH_readBE32(src); +} + + +#ifndef XXH_NO_LONG_LONG + +/* ******************************************************************* +* 64-bit hash functions +*********************************************************************/ +/*! + * @} + * @ingroup impl + * @{ + */ +/******* Memory access *******/ + +typedef XXH64_hash_t xxh_u64; + +#ifdef XXH_OLD_NAMES +# define U64 xxh_u64 +#endif + +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) +/* + * Manual byteshift. Best for old compilers which don't inline memcpy. + * We actually directly use XXH_readLE64 and XXH_readBE64. + */ +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==2)) + +/* Force direct memory access. Only works on CPU which support unaligned memory access in hardware */ +static xxh_u64 XXH_read64(const void* memPtr) +{ + return *(const xxh_u64*) memPtr; +} + +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==1)) + +/* + * __attribute__((aligned(1))) is supported by gcc and clang. Originally the + * documentation claimed that it only increased the alignment, but actually it + * can decrease it on gcc, clang, and icc: + * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=69502, + * https://gcc.godbolt.org/z/xYez1j67Y. + */ +#ifdef XXH_OLD_NAMES +typedef union { xxh_u32 u32; xxh_u64 u64; } __attribute__((__packed__)) unalign64; +#endif +static xxh_u64 XXH_read64(const void* ptr) +{ + typedef __attribute__((__aligned__(1))) xxh_u64 xxh_unalign64; + return *((const xxh_unalign64*)ptr); +} + +#else + +/* + * Portable and safe solution. Generally efficient. + * see: https://fastcompression.blogspot.com/2015/08/accessing-unaligned-memory.html + */ +static xxh_u64 XXH_read64(const void* memPtr) +{ + xxh_u64 val; + XXH_memcpy(&val, memPtr, sizeof(val)); + return val; +} + +#endif /* XXH_FORCE_DIRECT_MEMORY_ACCESS */ + +#if defined(_MSC_VER) /* Visual Studio */ +# define XXH_swap64 _byteswap_uint64 +#elif XXH_GCC_VERSION >= 403 +# define XXH_swap64 __builtin_bswap64 +#else +static xxh_u64 XXH_swap64(xxh_u64 x) +{ + return ((x << 56) & 0xff00000000000000ULL) | + ((x << 40) & 0x00ff000000000000ULL) | + ((x << 24) & 0x0000ff0000000000ULL) | + ((x << 8) & 0x000000ff00000000ULL) | + ((x >> 8) & 0x00000000ff000000ULL) | + ((x >> 24) & 0x0000000000ff0000ULL) | + ((x >> 40) & 0x000000000000ff00ULL) | + ((x >> 56) & 0x00000000000000ffULL); +} +#endif + + +/* XXH_FORCE_MEMORY_ACCESS==3 is an endian-independent byteshift load. */ +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) + +XXH_FORCE_INLINE xxh_u64 XXH_readLE64(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[0] + | ((xxh_u64)bytePtr[1] << 8) + | ((xxh_u64)bytePtr[2] << 16) + | ((xxh_u64)bytePtr[3] << 24) + | ((xxh_u64)bytePtr[4] << 32) + | ((xxh_u64)bytePtr[5] << 40) + | ((xxh_u64)bytePtr[6] << 48) + | ((xxh_u64)bytePtr[7] << 56); +} + +XXH_FORCE_INLINE xxh_u64 XXH_readBE64(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[7] + | ((xxh_u64)bytePtr[6] << 8) + | ((xxh_u64)bytePtr[5] << 16) + | ((xxh_u64)bytePtr[4] << 24) + | ((xxh_u64)bytePtr[3] << 32) + | ((xxh_u64)bytePtr[2] << 40) + | ((xxh_u64)bytePtr[1] << 48) + | ((xxh_u64)bytePtr[0] << 56); +} + +#else +XXH_FORCE_INLINE xxh_u64 XXH_readLE64(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_read64(ptr) : XXH_swap64(XXH_read64(ptr)); +} + +static xxh_u64 XXH_readBE64(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_swap64(XXH_read64(ptr)) : XXH_read64(ptr); +} +#endif + +XXH_FORCE_INLINE xxh_u64 +XXH_readLE64_align(const void* ptr, XXH_alignment align) +{ + if (align==XXH_unaligned) + return XXH_readLE64(ptr); + else + return XXH_CPU_LITTLE_ENDIAN ? *(const xxh_u64*)ptr : XXH_swap64(*(const xxh_u64*)ptr); +} + + +/******* xxh64 *******/ +/*! + * @} + * @defgroup XXH64_impl XXH64 implementation + * @ingroup impl + * + * Details on the XXH64 implementation. + * @{ + */ +/* #define rather that static const, to be used as initializers */ +#define XXH_PRIME64_1 0x9E3779B185EBCA87ULL /*!< 0b1001111000110111011110011011000110000101111010111100101010000111 */ +#define XXH_PRIME64_2 0xC2B2AE3D27D4EB4FULL /*!< 0b1100001010110010101011100011110100100111110101001110101101001111 */ +#define XXH_PRIME64_3 0x165667B19E3779F9ULL /*!< 0b0001011001010110011001111011000110011110001101110111100111111001 */ +#define XXH_PRIME64_4 0x85EBCA77C2B2AE63ULL /*!< 0b1000010111101011110010100111011111000010101100101010111001100011 */ +#define XXH_PRIME64_5 0x27D4EB2F165667C5ULL /*!< 0b0010011111010100111010110010111100010110010101100110011111000101 */ + +#ifdef XXH_OLD_NAMES +# define PRIME64_1 XXH_PRIME64_1 +# define PRIME64_2 XXH_PRIME64_2 +# define PRIME64_3 XXH_PRIME64_3 +# define PRIME64_4 XXH_PRIME64_4 +# define PRIME64_5 XXH_PRIME64_5 +#endif + +/*! @copydoc XXH32_round */ +static xxh_u64 XXH64_round(xxh_u64 acc, xxh_u64 input) +{ + acc += input * XXH_PRIME64_2; + acc = XXH_rotl64(acc, 31); + acc *= XXH_PRIME64_1; +#if (defined(__AVX512F__)) && !defined(XXH_ENABLE_AUTOVECTORIZE) + /* + * DISABLE AUTOVECTORIZATION: + * A compiler fence is used to prevent GCC and Clang from + * autovectorizing the XXH64 loop (pragmas and attributes don't work for some + * reason) without globally disabling AVX512. + * + * Autovectorization of XXH64 tends to be detrimental, + * though the exact outcome may change depending on exact cpu and compiler version. + * For information, it has been reported as detrimental for Skylake-X, + * but possibly beneficial for Zen4. + * + * The default is to disable auto-vectorization, + * but you can select to enable it instead using `XXH_ENABLE_AUTOVECTORIZE` build variable. + */ + XXH_COMPILER_GUARD(acc); +#endif + return acc; +} + +static xxh_u64 XXH64_mergeRound(xxh_u64 acc, xxh_u64 val) +{ + val = XXH64_round(0, val); + acc ^= val; + acc = acc * XXH_PRIME64_1 + XXH_PRIME64_4; + return acc; +} + +/*! @copydoc XXH32_avalanche */ +static xxh_u64 XXH64_avalanche(xxh_u64 hash) +{ + hash ^= hash >> 33; + hash *= XXH_PRIME64_2; + hash ^= hash >> 29; + hash *= XXH_PRIME64_3; + hash ^= hash >> 32; + return hash; +} + + +#define XXH_get64bits(p) XXH_readLE64_align(p, align) + +/*! + * @internal + * @brief Processes the last 0-31 bytes of @p ptr. + * + * There may be up to 31 bytes remaining to consume from the input. + * This final stage will digest them to ensure that all input bytes are present + * in the final mix. + * + * @param hash The hash to finalize. + * @param ptr The pointer to the remaining input. + * @param len The remaining length, modulo 32. + * @param align Whether @p ptr is aligned. + * @return The finalized hash + * @see XXH32_finalize(). + */ +static XXH_PUREF xxh_u64 +XXH64_finalize(xxh_u64 hash, const xxh_u8* ptr, size_t len, XXH_alignment align) +{ + if (ptr==NULL) XXH_ASSERT(len == 0); + len &= 31; + while (len >= 8) { + xxh_u64 const k1 = XXH64_round(0, XXH_get64bits(ptr)); + ptr += 8; + hash ^= k1; + hash = XXH_rotl64(hash,27) * XXH_PRIME64_1 + XXH_PRIME64_4; + len -= 8; + } + if (len >= 4) { + hash ^= (xxh_u64)(XXH_get32bits(ptr)) * XXH_PRIME64_1; + ptr += 4; + hash = XXH_rotl64(hash, 23) * XXH_PRIME64_2 + XXH_PRIME64_3; + len -= 4; + } + while (len > 0) { + hash ^= (*ptr++) * XXH_PRIME64_5; + hash = XXH_rotl64(hash, 11) * XXH_PRIME64_1; + --len; + } + return XXH64_avalanche(hash); +} + +#ifdef XXH_OLD_NAMES +# define PROCESS1_64 XXH_PROCESS1_64 +# define PROCESS4_64 XXH_PROCESS4_64 +# define PROCESS8_64 XXH_PROCESS8_64 +#else +# undef XXH_PROCESS1_64 +# undef XXH_PROCESS4_64 +# undef XXH_PROCESS8_64 +#endif + +/*! + * @internal + * @brief The implementation for @ref XXH64(). + * + * @param input , len , seed Directly passed from @ref XXH64(). + * @param align Whether @p input is aligned. + * @return The calculated hash. + */ +XXH_FORCE_INLINE XXH_PUREF xxh_u64 +XXH64_endian_align(const xxh_u8* input, size_t len, xxh_u64 seed, XXH_alignment align) +{ + xxh_u64 h64; + if (input==NULL) XXH_ASSERT(len == 0); + + if (len>=32) { + const xxh_u8* const bEnd = input + len; + const xxh_u8* const limit = bEnd - 31; + xxh_u64 v1 = seed + XXH_PRIME64_1 + XXH_PRIME64_2; + xxh_u64 v2 = seed + XXH_PRIME64_2; + xxh_u64 v3 = seed + 0; + xxh_u64 v4 = seed - XXH_PRIME64_1; + + do { + v1 = XXH64_round(v1, XXH_get64bits(input)); input+=8; + v2 = XXH64_round(v2, XXH_get64bits(input)); input+=8; + v3 = XXH64_round(v3, XXH_get64bits(input)); input+=8; + v4 = XXH64_round(v4, XXH_get64bits(input)); input+=8; + } while (input= 2 + /* Simple version, good for code maintenance, but unfortunately slow for small inputs */ + XXH64_state_t state; + XXH64_reset(&state, seed); + XXH64_update(&state, (const xxh_u8*)input, len); + return XXH64_digest(&state); +#else + if (XXH_FORCE_ALIGN_CHECK) { + if ((((size_t)input) & 7)==0) { /* Input is aligned, let's leverage the speed advantage */ + return XXH64_endian_align((const xxh_u8*)input, len, seed, XXH_aligned); + } } + + return XXH64_endian_align((const xxh_u8*)input, len, seed, XXH_unaligned); + +#endif +} + +/******* Hash Streaming *******/ +#ifndef XXH_NO_STREAM +/*! @ingroup XXH64_family*/ +XXH_PUBLIC_API XXH64_state_t* XXH64_createState(void) +{ + return (XXH64_state_t*)XXH_malloc(sizeof(XXH64_state_t)); +} +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API XXH_errorcode XXH64_freeState(XXH64_state_t* statePtr) +{ + XXH_free(statePtr); + return XXH_OK; +} + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API void XXH64_copyState(XXH_NOESCAPE XXH64_state_t* dstState, const XXH64_state_t* srcState) +{ + XXH_memcpy(dstState, srcState, sizeof(*dstState)); +} + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API XXH_errorcode XXH64_reset(XXH_NOESCAPE XXH64_state_t* statePtr, XXH64_hash_t seed) +{ + XXH_ASSERT(statePtr != NULL); + memset(statePtr, 0, sizeof(*statePtr)); + statePtr->v[0] = seed + XXH_PRIME64_1 + XXH_PRIME64_2; + statePtr->v[1] = seed + XXH_PRIME64_2; + statePtr->v[2] = seed + 0; + statePtr->v[3] = seed - XXH_PRIME64_1; + return XXH_OK; +} + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API XXH_errorcode +XXH64_update (XXH_NOESCAPE XXH64_state_t* state, XXH_NOESCAPE const void* input, size_t len) +{ + if (input==NULL) { + XXH_ASSERT(len == 0); + return XXH_OK; + } + + { const xxh_u8* p = (const xxh_u8*)input; + const xxh_u8* const bEnd = p + len; + + state->total_len += len; + + if (state->memsize + len < 32) { /* fill in tmp buffer */ + XXH_memcpy(((xxh_u8*)state->mem64) + state->memsize, input, len); + state->memsize += (xxh_u32)len; + return XXH_OK; + } + + if (state->memsize) { /* tmp buffer is full */ + XXH_memcpy(((xxh_u8*)state->mem64) + state->memsize, input, 32-state->memsize); + state->v[0] = XXH64_round(state->v[0], XXH_readLE64(state->mem64+0)); + state->v[1] = XXH64_round(state->v[1], XXH_readLE64(state->mem64+1)); + state->v[2] = XXH64_round(state->v[2], XXH_readLE64(state->mem64+2)); + state->v[3] = XXH64_round(state->v[3], XXH_readLE64(state->mem64+3)); + p += 32 - state->memsize; + state->memsize = 0; + } + + if (p+32 <= bEnd) { + const xxh_u8* const limit = bEnd - 32; + + do { + state->v[0] = XXH64_round(state->v[0], XXH_readLE64(p)); p+=8; + state->v[1] = XXH64_round(state->v[1], XXH_readLE64(p)); p+=8; + state->v[2] = XXH64_round(state->v[2], XXH_readLE64(p)); p+=8; + state->v[3] = XXH64_round(state->v[3], XXH_readLE64(p)); p+=8; + } while (p<=limit); + + } + + if (p < bEnd) { + XXH_memcpy(state->mem64, p, (size_t)(bEnd-p)); + state->memsize = (unsigned)(bEnd-p); + } + } + + return XXH_OK; +} + + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API XXH64_hash_t XXH64_digest(XXH_NOESCAPE const XXH64_state_t* state) +{ + xxh_u64 h64; + + if (state->total_len >= 32) { + h64 = XXH_rotl64(state->v[0], 1) + XXH_rotl64(state->v[1], 7) + XXH_rotl64(state->v[2], 12) + XXH_rotl64(state->v[3], 18); + h64 = XXH64_mergeRound(h64, state->v[0]); + h64 = XXH64_mergeRound(h64, state->v[1]); + h64 = XXH64_mergeRound(h64, state->v[2]); + h64 = XXH64_mergeRound(h64, state->v[3]); + } else { + h64 = state->v[2] /*seed*/ + XXH_PRIME64_5; + } + + h64 += (xxh_u64) state->total_len; + + return XXH64_finalize(h64, (const xxh_u8*)state->mem64, (size_t)state->total_len, XXH_aligned); +} +#endif /* !XXH_NO_STREAM */ + +/******* Canonical representation *******/ + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API void XXH64_canonicalFromHash(XXH_NOESCAPE XXH64_canonical_t* dst, XXH64_hash_t hash) +{ + XXH_STATIC_ASSERT(sizeof(XXH64_canonical_t) == sizeof(XXH64_hash_t)); + if (XXH_CPU_LITTLE_ENDIAN) hash = XXH_swap64(hash); + XXH_memcpy(dst, &hash, sizeof(*dst)); +} + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API XXH64_hash_t XXH64_hashFromCanonical(XXH_NOESCAPE const XXH64_canonical_t* src) +{ + return XXH_readBE64(src); +} + +#ifndef XXH_NO_XXH3 + +/* ********************************************************************* +* XXH3 +* New generation hash designed for speed on small keys and vectorization +************************************************************************ */ +/*! + * @} + * @defgroup XXH3_impl XXH3 implementation + * @ingroup impl + * @{ + */ + +/* === Compiler specifics === */ + +#if ((defined(sun) || defined(__sun)) && __cplusplus) /* Solaris includes __STDC_VERSION__ with C++. Tested with GCC 5.5 */ +# define XXH_RESTRICT /* disable */ +#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* >= C99 */ +# define XXH_RESTRICT restrict +#elif (defined (__GNUC__) && ((__GNUC__ > 3) || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1))) \ + || (defined (__clang__)) \ + || (defined (_MSC_VER) && (_MSC_VER >= 1400)) \ + || (defined (__INTEL_COMPILER) && (__INTEL_COMPILER >= 1300)) +/* + * There are a LOT more compilers that recognize __restrict but this + * covers the major ones. + */ +# define XXH_RESTRICT __restrict +#else +# define XXH_RESTRICT /* disable */ +#endif + +#if (defined(__GNUC__) && (__GNUC__ >= 3)) \ + || (defined(__INTEL_COMPILER) && (__INTEL_COMPILER >= 800)) \ + || defined(__clang__) +# define XXH_likely(x) __builtin_expect(x, 1) +# define XXH_unlikely(x) __builtin_expect(x, 0) +#else +# define XXH_likely(x) (x) +# define XXH_unlikely(x) (x) +#endif + +#ifndef XXH_HAS_INCLUDE +# ifdef __has_include +/* + * Not defined as XXH_HAS_INCLUDE(x) (function-like) because + * this causes segfaults in Apple Clang 4.2 (on Mac OS X 10.7 Lion) + */ +# define XXH_HAS_INCLUDE __has_include +# else +# define XXH_HAS_INCLUDE(x) 0 +# endif +#endif + +#if defined(__GNUC__) || defined(__clang__) +# if defined(__ARM_FEATURE_SVE) +# include +# endif +# if defined(__ARM_NEON__) || defined(__ARM_NEON) \ + || (defined(_M_ARM) && _M_ARM >= 7) \ + || defined(_M_ARM64) || defined(_M_ARM64EC) \ + || (defined(__wasm_simd128__) && XXH_HAS_INCLUDE()) /* WASM SIMD128 via SIMDe */ +# define inline __inline__ /* circumvent a clang bug */ +# include +# undef inline +# elif defined(__AVX2__) +# include +# elif defined(__SSE2__) +# include +# endif +#endif + +#if defined(_MSC_VER) +# include +#endif + +/* + * One goal of XXH3 is to make it fast on both 32-bit and 64-bit, while + * remaining a true 64-bit/128-bit hash function. + * + * This is done by prioritizing a subset of 64-bit operations that can be + * emulated without too many steps on the average 32-bit machine. + * + * For example, these two lines seem similar, and run equally fast on 64-bit: + * + * xxh_u64 x; + * x ^= (x >> 47); // good + * x ^= (x >> 13); // bad + * + * However, to a 32-bit machine, there is a major difference. + * + * x ^= (x >> 47) looks like this: + * + * x.lo ^= (x.hi >> (47 - 32)); + * + * while x ^= (x >> 13) looks like this: + * + * // note: funnel shifts are not usually cheap. + * x.lo ^= (x.lo >> 13) | (x.hi << (32 - 13)); + * x.hi ^= (x.hi >> 13); + * + * The first one is significantly faster than the second, simply because the + * shift is larger than 32. This means: + * - All the bits we need are in the upper 32 bits, so we can ignore the lower + * 32 bits in the shift. + * - The shift result will always fit in the lower 32 bits, and therefore, + * we can ignore the upper 32 bits in the xor. + * + * Thanks to this optimization, XXH3 only requires these features to be efficient: + * + * - Usable unaligned access + * - A 32-bit or 64-bit ALU + * - If 32-bit, a decent ADC instruction + * - A 32 or 64-bit multiply with a 64-bit result + * - For the 128-bit variant, a decent byteswap helps short inputs. + * + * The first two are already required by XXH32, and almost all 32-bit and 64-bit + * platforms which can run XXH32 can run XXH3 efficiently. + * + * Thumb-1, the classic 16-bit only subset of ARM's instruction set, is one + * notable exception. + * + * First of all, Thumb-1 lacks support for the UMULL instruction which + * performs the important long multiply. This means numerous __aeabi_lmul + * calls. + * + * Second of all, the 8 functional registers are just not enough. + * Setup for __aeabi_lmul, byteshift loads, pointers, and all arithmetic need + * Lo registers, and this shuffling results in thousands more MOVs than A32. + * + * A32 and T32 don't have this limitation. They can access all 14 registers, + * do a 32->64 multiply with UMULL, and the flexible operand allowing free + * shifts is helpful, too. + * + * Therefore, we do a quick sanity check. + * + * If compiling Thumb-1 for a target which supports ARM instructions, we will + * emit a warning, as it is not a "sane" platform to compile for. + * + * Usually, if this happens, it is because of an accident and you probably need + * to specify -march, as you likely meant to compile for a newer architecture. + * + * Credit: large sections of the vectorial and asm source code paths + * have been contributed by @easyaspi314 + */ +#if defined(__thumb__) && !defined(__thumb2__) && defined(__ARM_ARCH_ISA_ARM) +# warning "XXH3 is highly inefficient without ARM or Thumb-2." +#endif + +/* ========================================== + * Vectorization detection + * ========================================== */ + +#ifdef XXH_DOXYGEN +/*! + * @ingroup tuning + * @brief Overrides the vectorization implementation chosen for XXH3. + * + * Can be defined to 0 to disable SIMD or any of the values mentioned in + * @ref XXH_VECTOR_TYPE. + * + * If this is not defined, it uses predefined macros to determine the best + * implementation. + */ +# define XXH_VECTOR XXH_SCALAR +/*! + * @ingroup tuning + * @brief Possible values for @ref XXH_VECTOR. + * + * Note that these are actually implemented as macros. + * + * If this is not defined, it is detected automatically. + * internal macro XXH_X86DISPATCH overrides this. + */ +enum XXH_VECTOR_TYPE /* fake enum */ { + XXH_SCALAR = 0, /*!< Portable scalar version */ + XXH_SSE2 = 1, /*!< + * SSE2 for Pentium 4, Opteron, all x86_64. + * + * @note SSE2 is also guaranteed on Windows 10, macOS, and + * Android x86. + */ + XXH_AVX2 = 2, /*!< AVX2 for Haswell and Bulldozer */ + XXH_AVX512 = 3, /*!< AVX512 for Skylake and Icelake */ + XXH_NEON = 4, /*!< + * NEON for most ARMv7-A, all AArch64, and WASM SIMD128 + * via the SIMDeverywhere polyfill provided with the + * Emscripten SDK. + */ + XXH_VSX = 5, /*!< VSX and ZVector for POWER8/z13 (64-bit) */ + XXH_SVE = 6, /*!< SVE for some ARMv8-A and ARMv9-A */ +}; +/*! + * @ingroup tuning + * @brief Selects the minimum alignment for XXH3's accumulators. + * + * When using SIMD, this should match the alignment required for said vector + * type, so, for example, 32 for AVX2. + * + * Default: Auto detected. + */ +# define XXH_ACC_ALIGN 8 +#endif + +/* Actual definition */ +#ifndef XXH_DOXYGEN +# define XXH_SCALAR 0 +# define XXH_SSE2 1 +# define XXH_AVX2 2 +# define XXH_AVX512 3 +# define XXH_NEON 4 +# define XXH_VSX 5 +# define XXH_SVE 6 +#endif + +#ifndef XXH_VECTOR /* can be defined on command line */ +# if defined(__ARM_FEATURE_SVE) +# define XXH_VECTOR XXH_SVE +# elif ( \ + defined(__ARM_NEON__) || defined(__ARM_NEON) /* gcc */ \ + || defined(_M_ARM) || defined(_M_ARM64) || defined(_M_ARM64EC) /* msvc */ \ + || (defined(__wasm_simd128__) && XXH_HAS_INCLUDE()) /* wasm simd128 via SIMDe */ \ + ) && ( \ + defined(_WIN32) || defined(__LITTLE_ENDIAN__) /* little endian only */ \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) \ + ) +# define XXH_VECTOR XXH_NEON +# elif defined(__AVX512F__) +# define XXH_VECTOR XXH_AVX512 +# elif defined(__AVX2__) +# define XXH_VECTOR XXH_AVX2 +# elif defined(__SSE2__) || defined(_M_AMD64) || defined(_M_X64) || (defined(_M_IX86_FP) && (_M_IX86_FP == 2)) +# define XXH_VECTOR XXH_SSE2 +# elif (defined(__PPC64__) && defined(__POWER8_VECTOR__)) \ + || (defined(__s390x__) && defined(__VEC__)) \ + && defined(__GNUC__) /* TODO: IBM XL */ +# define XXH_VECTOR XXH_VSX +# else +# define XXH_VECTOR XXH_SCALAR +# endif +#endif + +/* __ARM_FEATURE_SVE is only supported by GCC & Clang. */ +#if (XXH_VECTOR == XXH_SVE) && !defined(__ARM_FEATURE_SVE) +# ifdef _MSC_VER +# pragma warning(once : 4606) +# else +# warning "__ARM_FEATURE_SVE isn't supported. Use SCALAR instead." +# endif +# undef XXH_VECTOR +# define XXH_VECTOR XXH_SCALAR +#endif + +/* + * Controls the alignment of the accumulator, + * for compatibility with aligned vector loads, which are usually faster. + */ +#ifndef XXH_ACC_ALIGN +# if defined(XXH_X86DISPATCH) +# define XXH_ACC_ALIGN 64 /* for compatibility with avx512 */ +# elif XXH_VECTOR == XXH_SCALAR /* scalar */ +# define XXH_ACC_ALIGN 8 +# elif XXH_VECTOR == XXH_SSE2 /* sse2 */ +# define XXH_ACC_ALIGN 16 +# elif XXH_VECTOR == XXH_AVX2 /* avx2 */ +# define XXH_ACC_ALIGN 32 +# elif XXH_VECTOR == XXH_NEON /* neon */ +# define XXH_ACC_ALIGN 16 +# elif XXH_VECTOR == XXH_VSX /* vsx */ +# define XXH_ACC_ALIGN 16 +# elif XXH_VECTOR == XXH_AVX512 /* avx512 */ +# define XXH_ACC_ALIGN 64 +# elif XXH_VECTOR == XXH_SVE /* sve */ +# define XXH_ACC_ALIGN 64 +# endif +#endif + +#if defined(XXH_X86DISPATCH) || XXH_VECTOR == XXH_SSE2 \ + || XXH_VECTOR == XXH_AVX2 || XXH_VECTOR == XXH_AVX512 +# define XXH_SEC_ALIGN XXH_ACC_ALIGN +#elif XXH_VECTOR == XXH_SVE +# define XXH_SEC_ALIGN XXH_ACC_ALIGN +#else +# define XXH_SEC_ALIGN 8 +#endif + +#if defined(__GNUC__) || defined(__clang__) +# define XXH_ALIASING __attribute__((__may_alias__)) +#else +# define XXH_ALIASING /* nothing */ +#endif + +/* + * UGLY HACK: + * GCC usually generates the best code with -O3 for xxHash. + * + * However, when targeting AVX2, it is overzealous in its unrolling resulting + * in code roughly 3/4 the speed of Clang. + * + * There are other issues, such as GCC splitting _mm256_loadu_si256 into + * _mm_loadu_si128 + _mm256_inserti128_si256. This is an optimization which + * only applies to Sandy and Ivy Bridge... which don't even support AVX2. + * + * That is why when compiling the AVX2 version, it is recommended to use either + * -O2 -mavx2 -march=haswell + * or + * -O2 -mavx2 -mno-avx256-split-unaligned-load + * for decent performance, or to use Clang instead. + * + * Fortunately, we can control the first one with a pragma that forces GCC into + * -O2, but the other one we can't control without "failed to inline always + * inline function due to target mismatch" warnings. + */ +#if XXH_VECTOR == XXH_AVX2 /* AVX2 */ \ + && defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \ + && defined(__OPTIMIZE__) && XXH_SIZE_OPT <= 0 /* respect -O0 and -Os */ +# pragma GCC push_options +# pragma GCC optimize("-O2") +#endif + +#if XXH_VECTOR == XXH_NEON + +/* + * UGLY HACK: While AArch64 GCC on Linux does not seem to care, on macOS, GCC -O3 + * optimizes out the entire hashLong loop because of the aliasing violation. + * + * However, GCC is also inefficient at load-store optimization with vld1q/vst1q, + * so the only option is to mark it as aliasing. + */ +typedef uint64x2_t xxh_aliasing_uint64x2_t XXH_ALIASING; + +/*! + * @internal + * @brief `vld1q_u64` but faster and alignment-safe. + * + * On AArch64, unaligned access is always safe, but on ARMv7-a, it is only + * *conditionally* safe (`vld1` has an alignment bit like `movdq[ua]` in x86). + * + * GCC for AArch64 sees `vld1q_u8` as an intrinsic instead of a load, so it + * prohibits load-store optimizations. Therefore, a direct dereference is used. + * + * Otherwise, `vld1q_u8` is used with `vreinterpretq_u8_u64` to do a safe + * unaligned load. + */ +#if defined(__aarch64__) && defined(__GNUC__) && !defined(__clang__) +XXH_FORCE_INLINE uint64x2_t XXH_vld1q_u64(void const* ptr) /* silence -Wcast-align */ +{ + return *(xxh_aliasing_uint64x2_t const *)ptr; +} +#else +XXH_FORCE_INLINE uint64x2_t XXH_vld1q_u64(void const* ptr) +{ + return vreinterpretq_u64_u8(vld1q_u8((uint8_t const*)ptr)); +} +#endif + +/*! + * @internal + * @brief `vmlal_u32` on low and high halves of a vector. + * + * This is a workaround for AArch64 GCC < 11 which implemented arm_neon.h with + * inline assembly and were therefore incapable of merging the `vget_{low, high}_u32` + * with `vmlal_u32`. + */ +#if defined(__aarch64__) && defined(__GNUC__) && !defined(__clang__) && __GNUC__ < 11 +XXH_FORCE_INLINE uint64x2_t +XXH_vmlal_low_u32(uint64x2_t acc, uint32x4_t lhs, uint32x4_t rhs) +{ + /* Inline assembly is the only way */ + __asm__("umlal %0.2d, %1.2s, %2.2s" : "+w" (acc) : "w" (lhs), "w" (rhs)); + return acc; +} +XXH_FORCE_INLINE uint64x2_t +XXH_vmlal_high_u32(uint64x2_t acc, uint32x4_t lhs, uint32x4_t rhs) +{ + /* This intrinsic works as expected */ + return vmlal_high_u32(acc, lhs, rhs); +} +#else +/* Portable intrinsic versions */ +XXH_FORCE_INLINE uint64x2_t +XXH_vmlal_low_u32(uint64x2_t acc, uint32x4_t lhs, uint32x4_t rhs) +{ + return vmlal_u32(acc, vget_low_u32(lhs), vget_low_u32(rhs)); +} +/*! @copydoc XXH_vmlal_low_u32 + * Assume the compiler converts this to vmlal_high_u32 on aarch64 */ +XXH_FORCE_INLINE uint64x2_t +XXH_vmlal_high_u32(uint64x2_t acc, uint32x4_t lhs, uint32x4_t rhs) +{ + return vmlal_u32(acc, vget_high_u32(lhs), vget_high_u32(rhs)); +} +#endif + +/*! + * @ingroup tuning + * @brief Controls the NEON to scalar ratio for XXH3 + * + * This can be set to 2, 4, 6, or 8. + * + * ARM Cortex CPUs are _very_ sensitive to how their pipelines are used. + * + * For example, the Cortex-A73 can dispatch 3 micro-ops per cycle, but only 2 of those + * can be NEON. If you are only using NEON instructions, you are only using 2/3 of the CPU + * bandwidth. + * + * This is even more noticeable on the more advanced cores like the Cortex-A76 which + * can dispatch 8 micro-ops per cycle, but still only 2 NEON micro-ops at once. + * + * Therefore, to make the most out of the pipeline, it is beneficial to run 6 NEON lanes + * and 2 scalar lanes, which is chosen by default. + * + * This does not apply to Apple processors or 32-bit processors, which run better with + * full NEON. These will default to 8. Additionally, size-optimized builds run 8 lanes. + * + * This change benefits CPUs with large micro-op buffers without negatively affecting + * most other CPUs: + * + * | Chipset | Dispatch type | NEON only | 6:2 hybrid | Diff. | + * |:----------------------|:--------------------|----------:|-----------:|------:| + * | Snapdragon 730 (A76) | 2 NEON/8 micro-ops | 8.8 GB/s | 10.1 GB/s | ~16% | + * | Snapdragon 835 (A73) | 2 NEON/3 micro-ops | 5.1 GB/s | 5.3 GB/s | ~5% | + * | Marvell PXA1928 (A53) | In-order dual-issue | 1.9 GB/s | 1.9 GB/s | 0% | + * | Apple M1 | 4 NEON/8 micro-ops | 37.3 GB/s | 36.1 GB/s | ~-3% | + * + * It also seems to fix some bad codegen on GCC, making it almost as fast as clang. + * + * When using WASM SIMD128, if this is 2 or 6, SIMDe will scalarize 2 of the lanes meaning + * it effectively becomes worse 4. + * + * @see XXH3_accumulate_512_neon() + */ +# ifndef XXH3_NEON_LANES +# if (defined(__aarch64__) || defined(__arm64__) || defined(_M_ARM64) || defined(_M_ARM64EC)) \ + && !defined(__APPLE__) && XXH_SIZE_OPT <= 0 +# define XXH3_NEON_LANES 6 +# else +# define XXH3_NEON_LANES XXH_ACC_NB +# endif +# endif +#endif /* XXH_VECTOR == XXH_NEON */ + +/* + * VSX and Z Vector helpers. + * + * This is very messy, and any pull requests to clean this up are welcome. + * + * There are a lot of problems with supporting VSX and s390x, due to + * inconsistent intrinsics, spotty coverage, and multiple endiannesses. + */ +#if XXH_VECTOR == XXH_VSX +/* Annoyingly, these headers _may_ define three macros: `bool`, `vector`, + * and `pixel`. This is a problem for obvious reasons. + * + * These keywords are unnecessary; the spec literally says they are + * equivalent to `__bool`, `__vector`, and `__pixel` and may be undef'd + * after including the header. + * + * We use pragma push_macro/pop_macro to keep the namespace clean. */ +# pragma push_macro("bool") +# pragma push_macro("vector") +# pragma push_macro("pixel") +/* silence potential macro redefined warnings */ +# undef bool +# undef vector +# undef pixel + +# if defined(__s390x__) +# include +# else +# include +# endif + +/* Restore the original macro values, if applicable. */ +# pragma pop_macro("pixel") +# pragma pop_macro("vector") +# pragma pop_macro("bool") + +typedef __vector unsigned long long xxh_u64x2; +typedef __vector unsigned char xxh_u8x16; +typedef __vector unsigned xxh_u32x4; + +/* + * UGLY HACK: Similar to aarch64 macOS GCC, s390x GCC has the same aliasing issue. + */ +typedef xxh_u64x2 xxh_aliasing_u64x2 XXH_ALIASING; + +# ifndef XXH_VSX_BE +# if defined(__BIG_ENDIAN__) \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) +# define XXH_VSX_BE 1 +# elif defined(__VEC_ELEMENT_REG_ORDER__) && __VEC_ELEMENT_REG_ORDER__ == __ORDER_BIG_ENDIAN__ +# warning "-maltivec=be is not recommended. Please use native endianness." +# define XXH_VSX_BE 1 +# else +# define XXH_VSX_BE 0 +# endif +# endif /* !defined(XXH_VSX_BE) */ + +# if XXH_VSX_BE +# if defined(__POWER9_VECTOR__) || (defined(__clang__) && defined(__s390x__)) +# define XXH_vec_revb vec_revb +# else +/*! + * A polyfill for POWER9's vec_revb(). + */ +XXH_FORCE_INLINE xxh_u64x2 XXH_vec_revb(xxh_u64x2 val) +{ + xxh_u8x16 const vByteSwap = { 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00, + 0x0F, 0x0E, 0x0D, 0x0C, 0x0B, 0x0A, 0x09, 0x08 }; + return vec_perm(val, val, vByteSwap); +} +# endif +# endif /* XXH_VSX_BE */ + +/*! + * Performs an unaligned vector load and byte swaps it on big endian. + */ +XXH_FORCE_INLINE xxh_u64x2 XXH_vec_loadu(const void *ptr) +{ + xxh_u64x2 ret; + XXH_memcpy(&ret, ptr, sizeof(xxh_u64x2)); +# if XXH_VSX_BE + ret = XXH_vec_revb(ret); +# endif + return ret; +} + +/* + * vec_mulo and vec_mule are very problematic intrinsics on PowerPC + * + * These intrinsics weren't added until GCC 8, despite existing for a while, + * and they are endian dependent. Also, their meaning swap depending on version. + * */ +# if defined(__s390x__) + /* s390x is always big endian, no issue on this platform */ +# define XXH_vec_mulo vec_mulo +# define XXH_vec_mule vec_mule +# elif defined(__clang__) && XXH_HAS_BUILTIN(__builtin_altivec_vmuleuw) && !defined(__ibmxl__) +/* Clang has a better way to control this, we can just use the builtin which doesn't swap. */ + /* The IBM XL Compiler (which defined __clang__) only implements the vec_* operations */ +# define XXH_vec_mulo __builtin_altivec_vmulouw +# define XXH_vec_mule __builtin_altivec_vmuleuw +# else +/* gcc needs inline assembly */ +/* Adapted from https://github.com/google/highwayhash/blob/master/highwayhash/hh_vsx.h. */ +XXH_FORCE_INLINE xxh_u64x2 XXH_vec_mulo(xxh_u32x4 a, xxh_u32x4 b) +{ + xxh_u64x2 result; + __asm__("vmulouw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b)); + return result; +} +XXH_FORCE_INLINE xxh_u64x2 XXH_vec_mule(xxh_u32x4 a, xxh_u32x4 b) +{ + xxh_u64x2 result; + __asm__("vmuleuw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b)); + return result; +} +# endif /* XXH_vec_mulo, XXH_vec_mule */ +#endif /* XXH_VECTOR == XXH_VSX */ + +#if XXH_VECTOR == XXH_SVE +#define ACCRND(acc, offset) \ +do { \ + svuint64_t input_vec = svld1_u64(mask, xinput + offset); \ + svuint64_t secret_vec = svld1_u64(mask, xsecret + offset); \ + svuint64_t mixed = sveor_u64_x(mask, secret_vec, input_vec); \ + svuint64_t swapped = svtbl_u64(input_vec, kSwap); \ + svuint64_t mixed_lo = svextw_u64_x(mask, mixed); \ + svuint64_t mixed_hi = svlsr_n_u64_x(mask, mixed, 32); \ + svuint64_t mul = svmad_u64_x(mask, mixed_lo, mixed_hi, swapped); \ + acc = svadd_u64_x(mask, acc, mul); \ +} while (0) +#endif /* XXH_VECTOR == XXH_SVE */ + +/* prefetch + * can be disabled, by declaring XXH_NO_PREFETCH build macro */ +#if defined(XXH_NO_PREFETCH) +# define XXH_PREFETCH(ptr) (void)(ptr) /* disabled */ +#else +# if XXH_SIZE_OPT >= 1 +# define XXH_PREFETCH(ptr) (void)(ptr) +# elif defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)) /* _mm_prefetch() not defined outside of x86/x64 */ +# include /* https://msdn.microsoft.com/fr-fr/library/84szxsww(v=vs.90).aspx */ +# define XXH_PREFETCH(ptr) _mm_prefetch((const char*)(ptr), _MM_HINT_T0) +# elif defined(__GNUC__) && ( (__GNUC__ >= 4) || ( (__GNUC__ == 3) && (__GNUC_MINOR__ >= 1) ) ) +# define XXH_PREFETCH(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) +# else +# define XXH_PREFETCH(ptr) (void)(ptr) /* disabled */ +# endif +#endif /* XXH_NO_PREFETCH */ + + +/* ========================================== + * XXH3 default settings + * ========================================== */ + +#define XXH_SECRET_DEFAULT_SIZE 192 /* minimum XXH3_SECRET_SIZE_MIN */ + +#if (XXH_SECRET_DEFAULT_SIZE < XXH3_SECRET_SIZE_MIN) +# error "default keyset is not large enough" +#endif + +/*! Pseudorandom secret taken directly from FARSH. */ +XXH_ALIGN(64) static const xxh_u8 XXH3_kSecret[XXH_SECRET_DEFAULT_SIZE] = { + 0xb8, 0xfe, 0x6c, 0x39, 0x23, 0xa4, 0x4b, 0xbe, 0x7c, 0x01, 0x81, 0x2c, 0xf7, 0x21, 0xad, 0x1c, + 0xde, 0xd4, 0x6d, 0xe9, 0x83, 0x90, 0x97, 0xdb, 0x72, 0x40, 0xa4, 0xa4, 0xb7, 0xb3, 0x67, 0x1f, + 0xcb, 0x79, 0xe6, 0x4e, 0xcc, 0xc0, 0xe5, 0x78, 0x82, 0x5a, 0xd0, 0x7d, 0xcc, 0xff, 0x72, 0x21, + 0xb8, 0x08, 0x46, 0x74, 0xf7, 0x43, 0x24, 0x8e, 0xe0, 0x35, 0x90, 0xe6, 0x81, 0x3a, 0x26, 0x4c, + 0x3c, 0x28, 0x52, 0xbb, 0x91, 0xc3, 0x00, 0xcb, 0x88, 0xd0, 0x65, 0x8b, 0x1b, 0x53, 0x2e, 0xa3, + 0x71, 0x64, 0x48, 0x97, 0xa2, 0x0d, 0xf9, 0x4e, 0x38, 0x19, 0xef, 0x46, 0xa9, 0xde, 0xac, 0xd8, + 0xa8, 0xfa, 0x76, 0x3f, 0xe3, 0x9c, 0x34, 0x3f, 0xf9, 0xdc, 0xbb, 0xc7, 0xc7, 0x0b, 0x4f, 0x1d, + 0x8a, 0x51, 0xe0, 0x4b, 0xcd, 0xb4, 0x59, 0x31, 0xc8, 0x9f, 0x7e, 0xc9, 0xd9, 0x78, 0x73, 0x64, + 0xea, 0xc5, 0xac, 0x83, 0x34, 0xd3, 0xeb, 0xc3, 0xc5, 0x81, 0xa0, 0xff, 0xfa, 0x13, 0x63, 0xeb, + 0x17, 0x0d, 0xdd, 0x51, 0xb7, 0xf0, 0xda, 0x49, 0xd3, 0x16, 0x55, 0x26, 0x29, 0xd4, 0x68, 0x9e, + 0x2b, 0x16, 0xbe, 0x58, 0x7d, 0x47, 0xa1, 0xfc, 0x8f, 0xf8, 0xb8, 0xd1, 0x7a, 0xd0, 0x31, 0xce, + 0x45, 0xcb, 0x3a, 0x8f, 0x95, 0x16, 0x04, 0x28, 0xaf, 0xd7, 0xfb, 0xca, 0xbb, 0x4b, 0x40, 0x7e, +}; + +static const xxh_u64 PRIME_MX1 = 0x165667919E3779F9ULL; /*!< 0b0001011001010110011001111001000110011110001101110111100111111001 */ +static const xxh_u64 PRIME_MX2 = 0x9FB21C651E98DF25ULL; /*!< 0b1001111110110010000111000110010100011110100110001101111100100101 */ + +#ifdef XXH_OLD_NAMES +# define kSecret XXH3_kSecret +#endif + +#ifdef XXH_DOXYGEN +/*! + * @brief Calculates a 32-bit to 64-bit long multiply. + * + * Implemented as a macro. + * + * Wraps `__emulu` on MSVC x86 because it tends to call `__allmul` when it doesn't + * need to (but it shouldn't need to anyways, it is about 7 instructions to do + * a 64x64 multiply...). Since we know that this will _always_ emit `MULL`, we + * use that instead of the normal method. + * + * If you are compiling for platforms like Thumb-1 and don't have a better option, + * you may also want to write your own long multiply routine here. + * + * @param x, y Numbers to be multiplied + * @return 64-bit product of the low 32 bits of @p x and @p y. + */ +XXH_FORCE_INLINE xxh_u64 +XXH_mult32to64(xxh_u64 x, xxh_u64 y) +{ + return (x & 0xFFFFFFFF) * (y & 0xFFFFFFFF); +} +#elif defined(_MSC_VER) && defined(_M_IX86) +# define XXH_mult32to64(x, y) __emulu((unsigned)(x), (unsigned)(y)) +#else +/* + * Downcast + upcast is usually better than masking on older compilers like + * GCC 4.2 (especially 32-bit ones), all without affecting newer compilers. + * + * The other method, (x & 0xFFFFFFFF) * (y & 0xFFFFFFFF), will AND both operands + * and perform a full 64x64 multiply -- entirely redundant on 32-bit. + */ +# define XXH_mult32to64(x, y) ((xxh_u64)(xxh_u32)(x) * (xxh_u64)(xxh_u32)(y)) +#endif + +/*! + * @brief Calculates a 64->128-bit long multiply. + * + * Uses `__uint128_t` and `_umul128` if available, otherwise uses a scalar + * version. + * + * @param lhs , rhs The 64-bit integers to be multiplied + * @return The 128-bit result represented in an @ref XXH128_hash_t. + */ +static XXH128_hash_t +XXH_mult64to128(xxh_u64 lhs, xxh_u64 rhs) +{ + /* + * GCC/Clang __uint128_t method. + * + * On most 64-bit targets, GCC and Clang define a __uint128_t type. + * This is usually the best way as it usually uses a native long 64-bit + * multiply, such as MULQ on x86_64 or MUL + UMULH on aarch64. + * + * Usually. + * + * Despite being a 32-bit platform, Clang (and emscripten) define this type + * despite not having the arithmetic for it. This results in a laggy + * compiler builtin call which calculates a full 128-bit multiply. + * In that case it is best to use the portable one. + * https://github.com/Cyan4973/xxHash/issues/211#issuecomment-515575677 + */ +#if (defined(__GNUC__) || defined(__clang__)) && !defined(__wasm__) \ + && defined(__SIZEOF_INT128__) \ + || (defined(_INTEGRAL_MAX_BITS) && _INTEGRAL_MAX_BITS >= 128) + + __uint128_t const product = (__uint128_t)lhs * (__uint128_t)rhs; + XXH128_hash_t r128; + r128.low64 = (xxh_u64)(product); + r128.high64 = (xxh_u64)(product >> 64); + return r128; + + /* + * MSVC for x64's _umul128 method. + * + * xxh_u64 _umul128(xxh_u64 Multiplier, xxh_u64 Multiplicand, xxh_u64 *HighProduct); + * + * This compiles to single operand MUL on x64. + */ +#elif (defined(_M_X64) || defined(_M_IA64)) && !defined(_M_ARM64EC) + +#ifndef _MSC_VER +# pragma intrinsic(_umul128) +#endif + xxh_u64 product_high; + xxh_u64 const product_low = _umul128(lhs, rhs, &product_high); + XXH128_hash_t r128; + r128.low64 = product_low; + r128.high64 = product_high; + return r128; + + /* + * MSVC for ARM64's __umulh method. + * + * This compiles to the same MUL + UMULH as GCC/Clang's __uint128_t method. + */ +#elif defined(_M_ARM64) || defined(_M_ARM64EC) + +#ifndef _MSC_VER +# pragma intrinsic(__umulh) +#endif + XXH128_hash_t r128; + r128.low64 = lhs * rhs; + r128.high64 = __umulh(lhs, rhs); + return r128; + +#else + /* + * Portable scalar method. Optimized for 32-bit and 64-bit ALUs. + * + * This is a fast and simple grade school multiply, which is shown below + * with base 10 arithmetic instead of base 0x100000000. + * + * 9 3 // D2 lhs = 93 + * x 7 5 // D2 rhs = 75 + * ---------- + * 1 5 // D2 lo_lo = (93 % 10) * (75 % 10) = 15 + * 4 5 | // D2 hi_lo = (93 / 10) * (75 % 10) = 45 + * 2 1 | // D2 lo_hi = (93 % 10) * (75 / 10) = 21 + * + 6 3 | | // D2 hi_hi = (93 / 10) * (75 / 10) = 63 + * --------- + * 2 7 | // D2 cross = (15 / 10) + (45 % 10) + 21 = 27 + * + 6 7 | | // D2 upper = (27 / 10) + (45 / 10) + 63 = 67 + * --------- + * 6 9 7 5 // D4 res = (27 * 10) + (15 % 10) + (67 * 100) = 6975 + * + * The reasons for adding the products like this are: + * 1. It avoids manual carry tracking. Just like how + * (9 * 9) + 9 + 9 = 99, the same applies with this for UINT64_MAX. + * This avoids a lot of complexity. + * + * 2. It hints for, and on Clang, compiles to, the powerful UMAAL + * instruction available in ARM's Digital Signal Processing extension + * in 32-bit ARMv6 and later, which is shown below: + * + * void UMAAL(xxh_u32 *RdLo, xxh_u32 *RdHi, xxh_u32 Rn, xxh_u32 Rm) + * { + * xxh_u64 product = (xxh_u64)*RdLo * (xxh_u64)*RdHi + Rn + Rm; + * *RdLo = (xxh_u32)(product & 0xFFFFFFFF); + * *RdHi = (xxh_u32)(product >> 32); + * } + * + * This instruction was designed for efficient long multiplication, and + * allows this to be calculated in only 4 instructions at speeds + * comparable to some 64-bit ALUs. + * + * 3. It isn't terrible on other platforms. Usually this will be a couple + * of 32-bit ADD/ADCs. + */ + + /* First calculate all of the cross products. */ + xxh_u64 const lo_lo = XXH_mult32to64(lhs & 0xFFFFFFFF, rhs & 0xFFFFFFFF); + xxh_u64 const hi_lo = XXH_mult32to64(lhs >> 32, rhs & 0xFFFFFFFF); + xxh_u64 const lo_hi = XXH_mult32to64(lhs & 0xFFFFFFFF, rhs >> 32); + xxh_u64 const hi_hi = XXH_mult32to64(lhs >> 32, rhs >> 32); + + /* Now add the products together. These will never overflow. */ + xxh_u64 const cross = (lo_lo >> 32) + (hi_lo & 0xFFFFFFFF) + lo_hi; + xxh_u64 const upper = (hi_lo >> 32) + (cross >> 32) + hi_hi; + xxh_u64 const lower = (cross << 32) | (lo_lo & 0xFFFFFFFF); + + XXH128_hash_t r128; + r128.low64 = lower; + r128.high64 = upper; + return r128; +#endif +} + +/*! + * @brief Calculates a 64-bit to 128-bit multiply, then XOR folds it. + * + * The reason for the separate function is to prevent passing too many structs + * around by value. This will hopefully inline the multiply, but we don't force it. + * + * @param lhs , rhs The 64-bit integers to multiply + * @return The low 64 bits of the product XOR'd by the high 64 bits. + * @see XXH_mult64to128() + */ +static xxh_u64 +XXH3_mul128_fold64(xxh_u64 lhs, xxh_u64 rhs) +{ + XXH128_hash_t product = XXH_mult64to128(lhs, rhs); + return product.low64 ^ product.high64; +} + +/*! Seems to produce slightly better code on GCC for some reason. */ +XXH_FORCE_INLINE XXH_CONSTF xxh_u64 XXH_xorshift64(xxh_u64 v64, int shift) +{ + XXH_ASSERT(0 <= shift && shift < 64); + return v64 ^ (v64 >> shift); +} + +/* + * This is a fast avalanche stage, + * suitable when input bits are already partially mixed + */ +static XXH64_hash_t XXH3_avalanche(xxh_u64 h64) +{ + h64 = XXH_xorshift64(h64, 37); + h64 *= PRIME_MX1; + h64 = XXH_xorshift64(h64, 32); + return h64; +} + +/* + * This is a stronger avalanche, + * inspired by Pelle Evensen's rrmxmx + * preferable when input has not been previously mixed + */ +static XXH64_hash_t XXH3_rrmxmx(xxh_u64 h64, xxh_u64 len) +{ + /* this mix is inspired by Pelle Evensen's rrmxmx */ + h64 ^= XXH_rotl64(h64, 49) ^ XXH_rotl64(h64, 24); + h64 *= PRIME_MX2; + h64 ^= (h64 >> 35) + len ; + h64 *= PRIME_MX2; + return XXH_xorshift64(h64, 28); +} + + +/* ========================================== + * Short keys + * ========================================== + * One of the shortcomings of XXH32 and XXH64 was that their performance was + * sub-optimal on short lengths. It used an iterative algorithm which strongly + * favored lengths that were a multiple of 4 or 8. + * + * Instead of iterating over individual inputs, we use a set of single shot + * functions which piece together a range of lengths and operate in constant time. + * + * Additionally, the number of multiplies has been significantly reduced. This + * reduces latency, especially when emulating 64-bit multiplies on 32-bit. + * + * Depending on the platform, this may or may not be faster than XXH32, but it + * is almost guaranteed to be faster than XXH64. + */ + +/* + * At very short lengths, there isn't enough input to fully hide secrets, or use + * the entire secret. + * + * There is also only a limited amount of mixing we can do before significantly + * impacting performance. + * + * Therefore, we use different sections of the secret and always mix two secret + * samples with an XOR. This should have no effect on performance on the + * seedless or withSeed variants because everything _should_ be constant folded + * by modern compilers. + * + * The XOR mixing hides individual parts of the secret and increases entropy. + * + * This adds an extra layer of strength for custom secrets. + */ +XXH_FORCE_INLINE XXH_PUREF XXH64_hash_t +XXH3_len_1to3_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(1 <= len && len <= 3); + XXH_ASSERT(secret != NULL); + /* + * len = 1: combined = { input[0], 0x01, input[0], input[0] } + * len = 2: combined = { input[1], 0x02, input[0], input[1] } + * len = 3: combined = { input[2], 0x03, input[0], input[1] } + */ + { xxh_u8 const c1 = input[0]; + xxh_u8 const c2 = input[len >> 1]; + xxh_u8 const c3 = input[len - 1]; + xxh_u32 const combined = ((xxh_u32)c1 << 16) | ((xxh_u32)c2 << 24) + | ((xxh_u32)c3 << 0) | ((xxh_u32)len << 8); + xxh_u64 const bitflip = (XXH_readLE32(secret) ^ XXH_readLE32(secret+4)) + seed; + xxh_u64 const keyed = (xxh_u64)combined ^ bitflip; + return XXH64_avalanche(keyed); + } +} + +XXH_FORCE_INLINE XXH_PUREF XXH64_hash_t +XXH3_len_4to8_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(4 <= len && len <= 8); + seed ^= (xxh_u64)XXH_swap32((xxh_u32)seed) << 32; + { xxh_u32 const input1 = XXH_readLE32(input); + xxh_u32 const input2 = XXH_readLE32(input + len - 4); + xxh_u64 const bitflip = (XXH_readLE64(secret+8) ^ XXH_readLE64(secret+16)) - seed; + xxh_u64 const input64 = input2 + (((xxh_u64)input1) << 32); + xxh_u64 const keyed = input64 ^ bitflip; + return XXH3_rrmxmx(keyed, len); + } +} + +XXH_FORCE_INLINE XXH_PUREF XXH64_hash_t +XXH3_len_9to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(9 <= len && len <= 16); + { xxh_u64 const bitflip1 = (XXH_readLE64(secret+24) ^ XXH_readLE64(secret+32)) + seed; + xxh_u64 const bitflip2 = (XXH_readLE64(secret+40) ^ XXH_readLE64(secret+48)) - seed; + xxh_u64 const input_lo = XXH_readLE64(input) ^ bitflip1; + xxh_u64 const input_hi = XXH_readLE64(input + len - 8) ^ bitflip2; + xxh_u64 const acc = len + + XXH_swap64(input_lo) + input_hi + + XXH3_mul128_fold64(input_lo, input_hi); + return XXH3_avalanche(acc); + } +} + +XXH_FORCE_INLINE XXH_PUREF XXH64_hash_t +XXH3_len_0to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(len <= 16); + { if (XXH_likely(len > 8)) return XXH3_len_9to16_64b(input, len, secret, seed); + if (XXH_likely(len >= 4)) return XXH3_len_4to8_64b(input, len, secret, seed); + if (len) return XXH3_len_1to3_64b(input, len, secret, seed); + return XXH64_avalanche(seed ^ (XXH_readLE64(secret+56) ^ XXH_readLE64(secret+64))); + } +} + +/* + * DISCLAIMER: There are known *seed-dependent* multicollisions here due to + * multiplication by zero, affecting hashes of lengths 17 to 240. + * + * However, they are very unlikely. + * + * Keep this in mind when using the unseeded XXH3_64bits() variant: As with all + * unseeded non-cryptographic hashes, it does not attempt to defend itself + * against specially crafted inputs, only random inputs. + * + * Compared to classic UMAC where a 1 in 2^31 chance of 4 consecutive bytes + * cancelling out the secret is taken an arbitrary number of times (addressed + * in XXH3_accumulate_512), this collision is very unlikely with random inputs + * and/or proper seeding: + * + * This only has a 1 in 2^63 chance of 8 consecutive bytes cancelling out, in a + * function that is only called up to 16 times per hash with up to 240 bytes of + * input. + * + * This is not too bad for a non-cryptographic hash function, especially with + * only 64 bit outputs. + * + * The 128-bit variant (which trades some speed for strength) is NOT affected + * by this, although it is always a good idea to use a proper seed if you care + * about strength. + */ +XXH_FORCE_INLINE xxh_u64 XXH3_mix16B(const xxh_u8* XXH_RESTRICT input, + const xxh_u8* XXH_RESTRICT secret, xxh_u64 seed64) +{ +#if defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \ + && defined(__i386__) && defined(__SSE2__) /* x86 + SSE2 */ \ + && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable like XXH32 hack */ + /* + * UGLY HACK: + * GCC for x86 tends to autovectorize the 128-bit multiply, resulting in + * slower code. + * + * By forcing seed64 into a register, we disrupt the cost model and + * cause it to scalarize. See `XXH32_round()` + * + * FIXME: Clang's output is still _much_ faster -- On an AMD Ryzen 3600, + * XXH3_64bits @ len=240 runs at 4.6 GB/s with Clang 9, but 3.3 GB/s on + * GCC 9.2, despite both emitting scalar code. + * + * GCC generates much better scalar code than Clang for the rest of XXH3, + * which is why finding a more optimal codepath is an interest. + */ + XXH_COMPILER_GUARD(seed64); +#endif + { xxh_u64 const input_lo = XXH_readLE64(input); + xxh_u64 const input_hi = XXH_readLE64(input+8); + return XXH3_mul128_fold64( + input_lo ^ (XXH_readLE64(secret) + seed64), + input_hi ^ (XXH_readLE64(secret+8) - seed64) + ); + } +} + +/* For mid range keys, XXH3 uses a Mum-hash variant. */ +XXH_FORCE_INLINE XXH_PUREF XXH64_hash_t +XXH3_len_17to128_64b(const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH64_hash_t seed) +{ + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize; + XXH_ASSERT(16 < len && len <= 128); + + { xxh_u64 acc = len * XXH_PRIME64_1; +#if XXH_SIZE_OPT >= 1 + /* Smaller and cleaner, but slightly slower. */ + unsigned int i = (unsigned int)(len - 1) / 32; + do { + acc += XXH3_mix16B(input+16 * i, secret+32*i, seed); + acc += XXH3_mix16B(input+len-16*(i+1), secret+32*i+16, seed); + } while (i-- != 0); +#else + if (len > 32) { + if (len > 64) { + if (len > 96) { + acc += XXH3_mix16B(input+48, secret+96, seed); + acc += XXH3_mix16B(input+len-64, secret+112, seed); + } + acc += XXH3_mix16B(input+32, secret+64, seed); + acc += XXH3_mix16B(input+len-48, secret+80, seed); + } + acc += XXH3_mix16B(input+16, secret+32, seed); + acc += XXH3_mix16B(input+len-32, secret+48, seed); + } + acc += XXH3_mix16B(input+0, secret+0, seed); + acc += XXH3_mix16B(input+len-16, secret+16, seed); +#endif + return XXH3_avalanche(acc); + } +} + +XXH_NO_INLINE XXH_PUREF XXH64_hash_t +XXH3_len_129to240_64b(const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH64_hash_t seed) +{ + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize; + XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX); + + #define XXH3_MIDSIZE_STARTOFFSET 3 + #define XXH3_MIDSIZE_LASTOFFSET 17 + + { xxh_u64 acc = len * XXH_PRIME64_1; + xxh_u64 acc_end; + unsigned int const nbRounds = (unsigned int)len / 16; + unsigned int i; + XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX); + for (i=0; i<8; i++) { + acc += XXH3_mix16B(input+(16*i), secret+(16*i), seed); + } + /* last bytes */ + acc_end = XXH3_mix16B(input + len - 16, secret + XXH3_SECRET_SIZE_MIN - XXH3_MIDSIZE_LASTOFFSET, seed); + XXH_ASSERT(nbRounds >= 8); + acc = XXH3_avalanche(acc); +#if defined(__clang__) /* Clang */ \ + && (defined(__ARM_NEON) || defined(__ARM_NEON__)) /* NEON */ \ + && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable */ + /* + * UGLY HACK: + * Clang for ARMv7-A tries to vectorize this loop, similar to GCC x86. + * In everywhere else, it uses scalar code. + * + * For 64->128-bit multiplies, even if the NEON was 100% optimal, it + * would still be slower than UMAAL (see XXH_mult64to128). + * + * Unfortunately, Clang doesn't handle the long multiplies properly and + * converts them to the nonexistent "vmulq_u64" intrinsic, which is then + * scalarized into an ugly mess of VMOV.32 instructions. + * + * This mess is difficult to avoid without turning autovectorization + * off completely, but they are usually relatively minor and/or not + * worth it to fix. + * + * This loop is the easiest to fix, as unlike XXH32, this pragma + * _actually works_ because it is a loop vectorization instead of an + * SLP vectorization. + */ + #pragma clang loop vectorize(disable) +#endif + for (i=8 ; i < nbRounds; i++) { + /* + * Prevents clang for unrolling the acc loop and interleaving with this one. + */ + XXH_COMPILER_GUARD(acc); + acc_end += XXH3_mix16B(input+(16*i), secret+(16*(i-8)) + XXH3_MIDSIZE_STARTOFFSET, seed); + } + return XXH3_avalanche(acc + acc_end); + } +} + + +/* ======= Long Keys ======= */ + +#define XXH_STRIPE_LEN 64 +#define XXH_SECRET_CONSUME_RATE 8 /* nb of secret bytes consumed at each accumulation */ +#define XXH_ACC_NB (XXH_STRIPE_LEN / sizeof(xxh_u64)) + +#ifdef XXH_OLD_NAMES +# define STRIPE_LEN XXH_STRIPE_LEN +# define ACC_NB XXH_ACC_NB +#endif + +#ifndef XXH_PREFETCH_DIST +# ifdef __clang__ +# define XXH_PREFETCH_DIST 320 +# else +# if (XXH_VECTOR == XXH_AVX512) +# define XXH_PREFETCH_DIST 512 +# else +# define XXH_PREFETCH_DIST 384 +# endif +# endif /* __clang__ */ +#endif /* XXH_PREFETCH_DIST */ + +/* + * These macros are to generate an XXH3_accumulate() function. + * The two arguments select the name suffix and target attribute. + * + * The name of this symbol is XXH3_accumulate_() and it calls + * XXH3_accumulate_512_(). + * + * It may be useful to hand implement this function if the compiler fails to + * optimize the inline function. + */ +#define XXH3_ACCUMULATE_TEMPLATE(name) \ +void \ +XXH3_accumulate_##name(xxh_u64* XXH_RESTRICT acc, \ + const xxh_u8* XXH_RESTRICT input, \ + const xxh_u8* XXH_RESTRICT secret, \ + size_t nbStripes) \ +{ \ + size_t n; \ + for (n = 0; n < nbStripes; n++ ) { \ + const xxh_u8* const in = input + n*XXH_STRIPE_LEN; \ + XXH_PREFETCH(in + XXH_PREFETCH_DIST); \ + XXH3_accumulate_512_##name( \ + acc, \ + in, \ + secret + n*XXH_SECRET_CONSUME_RATE); \ + } \ +} + + +XXH_FORCE_INLINE void XXH_writeLE64(void* dst, xxh_u64 v64) +{ + if (!XXH_CPU_LITTLE_ENDIAN) v64 = XXH_swap64(v64); + XXH_memcpy(dst, &v64, sizeof(v64)); +} + +/* Several intrinsic functions below are supposed to accept __int64 as argument, + * as documented in https://software.intel.com/sites/landingpage/IntrinsicsGuide/ . + * However, several environments do not define __int64 type, + * requiring a workaround. + */ +#if !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) + typedef int64_t xxh_i64; +#else + /* the following type must have a width of 64-bit */ + typedef long long xxh_i64; +#endif + + +/* + * XXH3_accumulate_512 is the tightest loop for long inputs, and it is the most optimized. + * + * It is a hardened version of UMAC, based off of FARSH's implementation. + * + * This was chosen because it adapts quite well to 32-bit, 64-bit, and SIMD + * implementations, and it is ridiculously fast. + * + * We harden it by mixing the original input to the accumulators as well as the product. + * + * This means that in the (relatively likely) case of a multiply by zero, the + * original input is preserved. + * + * On 128-bit inputs, we swap 64-bit pairs when we add the input to improve + * cross-pollination, as otherwise the upper and lower halves would be + * essentially independent. + * + * This doesn't matter on 64-bit hashes since they all get merged together in + * the end, so we skip the extra step. + * + * Both XXH3_64bits and XXH3_128bits use this subroutine. + */ + +#if (XXH_VECTOR == XXH_AVX512) \ + || (defined(XXH_DISPATCH_AVX512) && XXH_DISPATCH_AVX512 != 0) + +#ifndef XXH_TARGET_AVX512 +# define XXH_TARGET_AVX512 /* disable attribute target */ +#endif + +XXH_FORCE_INLINE XXH_TARGET_AVX512 void +XXH3_accumulate_512_avx512(void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + __m512i* const xacc = (__m512i *) acc; + XXH_ASSERT((((size_t)acc) & 63) == 0); + XXH_STATIC_ASSERT(XXH_STRIPE_LEN == sizeof(__m512i)); + + { + /* data_vec = input[0]; */ + __m512i const data_vec = _mm512_loadu_si512 (input); + /* key_vec = secret[0]; */ + __m512i const key_vec = _mm512_loadu_si512 (secret); + /* data_key = data_vec ^ key_vec; */ + __m512i const data_key = _mm512_xor_si512 (data_vec, key_vec); + /* data_key_lo = data_key >> 32; */ + __m512i const data_key_lo = _mm512_srli_epi64 (data_key, 32); + /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */ + __m512i const product = _mm512_mul_epu32 (data_key, data_key_lo); + /* xacc[0] += swap(data_vec); */ + __m512i const data_swap = _mm512_shuffle_epi32(data_vec, (_MM_PERM_ENUM)_MM_SHUFFLE(1, 0, 3, 2)); + __m512i const sum = _mm512_add_epi64(*xacc, data_swap); + /* xacc[0] += product; */ + *xacc = _mm512_add_epi64(product, sum); + } +} +XXH_FORCE_INLINE XXH_TARGET_AVX512 XXH3_ACCUMULATE_TEMPLATE(avx512) + +/* + * XXH3_scrambleAcc: Scrambles the accumulators to improve mixing. + * + * Multiplication isn't perfect, as explained by Google in HighwayHash: + * + * // Multiplication mixes/scrambles bytes 0-7 of the 64-bit result to + * // varying degrees. In descending order of goodness, bytes + * // 3 4 2 5 1 6 0 7 have quality 228 224 164 160 100 96 36 32. + * // As expected, the upper and lower bytes are much worse. + * + * Source: https://github.com/google/highwayhash/blob/0aaf66b/highwayhash/hh_avx2.h#L291 + * + * Since our algorithm uses a pseudorandom secret to add some variance into the + * mix, we don't need to (or want to) mix as often or as much as HighwayHash does. + * + * This isn't as tight as XXH3_accumulate, but still written in SIMD to avoid + * extraction. + * + * Both XXH3_64bits and XXH3_128bits use this subroutine. + */ + +XXH_FORCE_INLINE XXH_TARGET_AVX512 void +XXH3_scrambleAcc_avx512(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 63) == 0); + XXH_STATIC_ASSERT(XXH_STRIPE_LEN == sizeof(__m512i)); + { __m512i* const xacc = (__m512i*) acc; + const __m512i prime32 = _mm512_set1_epi32((int)XXH_PRIME32_1); + + /* xacc[0] ^= (xacc[0] >> 47) */ + __m512i const acc_vec = *xacc; + __m512i const shifted = _mm512_srli_epi64 (acc_vec, 47); + /* xacc[0] ^= secret; */ + __m512i const key_vec = _mm512_loadu_si512 (secret); + __m512i const data_key = _mm512_ternarylogic_epi32(key_vec, acc_vec, shifted, 0x96 /* key_vec ^ acc_vec ^ shifted */); + + /* xacc[0] *= XXH_PRIME32_1; */ + __m512i const data_key_hi = _mm512_srli_epi64 (data_key, 32); + __m512i const prod_lo = _mm512_mul_epu32 (data_key, prime32); + __m512i const prod_hi = _mm512_mul_epu32 (data_key_hi, prime32); + *xacc = _mm512_add_epi64(prod_lo, _mm512_slli_epi64(prod_hi, 32)); + } +} + +XXH_FORCE_INLINE XXH_TARGET_AVX512 void +XXH3_initCustomSecret_avx512(void* XXH_RESTRICT customSecret, xxh_u64 seed64) +{ + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 63) == 0); + XXH_STATIC_ASSERT(XXH_SEC_ALIGN == 64); + XXH_ASSERT(((size_t)customSecret & 63) == 0); + (void)(&XXH_writeLE64); + { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / sizeof(__m512i); + __m512i const seed_pos = _mm512_set1_epi64((xxh_i64)seed64); + __m512i const seed = _mm512_mask_sub_epi64(seed_pos, 0xAA, _mm512_set1_epi8(0), seed_pos); + + const __m512i* const src = (const __m512i*) ((const void*) XXH3_kSecret); + __m512i* const dest = ( __m512i*) customSecret; + int i; + XXH_ASSERT(((size_t)src & 63) == 0); /* control alignment */ + XXH_ASSERT(((size_t)dest & 63) == 0); + for (i=0; i < nbRounds; ++i) { + dest[i] = _mm512_add_epi64(_mm512_load_si512(src + i), seed); + } } +} + +#endif + +#if (XXH_VECTOR == XXH_AVX2) \ + || (defined(XXH_DISPATCH_AVX2) && XXH_DISPATCH_AVX2 != 0) + +#ifndef XXH_TARGET_AVX2 +# define XXH_TARGET_AVX2 /* disable attribute target */ +#endif + +XXH_FORCE_INLINE XXH_TARGET_AVX2 void +XXH3_accumulate_512_avx2( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 31) == 0); + { __m256i* const xacc = (__m256i *) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */ + const __m256i* const xinput = (const __m256i *) input; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */ + const __m256i* const xsecret = (const __m256i *) secret; + + size_t i; + for (i=0; i < XXH_STRIPE_LEN/sizeof(__m256i); i++) { + /* data_vec = xinput[i]; */ + __m256i const data_vec = _mm256_loadu_si256 (xinput+i); + /* key_vec = xsecret[i]; */ + __m256i const key_vec = _mm256_loadu_si256 (xsecret+i); + /* data_key = data_vec ^ key_vec; */ + __m256i const data_key = _mm256_xor_si256 (data_vec, key_vec); + /* data_key_lo = data_key >> 32; */ + __m256i const data_key_lo = _mm256_srli_epi64 (data_key, 32); + /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */ + __m256i const product = _mm256_mul_epu32 (data_key, data_key_lo); + /* xacc[i] += swap(data_vec); */ + __m256i const data_swap = _mm256_shuffle_epi32(data_vec, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i const sum = _mm256_add_epi64(xacc[i], data_swap); + /* xacc[i] += product; */ + xacc[i] = _mm256_add_epi64(product, sum); + } } +} +XXH_FORCE_INLINE XXH_TARGET_AVX2 XXH3_ACCUMULATE_TEMPLATE(avx2) + +XXH_FORCE_INLINE XXH_TARGET_AVX2 void +XXH3_scrambleAcc_avx2(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 31) == 0); + { __m256i* const xacc = (__m256i*) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */ + const __m256i* const xsecret = (const __m256i *) secret; + const __m256i prime32 = _mm256_set1_epi32((int)XXH_PRIME32_1); + + size_t i; + for (i=0; i < XXH_STRIPE_LEN/sizeof(__m256i); i++) { + /* xacc[i] ^= (xacc[i] >> 47) */ + __m256i const acc_vec = xacc[i]; + __m256i const shifted = _mm256_srli_epi64 (acc_vec, 47); + __m256i const data_vec = _mm256_xor_si256 (acc_vec, shifted); + /* xacc[i] ^= xsecret; */ + __m256i const key_vec = _mm256_loadu_si256 (xsecret+i); + __m256i const data_key = _mm256_xor_si256 (data_vec, key_vec); + + /* xacc[i] *= XXH_PRIME32_1; */ + __m256i const data_key_hi = _mm256_srli_epi64 (data_key, 32); + __m256i const prod_lo = _mm256_mul_epu32 (data_key, prime32); + __m256i const prod_hi = _mm256_mul_epu32 (data_key_hi, prime32); + xacc[i] = _mm256_add_epi64(prod_lo, _mm256_slli_epi64(prod_hi, 32)); + } + } +} + +XXH_FORCE_INLINE XXH_TARGET_AVX2 void XXH3_initCustomSecret_avx2(void* XXH_RESTRICT customSecret, xxh_u64 seed64) +{ + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 31) == 0); + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE / sizeof(__m256i)) == 6); + XXH_STATIC_ASSERT(XXH_SEC_ALIGN <= 64); + (void)(&XXH_writeLE64); + XXH_PREFETCH(customSecret); + { __m256i const seed = _mm256_set_epi64x((xxh_i64)(0U - seed64), (xxh_i64)seed64, (xxh_i64)(0U - seed64), (xxh_i64)seed64); + + const __m256i* const src = (const __m256i*) ((const void*) XXH3_kSecret); + __m256i* dest = ( __m256i*) customSecret; + +# if defined(__GNUC__) || defined(__clang__) + /* + * On GCC & Clang, marking 'dest' as modified will cause the compiler: + * - do not extract the secret from sse registers in the internal loop + * - use less common registers, and avoid pushing these reg into stack + */ + XXH_COMPILER_GUARD(dest); +# endif + XXH_ASSERT(((size_t)src & 31) == 0); /* control alignment */ + XXH_ASSERT(((size_t)dest & 31) == 0); + + /* GCC -O2 need unroll loop manually */ + dest[0] = _mm256_add_epi64(_mm256_load_si256(src+0), seed); + dest[1] = _mm256_add_epi64(_mm256_load_si256(src+1), seed); + dest[2] = _mm256_add_epi64(_mm256_load_si256(src+2), seed); + dest[3] = _mm256_add_epi64(_mm256_load_si256(src+3), seed); + dest[4] = _mm256_add_epi64(_mm256_load_si256(src+4), seed); + dest[5] = _mm256_add_epi64(_mm256_load_si256(src+5), seed); + } +} + +#endif + +/* x86dispatch always generates SSE2 */ +#if (XXH_VECTOR == XXH_SSE2) || defined(XXH_X86DISPATCH) + +#ifndef XXH_TARGET_SSE2 +# define XXH_TARGET_SSE2 /* disable attribute target */ +#endif + +XXH_FORCE_INLINE XXH_TARGET_SSE2 void +XXH3_accumulate_512_sse2( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + /* SSE2 is just a half-scale version of the AVX2 version. */ + XXH_ASSERT((((size_t)acc) & 15) == 0); + { __m128i* const xacc = (__m128i *) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */ + const __m128i* const xinput = (const __m128i *) input; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */ + const __m128i* const xsecret = (const __m128i *) secret; + + size_t i; + for (i=0; i < XXH_STRIPE_LEN/sizeof(__m128i); i++) { + /* data_vec = xinput[i]; */ + __m128i const data_vec = _mm_loadu_si128 (xinput+i); + /* key_vec = xsecret[i]; */ + __m128i const key_vec = _mm_loadu_si128 (xsecret+i); + /* data_key = data_vec ^ key_vec; */ + __m128i const data_key = _mm_xor_si128 (data_vec, key_vec); + /* data_key_lo = data_key >> 32; */ + __m128i const data_key_lo = _mm_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1)); + /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */ + __m128i const product = _mm_mul_epu32 (data_key, data_key_lo); + /* xacc[i] += swap(data_vec); */ + __m128i const data_swap = _mm_shuffle_epi32(data_vec, _MM_SHUFFLE(1,0,3,2)); + __m128i const sum = _mm_add_epi64(xacc[i], data_swap); + /* xacc[i] += product; */ + xacc[i] = _mm_add_epi64(product, sum); + } } +} +XXH_FORCE_INLINE XXH_TARGET_SSE2 XXH3_ACCUMULATE_TEMPLATE(sse2) + +XXH_FORCE_INLINE XXH_TARGET_SSE2 void +XXH3_scrambleAcc_sse2(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + { __m128i* const xacc = (__m128i*) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */ + const __m128i* const xsecret = (const __m128i *) secret; + const __m128i prime32 = _mm_set1_epi32((int)XXH_PRIME32_1); + + size_t i; + for (i=0; i < XXH_STRIPE_LEN/sizeof(__m128i); i++) { + /* xacc[i] ^= (xacc[i] >> 47) */ + __m128i const acc_vec = xacc[i]; + __m128i const shifted = _mm_srli_epi64 (acc_vec, 47); + __m128i const data_vec = _mm_xor_si128 (acc_vec, shifted); + /* xacc[i] ^= xsecret[i]; */ + __m128i const key_vec = _mm_loadu_si128 (xsecret+i); + __m128i const data_key = _mm_xor_si128 (data_vec, key_vec); + + /* xacc[i] *= XXH_PRIME32_1; */ + __m128i const data_key_hi = _mm_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1)); + __m128i const prod_lo = _mm_mul_epu32 (data_key, prime32); + __m128i const prod_hi = _mm_mul_epu32 (data_key_hi, prime32); + xacc[i] = _mm_add_epi64(prod_lo, _mm_slli_epi64(prod_hi, 32)); + } + } +} + +XXH_FORCE_INLINE XXH_TARGET_SSE2 void XXH3_initCustomSecret_sse2(void* XXH_RESTRICT customSecret, xxh_u64 seed64) +{ + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 15) == 0); + (void)(&XXH_writeLE64); + { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / sizeof(__m128i); + +# if defined(_MSC_VER) && defined(_M_IX86) && _MSC_VER < 1900 + /* MSVC 32bit mode does not support _mm_set_epi64x before 2015 */ + XXH_ALIGN(16) const xxh_i64 seed64x2[2] = { (xxh_i64)seed64, (xxh_i64)(0U - seed64) }; + __m128i const seed = _mm_load_si128((__m128i const*)seed64x2); +# else + __m128i const seed = _mm_set_epi64x((xxh_i64)(0U - seed64), (xxh_i64)seed64); +# endif + int i; + + const void* const src16 = XXH3_kSecret; + __m128i* dst16 = (__m128i*) customSecret; +# if defined(__GNUC__) || defined(__clang__) + /* + * On GCC & Clang, marking 'dest' as modified will cause the compiler: + * - do not extract the secret from sse registers in the internal loop + * - use less common registers, and avoid pushing these reg into stack + */ + XXH_COMPILER_GUARD(dst16); +# endif + XXH_ASSERT(((size_t)src16 & 15) == 0); /* control alignment */ + XXH_ASSERT(((size_t)dst16 & 15) == 0); + + for (i=0; i < nbRounds; ++i) { + dst16[i] = _mm_add_epi64(_mm_load_si128((const __m128i *)src16+i), seed); + } } +} + +#endif + +#if (XXH_VECTOR == XXH_NEON) + +/* forward declarations for the scalar routines */ +XXH_FORCE_INLINE void +XXH3_scalarRound(void* XXH_RESTRICT acc, void const* XXH_RESTRICT input, + void const* XXH_RESTRICT secret, size_t lane); + +XXH_FORCE_INLINE void +XXH3_scalarScrambleRound(void* XXH_RESTRICT acc, + void const* XXH_RESTRICT secret, size_t lane); + +/*! + * @internal + * @brief The bulk processing loop for NEON and WASM SIMD128. + * + * The NEON code path is actually partially scalar when running on AArch64. This + * is to optimize the pipelining and can have up to 15% speedup depending on the + * CPU, and it also mitigates some GCC codegen issues. + * + * @see XXH3_NEON_LANES for configuring this and details about this optimization. + * + * NEON's 32-bit to 64-bit long multiply takes a half vector of 32-bit + * integers instead of the other platforms which mask full 64-bit vectors, + * so the setup is more complicated than just shifting right. + * + * Additionally, there is an optimization for 4 lanes at once noted below. + * + * Since, as stated, the most optimal amount of lanes for Cortexes is 6, + * there needs to be *three* versions of the accumulate operation used + * for the remaining 2 lanes. + * + * WASM's SIMD128 uses SIMDe's arm_neon.h polyfill because the intrinsics overlap + * nearly perfectly. + */ + +XXH_FORCE_INLINE void +XXH3_accumulate_512_neon( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + XXH_STATIC_ASSERT(XXH3_NEON_LANES > 0 && XXH3_NEON_LANES <= XXH_ACC_NB && XXH3_NEON_LANES % 2 == 0); + { /* GCC for darwin arm64 does not like aliasing here */ + xxh_aliasing_uint64x2_t* const xacc = (xxh_aliasing_uint64x2_t*) acc; + /* We don't use a uint32x4_t pointer because it causes bus errors on ARMv7. */ + uint8_t const* xinput = (const uint8_t *) input; + uint8_t const* xsecret = (const uint8_t *) secret; + + size_t i; +#ifdef __wasm_simd128__ + /* + * On WASM SIMD128, Clang emits direct address loads when XXH3_kSecret + * is constant propagated, which results in it converting it to this + * inside the loop: + * + * a = v128.load(XXH3_kSecret + 0 + $secret_offset, offset = 0) + * b = v128.load(XXH3_kSecret + 16 + $secret_offset, offset = 0) + * ... + * + * This requires a full 32-bit address immediate (and therefore a 6 byte + * instruction) as well as an add for each offset. + * + * Putting an asm guard prevents it from folding (at the cost of losing + * the alignment hint), and uses the free offset in `v128.load` instead + * of adding secret_offset each time which overall reduces code size by + * about a kilobyte and improves performance. + */ + XXH_COMPILER_GUARD(xsecret); +#endif + /* Scalar lanes use the normal scalarRound routine */ + for (i = XXH3_NEON_LANES; i < XXH_ACC_NB; i++) { + XXH3_scalarRound(acc, input, secret, i); + } + i = 0; + /* 4 NEON lanes at a time. */ + for (; i+1 < XXH3_NEON_LANES / 2; i+=2) { + /* data_vec = xinput[i]; */ + uint64x2_t data_vec_1 = XXH_vld1q_u64(xinput + (i * 16)); + uint64x2_t data_vec_2 = XXH_vld1q_u64(xinput + ((i+1) * 16)); + /* key_vec = xsecret[i]; */ + uint64x2_t key_vec_1 = XXH_vld1q_u64(xsecret + (i * 16)); + uint64x2_t key_vec_2 = XXH_vld1q_u64(xsecret + ((i+1) * 16)); + /* data_swap = swap(data_vec) */ + uint64x2_t data_swap_1 = vextq_u64(data_vec_1, data_vec_1, 1); + uint64x2_t data_swap_2 = vextq_u64(data_vec_2, data_vec_2, 1); + /* data_key = data_vec ^ key_vec; */ + uint64x2_t data_key_1 = veorq_u64(data_vec_1, key_vec_1); + uint64x2_t data_key_2 = veorq_u64(data_vec_2, key_vec_2); + + /* + * If we reinterpret the 64x2 vectors as 32x4 vectors, we can use a + * de-interleave operation for 4 lanes in 1 step with `vuzpq_u32` to + * get one vector with the low 32 bits of each lane, and one vector + * with the high 32 bits of each lane. + * + * The intrinsic returns a double vector because the original ARMv7-a + * instruction modified both arguments in place. AArch64 and SIMD128 emit + * two instructions from this intrinsic. + * + * [ dk11L | dk11H | dk12L | dk12H ] -> [ dk11L | dk12L | dk21L | dk22L ] + * [ dk21L | dk21H | dk22L | dk22H ] -> [ dk11H | dk12H | dk21H | dk22H ] + */ + uint32x4x2_t unzipped = vuzpq_u32( + vreinterpretq_u32_u64(data_key_1), + vreinterpretq_u32_u64(data_key_2) + ); + /* data_key_lo = data_key & 0xFFFFFFFF */ + uint32x4_t data_key_lo = unzipped.val[0]; + /* data_key_hi = data_key >> 32 */ + uint32x4_t data_key_hi = unzipped.val[1]; + /* + * Then, we can split the vectors horizontally and multiply which, as for most + * widening intrinsics, have a variant that works on both high half vectors + * for free on AArch64. A similar instruction is available on SIMD128. + * + * sum = data_swap + (u64x2) data_key_lo * (u64x2) data_key_hi + */ + uint64x2_t sum_1 = XXH_vmlal_low_u32(data_swap_1, data_key_lo, data_key_hi); + uint64x2_t sum_2 = XXH_vmlal_high_u32(data_swap_2, data_key_lo, data_key_hi); + /* + * Clang reorders + * a += b * c; // umlal swap.2d, dkl.2s, dkh.2s + * c += a; // add acc.2d, acc.2d, swap.2d + * to + * c += a; // add acc.2d, acc.2d, swap.2d + * c += b * c; // umlal acc.2d, dkl.2s, dkh.2s + * + * While it would make sense in theory since the addition is faster, + * for reasons likely related to umlal being limited to certain NEON + * pipelines, this is worse. A compiler guard fixes this. + */ + XXH_COMPILER_GUARD_CLANG_NEON(sum_1); + XXH_COMPILER_GUARD_CLANG_NEON(sum_2); + /* xacc[i] = acc_vec + sum; */ + xacc[i] = vaddq_u64(xacc[i], sum_1); + xacc[i+1] = vaddq_u64(xacc[i+1], sum_2); + } + /* Operate on the remaining NEON lanes 2 at a time. */ + for (; i < XXH3_NEON_LANES / 2; i++) { + /* data_vec = xinput[i]; */ + uint64x2_t data_vec = XXH_vld1q_u64(xinput + (i * 16)); + /* key_vec = xsecret[i]; */ + uint64x2_t key_vec = XXH_vld1q_u64(xsecret + (i * 16)); + /* acc_vec_2 = swap(data_vec) */ + uint64x2_t data_swap = vextq_u64(data_vec, data_vec, 1); + /* data_key = data_vec ^ key_vec; */ + uint64x2_t data_key = veorq_u64(data_vec, key_vec); + /* For two lanes, just use VMOVN and VSHRN. */ + /* data_key_lo = data_key & 0xFFFFFFFF; */ + uint32x2_t data_key_lo = vmovn_u64(data_key); + /* data_key_hi = data_key >> 32; */ + uint32x2_t data_key_hi = vshrn_n_u64(data_key, 32); + /* sum = data_swap + (u64x2) data_key_lo * (u64x2) data_key_hi; */ + uint64x2_t sum = vmlal_u32(data_swap, data_key_lo, data_key_hi); + /* Same Clang workaround as before */ + XXH_COMPILER_GUARD_CLANG_NEON(sum); + /* xacc[i] = acc_vec + sum; */ + xacc[i] = vaddq_u64 (xacc[i], sum); + } + } +} +XXH_FORCE_INLINE XXH3_ACCUMULATE_TEMPLATE(neon) + +XXH_FORCE_INLINE void +XXH3_scrambleAcc_neon(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + + { xxh_aliasing_uint64x2_t* xacc = (xxh_aliasing_uint64x2_t*) acc; + uint8_t const* xsecret = (uint8_t const*) secret; + + size_t i; + /* WASM uses operator overloads and doesn't need these. */ +#ifndef __wasm_simd128__ + /* { prime32_1, prime32_1 } */ + uint32x2_t const kPrimeLo = vdup_n_u32(XXH_PRIME32_1); + /* { 0, prime32_1, 0, prime32_1 } */ + uint32x4_t const kPrimeHi = vreinterpretq_u32_u64(vdupq_n_u64((xxh_u64)XXH_PRIME32_1 << 32)); +#endif + + /* AArch64 uses both scalar and neon at the same time */ + for (i = XXH3_NEON_LANES; i < XXH_ACC_NB; i++) { + XXH3_scalarScrambleRound(acc, secret, i); + } + for (i=0; i < XXH3_NEON_LANES / 2; i++) { + /* xacc[i] ^= (xacc[i] >> 47); */ + uint64x2_t acc_vec = xacc[i]; + uint64x2_t shifted = vshrq_n_u64(acc_vec, 47); + uint64x2_t data_vec = veorq_u64(acc_vec, shifted); + + /* xacc[i] ^= xsecret[i]; */ + uint64x2_t key_vec = XXH_vld1q_u64(xsecret + (i * 16)); + uint64x2_t data_key = veorq_u64(data_vec, key_vec); + /* xacc[i] *= XXH_PRIME32_1 */ +#ifdef __wasm_simd128__ + /* SIMD128 has multiply by u64x2, use it instead of expanding and scalarizing */ + xacc[i] = data_key * XXH_PRIME32_1; +#else + /* + * Expanded version with portable NEON intrinsics + * + * lo(x) * lo(y) + (hi(x) * lo(y) << 32) + * + * prod_hi = hi(data_key) * lo(prime) << 32 + * + * Since we only need 32 bits of this multiply a trick can be used, reinterpreting the vector + * as a uint32x4_t and multiplying by { 0, prime, 0, prime } to cancel out the unwanted bits + * and avoid the shift. + */ + uint32x4_t prod_hi = vmulq_u32 (vreinterpretq_u32_u64(data_key), kPrimeHi); + /* Extract low bits for vmlal_u32 */ + uint32x2_t data_key_lo = vmovn_u64(data_key); + /* xacc[i] = prod_hi + lo(data_key) * XXH_PRIME32_1; */ + xacc[i] = vmlal_u32(vreinterpretq_u64_u32(prod_hi), data_key_lo, kPrimeLo); +#endif + } + } +} +#endif + +#if (XXH_VECTOR == XXH_VSX) + +XXH_FORCE_INLINE void +XXH3_accumulate_512_vsx( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + /* presumed aligned */ + xxh_aliasing_u64x2* const xacc = (xxh_aliasing_u64x2*) acc; + xxh_u8 const* const xinput = (xxh_u8 const*) input; /* no alignment restriction */ + xxh_u8 const* const xsecret = (xxh_u8 const*) secret; /* no alignment restriction */ + xxh_u64x2 const v32 = { 32, 32 }; + size_t i; + for (i = 0; i < XXH_STRIPE_LEN / sizeof(xxh_u64x2); i++) { + /* data_vec = xinput[i]; */ + xxh_u64x2 const data_vec = XXH_vec_loadu(xinput + 16*i); + /* key_vec = xsecret[i]; */ + xxh_u64x2 const key_vec = XXH_vec_loadu(xsecret + 16*i); + xxh_u64x2 const data_key = data_vec ^ key_vec; + /* shuffled = (data_key << 32) | (data_key >> 32); */ + xxh_u32x4 const shuffled = (xxh_u32x4)vec_rl(data_key, v32); + /* product = ((xxh_u64x2)data_key & 0xFFFFFFFF) * ((xxh_u64x2)shuffled & 0xFFFFFFFF); */ + xxh_u64x2 const product = XXH_vec_mulo((xxh_u32x4)data_key, shuffled); + /* acc_vec = xacc[i]; */ + xxh_u64x2 acc_vec = xacc[i]; + acc_vec += product; + + /* swap high and low halves */ +#ifdef __s390x__ + acc_vec += vec_permi(data_vec, data_vec, 2); +#else + acc_vec += vec_xxpermdi(data_vec, data_vec, 2); +#endif + xacc[i] = acc_vec; + } +} +XXH_FORCE_INLINE XXH3_ACCUMULATE_TEMPLATE(vsx) + +XXH_FORCE_INLINE void +XXH3_scrambleAcc_vsx(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + + { xxh_aliasing_u64x2* const xacc = (xxh_aliasing_u64x2*) acc; + const xxh_u8* const xsecret = (const xxh_u8*) secret; + /* constants */ + xxh_u64x2 const v32 = { 32, 32 }; + xxh_u64x2 const v47 = { 47, 47 }; + xxh_u32x4 const prime = { XXH_PRIME32_1, XXH_PRIME32_1, XXH_PRIME32_1, XXH_PRIME32_1 }; + size_t i; + for (i = 0; i < XXH_STRIPE_LEN / sizeof(xxh_u64x2); i++) { + /* xacc[i] ^= (xacc[i] >> 47); */ + xxh_u64x2 const acc_vec = xacc[i]; + xxh_u64x2 const data_vec = acc_vec ^ (acc_vec >> v47); + + /* xacc[i] ^= xsecret[i]; */ + xxh_u64x2 const key_vec = XXH_vec_loadu(xsecret + 16*i); + xxh_u64x2 const data_key = data_vec ^ key_vec; + + /* xacc[i] *= XXH_PRIME32_1 */ + /* prod_lo = ((xxh_u64x2)data_key & 0xFFFFFFFF) * ((xxh_u64x2)prime & 0xFFFFFFFF); */ + xxh_u64x2 const prod_even = XXH_vec_mule((xxh_u32x4)data_key, prime); + /* prod_hi = ((xxh_u64x2)data_key >> 32) * ((xxh_u64x2)prime >> 32); */ + xxh_u64x2 const prod_odd = XXH_vec_mulo((xxh_u32x4)data_key, prime); + xacc[i] = prod_odd + (prod_even << v32); + } } +} + +#endif + +#if (XXH_VECTOR == XXH_SVE) + +XXH_FORCE_INLINE void +XXH3_accumulate_512_sve( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + uint64_t *xacc = (uint64_t *)acc; + const uint64_t *xinput = (const uint64_t *)(const void *)input; + const uint64_t *xsecret = (const uint64_t *)(const void *)secret; + svuint64_t kSwap = sveor_n_u64_z(svptrue_b64(), svindex_u64(0, 1), 1); + uint64_t element_count = svcntd(); + if (element_count >= 8) { + svbool_t mask = svptrue_pat_b64(SV_VL8); + svuint64_t vacc = svld1_u64(mask, xacc); + ACCRND(vacc, 0); + svst1_u64(mask, xacc, vacc); + } else if (element_count == 2) { /* sve128 */ + svbool_t mask = svptrue_pat_b64(SV_VL2); + svuint64_t acc0 = svld1_u64(mask, xacc + 0); + svuint64_t acc1 = svld1_u64(mask, xacc + 2); + svuint64_t acc2 = svld1_u64(mask, xacc + 4); + svuint64_t acc3 = svld1_u64(mask, xacc + 6); + ACCRND(acc0, 0); + ACCRND(acc1, 2); + ACCRND(acc2, 4); + ACCRND(acc3, 6); + svst1_u64(mask, xacc + 0, acc0); + svst1_u64(mask, xacc + 2, acc1); + svst1_u64(mask, xacc + 4, acc2); + svst1_u64(mask, xacc + 6, acc3); + } else { + svbool_t mask = svptrue_pat_b64(SV_VL4); + svuint64_t acc0 = svld1_u64(mask, xacc + 0); + svuint64_t acc1 = svld1_u64(mask, xacc + 4); + ACCRND(acc0, 0); + ACCRND(acc1, 4); + svst1_u64(mask, xacc + 0, acc0); + svst1_u64(mask, xacc + 4, acc1); + } +} + +XXH_FORCE_INLINE void +XXH3_accumulate_sve(xxh_u64* XXH_RESTRICT acc, + const xxh_u8* XXH_RESTRICT input, + const xxh_u8* XXH_RESTRICT secret, + size_t nbStripes) +{ + if (nbStripes != 0) { + uint64_t *xacc = (uint64_t *)acc; + const uint64_t *xinput = (const uint64_t *)(const void *)input; + const uint64_t *xsecret = (const uint64_t *)(const void *)secret; + svuint64_t kSwap = sveor_n_u64_z(svptrue_b64(), svindex_u64(0, 1), 1); + uint64_t element_count = svcntd(); + if (element_count >= 8) { + svbool_t mask = svptrue_pat_b64(SV_VL8); + svuint64_t vacc = svld1_u64(mask, xacc + 0); + do { + /* svprfd(svbool_t, void *, enum svfprop); */ + svprfd(mask, xinput + 128, SV_PLDL1STRM); + ACCRND(vacc, 0); + xinput += 8; + xsecret += 1; + nbStripes--; + } while (nbStripes != 0); + + svst1_u64(mask, xacc + 0, vacc); + } else if (element_count == 2) { /* sve128 */ + svbool_t mask = svptrue_pat_b64(SV_VL2); + svuint64_t acc0 = svld1_u64(mask, xacc + 0); + svuint64_t acc1 = svld1_u64(mask, xacc + 2); + svuint64_t acc2 = svld1_u64(mask, xacc + 4); + svuint64_t acc3 = svld1_u64(mask, xacc + 6); + do { + svprfd(mask, xinput + 128, SV_PLDL1STRM); + ACCRND(acc0, 0); + ACCRND(acc1, 2); + ACCRND(acc2, 4); + ACCRND(acc3, 6); + xinput += 8; + xsecret += 1; + nbStripes--; + } while (nbStripes != 0); + + svst1_u64(mask, xacc + 0, acc0); + svst1_u64(mask, xacc + 2, acc1); + svst1_u64(mask, xacc + 4, acc2); + svst1_u64(mask, xacc + 6, acc3); + } else { + svbool_t mask = svptrue_pat_b64(SV_VL4); + svuint64_t acc0 = svld1_u64(mask, xacc + 0); + svuint64_t acc1 = svld1_u64(mask, xacc + 4); + do { + svprfd(mask, xinput + 128, SV_PLDL1STRM); + ACCRND(acc0, 0); + ACCRND(acc1, 4); + xinput += 8; + xsecret += 1; + nbStripes--; + } while (nbStripes != 0); + + svst1_u64(mask, xacc + 0, acc0); + svst1_u64(mask, xacc + 4, acc1); + } + } +} + +#endif + +/* scalar variants - universal */ + +#if defined(__aarch64__) && (defined(__GNUC__) || defined(__clang__)) +/* + * In XXH3_scalarRound(), GCC and Clang have a similar codegen issue, where they + * emit an excess mask and a full 64-bit multiply-add (MADD X-form). + * + * While this might not seem like much, as AArch64 is a 64-bit architecture, only + * big Cortex designs have a full 64-bit multiplier. + * + * On the little cores, the smaller 32-bit multiplier is used, and full 64-bit + * multiplies expand to 2-3 multiplies in microcode. This has a major penalty + * of up to 4 latency cycles and 2 stall cycles in the multiply pipeline. + * + * Thankfully, AArch64 still provides the 32-bit long multiply-add (UMADDL) which does + * not have this penalty and does the mask automatically. + */ +XXH_FORCE_INLINE xxh_u64 +XXH_mult32to64_add64(xxh_u64 lhs, xxh_u64 rhs, xxh_u64 acc) +{ + xxh_u64 ret; + /* note: %x = 64-bit register, %w = 32-bit register */ + __asm__("umaddl %x0, %w1, %w2, %x3" : "=r" (ret) : "r" (lhs), "r" (rhs), "r" (acc)); + return ret; +} +#else +XXH_FORCE_INLINE xxh_u64 +XXH_mult32to64_add64(xxh_u64 lhs, xxh_u64 rhs, xxh_u64 acc) +{ + return XXH_mult32to64((xxh_u32)lhs, (xxh_u32)rhs) + acc; +} +#endif + +/*! + * @internal + * @brief Scalar round for @ref XXH3_accumulate_512_scalar(). + * + * This is extracted to its own function because the NEON path uses a combination + * of NEON and scalar. + */ +XXH_FORCE_INLINE void +XXH3_scalarRound(void* XXH_RESTRICT acc, + void const* XXH_RESTRICT input, + void const* XXH_RESTRICT secret, + size_t lane) +{ + xxh_u64* xacc = (xxh_u64*) acc; + xxh_u8 const* xinput = (xxh_u8 const*) input; + xxh_u8 const* xsecret = (xxh_u8 const*) secret; + XXH_ASSERT(lane < XXH_ACC_NB); + XXH_ASSERT(((size_t)acc & (XXH_ACC_ALIGN-1)) == 0); + { + xxh_u64 const data_val = XXH_readLE64(xinput + lane * 8); + xxh_u64 const data_key = data_val ^ XXH_readLE64(xsecret + lane * 8); + xacc[lane ^ 1] += data_val; /* swap adjacent lanes */ + xacc[lane] = XXH_mult32to64_add64(data_key /* & 0xFFFFFFFF */, data_key >> 32, xacc[lane]); + } +} + +/*! + * @internal + * @brief Processes a 64 byte block of data using the scalar path. + */ +XXH_FORCE_INLINE void +XXH3_accumulate_512_scalar(void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + size_t i; + /* ARM GCC refuses to unroll this loop, resulting in a 24% slowdown on ARMv6. */ +#if defined(__GNUC__) && !defined(__clang__) \ + && (defined(__arm__) || defined(__thumb2__)) \ + && defined(__ARM_FEATURE_UNALIGNED) /* no unaligned access just wastes bytes */ \ + && XXH_SIZE_OPT <= 0 +# pragma GCC unroll 8 +#endif + for (i=0; i < XXH_ACC_NB; i++) { + XXH3_scalarRound(acc, input, secret, i); + } +} +XXH_FORCE_INLINE XXH3_ACCUMULATE_TEMPLATE(scalar) + +/*! + * @internal + * @brief Scalar scramble step for @ref XXH3_scrambleAcc_scalar(). + * + * This is extracted to its own function because the NEON path uses a combination + * of NEON and scalar. + */ +XXH_FORCE_INLINE void +XXH3_scalarScrambleRound(void* XXH_RESTRICT acc, + void const* XXH_RESTRICT secret, + size_t lane) +{ + xxh_u64* const xacc = (xxh_u64*) acc; /* presumed aligned */ + const xxh_u8* const xsecret = (const xxh_u8*) secret; /* no alignment restriction */ + XXH_ASSERT((((size_t)acc) & (XXH_ACC_ALIGN-1)) == 0); + XXH_ASSERT(lane < XXH_ACC_NB); + { + xxh_u64 const key64 = XXH_readLE64(xsecret + lane * 8); + xxh_u64 acc64 = xacc[lane]; + acc64 = XXH_xorshift64(acc64, 47); + acc64 ^= key64; + acc64 *= XXH_PRIME32_1; + xacc[lane] = acc64; + } +} + +/*! + * @internal + * @brief Scrambles the accumulators after a large chunk has been read + */ +XXH_FORCE_INLINE void +XXH3_scrambleAcc_scalar(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + size_t i; + for (i=0; i < XXH_ACC_NB; i++) { + XXH3_scalarScrambleRound(acc, secret, i); + } +} + +XXH_FORCE_INLINE void +XXH3_initCustomSecret_scalar(void* XXH_RESTRICT customSecret, xxh_u64 seed64) +{ + /* + * We need a separate pointer for the hack below, + * which requires a non-const pointer. + * Any decent compiler will optimize this out otherwise. + */ + const xxh_u8* kSecretPtr = XXH3_kSecret; + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 15) == 0); + +#if defined(__GNUC__) && defined(__aarch64__) + /* + * UGLY HACK: + * GCC and Clang generate a bunch of MOV/MOVK pairs for aarch64, and they are + * placed sequentially, in order, at the top of the unrolled loop. + * + * While MOVK is great for generating constants (2 cycles for a 64-bit + * constant compared to 4 cycles for LDR), it fights for bandwidth with + * the arithmetic instructions. + * + * I L S + * MOVK + * MOVK + * MOVK + * MOVK + * ADD + * SUB STR + * STR + * By forcing loads from memory (as the asm line causes the compiler to assume + * that XXH3_kSecretPtr has been changed), the pipelines are used more + * efficiently: + * I L S + * LDR + * ADD LDR + * SUB STR + * STR + * + * See XXH3_NEON_LANES for details on the pipsline. + * + * XXH3_64bits_withSeed, len == 256, Snapdragon 835 + * without hack: 2654.4 MB/s + * with hack: 3202.9 MB/s + */ + XXH_COMPILER_GUARD(kSecretPtr); +#endif + { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / 16; + int i; + for (i=0; i < nbRounds; i++) { + /* + * The asm hack causes the compiler to assume that kSecretPtr aliases with + * customSecret, and on aarch64, this prevented LDP from merging two + * loads together for free. Putting the loads together before the stores + * properly generates LDP. + */ + xxh_u64 lo = XXH_readLE64(kSecretPtr + 16*i) + seed64; + xxh_u64 hi = XXH_readLE64(kSecretPtr + 16*i + 8) - seed64; + XXH_writeLE64((xxh_u8*)customSecret + 16*i, lo); + XXH_writeLE64((xxh_u8*)customSecret + 16*i + 8, hi); + } } +} + + +typedef void (*XXH3_f_accumulate)(xxh_u64* XXH_RESTRICT, const xxh_u8* XXH_RESTRICT, const xxh_u8* XXH_RESTRICT, size_t); +typedef void (*XXH3_f_scrambleAcc)(void* XXH_RESTRICT, const void*); +typedef void (*XXH3_f_initCustomSecret)(void* XXH_RESTRICT, xxh_u64); + + +#if (XXH_VECTOR == XXH_AVX512) + +#define XXH3_accumulate_512 XXH3_accumulate_512_avx512 +#define XXH3_accumulate XXH3_accumulate_avx512 +#define XXH3_scrambleAcc XXH3_scrambleAcc_avx512 +#define XXH3_initCustomSecret XXH3_initCustomSecret_avx512 + +#elif (XXH_VECTOR == XXH_AVX2) + +#define XXH3_accumulate_512 XXH3_accumulate_512_avx2 +#define XXH3_accumulate XXH3_accumulate_avx2 +#define XXH3_scrambleAcc XXH3_scrambleAcc_avx2 +#define XXH3_initCustomSecret XXH3_initCustomSecret_avx2 + +#elif (XXH_VECTOR == XXH_SSE2) + +#define XXH3_accumulate_512 XXH3_accumulate_512_sse2 +#define XXH3_accumulate XXH3_accumulate_sse2 +#define XXH3_scrambleAcc XXH3_scrambleAcc_sse2 +#define XXH3_initCustomSecret XXH3_initCustomSecret_sse2 + +#elif (XXH_VECTOR == XXH_NEON) + +#define XXH3_accumulate_512 XXH3_accumulate_512_neon +#define XXH3_accumulate XXH3_accumulate_neon +#define XXH3_scrambleAcc XXH3_scrambleAcc_neon +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#elif (XXH_VECTOR == XXH_VSX) + +#define XXH3_accumulate_512 XXH3_accumulate_512_vsx +#define XXH3_accumulate XXH3_accumulate_vsx +#define XXH3_scrambleAcc XXH3_scrambleAcc_vsx +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#elif (XXH_VECTOR == XXH_SVE) +#define XXH3_accumulate_512 XXH3_accumulate_512_sve +#define XXH3_accumulate XXH3_accumulate_sve +#define XXH3_scrambleAcc XXH3_scrambleAcc_scalar +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#else /* scalar */ + +#define XXH3_accumulate_512 XXH3_accumulate_512_scalar +#define XXH3_accumulate XXH3_accumulate_scalar +#define XXH3_scrambleAcc XXH3_scrambleAcc_scalar +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#endif + +#if XXH_SIZE_OPT >= 1 /* don't do SIMD for initialization */ +# undef XXH3_initCustomSecret +# define XXH3_initCustomSecret XXH3_initCustomSecret_scalar +#endif + +XXH_FORCE_INLINE void +XXH3_hashLong_internal_loop(xxh_u64* XXH_RESTRICT acc, + const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH3_f_accumulate f_acc, + XXH3_f_scrambleAcc f_scramble) +{ + size_t const nbStripesPerBlock = (secretSize - XXH_STRIPE_LEN) / XXH_SECRET_CONSUME_RATE; + size_t const block_len = XXH_STRIPE_LEN * nbStripesPerBlock; + size_t const nb_blocks = (len - 1) / block_len; + + size_t n; + + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); + + for (n = 0; n < nb_blocks; n++) { + f_acc(acc, input + n*block_len, secret, nbStripesPerBlock); + f_scramble(acc, secret + secretSize - XXH_STRIPE_LEN); + } + + /* last partial block */ + XXH_ASSERT(len > XXH_STRIPE_LEN); + { size_t const nbStripes = ((len - 1) - (block_len * nb_blocks)) / XXH_STRIPE_LEN; + XXH_ASSERT(nbStripes <= (secretSize / XXH_SECRET_CONSUME_RATE)); + f_acc(acc, input + nb_blocks*block_len, secret, nbStripes); + + /* last stripe */ + { const xxh_u8* const p = input + len - XXH_STRIPE_LEN; +#define XXH_SECRET_LASTACC_START 7 /* not aligned on 8, last secret is different from acc & scrambler */ + XXH3_accumulate_512(acc, p, secret + secretSize - XXH_STRIPE_LEN - XXH_SECRET_LASTACC_START); + } } +} + +XXH_FORCE_INLINE xxh_u64 +XXH3_mix2Accs(const xxh_u64* XXH_RESTRICT acc, const xxh_u8* XXH_RESTRICT secret) +{ + return XXH3_mul128_fold64( + acc[0] ^ XXH_readLE64(secret), + acc[1] ^ XXH_readLE64(secret+8) ); +} + +static XXH64_hash_t +XXH3_mergeAccs(const xxh_u64* XXH_RESTRICT acc, const xxh_u8* XXH_RESTRICT secret, xxh_u64 start) +{ + xxh_u64 result64 = start; + size_t i = 0; + + for (i = 0; i < 4; i++) { + result64 += XXH3_mix2Accs(acc+2*i, secret + 16*i); +#if defined(__clang__) /* Clang */ \ + && (defined(__arm__) || defined(__thumb__)) /* ARMv7 */ \ + && (defined(__ARM_NEON) || defined(__ARM_NEON__)) /* NEON */ \ + && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable */ + /* + * UGLY HACK: + * Prevent autovectorization on Clang ARMv7-a. Exact same problem as + * the one in XXH3_len_129to240_64b. Speeds up shorter keys > 240b. + * XXH3_64bits, len == 256, Snapdragon 835: + * without hack: 2063.7 MB/s + * with hack: 2560.7 MB/s + */ + XXH_COMPILER_GUARD(result64); +#endif + } + + return XXH3_avalanche(result64); +} + +#define XXH3_INIT_ACC { XXH_PRIME32_3, XXH_PRIME64_1, XXH_PRIME64_2, XXH_PRIME64_3, \ + XXH_PRIME64_4, XXH_PRIME32_2, XXH_PRIME64_5, XXH_PRIME32_1 } + +XXH_FORCE_INLINE XXH64_hash_t +XXH3_hashLong_64b_internal(const void* XXH_RESTRICT input, size_t len, + const void* XXH_RESTRICT secret, size_t secretSize, + XXH3_f_accumulate f_acc, + XXH3_f_scrambleAcc f_scramble) +{ + XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[XXH_ACC_NB] = XXH3_INIT_ACC; + + XXH3_hashLong_internal_loop(acc, (const xxh_u8*)input, len, (const xxh_u8*)secret, secretSize, f_acc, f_scramble); + + /* converge into final hash */ + XXH_STATIC_ASSERT(sizeof(acc) == 64); + /* do not align on 8, so that the secret is different from the accumulator */ +#define XXH_SECRET_MERGEACCS_START 11 + XXH_ASSERT(secretSize >= sizeof(acc) + XXH_SECRET_MERGEACCS_START); + return XXH3_mergeAccs(acc, (const xxh_u8*)secret + XXH_SECRET_MERGEACCS_START, (xxh_u64)len * XXH_PRIME64_1); +} + +/* + * It's important for performance to transmit secret's size (when it's static) + * so that the compiler can properly optimize the vectorized loop. + * This makes a big performance difference for "medium" keys (<1 KB) when using AVX instruction set. + * When the secret size is unknown, or on GCC 12 where the mix of NO_INLINE and FORCE_INLINE + * breaks -Og, this is XXH_NO_INLINE. + */ +XXH3_WITH_SECRET_INLINE XXH64_hash_t +XXH3_hashLong_64b_withSecret(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, const xxh_u8* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; + return XXH3_hashLong_64b_internal(input, len, secret, secretLen, XXH3_accumulate, XXH3_scrambleAcc); +} + +/* + * It's preferable for performance that XXH3_hashLong is not inlined, + * as it results in a smaller function for small data, easier to the instruction cache. + * Note that inside this no_inline function, we do inline the internal loop, + * and provide a statically defined secret size to allow optimization of vector loop. + */ +XXH_NO_INLINE XXH_PUREF XXH64_hash_t +XXH3_hashLong_64b_default(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, const xxh_u8* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; (void)secret; (void)secretLen; + return XXH3_hashLong_64b_internal(input, len, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_accumulate, XXH3_scrambleAcc); +} + +/* + * XXH3_hashLong_64b_withSeed(): + * Generate a custom key based on alteration of default XXH3_kSecret with the seed, + * and then use this key for long mode hashing. + * + * This operation is decently fast but nonetheless costs a little bit of time. + * Try to avoid it whenever possible (typically when seed==0). + * + * It's important for performance that XXH3_hashLong is not inlined. Not sure + * why (uop cache maybe?), but the difference is large and easily measurable. + */ +XXH_FORCE_INLINE XXH64_hash_t +XXH3_hashLong_64b_withSeed_internal(const void* input, size_t len, + XXH64_hash_t seed, + XXH3_f_accumulate f_acc, + XXH3_f_scrambleAcc f_scramble, + XXH3_f_initCustomSecret f_initSec) +{ +#if XXH_SIZE_OPT <= 0 + if (seed == 0) + return XXH3_hashLong_64b_internal(input, len, + XXH3_kSecret, sizeof(XXH3_kSecret), + f_acc, f_scramble); +#endif + { XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE]; + f_initSec(secret, seed); + return XXH3_hashLong_64b_internal(input, len, secret, sizeof(secret), + f_acc, f_scramble); + } +} + +/* + * It's important for performance that XXH3_hashLong is not inlined. + */ +XXH_NO_INLINE XXH64_hash_t +XXH3_hashLong_64b_withSeed(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed, const xxh_u8* XXH_RESTRICT secret, size_t secretLen) +{ + (void)secret; (void)secretLen; + return XXH3_hashLong_64b_withSeed_internal(input, len, seed, + XXH3_accumulate, XXH3_scrambleAcc, XXH3_initCustomSecret); +} + + +typedef XXH64_hash_t (*XXH3_hashLong64_f)(const void* XXH_RESTRICT, size_t, + XXH64_hash_t, const xxh_u8* XXH_RESTRICT, size_t); + +XXH_FORCE_INLINE XXH64_hash_t +XXH3_64bits_internal(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen, + XXH3_hashLong64_f f_hashLong) +{ + XXH_ASSERT(secretLen >= XXH3_SECRET_SIZE_MIN); + /* + * If an action is to be taken if `secretLen` condition is not respected, + * it should be done here. + * For now, it's a contract pre-condition. + * Adding a check and a branch here would cost performance at every hash. + * Also, note that function signature doesn't offer room to return an error. + */ + if (len <= 16) + return XXH3_len_0to16_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, seed64); + if (len <= 128) + return XXH3_len_17to128_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + if (len <= XXH3_MIDSIZE_MAX) + return XXH3_len_129to240_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + return f_hashLong(input, len, seed64, (const xxh_u8*)secret, secretLen); +} + + +/* === Public entry point === */ + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH64_hash_t XXH3_64bits(XXH_NOESCAPE const void* input, size_t length) +{ + return XXH3_64bits_internal(input, length, 0, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_hashLong_64b_default); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH64_hash_t +XXH3_64bits_withSecret(XXH_NOESCAPE const void* input, size_t length, XXH_NOESCAPE const void* secret, size_t secretSize) +{ + return XXH3_64bits_internal(input, length, 0, secret, secretSize, XXH3_hashLong_64b_withSecret); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH64_hash_t +XXH3_64bits_withSeed(XXH_NOESCAPE const void* input, size_t length, XXH64_hash_t seed) +{ + return XXH3_64bits_internal(input, length, seed, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_hashLong_64b_withSeed); +} + +XXH_PUBLIC_API XXH64_hash_t +XXH3_64bits_withSecretandSeed(XXH_NOESCAPE const void* input, size_t length, XXH_NOESCAPE const void* secret, size_t secretSize, XXH64_hash_t seed) +{ + if (length <= XXH3_MIDSIZE_MAX) + return XXH3_64bits_internal(input, length, seed, XXH3_kSecret, sizeof(XXH3_kSecret), NULL); + return XXH3_hashLong_64b_withSecret(input, length, seed, (const xxh_u8*)secret, secretSize); +} + + +/* === XXH3 streaming === */ +#ifndef XXH_NO_STREAM +/* + * Malloc's a pointer that is always aligned to align. + * + * This must be freed with `XXH_alignedFree()`. + * + * malloc typically guarantees 16 byte alignment on 64-bit systems and 8 byte + * alignment on 32-bit. This isn't enough for the 32 byte aligned loads in AVX2 + * or on 32-bit, the 16 byte aligned loads in SSE2 and NEON. + * + * This underalignment previously caused a rather obvious crash which went + * completely unnoticed due to XXH3_createState() not actually being tested. + * Credit to RedSpah for noticing this bug. + * + * The alignment is done manually: Functions like posix_memalign or _mm_malloc + * are avoided: To maintain portability, we would have to write a fallback + * like this anyways, and besides, testing for the existence of library + * functions without relying on external build tools is impossible. + * + * The method is simple: Overallocate, manually align, and store the offset + * to the original behind the returned pointer. + * + * Align must be a power of 2 and 8 <= align <= 128. + */ +static XXH_MALLOCF void* XXH_alignedMalloc(size_t s, size_t align) +{ + XXH_ASSERT(align <= 128 && align >= 8); /* range check */ + XXH_ASSERT((align & (align-1)) == 0); /* power of 2 */ + XXH_ASSERT(s != 0 && s < (s + align)); /* empty/overflow */ + { /* Overallocate to make room for manual realignment and an offset byte */ + xxh_u8* base = (xxh_u8*)XXH_malloc(s + align); + if (base != NULL) { + /* + * Get the offset needed to align this pointer. + * + * Even if the returned pointer is aligned, there will always be + * at least one byte to store the offset to the original pointer. + */ + size_t offset = align - ((size_t)base & (align - 1)); /* base % align */ + /* Add the offset for the now-aligned pointer */ + xxh_u8* ptr = base + offset; + + XXH_ASSERT((size_t)ptr % align == 0); + + /* Store the offset immediately before the returned pointer. */ + ptr[-1] = (xxh_u8)offset; + return ptr; + } + return NULL; + } +} +/* + * Frees an aligned pointer allocated by XXH_alignedMalloc(). Don't pass + * normal malloc'd pointers, XXH_alignedMalloc has a specific data layout. + */ +static void XXH_alignedFree(void* p) +{ + if (p != NULL) { + xxh_u8* ptr = (xxh_u8*)p; + /* Get the offset byte we added in XXH_malloc. */ + xxh_u8 offset = ptr[-1]; + /* Free the original malloc'd pointer */ + xxh_u8* base = ptr - offset; + XXH_free(base); + } +} +/*! @ingroup XXH3_family */ +/*! + * @brief Allocate an @ref XXH3_state_t. + * + * @return An allocated pointer of @ref XXH3_state_t on success. + * @return `NULL` on failure. + * + * @note Must be freed with XXH3_freeState(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH3_state_t* XXH3_createState(void) +{ + XXH3_state_t* const state = (XXH3_state_t*)XXH_alignedMalloc(sizeof(XXH3_state_t), 64); + if (state==NULL) return NULL; + XXH3_INITSTATE(state); + return state; +} + +/*! @ingroup XXH3_family */ +/*! + * @brief Frees an @ref XXH3_state_t. + * + * @param statePtr A pointer to an @ref XXH3_state_t allocated with @ref XXH3_createState(). + * + * @return @ref XXH_OK. + * + * @note Must be allocated with XXH3_createState(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH3_freeState(XXH3_state_t* statePtr) +{ + XXH_alignedFree(statePtr); + return XXH_OK; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API void +XXH3_copyState(XXH_NOESCAPE XXH3_state_t* dst_state, XXH_NOESCAPE const XXH3_state_t* src_state) +{ + XXH_memcpy(dst_state, src_state, sizeof(*dst_state)); +} + +static void +XXH3_reset_internal(XXH3_state_t* statePtr, + XXH64_hash_t seed, + const void* secret, size_t secretSize) +{ + size_t const initStart = offsetof(XXH3_state_t, bufferedSize); + size_t const initLength = offsetof(XXH3_state_t, nbStripesPerBlock) - initStart; + XXH_ASSERT(offsetof(XXH3_state_t, nbStripesPerBlock) > initStart); + XXH_ASSERT(statePtr != NULL); + /* set members from bufferedSize to nbStripesPerBlock (excluded) to 0 */ + memset((char*)statePtr + initStart, 0, initLength); + statePtr->acc[0] = XXH_PRIME32_3; + statePtr->acc[1] = XXH_PRIME64_1; + statePtr->acc[2] = XXH_PRIME64_2; + statePtr->acc[3] = XXH_PRIME64_3; + statePtr->acc[4] = XXH_PRIME64_4; + statePtr->acc[5] = XXH_PRIME32_2; + statePtr->acc[6] = XXH_PRIME64_5; + statePtr->acc[7] = XXH_PRIME32_1; + statePtr->seed = seed; + statePtr->useSeed = (seed != 0); + statePtr->extSecret = (const unsigned char*)secret; + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); + statePtr->secretLimit = secretSize - XXH_STRIPE_LEN; + statePtr->nbStripesPerBlock = statePtr->secretLimit / XXH_SECRET_CONSUME_RATE; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset(XXH_NOESCAPE XXH3_state_t* statePtr) +{ + if (statePtr == NULL) return XXH_ERROR; + XXH3_reset_internal(statePtr, 0, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE); + return XXH_OK; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset_withSecret(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize) +{ + if (statePtr == NULL) return XXH_ERROR; + XXH3_reset_internal(statePtr, 0, secret, secretSize); + if (secret == NULL) return XXH_ERROR; + if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR; + return XXH_OK; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset_withSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH64_hash_t seed) +{ + if (statePtr == NULL) return XXH_ERROR; + if (seed==0) return XXH3_64bits_reset(statePtr); + if ((seed != statePtr->seed) || (statePtr->extSecret != NULL)) + XXH3_initCustomSecret(statePtr->customSecret, seed); + XXH3_reset_internal(statePtr, seed, NULL, XXH_SECRET_DEFAULT_SIZE); + return XXH_OK; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset_withSecretandSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize, XXH64_hash_t seed64) +{ + if (statePtr == NULL) return XXH_ERROR; + if (secret == NULL) return XXH_ERROR; + if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR; + XXH3_reset_internal(statePtr, seed64, secret, secretSize); + statePtr->useSeed = 1; /* always, even if seed64==0 */ + return XXH_OK; +} + +/*! + * @internal + * @brief Processes a large input for XXH3_update() and XXH3_digest_long(). + * + * Unlike XXH3_hashLong_internal_loop(), this can process data that overlaps a block. + * + * @param acc Pointer to the 8 accumulator lanes + * @param nbStripesSoFarPtr In/out pointer to the number of leftover stripes in the block* + * @param nbStripesPerBlock Number of stripes in a block + * @param input Input pointer + * @param nbStripes Number of stripes to process + * @param secret Secret pointer + * @param secretLimit Offset of the last block in @p secret + * @param f_acc Pointer to an XXH3_accumulate implementation + * @param f_scramble Pointer to an XXH3_scrambleAcc implementation + * @return Pointer past the end of @p input after processing + */ +XXH_FORCE_INLINE const xxh_u8 * +XXH3_consumeStripes(xxh_u64* XXH_RESTRICT acc, + size_t* XXH_RESTRICT nbStripesSoFarPtr, size_t nbStripesPerBlock, + const xxh_u8* XXH_RESTRICT input, size_t nbStripes, + const xxh_u8* XXH_RESTRICT secret, size_t secretLimit, + XXH3_f_accumulate f_acc, + XXH3_f_scrambleAcc f_scramble) +{ + const xxh_u8* initialSecret = secret + *nbStripesSoFarPtr * XXH_SECRET_CONSUME_RATE; + /* Process full blocks */ + if (nbStripes >= (nbStripesPerBlock - *nbStripesSoFarPtr)) { + /* Process the initial partial block... */ + size_t nbStripesThisIter = nbStripesPerBlock - *nbStripesSoFarPtr; + + do { + /* Accumulate and scramble */ + f_acc(acc, input, initialSecret, nbStripesThisIter); + f_scramble(acc, secret + secretLimit); + input += nbStripesThisIter * XXH_STRIPE_LEN; + nbStripes -= nbStripesThisIter; + /* Then continue the loop with the full block size */ + nbStripesThisIter = nbStripesPerBlock; + initialSecret = secret; + } while (nbStripes >= nbStripesPerBlock); + *nbStripesSoFarPtr = 0; + } + /* Process a partial block */ + if (nbStripes > 0) { + f_acc(acc, input, initialSecret, nbStripes); + input += nbStripes * XXH_STRIPE_LEN; + *nbStripesSoFarPtr += nbStripes; + } + /* Return end pointer */ + return input; +} + +#ifndef XXH3_STREAM_USE_STACK +# if XXH_SIZE_OPT <= 0 && !defined(__clang__) /* clang doesn't need additional stack space */ +# define XXH3_STREAM_USE_STACK 1 +# endif +#endif +/* + * Both XXH3_64bits_update and XXH3_128bits_update use this routine. + */ +XXH_FORCE_INLINE XXH_errorcode +XXH3_update(XXH3_state_t* XXH_RESTRICT const state, + const xxh_u8* XXH_RESTRICT input, size_t len, + XXH3_f_accumulate f_acc, + XXH3_f_scrambleAcc f_scramble) +{ + if (input==NULL) { + XXH_ASSERT(len == 0); + return XXH_OK; + } + + XXH_ASSERT(state != NULL); + { const xxh_u8* const bEnd = input + len; + const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret; +#if defined(XXH3_STREAM_USE_STACK) && XXH3_STREAM_USE_STACK >= 1 + /* For some reason, gcc and MSVC seem to suffer greatly + * when operating accumulators directly into state. + * Operating into stack space seems to enable proper optimization. + * clang, on the other hand, doesn't seem to need this trick */ + XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[8]; + XXH_memcpy(acc, state->acc, sizeof(acc)); +#else + xxh_u64* XXH_RESTRICT const acc = state->acc; +#endif + state->totalLen += len; + XXH_ASSERT(state->bufferedSize <= XXH3_INTERNALBUFFER_SIZE); + + /* small input : just fill in tmp buffer */ + if (len <= XXH3_INTERNALBUFFER_SIZE - state->bufferedSize) { + XXH_memcpy(state->buffer + state->bufferedSize, input, len); + state->bufferedSize += (XXH32_hash_t)len; + return XXH_OK; + } + + /* total input is now > XXH3_INTERNALBUFFER_SIZE */ + #define XXH3_INTERNALBUFFER_STRIPES (XXH3_INTERNALBUFFER_SIZE / XXH_STRIPE_LEN) + XXH_STATIC_ASSERT(XXH3_INTERNALBUFFER_SIZE % XXH_STRIPE_LEN == 0); /* clean multiple */ + + /* + * Internal buffer is partially filled (always, except at beginning) + * Complete it, then consume it. + */ + if (state->bufferedSize) { + size_t const loadSize = XXH3_INTERNALBUFFER_SIZE - state->bufferedSize; + XXH_memcpy(state->buffer + state->bufferedSize, input, loadSize); + input += loadSize; + XXH3_consumeStripes(acc, + &state->nbStripesSoFar, state->nbStripesPerBlock, + state->buffer, XXH3_INTERNALBUFFER_STRIPES, + secret, state->secretLimit, + f_acc, f_scramble); + state->bufferedSize = 0; + } + XXH_ASSERT(input < bEnd); + if (bEnd - input > XXH3_INTERNALBUFFER_SIZE) { + size_t nbStripes = (size_t)(bEnd - 1 - input) / XXH_STRIPE_LEN; + input = XXH3_consumeStripes(acc, + &state->nbStripesSoFar, state->nbStripesPerBlock, + input, nbStripes, + secret, state->secretLimit, + f_acc, f_scramble); + XXH_memcpy(state->buffer + sizeof(state->buffer) - XXH_STRIPE_LEN, input - XXH_STRIPE_LEN, XXH_STRIPE_LEN); + + } + /* Some remaining input (always) : buffer it */ + XXH_ASSERT(input < bEnd); + XXH_ASSERT(bEnd - input <= XXH3_INTERNALBUFFER_SIZE); + XXH_ASSERT(state->bufferedSize == 0); + XXH_memcpy(state->buffer, input, (size_t)(bEnd-input)); + state->bufferedSize = (XXH32_hash_t)(bEnd-input); +#if defined(XXH3_STREAM_USE_STACK) && XXH3_STREAM_USE_STACK >= 1 + /* save stack accumulators into state */ + XXH_memcpy(state->acc, acc, sizeof(acc)); +#endif + } + + return XXH_OK; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_update(XXH_NOESCAPE XXH3_state_t* state, XXH_NOESCAPE const void* input, size_t len) +{ + return XXH3_update(state, (const xxh_u8*)input, len, + XXH3_accumulate, XXH3_scrambleAcc); +} + + +XXH_FORCE_INLINE void +XXH3_digest_long (XXH64_hash_t* acc, + const XXH3_state_t* state, + const unsigned char* secret) +{ + xxh_u8 lastStripe[XXH_STRIPE_LEN]; + const xxh_u8* lastStripePtr; + + /* + * Digest on a local copy. This way, the state remains unaltered, and it can + * continue ingesting more input afterwards. + */ + XXH_memcpy(acc, state->acc, sizeof(state->acc)); + if (state->bufferedSize >= XXH_STRIPE_LEN) { + /* Consume remaining stripes then point to remaining data in buffer */ + size_t const nbStripes = (state->bufferedSize - 1) / XXH_STRIPE_LEN; + size_t nbStripesSoFar = state->nbStripesSoFar; + XXH3_consumeStripes(acc, + &nbStripesSoFar, state->nbStripesPerBlock, + state->buffer, nbStripes, + secret, state->secretLimit, + XXH3_accumulate, XXH3_scrambleAcc); + lastStripePtr = state->buffer + state->bufferedSize - XXH_STRIPE_LEN; + } else { /* bufferedSize < XXH_STRIPE_LEN */ + /* Copy to temp buffer */ + size_t const catchupSize = XXH_STRIPE_LEN - state->bufferedSize; + XXH_ASSERT(state->bufferedSize > 0); /* there is always some input buffered */ + XXH_memcpy(lastStripe, state->buffer + sizeof(state->buffer) - catchupSize, catchupSize); + XXH_memcpy(lastStripe + catchupSize, state->buffer, state->bufferedSize); + lastStripePtr = lastStripe; + } + /* Last stripe */ + XXH3_accumulate_512(acc, + lastStripePtr, + secret + state->secretLimit - XXH_SECRET_LASTACC_START); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_digest (XXH_NOESCAPE const XXH3_state_t* state) +{ + const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret; + if (state->totalLen > XXH3_MIDSIZE_MAX) { + XXH_ALIGN(XXH_ACC_ALIGN) XXH64_hash_t acc[XXH_ACC_NB]; + XXH3_digest_long(acc, state, secret); + return XXH3_mergeAccs(acc, + secret + XXH_SECRET_MERGEACCS_START, + (xxh_u64)state->totalLen * XXH_PRIME64_1); + } + /* totalLen <= XXH3_MIDSIZE_MAX: digesting a short input */ + if (state->useSeed) + return XXH3_64bits_withSeed(state->buffer, (size_t)state->totalLen, state->seed); + return XXH3_64bits_withSecret(state->buffer, (size_t)(state->totalLen), + secret, state->secretLimit + XXH_STRIPE_LEN); +} +#endif /* !XXH_NO_STREAM */ + + +/* ========================================== + * XXH3 128 bits (a.k.a XXH128) + * ========================================== + * XXH3's 128-bit variant has better mixing and strength than the 64-bit variant, + * even without counting the significantly larger output size. + * + * For example, extra steps are taken to avoid the seed-dependent collisions + * in 17-240 byte inputs (See XXH3_mix16B and XXH128_mix32B). + * + * This strength naturally comes at the cost of some speed, especially on short + * lengths. Note that longer hashes are about as fast as the 64-bit version + * due to it using only a slight modification of the 64-bit loop. + * + * XXH128 is also more oriented towards 64-bit machines. It is still extremely + * fast for a _128-bit_ hash on 32-bit (it usually clears XXH64). + */ + +XXH_FORCE_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_1to3_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + /* A doubled version of 1to3_64b with different constants. */ + XXH_ASSERT(input != NULL); + XXH_ASSERT(1 <= len && len <= 3); + XXH_ASSERT(secret != NULL); + /* + * len = 1: combinedl = { input[0], 0x01, input[0], input[0] } + * len = 2: combinedl = { input[1], 0x02, input[0], input[1] } + * len = 3: combinedl = { input[2], 0x03, input[0], input[1] } + */ + { xxh_u8 const c1 = input[0]; + xxh_u8 const c2 = input[len >> 1]; + xxh_u8 const c3 = input[len - 1]; + xxh_u32 const combinedl = ((xxh_u32)c1 <<16) | ((xxh_u32)c2 << 24) + | ((xxh_u32)c3 << 0) | ((xxh_u32)len << 8); + xxh_u32 const combinedh = XXH_rotl32(XXH_swap32(combinedl), 13); + xxh_u64 const bitflipl = (XXH_readLE32(secret) ^ XXH_readLE32(secret+4)) + seed; + xxh_u64 const bitfliph = (XXH_readLE32(secret+8) ^ XXH_readLE32(secret+12)) - seed; + xxh_u64 const keyed_lo = (xxh_u64)combinedl ^ bitflipl; + xxh_u64 const keyed_hi = (xxh_u64)combinedh ^ bitfliph; + XXH128_hash_t h128; + h128.low64 = XXH64_avalanche(keyed_lo); + h128.high64 = XXH64_avalanche(keyed_hi); + return h128; + } +} + +XXH_FORCE_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_4to8_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(4 <= len && len <= 8); + seed ^= (xxh_u64)XXH_swap32((xxh_u32)seed) << 32; + { xxh_u32 const input_lo = XXH_readLE32(input); + xxh_u32 const input_hi = XXH_readLE32(input + len - 4); + xxh_u64 const input_64 = input_lo + ((xxh_u64)input_hi << 32); + xxh_u64 const bitflip = (XXH_readLE64(secret+16) ^ XXH_readLE64(secret+24)) + seed; + xxh_u64 const keyed = input_64 ^ bitflip; + + /* Shift len to the left to ensure it is even, this avoids even multiplies. */ + XXH128_hash_t m128 = XXH_mult64to128(keyed, XXH_PRIME64_1 + (len << 2)); + + m128.high64 += (m128.low64 << 1); + m128.low64 ^= (m128.high64 >> 3); + + m128.low64 = XXH_xorshift64(m128.low64, 35); + m128.low64 *= PRIME_MX2; + m128.low64 = XXH_xorshift64(m128.low64, 28); + m128.high64 = XXH3_avalanche(m128.high64); + return m128; + } +} + +XXH_FORCE_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_9to16_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(9 <= len && len <= 16); + { xxh_u64 const bitflipl = (XXH_readLE64(secret+32) ^ XXH_readLE64(secret+40)) - seed; + xxh_u64 const bitfliph = (XXH_readLE64(secret+48) ^ XXH_readLE64(secret+56)) + seed; + xxh_u64 const input_lo = XXH_readLE64(input); + xxh_u64 input_hi = XXH_readLE64(input + len - 8); + XXH128_hash_t m128 = XXH_mult64to128(input_lo ^ input_hi ^ bitflipl, XXH_PRIME64_1); + /* + * Put len in the middle of m128 to ensure that the length gets mixed to + * both the low and high bits in the 128x64 multiply below. + */ + m128.low64 += (xxh_u64)(len - 1) << 54; + input_hi ^= bitfliph; + /* + * Add the high 32 bits of input_hi to the high 32 bits of m128, then + * add the long product of the low 32 bits of input_hi and XXH_PRIME32_2 to + * the high 64 bits of m128. + * + * The best approach to this operation is different on 32-bit and 64-bit. + */ + if (sizeof(void *) < sizeof(xxh_u64)) { /* 32-bit */ + /* + * 32-bit optimized version, which is more readable. + * + * On 32-bit, it removes an ADC and delays a dependency between the two + * halves of m128.high64, but it generates an extra mask on 64-bit. + */ + m128.high64 += (input_hi & 0xFFFFFFFF00000000ULL) + XXH_mult32to64((xxh_u32)input_hi, XXH_PRIME32_2); + } else { + /* + * 64-bit optimized (albeit more confusing) version. + * + * Uses some properties of addition and multiplication to remove the mask: + * + * Let: + * a = input_hi.lo = (input_hi & 0x00000000FFFFFFFF) + * b = input_hi.hi = (input_hi & 0xFFFFFFFF00000000) + * c = XXH_PRIME32_2 + * + * a + (b * c) + * Inverse Property: x + y - x == y + * a + (b * (1 + c - 1)) + * Distributive Property: x * (y + z) == (x * y) + (x * z) + * a + (b * 1) + (b * (c - 1)) + * Identity Property: x * 1 == x + * a + b + (b * (c - 1)) + * + * Substitute a, b, and c: + * input_hi.hi + input_hi.lo + ((xxh_u64)input_hi.lo * (XXH_PRIME32_2 - 1)) + * + * Since input_hi.hi + input_hi.lo == input_hi, we get this: + * input_hi + ((xxh_u64)input_hi.lo * (XXH_PRIME32_2 - 1)) + */ + m128.high64 += input_hi + XXH_mult32to64((xxh_u32)input_hi, XXH_PRIME32_2 - 1); + } + /* m128 ^= XXH_swap64(m128 >> 64); */ + m128.low64 ^= XXH_swap64(m128.high64); + + { /* 128x64 multiply: h128 = m128 * XXH_PRIME64_2; */ + XXH128_hash_t h128 = XXH_mult64to128(m128.low64, XXH_PRIME64_2); + h128.high64 += m128.high64 * XXH_PRIME64_2; + + h128.low64 = XXH3_avalanche(h128.low64); + h128.high64 = XXH3_avalanche(h128.high64); + return h128; + } } +} + +/* + * Assumption: `secret` size is >= XXH3_SECRET_SIZE_MIN + */ +XXH_FORCE_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_0to16_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(len <= 16); + { if (len > 8) return XXH3_len_9to16_128b(input, len, secret, seed); + if (len >= 4) return XXH3_len_4to8_128b(input, len, secret, seed); + if (len) return XXH3_len_1to3_128b(input, len, secret, seed); + { XXH128_hash_t h128; + xxh_u64 const bitflipl = XXH_readLE64(secret+64) ^ XXH_readLE64(secret+72); + xxh_u64 const bitfliph = XXH_readLE64(secret+80) ^ XXH_readLE64(secret+88); + h128.low64 = XXH64_avalanche(seed ^ bitflipl); + h128.high64 = XXH64_avalanche( seed ^ bitfliph); + return h128; + } } +} + +/* + * A bit slower than XXH3_mix16B, but handles multiply by zero better. + */ +XXH_FORCE_INLINE XXH128_hash_t +XXH128_mix32B(XXH128_hash_t acc, const xxh_u8* input_1, const xxh_u8* input_2, + const xxh_u8* secret, XXH64_hash_t seed) +{ + acc.low64 += XXH3_mix16B (input_1, secret+0, seed); + acc.low64 ^= XXH_readLE64(input_2) + XXH_readLE64(input_2 + 8); + acc.high64 += XXH3_mix16B (input_2, secret+16, seed); + acc.high64 ^= XXH_readLE64(input_1) + XXH_readLE64(input_1 + 8); + return acc; +} + + +XXH_FORCE_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_17to128_128b(const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH64_hash_t seed) +{ + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize; + XXH_ASSERT(16 < len && len <= 128); + + { XXH128_hash_t acc; + acc.low64 = len * XXH_PRIME64_1; + acc.high64 = 0; + +#if XXH_SIZE_OPT >= 1 + { + /* Smaller, but slightly slower. */ + unsigned int i = (unsigned int)(len - 1) / 32; + do { + acc = XXH128_mix32B(acc, input+16*i, input+len-16*(i+1), secret+32*i, seed); + } while (i-- != 0); + } +#else + if (len > 32) { + if (len > 64) { + if (len > 96) { + acc = XXH128_mix32B(acc, input+48, input+len-64, secret+96, seed); + } + acc = XXH128_mix32B(acc, input+32, input+len-48, secret+64, seed); + } + acc = XXH128_mix32B(acc, input+16, input+len-32, secret+32, seed); + } + acc = XXH128_mix32B(acc, input, input+len-16, secret, seed); +#endif + { XXH128_hash_t h128; + h128.low64 = acc.low64 + acc.high64; + h128.high64 = (acc.low64 * XXH_PRIME64_1) + + (acc.high64 * XXH_PRIME64_4) + + ((len - seed) * XXH_PRIME64_2); + h128.low64 = XXH3_avalanche(h128.low64); + h128.high64 = (XXH64_hash_t)0 - XXH3_avalanche(h128.high64); + return h128; + } + } +} + +XXH_NO_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_129to240_128b(const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH64_hash_t seed) +{ + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize; + XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX); + + { XXH128_hash_t acc; + unsigned i; + acc.low64 = len * XXH_PRIME64_1; + acc.high64 = 0; + /* + * We set as `i` as offset + 32. We do this so that unchanged + * `len` can be used as upper bound. This reaches a sweet spot + * where both x86 and aarch64 get simple agen and good codegen + * for the loop. + */ + for (i = 32; i < 160; i += 32) { + acc = XXH128_mix32B(acc, + input + i - 32, + input + i - 16, + secret + i - 32, + seed); + } + acc.low64 = XXH3_avalanche(acc.low64); + acc.high64 = XXH3_avalanche(acc.high64); + /* + * NB: `i <= len` will duplicate the last 32-bytes if + * len % 32 was zero. This is an unfortunate necessity to keep + * the hash result stable. + */ + for (i=160; i <= len; i += 32) { + acc = XXH128_mix32B(acc, + input + i - 32, + input + i - 16, + secret + XXH3_MIDSIZE_STARTOFFSET + i - 160, + seed); + } + /* last bytes */ + acc = XXH128_mix32B(acc, + input + len - 16, + input + len - 32, + secret + XXH3_SECRET_SIZE_MIN - XXH3_MIDSIZE_LASTOFFSET - 16, + (XXH64_hash_t)0 - seed); + + { XXH128_hash_t h128; + h128.low64 = acc.low64 + acc.high64; + h128.high64 = (acc.low64 * XXH_PRIME64_1) + + (acc.high64 * XXH_PRIME64_4) + + ((len - seed) * XXH_PRIME64_2); + h128.low64 = XXH3_avalanche(h128.low64); + h128.high64 = (XXH64_hash_t)0 - XXH3_avalanche(h128.high64); + return h128; + } + } +} + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_hashLong_128b_internal(const void* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH3_f_accumulate f_acc, + XXH3_f_scrambleAcc f_scramble) +{ + XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[XXH_ACC_NB] = XXH3_INIT_ACC; + + XXH3_hashLong_internal_loop(acc, (const xxh_u8*)input, len, secret, secretSize, f_acc, f_scramble); + + /* converge into final hash */ + XXH_STATIC_ASSERT(sizeof(acc) == 64); + XXH_ASSERT(secretSize >= sizeof(acc) + XXH_SECRET_MERGEACCS_START); + { XXH128_hash_t h128; + h128.low64 = XXH3_mergeAccs(acc, + secret + XXH_SECRET_MERGEACCS_START, + (xxh_u64)len * XXH_PRIME64_1); + h128.high64 = XXH3_mergeAccs(acc, + secret + secretSize + - sizeof(acc) - XXH_SECRET_MERGEACCS_START, + ~((xxh_u64)len * XXH_PRIME64_2)); + return h128; + } +} + +/* + * It's important for performance that XXH3_hashLong() is not inlined. + */ +XXH_NO_INLINE XXH_PUREF XXH128_hash_t +XXH3_hashLong_128b_default(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, + const void* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; (void)secret; (void)secretLen; + return XXH3_hashLong_128b_internal(input, len, XXH3_kSecret, sizeof(XXH3_kSecret), + XXH3_accumulate, XXH3_scrambleAcc); +} + +/* + * It's important for performance to pass @p secretLen (when it's static) + * to the compiler, so that it can properly optimize the vectorized loop. + * + * When the secret size is unknown, or on GCC 12 where the mix of NO_INLINE and FORCE_INLINE + * breaks -Og, this is XXH_NO_INLINE. + */ +XXH3_WITH_SECRET_INLINE XXH128_hash_t +XXH3_hashLong_128b_withSecret(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, + const void* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; + return XXH3_hashLong_128b_internal(input, len, (const xxh_u8*)secret, secretLen, + XXH3_accumulate, XXH3_scrambleAcc); +} + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_hashLong_128b_withSeed_internal(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, + XXH3_f_accumulate f_acc, + XXH3_f_scrambleAcc f_scramble, + XXH3_f_initCustomSecret f_initSec) +{ + if (seed64 == 0) + return XXH3_hashLong_128b_internal(input, len, + XXH3_kSecret, sizeof(XXH3_kSecret), + f_acc, f_scramble); + { XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE]; + f_initSec(secret, seed64); + return XXH3_hashLong_128b_internal(input, len, (const xxh_u8*)secret, sizeof(secret), + f_acc, f_scramble); + } +} + +/* + * It's important for performance that XXH3_hashLong is not inlined. + */ +XXH_NO_INLINE XXH128_hash_t +XXH3_hashLong_128b_withSeed(const void* input, size_t len, + XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen) +{ + (void)secret; (void)secretLen; + return XXH3_hashLong_128b_withSeed_internal(input, len, seed64, + XXH3_accumulate, XXH3_scrambleAcc, XXH3_initCustomSecret); +} + +typedef XXH128_hash_t (*XXH3_hashLong128_f)(const void* XXH_RESTRICT, size_t, + XXH64_hash_t, const void* XXH_RESTRICT, size_t); + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_128bits_internal(const void* input, size_t len, + XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen, + XXH3_hashLong128_f f_hl128) +{ + XXH_ASSERT(secretLen >= XXH3_SECRET_SIZE_MIN); + /* + * If an action is to be taken if `secret` conditions are not respected, + * it should be done here. + * For now, it's a contract pre-condition. + * Adding a check and a branch here would cost performance at every hash. + */ + if (len <= 16) + return XXH3_len_0to16_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, seed64); + if (len <= 128) + return XXH3_len_17to128_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + if (len <= XXH3_MIDSIZE_MAX) + return XXH3_len_129to240_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + return f_hl128(input, len, seed64, secret, secretLen); +} + + +/* === Public XXH128 API === */ + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t XXH3_128bits(XXH_NOESCAPE const void* input, size_t len) +{ + return XXH3_128bits_internal(input, len, 0, + XXH3_kSecret, sizeof(XXH3_kSecret), + XXH3_hashLong_128b_default); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t +XXH3_128bits_withSecret(XXH_NOESCAPE const void* input, size_t len, XXH_NOESCAPE const void* secret, size_t secretSize) +{ + return XXH3_128bits_internal(input, len, 0, + (const xxh_u8*)secret, secretSize, + XXH3_hashLong_128b_withSecret); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t +XXH3_128bits_withSeed(XXH_NOESCAPE const void* input, size_t len, XXH64_hash_t seed) +{ + return XXH3_128bits_internal(input, len, seed, + XXH3_kSecret, sizeof(XXH3_kSecret), + XXH3_hashLong_128b_withSeed); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t +XXH3_128bits_withSecretandSeed(XXH_NOESCAPE const void* input, size_t len, XXH_NOESCAPE const void* secret, size_t secretSize, XXH64_hash_t seed) +{ + if (len <= XXH3_MIDSIZE_MAX) + return XXH3_128bits_internal(input, len, seed, XXH3_kSecret, sizeof(XXH3_kSecret), NULL); + return XXH3_hashLong_128b_withSecret(input, len, seed, secret, secretSize); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t +XXH128(XXH_NOESCAPE const void* input, size_t len, XXH64_hash_t seed) +{ + return XXH3_128bits_withSeed(input, len, seed); +} + + +/* === XXH3 128-bit streaming === */ +#ifndef XXH_NO_STREAM +/* + * All initialization and update functions are identical to 64-bit streaming variant. + * The only difference is the finalization routine. + */ + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset(XXH_NOESCAPE XXH3_state_t* statePtr) +{ + return XXH3_64bits_reset(statePtr); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset_withSecret(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize) +{ + return XXH3_64bits_reset_withSecret(statePtr, secret, secretSize); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset_withSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH64_hash_t seed) +{ + return XXH3_64bits_reset_withSeed(statePtr, seed); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset_withSecretandSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize, XXH64_hash_t seed) +{ + return XXH3_64bits_reset_withSecretandSeed(statePtr, secret, secretSize, seed); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_update(XXH_NOESCAPE XXH3_state_t* state, XXH_NOESCAPE const void* input, size_t len) +{ + return XXH3_64bits_update(state, input, len); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_digest (XXH_NOESCAPE const XXH3_state_t* state) +{ + const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret; + if (state->totalLen > XXH3_MIDSIZE_MAX) { + XXH_ALIGN(XXH_ACC_ALIGN) XXH64_hash_t acc[XXH_ACC_NB]; + XXH3_digest_long(acc, state, secret); + XXH_ASSERT(state->secretLimit + XXH_STRIPE_LEN >= sizeof(acc) + XXH_SECRET_MERGEACCS_START); + { XXH128_hash_t h128; + h128.low64 = XXH3_mergeAccs(acc, + secret + XXH_SECRET_MERGEACCS_START, + (xxh_u64)state->totalLen * XXH_PRIME64_1); + h128.high64 = XXH3_mergeAccs(acc, + secret + state->secretLimit + XXH_STRIPE_LEN + - sizeof(acc) - XXH_SECRET_MERGEACCS_START, + ~((xxh_u64)state->totalLen * XXH_PRIME64_2)); + return h128; + } + } + /* len <= XXH3_MIDSIZE_MAX : short code */ + if (state->useSeed) + return XXH3_128bits_withSeed(state->buffer, (size_t)state->totalLen, state->seed); + return XXH3_128bits_withSecret(state->buffer, (size_t)(state->totalLen), + secret, state->secretLimit + XXH_STRIPE_LEN); +} +#endif /* !XXH_NO_STREAM */ +/* 128-bit utility functions */ + +#include /* memcmp, memcpy */ + +/* return : 1 is equal, 0 if different */ +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API int XXH128_isEqual(XXH128_hash_t h1, XXH128_hash_t h2) +{ + /* note : XXH128_hash_t is compact, it has no padding byte */ + return !(memcmp(&h1, &h2, sizeof(h1))); +} + +/* This prototype is compatible with stdlib's qsort(). + * @return : >0 if *h128_1 > *h128_2 + * <0 if *h128_1 < *h128_2 + * =0 if *h128_1 == *h128_2 */ +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API int XXH128_cmp(XXH_NOESCAPE const void* h128_1, XXH_NOESCAPE const void* h128_2) +{ + XXH128_hash_t const h1 = *(const XXH128_hash_t*)h128_1; + XXH128_hash_t const h2 = *(const XXH128_hash_t*)h128_2; + int const hcmp = (h1.high64 > h2.high64) - (h2.high64 > h1.high64); + /* note : bets that, in most cases, hash values are different */ + if (hcmp) return hcmp; + return (h1.low64 > h2.low64) - (h2.low64 > h1.low64); +} + + +/*====== Canonical representation ======*/ +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API void +XXH128_canonicalFromHash(XXH_NOESCAPE XXH128_canonical_t* dst, XXH128_hash_t hash) +{ + XXH_STATIC_ASSERT(sizeof(XXH128_canonical_t) == sizeof(XXH128_hash_t)); + if (XXH_CPU_LITTLE_ENDIAN) { + hash.high64 = XXH_swap64(hash.high64); + hash.low64 = XXH_swap64(hash.low64); + } + XXH_memcpy(dst, &hash.high64, sizeof(hash.high64)); + XXH_memcpy((char*)dst + sizeof(hash.high64), &hash.low64, sizeof(hash.low64)); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t +XXH128_hashFromCanonical(XXH_NOESCAPE const XXH128_canonical_t* src) +{ + XXH128_hash_t h; + h.high64 = XXH_readBE64(src); + h.low64 = XXH_readBE64(src->digest + 8); + return h; +} + + + +/* ========================================== + * Secret generators + * ========================================== + */ +#define XXH_MIN(x, y) (((x) > (y)) ? (y) : (x)) + +XXH_FORCE_INLINE void XXH3_combine16(void* dst, XXH128_hash_t h128) +{ + XXH_writeLE64( dst, XXH_readLE64(dst) ^ h128.low64 ); + XXH_writeLE64( (char*)dst+8, XXH_readLE64((char*)dst+8) ^ h128.high64 ); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_generateSecret(XXH_NOESCAPE void* secretBuffer, size_t secretSize, XXH_NOESCAPE const void* customSeed, size_t customSeedSize) +{ +#if (XXH_DEBUGLEVEL >= 1) + XXH_ASSERT(secretBuffer != NULL); + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); +#else + /* production mode, assert() are disabled */ + if (secretBuffer == NULL) return XXH_ERROR; + if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR; +#endif + + if (customSeedSize == 0) { + customSeed = XXH3_kSecret; + customSeedSize = XXH_SECRET_DEFAULT_SIZE; + } +#if (XXH_DEBUGLEVEL >= 1) + XXH_ASSERT(customSeed != NULL); +#else + if (customSeed == NULL) return XXH_ERROR; +#endif + + /* Fill secretBuffer with a copy of customSeed - repeat as needed */ + { size_t pos = 0; + while (pos < secretSize) { + size_t const toCopy = XXH_MIN((secretSize - pos), customSeedSize); + memcpy((char*)secretBuffer + pos, customSeed, toCopy); + pos += toCopy; + } } + + { size_t const nbSeg16 = secretSize / 16; + size_t n; + XXH128_canonical_t scrambler; + XXH128_canonicalFromHash(&scrambler, XXH128(customSeed, customSeedSize, 0)); + for (n=0; n /* abort() */ +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#include "xxhash/xxhash.h" +#include "sha1/sha1.h" +#include "sha256/sha256.h" + +#ifdef __cplusplus +} +#endif + + +// uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp') +#define UUID_NAMESPACE_LLAMA_CPP "ef001206-dadc-5f6d-a15f-3359e577d4e5" +#define UUID_NAMESPACE_LLAMA_CPP_HEX 0xef, 0x00, 0x12, 0x06, 0xda, 0xdc, 0x5f, 0x6d, 0xa1, 0x5f, 0x33, 0x59, 0xe5, 0x77, 0xd4, 0xe5 + + +#define HASH_TYPE_SHA256_STR "sha256" +#define HASH_TYPE_SHA1_STR "sha1" +#define HASH_TYPE_XXH64_STR "xxh64" +#define HASH_TYPE_UUID_STR "uuid" + + +typedef enum { + HASH_EXIT_SUCCESS = 0, // All hash has been generated or validated + HASH_EXIT_FAILURE = 1, // Generic Failure + HASH_EXIT_MISMATCH = 2, // Hash mismatched during validation + HASH_EXIT_MANIFEST_MISSING_ENTRY = 3, // Hash attempted validation but missing entry in manifest + HASH_EXIT_MANIFEST_UNKNOWN_HASH = 4, // Manifest is present, but we do not know any hash format within it + HASH_EXIT_MANIFEST_FILE_ERROR = 5 // Manifest is either missing or not a known format +} hash_exit_code_t; + + +typedef enum { + HASH_MANIFEST_NOT_FOUND, + HASH_MANIFEST_MISMATCH, + HASH_MANIFEST_OK, +} hash_manifest_result_t; + + +struct hash_params { + std::string input; + bool xxh64 = false; + bool sha1 = false; + bool sha256 = false; + bool uuid = false; + + bool no_layer = false; + + bool manifest_is_usable = false; + std::string manifest_file; +}; + +struct manifest_check_params { + bool xxh64 = false; + bool sha1 = false; + bool sha256 = false; + bool uuid = false; +}; + +static char const * hash_manifest_result_to_str(hash_manifest_result_t value) { + switch (value) { + case HASH_MANIFEST_NOT_FOUND: return "Not Found"; + case HASH_MANIFEST_MISMATCH: return "Mismatch"; + case HASH_MANIFEST_OK: return "Ok"; + } + return "?"; +} + +static char const * hash_exit_code_to_str(hash_exit_code_t value) { + switch (value) { + case HASH_EXIT_SUCCESS: return "Success"; + case HASH_EXIT_FAILURE: return "Failure"; + case HASH_EXIT_MISMATCH: return "Mismatch"; + case HASH_EXIT_MANIFEST_MISSING_ENTRY: return "Manifest Missing Entry"; + case HASH_EXIT_MANIFEST_UNKNOWN_HASH: return "Manifest Unknown Hash"; + case HASH_EXIT_MANIFEST_FILE_ERROR: return "Manifest File Error"; + } + return "?"; +} + +static void hash_print_usage(const char * executable) { + const hash_params default_params; + printf("\n"); + printf("usage: %s [options] GGUF_IN\n", executable); + printf("\n"); + printf("Hash a GGUF file"); + printf("\n"); + printf("options:\n"); + printf(" -h, --help show this help message and exit\n"); + printf(" --xxh64 use xxh64 hash\n"); + printf(" --sha1 use sha1 hash\n"); + printf(" --sha256 use sha256 hash\n"); + printf(" --all use all hash\n"); + printf(" --no-layer exclude per layer hash\n"); + printf(" --uuid generate UUIDv5 ID\n"); + printf(" -c, --check verify against a manifest\n"); + printf("\n"); +} + +static void hash_params_parse_ex(int argc, const char ** argv, hash_params & params) { + std::string arg; + bool invalid_param = false; + const std::string arg_prefix = "--"; + + int arg_idx = 1; + for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { + arg = argv[arg_idx]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + + bool arg_found = false; + if (arg == "-h" || arg == "--help") { + hash_print_usage(argv[0]); + exit(0); + } + + if (arg == "--xxh64") { + arg_found = true; + params.xxh64 = true; + } + + if (arg == "--sha1") { + arg_found = true; + params.sha1 = true; + } + + if (arg == "--uuid") { + arg_found = true; + params.uuid = true; + } + + if (arg == "--sha256") { + arg_found = true; + params.sha256 = true; + } + + if (arg == "--all") { + arg_found = true; + params.sha256 = true; + params.sha1 = true; + params.xxh64 = true; + } + + if (arg == "--no-layer") { + arg_found = true; + params.no_layer = true; + } + + if (arg == "-c" || arg == "--check") { + if (++arg_idx >= argc) { + invalid_param = true; + break; + } + arg_found = true; + params.manifest_file = argv[arg_idx]; + } + + if (!arg_found) { + throw std::invalid_argument("error: unknown argument: " + arg); + } + } + + if (invalid_param) { + throw std::invalid_argument("error: invalid parameter for argument:" + arg); + } + + if (argc - arg_idx < 1) { + throw std::invalid_argument("error: bad arguments"); + } + + params.input = argv[arg_idx++]; +} + +static bool hash_params_parse(int argc, const char ** argv, hash_params & params) { + bool result = true; + try { + hash_params_parse_ex(argc, argv, params); + } + catch (const std::invalid_argument & ex) { + fprintf(stderr, "%s\n", ex.what()); + hash_print_usage(argv[0]); + exit(EXIT_FAILURE); + } + return result; +} + +static bool manifest_type(const std::string & manifest_file, manifest_check_params & manifest_check) { + if (manifest_file.empty()) { + return false; + } + + std::ifstream file(manifest_file); + if (!file.is_open()) { + return false; + } + + std::string manifest_entry_line; + while (getline(file, manifest_entry_line)) { + // hash_type_str hash_str tensor_name + // e.g. 'xxh64 f66e9cd66a4396a0 test.gguf:tensor_0' + std::istringstream line_stream(manifest_entry_line); + std::string file_hash_type; + if (line_stream >> file_hash_type) { + if (file_hash_type == HASH_TYPE_SHA256_STR) { + manifest_check.sha256 = true; + } else if (file_hash_type == HASH_TYPE_SHA1_STR) { + manifest_check.sha1 = true; + } else if (file_hash_type == HASH_TYPE_XXH64_STR) { + manifest_check.xxh64 = true; + } else if (file_hash_type == HASH_TYPE_UUID_STR) { + manifest_check.uuid = true; + } + } + } + + return true; +} + +static hash_manifest_result_t manifest_verify(const std::string& manifest_file, const std::string& hash_type_str, const std::string& hash_str, const std::string& tensor_name) { + if (manifest_file.empty()) { + return HASH_MANIFEST_NOT_FOUND; + } + + std::ifstream file(manifest_file); + if (!file.is_open()) { + return HASH_MANIFEST_NOT_FOUND; + } + + std::string manifest_entry_line; + while (getline(file, manifest_entry_line)) { + std::istringstream line_stream(manifest_entry_line); + std::string file_hash_type; + std::string file_hash; + std::string file_tensor_name; + if (line_stream >> file_hash_type >> file_hash >> file_tensor_name) { + // Line parsed. Check hash validity + + if (file_hash_type != hash_type_str) { + continue; + } + + if (file_tensor_name != tensor_name) { + continue; + } + + return (file_hash == hash_str) ? HASH_MANIFEST_OK : HASH_MANIFEST_MISMATCH; + } + } + + return HASH_MANIFEST_NOT_FOUND; +} + +static void generate_uuidv5(const unsigned char sha1_digest[20], unsigned char uuid[16]) { + // Ref: https://www.rfc-editor.org/rfc/rfc9562.html#section-5.5 + // Assumes that digest was processed correctly with the expected namespace + for (int i = 0; i < 16; i++) { + uuid[i] = sha1_digest[i]; + } + + // Set bits corresponding to UUID ver 5 + uuid[ 6] &= ~(0xF << 4); + uuid[ 6] |= (5 << 4); + + // Set bits corresponding to UUID variant 0b10XX + uuid[ 8] &= ~(0xc << 4); + uuid[ 8] |= (0x8 << 4); +} + +static hash_exit_code_t gguf_hash(const hash_params & hash_params) { + const std::string & fname = hash_params.input; + struct ggml_context * ctx_data = NULL; + + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ &ctx_data, + }; + + // xxh64 init + XXH64_state_t* xxh64_model_hash_state = NULL; + if (hash_params.xxh64) { + xxh64_model_hash_state = XXH64_createState(); + if (xxh64_model_hash_state==NULL) { + abort(); + } + + XXH64_hash_t const seed = 0; + if (XXH64_reset(xxh64_model_hash_state, seed) == XXH_ERROR) { + abort(); + } + } + + // sha1 init + SHA1_CTX sha1_model_hash_ctx; + if (hash_params.sha1) { + SHA1Init(&sha1_model_hash_ctx); + } + + // sha256 init + sha256_t sha256_model_hash_ctx; + if (hash_params.sha256) { + sha256_init(&sha256_model_hash_ctx); + } + + // sha1 for uuid init + SHA1_CTX sha1_for_uuid_ctx; + if (hash_params.uuid) { + unsigned char const uuidv5_namespace[] = {UUID_NAMESPACE_LLAMA_CPP_HEX}; + SHA1Init(&sha1_for_uuid_ctx); + SHA1Update( &sha1_for_uuid_ctx, (unsigned char const *)uuidv5_namespace, sizeof(uuidv5_namespace)); + } + + struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params); + const int n_tensors = gguf_get_n_tensors(ctx); + bool tensor_layer_in_manifest = false; + bool model_in_manifest = false; + bool tensor_layer_has_mismatch = false; + bool model_has_mismatch = false; + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name(ctx, i); + struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name); + auto n_bytes = ggml_nbytes(cur); + auto *raw_data = cur->data; + const std::string tensor_layer_name = fname + ":" + name; + + if (hash_params.xxh64) { + + if (!hash_params.no_layer) { + // Per Layer Hash + XXH64_hash_t hash = XXH64(raw_data, n_bytes, 0); + + char hex_result[17]; + for (int offset = 0; offset < 8; offset++) { + unsigned int shift_bits_by = (8 * (8 - offset - 1)); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + tensor_layer_in_manifest = true; + tensor_layer_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + tensor_layer_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name.c_str()); + } + } + + // Overall Model Hash + if (XXH64_update(xxh64_model_hash_state, raw_data, n_bytes) == XXH_ERROR) abort(); + } + + if (hash_params.sha1) { + + if (!hash_params.no_layer) { + // Per Layer Hash + char result[21]; // sha1 outputs 20 bytes + SHA1( result, (const char *)raw_data, n_bytes); + + char hex_result[41] = {0}; + for (int offset = 0; offset < 20; offset++) { + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + tensor_layer_in_manifest = true; + tensor_layer_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + tensor_layer_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name.c_str()); + } + } + + // Overall Model Hash + SHA1Update( &sha1_model_hash_ctx, (unsigned char const *)raw_data, n_bytes); + } + + if (hash_params.sha256) { + + if (!hash_params.no_layer) { + // Per Layer Hash + unsigned char result[SHA256_DIGEST_SIZE]; // sha256 outputs 32 bytes + sha256_hash((unsigned char*) result, (const unsigned char *)raw_data, n_bytes); + + char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0}; + for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) { + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + tensor_layer_in_manifest = true; + tensor_layer_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + tensor_layer_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name.c_str()); + } + } + + // Overall Model Hash + sha256_update( &sha256_model_hash_ctx, (unsigned char const *)raw_data, n_bytes); + } + + if (hash_params.uuid) { + SHA1Update( &sha1_for_uuid_ctx, (unsigned char const *)raw_data, n_bytes); + } + } + + if (hash_params.xxh64) { + XXH64_hash_t const hash = XXH64_digest(xxh64_model_hash_state); + + char hex_result[17]; + for (int offset = 0; offset < 8; offset++) { + unsigned int shift_bits_by = (8 * (8 - offset - 1)); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_XXH64_STR, hex_result, fname); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + model_in_manifest = true; + model_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + model_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_XXH64_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_XXH64_STR, hex_result, fname.c_str()); + } + } + + if (hash_params.sha1) { + unsigned char result[21]; + SHA1Final(result, &sha1_model_hash_ctx); + + char hex_result[41]; + for (int offset = 0; offset < 20; offset++) { + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA1_STR, hex_result, fname); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + model_in_manifest = true; + model_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + model_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA1_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_SHA1_STR, hex_result, fname.c_str()); + } + } + + if (hash_params.sha256) { + unsigned char result[SHA256_DIGEST_SIZE]; // sha256 outputs 32 bytes + sha256_final( &sha256_model_hash_ctx, result); + + char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0}; + for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) { + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, hex_result, fname); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + model_in_manifest = true; + model_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + model_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA256_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_SHA256_STR, hex_result, fname.c_str()); + } + } + + if (hash_params.uuid) { + unsigned char result[21]; + SHA1Final(result, &sha1_for_uuid_ctx); + + unsigned char uuid[16]; + generate_uuidv5(result, uuid); + + char string_buffer[37] = {0}; + snprintf(string_buffer, sizeof(string_buffer), "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + uuid[0], uuid[1], uuid[2], uuid[3], + uuid[4], uuid[5], uuid[6], uuid[7], + uuid[8], uuid[9], uuid[10], uuid[11], + uuid[12], uuid[13], uuid[14], uuid[15]); + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, string_buffer, fname); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + model_in_manifest = true; + model_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + model_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_UUID_STR, string_buffer, fname.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_UUID_STR, string_buffer, fname.c_str()); + } + } + + + ggml_free(ctx_data); + gguf_free(ctx); + + + if (hash_params.manifest_is_usable) { + // In hash verification mode + + if (!model_in_manifest) { + // model missing in manifest? + + // Check tensor layer... + if (!tensor_layer_in_manifest) { + // Still missing? Maybe we are reading the wrong manifest. + return HASH_EXIT_MANIFEST_MISSING_ENTRY; + } + + if (tensor_layer_has_mismatch) { + // Per tensor check found error + return HASH_EXIT_FAILURE; + } + + // All per tensor layer checks passed? Sounds good enough. + return HASH_EXIT_SUCCESS; + } + + // Overall model check passed, but let's check per layer just in case + // If missing, we don't care too much as the overall model checked + if (tensor_layer_in_manifest && tensor_layer_has_mismatch) { + return HASH_EXIT_FAILURE; + } + + if (model_has_mismatch) { + // model has failed hash somewhere in the model + return HASH_EXIT_FAILURE; + } + + // All checks appears to be fine + return HASH_EXIT_SUCCESS; + } + + // In hash generation mode + return HASH_EXIT_SUCCESS; +} + +int main(int argc, const char ** argv) { + hash_params params; + manifest_check_params manifest_check; + hash_params_parse(argc, argv, params); + + if (!params.manifest_file.empty()) { + if (!manifest_type(params.manifest_file, manifest_check)) { + printf("ERROR cannot open manifest %s", params.manifest_file.c_str()); + return HASH_EXIT_MANIFEST_FILE_ERROR; + } + + if (!manifest_check.sha256 && !manifest_check.sha1 && !manifest_check.xxh64 && !manifest_check.uuid) { + printf("ERROR manifest does not have any known hash format in %s", params.manifest_file.c_str()); + return HASH_EXIT_MANIFEST_UNKNOWN_HASH; + } + + printf("manifest %s", params.manifest_file.c_str()); + + if (manifest_check.sha256) { + printf(" sha256"); + } + + if (manifest_check.sha1) { + printf(" sha1"); + } + + if (manifest_check.xxh64) { + printf(" xxh64"); + } + + if (manifest_check.uuid) { + printf(" uuid"); + } + + printf("\n"); + + // Autoselect the highest security hash if manifest is provided but + // the user has not specifically defined the hash they care about + if (!params.xxh64 && !params.sha1 && !params.uuid && !params.sha256) { + // User has not selected a specific value, pick most secure hash + if (manifest_check.sha256) { + params.sha256 = true; + } else if (manifest_check.sha1) { + params.sha1 = true; + } else if (manifest_check.xxh64) { + params.xxh64 = true; + } else if (manifest_check.uuid) { + params.uuid = true; + } + } + + params.manifest_is_usable = true; + } + + // By default if no swich argument provided, assume xxh64 + if (!params.xxh64 && !params.sha1 && !params.uuid && !params.sha256) { + params.xxh64 = true; + } + + hash_exit_code_t exit_code = gguf_hash(params); + + if (params.manifest_is_usable) { + printf("\nVerification results for %s - %s\n", params.manifest_file.c_str(), hash_exit_code_to_str(exit_code)); + } + + return exit_code; +} diff --git a/examples/gguf/CMakeLists.txt b/examples/gguf/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fb04eb83f34ce16059d0e56fbd1759849e873741 --- /dev/null +++ b/examples/gguf/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-gguf) +add_executable(${TARGET} gguf.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gguf/gguf.cpp b/examples/gguf/gguf.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f31989c8c55c6a5b6b2a7369e47bc48d4677d3ea --- /dev/null +++ b/examples/gguf/gguf.cpp @@ -0,0 +1,265 @@ +#include "ggml.h" +#include "gguf.h" + +#include +#include +#include +#include + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +template +static std::string to_string(const T & val) { + std::stringstream ss; + ss << val; + return ss.str(); +} + +static bool gguf_ex_write(const std::string & fname) { + struct gguf_context * ctx = gguf_init_empty(); + + gguf_set_val_u8 (ctx, "some.parameter.uint8", 0x12); + gguf_set_val_i8 (ctx, "some.parameter.int8", -0x13); + gguf_set_val_u16 (ctx, "some.parameter.uint16", 0x1234); + gguf_set_val_i16 (ctx, "some.parameter.int16", -0x1235); + gguf_set_val_u32 (ctx, "some.parameter.uint32", 0x12345678); + gguf_set_val_i32 (ctx, "some.parameter.int32", -0x12345679); + gguf_set_val_f32 (ctx, "some.parameter.float32", 0.123456789f); + gguf_set_val_u64 (ctx, "some.parameter.uint64", 0x123456789abcdef0ull); + gguf_set_val_i64 (ctx, "some.parameter.int64", -0x123456789abcdef1ll); + gguf_set_val_f64 (ctx, "some.parameter.float64", 0.1234567890123456789); + gguf_set_val_bool(ctx, "some.parameter.bool", true); + gguf_set_val_str (ctx, "some.parameter.string", "hello world"); + + gguf_set_arr_data(ctx, "some.parameter.arr.i16", GGUF_TYPE_INT16, std::vector{ 1, 2, 3, 4, }.data(), 4); + gguf_set_arr_data(ctx, "some.parameter.arr.f32", GGUF_TYPE_FLOAT32, std::vector{ 3.145f, 2.718f, 1.414f, }.data(), 3); + gguf_set_arr_str (ctx, "some.parameter.arr.str", std::vector{ "hello", "world", "!" }.data(), 3); + + struct ggml_init_params params = { + /*.mem_size =*/ 128ull*1024ull*1024ull, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ctx_data = ggml_init(params); + + const int n_tensors = 10; + + // tensor infos + for (int i = 0; i < n_tensors; ++i) { + const std::string name = "tensor_" + to_string(i); + + int64_t ne[GGML_MAX_DIMS] = { 1 }; + int32_t n_dims = rand() % GGML_MAX_DIMS + 1; + + for (int j = 0; j < n_dims; ++j) { + ne[j] = rand() % 10 + 1; + } + + struct ggml_tensor * cur = ggml_new_tensor(ctx_data, GGML_TYPE_F32, n_dims, ne); + ggml_set_name(cur, name.c_str()); + + { + float * data = (float *) cur->data; + for (int j = 0; j < ggml_nelements(cur); ++j) { + data[j] = 100 + i; + } + } + + gguf_add_tensor(ctx, cur); + } + + gguf_write_to_file(ctx, fname.c_str(), false); + + printf("%s: wrote file '%s;\n", __func__, fname.c_str()); + + ggml_free(ctx_data); + gguf_free(ctx); + + return true; +} + +// just read tensor info +static bool gguf_ex_read_0(const std::string & fname) { + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ NULL, + }; + + struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params); + + if (!ctx) { + fprintf(stderr, "%s: failed to load '%s'\n", __func__, fname.c_str()); + return false; + } + + printf("%s: version: %d\n", __func__, gguf_get_version(ctx)); + printf("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); + printf("%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx)); + + // kv + { + const int n_kv = gguf_get_n_kv(ctx); + + printf("%s: n_kv: %d\n", __func__, n_kv); + + for (int i = 0; i < n_kv; ++i) { + const char * key = gguf_get_key(ctx, i); + + printf("%s: kv[%d]: key = %s\n", __func__, i, key); + } + } + + // find kv string + { + const char * findkey = "some.parameter.string"; + + const int keyidx = gguf_find_key(ctx, findkey); + if (keyidx == -1) { + printf("%s: find key: %s not found.\n", __func__, findkey); + } else { + const char * key_value = gguf_get_val_str(ctx, keyidx); + printf("%s: find key: %s found, kv[%d] value = %s\n", __func__, findkey, keyidx, key_value); + } + } + + // tensor info + { + const int n_tensors = gguf_get_n_tensors(ctx); + + printf("%s: n_tensors: %d\n", __func__, n_tensors); + + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name (ctx, i); + const size_t size = gguf_get_tensor_size (ctx, i); + const size_t offset = gguf_get_tensor_offset(ctx, i); + + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset); + } + } + + gguf_free(ctx); + + return true; +} + +// read and create ggml_context containing the tensors and their data +static bool gguf_ex_read_1(const std::string & fname, bool check_data) { + struct ggml_context * ctx_data = NULL; + + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ &ctx_data, + }; + + struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params); + + printf("%s: version: %d\n", __func__, gguf_get_version(ctx)); + printf("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); + printf("%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx)); + + // kv + { + const int n_kv = gguf_get_n_kv(ctx); + + printf("%s: n_kv: %d\n", __func__, n_kv); + + for (int i = 0; i < n_kv; ++i) { + const char * key = gguf_get_key(ctx, i); + + printf("%s: kv[%d]: key = %s\n", __func__, i, key); + } + } + + // tensor info + { + const int n_tensors = gguf_get_n_tensors(ctx); + + printf("%s: n_tensors: %d\n", __func__, n_tensors); + + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name (ctx, i); + const size_t size = gguf_get_tensor_size (ctx, i); + const size_t offset = gguf_get_tensor_offset(ctx, i); + + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset); + } + } + + // data + { + const int n_tensors = gguf_get_n_tensors(ctx); + + for (int i = 0; i < n_tensors; ++i) { + printf("%s: reading tensor %d data\n", __func__, i); + + const char * name = gguf_get_tensor_name(ctx, i); + + struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name); + + printf("%s: tensor[%d]: n_dims = %d, ne = (%d, %d, %d, %d), name = %s, data = %p\n", + __func__, i, ggml_n_dims(cur), int(cur->ne[0]), int(cur->ne[1]), int(cur->ne[2]), int(cur->ne[3]), cur->name, cur->data); + + // print first 10 elements + const float * data = (const float *) cur->data; + + printf("%s data[:10] : ", name); + for (int j = 0; j < MIN(10, ggml_nelements(cur)); ++j) { + printf("%f ", data[j]); + } + printf("\n\n"); + + // check data + if (check_data) { + const float * data = (const float *) cur->data; + for (int j = 0; j < ggml_nelements(cur); ++j) { + if (data[j] != 100 + i) { + fprintf(stderr, "%s: tensor[%d], data[%d]: found %f, expected %f\n", __func__, i, j, data[j], float(100 + i)); + gguf_free(ctx); + return false; + } + } + } + } + } + + printf("%s: ctx_data size: %zu\n", __func__, ggml_get_mem_size(ctx_data)); + + ggml_free(ctx_data); + gguf_free(ctx); + + return true; +} + +int main(int argc, char ** argv) { + if (argc < 3) { + printf("usage: %s data.gguf r|w [n]\n", argv[0]); + printf("r: read data.gguf file\n"); + printf("w: write data.gguf file\n"); + printf("n: no check of tensor data\n"); + return -1; + } + bool check_data = true; + if (argc == 4) { + check_data = false; + } + + srand(123456); + + const std::string fname(argv[1]); + const std::string mode (argv[2]); + + GGML_ASSERT((mode == "r" || mode == "w") && "mode must be r or w"); + + if (mode == "w") { + GGML_ASSERT(gguf_ex_write(fname) && "failed to write gguf file"); + } else if (mode == "r") { + GGML_ASSERT(gguf_ex_read_0(fname) && "failed to read gguf file"); + GGML_ASSERT(gguf_ex_read_1(fname, check_data) && "failed to read gguf file"); + } + + return 0; +} diff --git a/examples/gritlm/CMakeLists.txt b/examples/gritlm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fa1b4dc70c2f63f94c7d092360a9365620b9fad1 --- /dev/null +++ b/examples/gritlm/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-gritlm) +add_executable(${TARGET} gritlm.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gritlm/README.md b/examples/gritlm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..786ba57363defd93fa8c51ac453a33a380e82ce2 --- /dev/null +++ b/examples/gritlm/README.md @@ -0,0 +1,62 @@ +## Generative Representational Instruction Tuning (GRIT) Example +[gritlm] a model which can generate embeddings as well as "normal" text +generation depending on the instructions in the prompt. + +* Paper: https://arxiv.org/pdf/2402.09906.pdf + +### Retrieval-Augmented Generation (RAG) use case +One use case for `gritlm` is to use it with RAG. If we recall how RAG works is +that we take documents that we want to use as context, to ground the large +language model (LLM), and we create token embeddings for them. We then store +these token embeddings in a vector database. + +When we perform a query, prompt the LLM, we will first create token embeddings +for the query and then search the vector database to retrieve the most +similar vectors, and return those documents so they can be passed to the LLM as +context. Then the query and the context will be passed to the LLM which will +have to _again_ create token embeddings for the query. But because gritlm is used +the first query can be cached and the second query tokenization generation does +not have to be performed at all. + +### Running the example +Download a Grit model: +```console +$ scripts/hf.sh --repo cohesionet/GritLM-7B_gguf --file gritlm-7b_q4_1.gguf --outdir models +``` + +Run the example using the downloaded model: +```console +$ ./llama-gritlm -m models/gritlm-7b_q4_1.gguf + +Cosine similarity between "Bitcoin: A Peer-to-Peer Electronic Cash System" and "A purely peer-to-peer version of electronic cash w" is: 0.605 +Cosine similarity between "Bitcoin: A Peer-to-Peer Electronic Cash System" and "All text-based language problems can be reduced to" is: 0.103 +Cosine similarity between "Generative Representational Instruction Tuning" and "A purely peer-to-peer version of electronic cash w" is: 0.112 +Cosine similarity between "Generative Representational Instruction Tuning" and "All text-based language problems can be reduced to" is: 0.547 + +Oh, brave adventurer, who dared to climb +The lofty peak of Mt. Fuji in the night, +When shadows lurk and ghosts do roam, +And darkness reigns, a fearsome sight. + +Thou didst set out, with heart aglow, +To conquer this mountain, so high, +And reach the summit, where the stars do glow, +And the moon shines bright, up in the sky. + +Through the mist and fog, thou didst press on, +With steadfast courage, and a steadfast will, +Through the darkness, thou didst not be gone, +But didst climb on, with a steadfast skill. + +At last, thou didst reach the summit's crest, +And gazed upon the world below, +And saw the beauty of the night's best, +And felt the peace, that only nature knows. + +Oh, brave adventurer, who dared to climb +The lofty peak of Mt. Fuji in the night, +Thou art a hero, in the eyes of all, +For thou didst conquer this mountain, so bright. +``` + +[gritlm]: https://github.com/ContextualAI/gritlm diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bdab052c3390ff7b440275237b46fbf13612da80 --- /dev/null +++ b/examples/gritlm/gritlm.cpp @@ -0,0 +1,231 @@ +#include "arg.h" +#include "common.h" +#include "llama.h" + +#include +#include + +// #define GRIT_DEBUG + +static std::vector> encode(llama_context * ctx, const std::vector & sentences, const std::string & instruction) { + std::vector> result; + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); + + for (uint64_t i = 0; i < sentences.size(); i++) { + common_batch_clear(batch); + + const std::string input_string = instruction + sentences[i]; + + std::vector inputs = common_tokenize(vocab, input_string, true, false); + + const int32_t n_toks = inputs.size(); + + // GritLM seems to have EOS = "" + // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 + // inputs.push_back(llama_vocab_eos(vocab)); + + // we want to ignore instruction tokens for mean pooling + const int32_t n_inst = common_tokenize(vocab, instruction, true, false).size(); + +#ifdef GRIT_DEBUG + // debug tokens - should be matching as referenced in the GritLM sample + std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) { + std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str()); + }); + std::printf("\n"); +#endif + + // add input to batch (this increments n_tokens) + for (int32_t j = 0; j < n_toks; j++) { + common_batch_add(batch, inputs[j], j, { 0 }, true); + } + + // clear previous kv_cache values (irrelevant for embeddings) + llama_memory_clear(llama_get_memory(ctx), true); + llama_set_causal_attn(ctx, false); + + // run model + llama_decode(ctx, batch); + + // get embedding dimensions + uint64_t n_embd = llama_model_n_embd(model); + + // allocate embedding output + std::vector emb_unorm(n_embd, 0.0f); + + // sum up all token embeddings + for (int32_t k = n_inst; k < n_toks; k++) { + float * emb = llama_get_embeddings_ith(ctx, k); + for (uint64_t j = 0; j < n_embd; j++) { + emb_unorm[j] += emb[j]; + } + } + + // divide by number of tokens (mean pooling) + { + const uint64_t n_sent = n_toks - n_inst; + + for (uint64_t j = 0; j < n_embd; j++) { + emb_unorm[j] /= n_sent; + } + } + + std::vector emb_norm(emb_unorm.size()); + common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2); + result.push_back(emb_norm); + +#ifdef GRIT_DEBUG + // print out emb_norm + std::printf("embedding %ld: ", i); + for (uint64_t j = 0; j < n_embd; j++) { + std::printf("%.5f ", emb_norm[j]); + } + std::printf("\n\n"); +#endif + } + + llama_batch_free(batch); + + return result; +} + +static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) { + std::string result; + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + llama_token eos_token = llama_vocab_eos(vocab); + + llama_memory_clear(llama_get_memory(ctx), true); + llama_set_causal_attn(ctx, true); + + llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); + + std::vector inputs = common_tokenize(vocab, prompt, false, true); + int32_t i_current_token = 0; + + while (true) { + common_batch_clear(bat); + { + const int32_t n_inputs = inputs.size(); + + for (int32_t i = 0; i < n_inputs; i++) { + common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); + } + } + inputs.clear(); + + llama_decode(ctx, bat); + + llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); + + if (token == eos_token) { + break; + } + + std::string piece = common_token_to_piece(ctx, token); + if (stream) { + std::printf("%s", piece.c_str()); + std::fflush(stdout); + } + + inputs.push_back(token); + + result += piece; + } + + if (stream) { + std::printf("\n"); + } + + llama_batch_free(bat); + + return result; +} + +static std::string gritlm_instruction(const std::string & instruction) { + return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n"; +} + +int main(int argc, char * argv[]) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + common_init(); + + llama_model_params mparams = common_model_params_to_llama(params); + llama_context_params cparams = common_context_params_to_llama(params); + + cparams.embeddings = true; + + llama_backend_init(); + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); + + // create generation context + llama_context * ctx = llama_init_from_model(model, cparams); + + auto sparams = llama_sampler_chain_default_params(); + + sparams.no_perf = false; + + llama_sampler * smpl = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); + + // ### Embedding/Representation ### + // samples taken from: https://github.com/ContextualAI/gritlm#basic + { + const std::string instruction = "Given a scientific paper title, retrieve the paper's abstract"; + + const std::vector queries = { + "Bitcoin: A Peer-to-Peer Electronic Cash System", + "Generative Representational Instruction Tuning", + }; + + const std::vector documents = { + "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.", + "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.", + }; + + // No need to add instruction for retrieval documents + const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); + const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); + + const int n_embd = llama_model_n_embd(model); + + const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); + const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); + const float cosine_sim_q1_d0 = common_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd); + const float cosine_sim_q1_d1 = common_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd); + + std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0); + std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1); + std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[0].c_str(), cosine_sim_q1_d0); + std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1); + } + + llama_set_embeddings(ctx, false); + + // ### Generation ### + // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction + { + const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n"; + std::string response = generate(ctx, smpl, prompt, true); + } + + llama_sampler_free(smpl); + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + + return 0; +} diff --git a/examples/jeopardy/README.md b/examples/jeopardy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ffa13cbf349b281ec82373d8bb497678c8c93e4d --- /dev/null +++ b/examples/jeopardy/README.md @@ -0,0 +1,21 @@ +# llama.cpp/example/jeopardy + +This is pretty much just a straight port of aigoopy/llm-jeopardy/ with an added graph viewer. + +The jeopardy test can be used to compare the fact knowledge of different models and compare them to each other. This is in contrast to some other tests, which test logical deduction, creativity, writing skills, etc. + + +Step 1: Open jeopardy.sh and modify the following: +``` +MODEL=(path to your model) +MODEL_NAME=(name of your model) +prefix=(basically, if you use vicuna it's Human: , if you use something else it might be User: , etc) +opts=(add -instruct here if needed for your model, or anything else you want to test out) +``` +Step 2: Run `jeopardy.sh` from the llama.cpp folder + +Step 3: Repeat steps 1 and 2 until you have all the results you need. + +Step 4: Run `graph.py`, and follow the instructions. At the end, it will generate your final graph. + +Note: The Human bar is based off of the full, original 100 sample questions. If you modify the question count or questions, it will not be valid. diff --git a/examples/jeopardy/graph.py b/examples/jeopardy/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc0706b86d05617e4733a8ed46b774dc186e602 --- /dev/null +++ b/examples/jeopardy/graph.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +import matplotlib.pyplot as plt +import os +import csv + +labels = [] +numbers = [] +numEntries = 1 + +rows = [] + + +def bar_chart(numbers, labels, pos): + plt.bar(pos, numbers, color='blue') + plt.xticks(ticks=pos, labels=labels) + plt.title("Jeopardy Results by Model") + plt.xlabel("Model") + plt.ylabel("Questions Correct") + plt.show() + + +def calculatecorrect(): + directory = os.fsencode("./examples/jeopardy/results/") + csv_reader = csv.reader(open("./examples/jeopardy/qasheet.csv", 'rt'), delimiter=',') + for row in csv_reader: + global rows + rows.append(row) + for listing in os.listdir(directory): + filename = os.fsdecode(listing) + if filename.endswith(".txt"): + file = open("./examples/jeopardy/results/" + filename, "rt") + global labels + global numEntries + global numbers + labels.append(filename[:-4]) + numEntries += 1 + i = 1 + totalcorrect = 0 + for line in file.readlines(): + if line.strip() != "------": + print(line) + else: + print("Correct answer: " + rows[i][2] + "\n") + i += 1 + print("Did the AI get the question right? (y/n)") + if input() == "y": + totalcorrect += 1 + numbers.append(totalcorrect) + + +if __name__ == '__main__': + calculatecorrect() + pos = list(range(numEntries)) + labels.append("Human") + numbers.append(48.11) + bar_chart(numbers, labels, pos) + print(labels) + print(numbers) diff --git a/examples/jeopardy/jeopardy.sh b/examples/jeopardy/jeopardy.sh new file mode 100644 index 0000000000000000000000000000000000000000..07bcb3b8d78ac2077387d11aa73d7487fa757057 --- /dev/null +++ b/examples/jeopardy/jeopardy.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -e + +MODEL=./models/ggml-vicuna-13b-1.1-q4_0.bin +MODEL_NAME=Vicuna + +# exec options +prefix="Human: " # Ex. Vicuna uses "Human: " +opts="--temp 0 -n 80" # additional flags +nl=' +' +introduction="You will be playing a game of Jeopardy. Simply answer the question in the correct format (Ex. What is Paris, or Who is George Washington)." + +# file options +question_file=./examples/jeopardy/questions.txt +touch ./examples/jeopardy/results/$MODEL_NAME.txt +output_file=./examples/jeopardy/results/$MODEL_NAME.txt + +counter=1 + +echo 'Running' +while IFS= read -r question +do + exe_cmd="./llama-cli -p "\"$prefix$introduction$nl$prefix$question\"" "$opts" -m ""\"$MODEL\""" >> ""\"$output_file\"" + echo $counter + echo "Current Question: $question" + eval "$exe_cmd" + echo -e "\n------" >> $output_file + counter=$((counter+1)) +done < "$question_file" diff --git a/examples/jeopardy/qasheet.csv b/examples/jeopardy/qasheet.csv new file mode 100644 index 0000000000000000000000000000000000000000..35b08418956ab5c54d0afa4c3f2b896931272d36 --- /dev/null +++ b/examples/jeopardy/qasheet.csv @@ -0,0 +1,103 @@ +Index,Original Category,Original Correct Question,Model Prompt +1,The Oscars,Who is John Williams?,Which actor Born in 1932 was the son of a percussionist in the CBS radio orchestra has been nominated for 53 Oscars? +2,English Literature,What is Paradise Lost?,"What work in English Literature says: 'The mind is its own place, & in itself can make a heaven of hell, a hell of heaven. What matter where, if I be still the same'?" +3,Writers’ Lesser-Known Works,Who is Niccolò Machiavelli?,"Known for more philosophical works, he wrote the play 'La Mandragola', in which Florentines are rewarded for immoral actions?" +4,Exploration,What is Easter Island (Rapa Nui)?,"James Cook's account of a 1774 visit where records an object 'near 27 feet long, and upwards of 8 feet over the breast or shoulders'?" +5,The Bill of Rights,What is the Eighth Amendment?,England's 'Bloody Assizes' & a 1685 life sentence for perjury were 2 main origins of which amendment to the U.S. Constitution? +6,Nobel Peace Prize Winners,Who are Nelson Mandela & Desmond Tutu?,"Which nobel peace price winners each lived at times on Vilakazi St. in Soweto , so it claims to be the world's only street home to 2 Nobel Peace Prize winners?" +7,Famous Names,Who is Walt Disney?,"In 1966, the year of who's death did he share plans for an experimental prototype community in Florida?" +8,Geography,What is Colombia?,"Of the 13 nations through which the Equator passes, what is the only one whose coastline borders the Caribbean Sea?" +9,Fashion History,What are rhinestones?,"Which decorative items in fashion history get their name from their origin in the port city of Strasbourg, on the border of France & Germany?" +10,Movies of the ’80s,What is Driving Miss Daisy?,What 1980's movie is based on an off-Broadway play with just 3 characters and won the Best Picture Oscar & the actors in all 3 roles were nominated? +11,Novelists,Who is John Grisham?,"A 2012 book review for which novelist noted subjects that 'sparked his ire': capital punishment, big tobacco & 'the plight of the unjustly convicted'?" +12,20th Century Eponyms,What is the Maginot Line?,"A 1940 headline about what 20th Century Eponym included 'failure', 'liability when it came to offense' & 'stout hearts no match for tanks'?" +13,City History,What is Stockholm?,"Over 700 years after its traditional 1252 founding date, what port city became associated with a psychological response?" +14,Brand Names,What is Jacuzzi?,"The success of what brand has its roots with a hydrotherapy pump its cofounder created for his son, who had arthritis?" +15,American Authors,Who is Washington Irving?,"In a periodical in 1807, what American Author called New York City 'Gotham, Gotham! Most enlightened of cities'?" +16,Symbols,What is “less than”?,What symbol is a rotated V in math and a feeling of some marginalized or underrepresented people in society? +17,Movie Theme Songs,Who is James Bond?,"Monty Norman, the composer of what character's theme, said the staccato riff conveyed sexiness, mystery & ruthlessness?" +18,American Novelists,Who is Joseph Heller?,"What American Novelist served with an airman named Yohannan in World War II & despite what readers might think, he said he enjoyed his service?" +19,Medieval Places,"What is Canterbury, England? (Canterbury Cathedral)","In what Medieval place did one of the participants in an 1170 event say, 'Let us away, knights; he will rise no more'?" +20,Countries of Africa,What is Morocco?,"At one time a province of the Roman Empire, what African country kingdom is known to Arabic scholars as Al-Maghrib Al-Aqsa, 'the far west'?" +21,Statehood,What is Wyoming?,Congress relented in 1890 after what prospective state said it would wait 100 years rather than come in without the women? +22,1980s Movies,What is Raiders of the Lost Ark?,"A writer & producer of what movie said he wanted it to be like a Western or James Bond film, 'only it takes place in the 30s'?" +23,Art Exhibitions,Who is Rembrandt?,In 1898 what's been called the first blockbuster art show was devoted to which artist & put on for Queen Wilhelmina's coronation? +24,Countries of the World,What is Mongolia?,"Part of the largest contiguous land empire during the 1200s & 1300s, today what is the world's second-largest landlocked country?" +25,Literature,What is “Howl”?,A 2006 book was titled 'The Poem That Changed America:' What 'Fifty Years Later'? +26,Invasions,Who is William of Orange?,"Backed by 14,000 troops, who invaded England to restore, in his words, its 'religion, laws, and liberties'?" +27,Landmarks,What is the Eiffel Tower?,"After its completion in the late 19th c., what was landmark was called 'a truly tragic street lamp' & a 'high & skinny pyramid of iron ladders'?" +28,Geographic Name’s the Same,What is Dover?,"The busiest passenger port in the U.K., what shares its name with a capital of one of the original 13 states?" +29,Names in the Bookstore,Who is Peter Mark Roget?,"This man made lists, perhaps to cope with depression; a set of lists he published in 1852 made whose name synonymous with a type of book?" +30,U.S. History,Who is Dr. Samuel Mudd?,"An 1869 presidential pardon was granted to which man, due in part to a plea by the Medical Society of Harford County, Maryland?" +31,American Literature,What is The Things They Carried?,"Letters, pocket knives, C rations & steel helmets are among the tangible items referred to in the title of what American literature modern war classic?" +32,Nonfiction,What is The Communist Manifesto,"What nonfiction book has the line, 'The discovery of America…opened up fresh ground for the rising bourgeoisie'?" +33, a new version was passed 81 years later,Laws in U.S. History,What is the Civil Rights Act?,,,,,,,,,,,,,,,,,,0, 2/3 +34,Names of Myth,Who is Helen of Troy?,"Whose brothers, Castor & Pollux, saved her after Theseus stole her away as a kid; a larger force would seek her later in life?" +35,African Countries,What is Sudan?,"Once Africa's largest country in area, what African Country dropped to third in 2011 when a portion of it declared independence?" +36,The Ancient World,What is Alexandria?,"The ancient writer Galen said books on ships arriving to what city's port were seized, originals kept & copies returned?" +37,Famous Names,Who is Andy Warhol?,"For a special 1970s cookbook, who provided one simple recipe–a can of Campbell's tomato soup & 2 cans of milk?" +38,People & Places,What is Guam?,"Thought to descend from people of Southeast Asia, the Chamorro make up what U.S. territory’s largest ethnic group?" +39,Current World Leaders,What is the Philippines?,"In office from 2022, the president of what country has taken so many foreign trips a play on his name is 'Ferdinand Magellan Jr.'?" +40,Writers & The South,Who is Tennessee Williams?,In 1939 which writer lived on Toulouse Street in the French Quarter & chose the professional name that bonded him to the South? +41,National Parks,What is Yellowstone?,"What National Park is named for a river indigenous people called Mi tse a-da-zi, translated by French-speaking trappers as 'Pierre Jaune'?" +42,Sports,Who are the Harlem Globetrotters?,"In 2010 who introduced the 4-point shot, 35 feet from the basket?" +43,The U.S. Military,What is “Top Gun”?,Losses over Asia in the 1960s led to the establishment of the program known as what at a San Diego naval base in 1969? +44,Art & Science,What is Halley’s Comet?,"A craft that visited what was named for Giotto, based on the story that 680 years earlier, the painter depicted it as the Star of Bethlehem?" +45,Words From World War I,What is “tank”?,"In World War I, 'Cistern' & 'reservoir' were suggested names for what secret invention, but the British preferred this less clumsy monosyllable?" +46,European History,What is Holy Roman Emperor?,"Until 1806, some German nobles included among their honors the title of 'Elector' for their role in selecting this personage?" +47,Theater History,Who is Peter Pan?,"In 1904, wearing a harness, actress Nina Boucicault became the first to play what character onstage?" +48,European Cities,What is Aachen?,"Alphabetically the first German city in encyclopedias, what was also the first one taken by the Allies in World War II?" +49,Word Origins,What is mantra?,This Sanskrit word referring to a spoken word or phrase comes from a word for 'to think'? +50,Inventions,What is barbed wire?,1917's 'Elements of Trench Warfare' said what Old West invention was 'difficult to destroy' & 'difficult to get through'? +51,World War II,What is Schindler’s list?,"Mimi Reinhard, who never learned to type using more than 2 fingers, produced what in World War II with 1,100 names, including hers?" +52, their offspring was the source of this mythical object,Mythology,What is the Golden Fleece? +53,Literature,What is Pride and Prejudice?,"Published in 2011, P.D. James' final novel, 'Death Comes to Pemberley', was a sequel to what novel from 200 years earlier?" +54, only these 2 west of the Mississippi River border each other,U.S. State Names,What are Oregon & Nevada? +55,Word Origins,What is passion?,"Originally relating to a story of suffering, what word now more commonly refers to strong emotion of any kind?" +56,World Cinema,What is La Vie en Rose?,"The 2007 biopic called 'La Môme' in France, meaning 'The Kid', was released in the U.S. under what other French title?" +57,History,What is Santa Maria?,"Returning home in 1493, Columbus stopped in the Azores at an island with what name, also something he'd lost off the Haiti coast?" +58,Landmarks,What is a kremlin?,Pskov & Nizhny Novgorod are 2 of the cities that have a fortress called what? +59,Foreign-Born Authors,Who is Vladimir Nabokov?,In the 1950s the New York Times said what author 'is writing about all lust' & his lecherous narrator 'is all of us'? +60,Astronomy & Geography,What is Capricorn?,"At the winter solstice, the sun is in Sagittarius; it once appeared in what constellation, giving a geographic feature its name?" +61,Television,What is Law & Order?,"Mike Post combined the sound of a slamming jail door, an anvil & 100 men stomping on a floor for what television series that debuted in 1990?" +62,British Landmarks,What is the Tower of London?,"Like Sir Thomas More, 3 16th century English queens are buried at what British location?" +63,Early American History,What are witches?,"In 1692 Increase Mather wrote, 'It were better that ten suspected' of these who 'escape, than that one innocent person … be condemned'?" +64,Geography Mnemonics,What are Arkansas and Louisiana?,"The Geography Mnemonic Mimal, sometimes said to be the silhouette of a chef or elf, stands for Minnesota, Iowa, Missouri, and what other 2 states?" +65,Business Milestones,What is the Ford Model T?,"What was first sold in 1908, at a price equivalent to about $27,000 today?" +66,In The Bookstore,Who is Tom Clancy?,The name of what author dead since 2013 now appears on books written by a former U.S. marshal & a former Apache helicopter pilot? +67,Historic Art,What is the Bayeux Tapestry?,The artwork once known in France as 'la tapisserie de la Reine Mathilde' is better known as what? +68,Pop Stars,Who is Madonna?,In 2022 which pop star became the first woman to have a Billboard Top 10 album in 5 decades starting with the 1980s? +69,Classic Tale Characters,Who is Scheherazade?,"In one 19th century translation, what female classic tale character 'perceived the dawn of day and ceased' speaking nearly 1,000 times?" +70,USA,What is Jack Daniel’s?,"Ironically, though what company founded in the 1860s is Moore County, Tennessee's largest employer, Moore is a dry county?" +71,Historic People,Who was William Bligh?,"After a 1789 event, who wrote, 'My first determination was to seek a supply of…water at Tofoa, & afterwards to sail for Tongataboo'?" +72,The Movies,What is The Godfather?,Laurence Olivier & Ernest Borgnine were considered for the lead role & Sergio Leone to direct for what film that turned 50 in 2022? +73,Continental Geography,What is Colombia?,"Until a 1903 secession, what country's contiguous territory spanned 2 continents?" +74,Foreign-Born Authors,Who is Isabel Allende?,"Early in her career which foreign-born author translated romance novels into Spanish, often changing the dialogue to make the heroines smarter?" +75,Historic Crimes,What is the Mona Lisa?,"Saying it was stolen by Napoleon, self-styled Italian patriot Vincenzo Peruggia took what in 1911?" +76,U.S. Bodies of Water,What is Lake Mead?,"Continuing a downward trend, in July 2022 what US body of water was at 27% capacity, its lowest level since 1937 when it was first being filled?" +77,Gods & Goddesses,Who is Aurora (or Eos)?,"Each morning which goddess began her ride in her chariot across the sky ahead of her brother Sol, or Helios?" +78,America At War,What is the Battle of New Orleans?,"Until the Civil War, the Jan. 8 date of what American battle of dubious military importance but big morale value was a national holiday?" +79,Children’s Books,What is The Velveteen Rabbit?,"Which children's book title character is told 'By the time you are real, most of your hair has been loved off your eyes drop out & you get shabby'?" +80,TV Finales,What is Grace and Frankie?,"In a TV reunion over 40 years in the making, Dolly Parton appeared as an angel named Agnes in the final episode of what comedy in 2022?" +81,American Poems,Who is Evangeline?,"In an 1847 American poem what character sees her town of Grand-Pré burned, but finally reunites with her beau for a kiss before his death?" +82,Famous Names,Who is Banksy?,"In 2001 who published a book called 'Banging Your Head Against a Brick Wall'; in 2002, 'Existencilism'?" +83,Children’s Lit,What is Charlotte’s Web?,The title object of what childrens book 'never looked more beautiful each strand held dozens of bright drops of early morning dew'? +84,Classic Songs,What is “Here Comes Santa Claus”?,The shouts of excited children at a 1946 holiday parade are said to have inspired what perennial classic song favorite? +85,Brand Names,What are Milk Duds?,"Unable to make what candies perfectly round, the confectioner embraced this flawed name for the product?" +86,Countries of the World,What is Italy?,"What country is home to 58 UNESCO World Heritage Sites, more than any other country; the sites include a volcano & a lagoon?" +87,Action Movies,What is Die Hard?,"What action movie's last line is 'If this is their idea of Christmas, I gotta be here for New Years'?" +88,Presidential Facts,Who is Woodrow Wilson?,Only 3 presidents have married while in office— John Tyler was the first & which one was the last? +89,19th Century Americans,Who is Frederick Douglass?,"Demonstrating the dignity & humanity of Black Americans, who sat for 160 known photographs, the most of any American in the 19th century?" +90,Latin Phrases,What is “quid pro quo”?,"Originally, which Latin 3-word phrase referred to when a doctor or apothecary substituted one medicine for another?" +91,1970s Movies,What is Monty Python and the Holy Grail?,The 1975 premiere of what movie comedy advertised free coconuts for the first thousand in the audience? +92,Name’s The Same,What is Manhattan?,"A cocktail, an island & a WWII venture originally called 'Development of Substitute Materials' all bear what name?" +93,U.S. Presidents,Who is Calvin Coolidge?,"Which US President was sworn in twice as President within 2 years, first by his father & then later by a former U.S. President?" +94,Plays,What is The Tempest?,A 1609 story in which an exiled king of Bulgaria creates a sea palace with his magic may have inspired the plot of what play? +95,Landmarks,What is the Berlin Wall?,"In 2009, during a 20th anniversary celebration, what landmark was called 'an edifice of fear. On Nov. 9, it became a place of joy'?" +96,World Capitals,"What is Vienna, Austria?","Among what world capital's nicknames are the 'City of Classical Music' &, possibly in honor of a famous resident from 1860 to 1938, the 'City of Dreams'?" +97,Language & Its Meanings,What is a night owl?,"Now meaning someone with nocturnal habits, what catches a sleeping dove in Shakespeare's 'Lucrece'?" +98,Flags of Our Hemisphere,What is Brazil?,"The stars on what country's flag represent states, 26 of them; unlike the USA's, its 'federal district' gets its own 27th star?" +99,Names in U.S. History,Who is Oliver Brown?,What father was the only man among the 13 plaintiffs in a US class-action case filed in 1951? +100,Children’s Authors,"Who is Sarah? (from Sarah, Plain and Tall)","Reversing the story of what heroine she created, childrens author Patricia Maclachlan was born on the prairie but spent much of her life in New England?" +,,, +TOTALS,,, diff --git a/examples/jeopardy/questions.txt b/examples/jeopardy/questions.txt new file mode 100644 index 0000000000000000000000000000000000000000..eea78a057126ce5da310e31157914b540c534449 --- /dev/null +++ b/examples/jeopardy/questions.txt @@ -0,0 +1,100 @@ +Which man born in 1932 was the son of a percussionist in the CBS radio orchestra has been nominated for 53 Oscars? +What work in English Literature says: 'The mind is its own place, & in itself can make a heaven of hell, a hell of heaven. What matter where, if I be still the same'? +Known for more philosophical works, he wrote the play 'La Mandragola', in which Florentines are rewarded for immoral actions? +James Cook's account of a 1774 visit where records an object 'near 27 feet long, and upwards of 8 feet over the breast or shoulders'? +England's 'Bloody Assizes' & a 1685 life sentence for perjury were 2 main origins of which amendment to the U.S. Constitution? +Which nobel peace price winners each lived at times on Vilakazi St. in Soweto , so it claims to be the world's only street home to 2 Nobel Peace Prize winners? +In 1966, the year of who's death did he share plans for an experimental prototype community in Florida? +Of the 13 nations through which the Equator passes, what is the only one whose coastline borders the Caribbean Sea? +Which decorative items in fashion history get their name from their origin in the port city of Strasbourg, on the border of France & Germany? +What 1980's movie is based on an off-Broadway play with just 3 characters and won the Best Picture Oscar & the actors in all 3 roles were nominated? +A 2012 book review for which novelist noted subjects that 'sparked his ire': capital punishment, big tobacco & 'the plight of the unjustly convicted'? +A 1940 headline about what 20th Century Eponym included 'failure', 'liability when it came to offense' & 'stout hearts no match for tanks'? +Over 700 years after its traditional 1252 founding date, what port city became associated with a psychological response? +The success of what brand has its roots with a hydrotherapy pump its cofounder created for his son, who had arthritis? +In a periodical in 1807, what American Author called New York City 'Gotham, Gotham! Most enlightened of cities'? +What symbol is a rotated V in math and a feeling of some marginalized or underrepresented people in society? +Monty Norman, the composer of what character's theme, said the staccato riff conveyed sexiness, mystery & ruthlessness? +What American Novelist served with an airman named Yohannan in World War II & despite what readers might think, he said he enjoyed his service? +In what Medieval place did one of the participants in an 1170 event say, 'Let us away, knights; he will rise no more'? +At one time a province of the Roman Empire, what African country kingdom is known to Arabic scholars as Al-Maghrib Al-Aqsa, 'the far west'? +Congress relented in 1890 after what prospective state said it would wait 100 years rather than come in without the women? +A writer & producer of what movie said he wanted it to be like a Western or James Bond film, 'only it takes place in the 30s'? +In 1898 what's been called the first blockbuster art show was devoted to which artist & put on for Queen Wilhelmina's coronation? +Part of the largest contiguous land empire during the 1200s & 1300s, today what is the world's second-largest landlocked country? +A 2006 book was titled 'The Poem That Changed America:' What 'Fifty Years Later'? +Backed by 14,000 troops, who invaded England to restore, in his words, its 'religion, laws, and liberties'? +After its completion in the late 19th c., what was landmark was called 'a truly tragic street lamp' & a 'high & skinny pyramid of iron ladders'? +The busiest passenger port in the U.K., what shares its name with a capital of one of the original 13 states? +This man made lists, perhaps to cope with depression; a set of lists he published in 1852 made whose name synonymous with a type of book? +An 1869 presidential pardon was granted to which man, due in part to a plea by the Medical Society of Harford County, Maryland? +Letters, pocket knives, C rations & steel helmets are among the tangible items referred to in the title of what American literature modern war classic? +What nonfiction book has the line, 'The discovery of America…opened up fresh ground for the rising bourgeoisie'? +A radical Republican championed what 1875 act but the Supreme Court struck it down in 1883; a new version was passed 81 years later? +Whose brothers, Castor & Pollux, saved her after Theseus stole her away as a kid; a larger force would seek her later in life? +Once Africa's largest country in area, what African Country dropped to third in 2011 when a portion of it declared independence? +The ancient writer Galen said books on ships arriving to what city's port were seized, originals kept & copies returned? +For a special 1970s cookbook, who provided one simple recipe–a can of Campbell's tomato soup & 2 cans of milk? +Thought to descend from people of Southeast Asia, the Chamorro make up what U.S. territory’s largest ethnic group? +In office from 2022, the president of what country has taken so many foreign trips a play on his name is 'Ferdinand Magellan Jr.'? +In 1939 which writer lived on Toulouse Street in the French Quarter & chose the professional name that bonded him to the South? +What National Park is named for a river indigenous people called Mi tse a-da-zi, translated by French-speaking trappers as 'Pierre Jaune'? +In 2010 who introduced the 4-point shot, 35 feet from the basket? +Losses over Asia in the 1960s led to the establishment of the program known as what at a San Diego naval base in 1969? +A craft that visited what was named for Giotto, based on the story that 680 years earlier, the painter depicted it as the Star of Bethlehem? +In World War I, 'Cistern' & 'reservoir' were suggested names for what secret invention, but the British preferred this less clumsy monosyllable? +Until 1806, some German nobles included among their honors the title of 'Elector' for their role in selecting this personage? +In 1904, wearing a harness, actress Nina Boucicault became the first to play what character onstage? +Alphabetically the first German city in encyclopedias, what was also the first one taken by the Allies in World War II? +This Sanskrit word referring to a spoken word or phrase comes from a word for 'to think'? +1917's 'Elements of Trench Warfare' said what Old West invention was 'difficult to destroy' & 'difficult to get through'? +Mimi Reinhard, who never learned to type using more than 2 fingers, produced what in World War II with 1,100 names, including hers? +Poseidon carried off the maiden Theophane & turned her into a ewe; their offspring was the source of what mythical object? +Published in 2011, P.D. James' final novel, 'Death Comes to Pemberley', was a sequel to what novel from 200 years earlier? +5 U.S. states have 6-letter names; only which 2 west of the Mississippi River border each other? +Originally relating to a story of suffering, what word now more commonly refers to strong emotion of any kind? +The 2007 biopic called 'La Môme' in France, meaning 'The Kid', was released in the U.S. under what other French title? +Returning home in 1493, Columbus stopped in the Azores at an island with what name, also something he'd lost off the Haiti coast? +Pskov & Nizhny Novgorod are 2 of the cities that have a fortress called what? +In the 1950s the New York Times said what author 'is writing about all lust' & his lecherous narrator 'is all of us'? +At the winter solstice, the sun is in Sagittarius; it once appeared in what constellation, giving a geographic feature its name? +Mike Post combined the sound of a slamming jail door, an anvil & 100 men stomping on a floor for what television series that debuted in 1990? +Like Sir Thomas More, 3 16th century English queens are buried at what British location? +In 1692 Increase Mather wrote, 'It were better that ten suspected' of these who 'escape, than that one innocent person be condemned'? +The Geography Mnemonic Mimal, sometimes said to be the silhouette of a chef or elf, stands for Minnesota, Iowa, Missouri, and what other 2 states? +What was first sold in 1908, at a price equivalent to about $27,000 today? +The name of what author dead since 2013 now appears on books written by a former U.S. marshal & a former Apache helicopter pilot? +The artwork once known in France as 'la tapisserie de la Reine Mathilde' is better known as what? +In 2022 which pop star became the first woman to have a Billboard Top 10 album in 5 decades starting with the 1980s? +In one 19th century translation, what female classic tale character 'perceived the dawn of day and ceased' speaking nearly 1,000 times? +Ironically, though what company founded in the 1860s is Moore County, Tennessee's largest employer, Moore is a dry county? +After a 1789 event, who wrote, 'My first determination was to seek a supply of…water at Tofoa, & afterwards to sail for Tongataboo'? +Laurence Olivier & Ernest Borgnine were considered for the lead role & Sergio Leone to direct for what film that turned 50 in 2022? +Until a 1903 secession, what country's contiguous territory spanned 2 continents? +Early in her career which foreign-born author translated romance novels into Spanish, often changing the dialogue to make the heroines smarter? +Saying it was stolen by Napoleon, self-styled Italian patriot Vincenzo Peruggia took what in 1911? +Continuing a downward trend, in July 2022 what US body of water was at 27% capacity, its lowest level since 1937 when it was first being filled? +Each morning which goddess began her ride in her chariot across the sky ahead of her brother Sol, or Helios? +Until the Civil War, the Jan. 8 date of what American battle of dubious military importance but big morale value was a national holiday? +Which children's book title character is told 'By the time you are real, most of your hair has been loved off your eyes drop out & you get shabby'? +In a TV reunion over 40 years in the making, Dolly Parton appeared as an angel named Agnes in the final episode of what comedy in 2022? +In an 1847 American poem what character sees her town of Grand-Pré burned, but finally reunites with her beau for a kiss before his death? +In 2001 who published a book called 'Banging Your Head Against a Brick Wall'; in 2002, 'Existencilism'? +The title object of what childrens book 'never looked more beautiful each strand held dozens of bright drops of early morning dew'? +The shouts of excited children at a 1946 holiday parade are said to have inspired what perennial classic song favorite? +Unable to make what candies perfectly round, the confectioner embraced this flawed name for the product? +What country is home to 58 UNESCO World Heritage Sites, more than any other country; the sites include a volcano & a lagoon? +What action movie's last line is 'If this is their idea of Christmas, I gotta be here for New Years'? +Only 3 presidents have married while in office— John Tyler was the first & which one was the last? +Demonstrating the dignity & humanity of Black Americans, who sat for 160 known photographs, the most of any American in the 19th century? +Originally, which Latin 3-word phrase referred to when a doctor or apothecary substituted one medicine for another? +The 1975 premiere of what movie comedy advertised free coconuts for the first thousand in the audience? +A cocktail, an island & a WWII venture originally called 'Development of Substitute Materials' all bear what name? +Which US President was sworn in twice as President within 2 years, first by his father & then later by a former U.S. President? +A 1609 story in which an exiled king of Bulgaria creates a sea palace with his magic may have inspired the plot of what play? +In 2009, during a 20th anniversary celebration, what landmark was called 'an edifice of fear. On Nov. 9, it became a place of joy'? +Among what world capital's nicknames are the 'City of Classical Music' &, possibly in honor of a famous resident from 1860 to 1938, the 'City of Dreams'? +Now meaning someone with nocturnal habits, what catches a sleeping dove in Shakespeare's 'Lucrece'? +The stars on what country's flag represent states, 26 of them; unlike the USA's, its 'federal district' gets its own 27th star? +What father was the only man among the 13 plaintiffs in a US class-action case filed in 1951? +Reversing the story of what heroine she created, childrens author Patricia Maclachlan was born on the prairie but spent much of her life in New England? diff --git a/examples/json_schema_pydantic_example.py b/examples/json_schema_pydantic_example.py new file mode 100644 index 0000000000000000000000000000000000000000..19c0bdb5b677024a77528762b06bab9ca4453d1b --- /dev/null +++ b/examples/json_schema_pydantic_example.py @@ -0,0 +1,82 @@ +# Usage: +#! ./llama-server -m some-model.gguf & +#! pip install pydantic +#! python json_schema_pydantic_example.py + +from pydantic import BaseModel, Field, TypeAdapter +from annotated_types import MinLen +from typing import Annotated, List, Optional +import json, requests + +if True: + + def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1/chat/completions", messages, **kwargs): + ''' + Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support + (llama.cpp server, llama-cpp-python, Anyscale / Together...) + + The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below) + ''' + response_format = None + type_adapter = None + + if response_model: + type_adapter = TypeAdapter(response_model) + schema = type_adapter.json_schema() + messages = [{ + "role": "system", + "content": f"You respond in JSON format with the following schema: {json.dumps(schema, indent=2)}" + }] + messages + response_format={"type": "json_object", "schema": schema} + + data = requests.post(endpoint, headers={"Content-Type": "application/json"}, + json=dict(messages=messages, response_format=response_format, **kwargs)).json() + if 'error' in data: + raise Exception(data['error']['message']) + + content = data["choices"][0]["message"]["content"] + return type_adapter.validate_json(content) if type_adapter else content + +else: + + # This alternative branch uses Instructor + OpenAI client lib. + # Instructor support streamed iterable responses, retry & more. + # (see https://python.useinstructor.com/) + #! pip install instructor openai + import instructor, openai + client = instructor.patch( + openai.OpenAI(api_key="123", base_url="http://localhost:8080"), + mode=instructor.Mode.JSON_SCHEMA) + create_completion = client.chat.completions.create + + +if __name__ == '__main__': + + class QAPair(BaseModel): + class Config: + extra = 'forbid' # triggers additionalProperties: false in the JSON schema + question: str + concise_answer: str + justification: str + stars: Annotated[int, Field(ge=1, le=5)] + + class PyramidalSummary(BaseModel): + class Config: + extra = 'forbid' # triggers additionalProperties: false in the JSON schema + title: str + summary: str + question_answers: Annotated[List[QAPair], MinLen(2)] + sub_sections: Optional[Annotated[List['PyramidalSummary'], MinLen(2)]] + + print("# Summary\n", create_completion( + model="...", + response_model=PyramidalSummary, + messages=[{ + "role": "user", + "content": f""" + You are a highly efficient corporate document summarizer. + Create a pyramidal summary of an imaginary internal document about our company processes + (starting high-level, going down to each sub sections). + Keep questions short, and answers even shorter (trivia / quizz style). + """ + }])) diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py new file mode 100644 index 0000000000000000000000000000000000000000..ed379585546c24c68ffed7d15de7f5954487ea5d --- /dev/null +++ b/examples/json_schema_to_grammar.py @@ -0,0 +1,814 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import itertools +import json +import re +import sys +from typing import Any, List, Optional, Set, Tuple, Union + +def _build_repetition(item_rule, min_items, max_items, separator_rule=None): + + if max_items == 0: + return "" + + if min_items == 0 and max_items == 1: + return f'{item_rule}?' + + if not separator_rule: + if min_items == 1 and max_items is None: + return f'{item_rule}+' + elif min_items == 0 and max_items is None: + return f'{item_rule}*' + else: + return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}' + + result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None) + return f'({result})?' if min_items == 0 else result + +def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True): + has_min = min_value != None + has_max = max_value != None + + def digit_range(from_char: str, to_char: str): + out.append("[") + if from_char == to_char: + out.append(from_char) + else: + out.append(from_char) + out.append("-") + out.append(to_char) + out.append("]") + + def more_digits(min_digits: int, max_digits: int): + out.append("[0-9]") + if min_digits == max_digits and min_digits == 1: + return + out.append("{") + out.append(str(min_digits)) + if max_digits != min_digits: + out.append(",") + if max_digits != sys.maxsize: + out.append(str(max_digits)) + out.append("}") + + def uniform_range(from_str: str, to_str: str): + i = 0 + while i < len(from_str) and from_str[i] == to_str[i]: + i += 1 + if i > 0: + out.append("\"") + out.append(from_str[:i]) + out.append("\"") + if i < len(from_str): + if i > 0: + out.append(" ") + sub_len = len(from_str) - i - 1 + if sub_len > 0: + from_sub = from_str[i+1:] + to_sub = to_str[i+1:] + sub_zeros = "0" * sub_len + sub_nines = "9" * sub_len + + to_reached = False + out.append("(") + if from_sub == sub_zeros: + digit_range(from_str[i], chr(ord(to_str[i]) - 1)) + out.append(" ") + more_digits(sub_len, sub_len) + else: + out.append("[") + out.append(from_str[i]) + out.append("] ") + out.append("(") + uniform_range(from_sub, sub_nines) + out.append(")") + if ord(from_str[i]) < ord(to_str[i]) - 1: + out.append(" | ") + if to_sub == sub_nines: + digit_range(chr(ord(from_str[i]) + 1), to_str[i]) + to_reached = True + else: + digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1)) + out.append(" ") + more_digits(sub_len, sub_len) + if not to_reached: + out.append(" | ") + digit_range(to_str[i], to_str[i]) + out.append(" ") + uniform_range(sub_zeros, to_sub) + out.append(")") + else: + out.append("[") + out.append(from_str[i]) + out.append("-") + out.append(to_str[i]) + out.append("]") + + if has_min and has_max: + if min_value < 0 and max_value < 0: + out.append("\"-\" (") + _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True) + out.append(")") + return + + if min_value < 0: + out.append("\"-\" (") + _generate_min_max_int(0, -min_value, out, decimals_left, top_level=True) + out.append(") | ") + min_value = 0 + + min_s = str(min_value) + max_s = str(max_value) + min_digits = len(min_s) + max_digits = len(max_s) + + for digits in range(min_digits, max_digits): + uniform_range(min_s, "9" * digits) + min_s = "1" + "0" * digits + out.append(" | ") + uniform_range(min_s, max_s) + return + + less_decimals = max(decimals_left - 1, 1) + + if has_min: + if min_value < 0: + out.append("\"-\" (") + _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False) + out.append(") | [0] | [1-9] ") + more_digits(0, decimals_left - 1) + elif min_value == 0: + if top_level: + out.append("[0] | [1-9] ") + more_digits(0, less_decimals) + else: + more_digits(1, decimals_left) + elif min_value <= 9: + c = str(min_value) + range_start = '1' if top_level else '0' + if c > range_start: + digit_range(range_start, chr(ord(c) - 1)) + out.append(" ") + more_digits(1, less_decimals) + out.append(" | ") + digit_range(c, "9") + out.append(" ") + more_digits(0, less_decimals) + else: + min_s = str(min_value) + length = len(min_s) + c = min_s[0] + + if c > "1": + digit_range("1" if top_level else "0", chr(ord(c) - 1)) + out.append(" ") + more_digits(length, less_decimals) + out.append(" | ") + digit_range(c, c) + out.append(" (") + _generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False) + out.append(")") + if c < "9": + out.append(" | ") + digit_range(chr(ord(c) + 1), "9") + out.append(" ") + more_digits(length - 1, less_decimals) + return + + if has_max: + if max_value >= 0: + if top_level: + out.append("\"-\" [1-9] ") + more_digits(0, less_decimals) + out.append(" | ") + _generate_min_max_int(0, max_value, out, decimals_left, top_level=True) + else: + out.append("\"-\" (") + _generate_min_max_int(-max_value, None, out, decimals_left, top_level=False) + out.append(")") + return + + raise RuntimeError("At least one of min_value or max_value must be set") + +class BuiltinRule: + def __init__(self, content: str, deps: list | None = None): + self.content = content + self.deps = deps or [] + +# Constraining spaces to prevent model "running away". +SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}' + +PRIMITIVE_RULES = { + 'boolean' : BuiltinRule('("true" | "false") space', []), + 'decimal-part' : BuiltinRule('[0-9]{1,16}', []), + 'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []), + 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), + 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), + 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), + 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), + 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), + 'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []), + 'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []), + 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), + 'null' : BuiltinRule('"null" space', []), +} + +# TODO: support "uri", "email" string formats +STRING_FORMAT_RULES = { + 'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), + 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), + 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), + 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), +} + +DOTALL = '[\\U00000000-\\U0010FFFF]' +DOT = '[^\\x0A\\x0D]' + +RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) + +INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') +GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') +GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'} + +NON_LITERAL_SET = set('|.()[]{}*+?') +ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?') + + +class SchemaConverter: + def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): + self._prop_order = prop_order + self._allow_fetch = allow_fetch + self._dotall = dotall + self._raw_pattern = raw_pattern + self._rules = { + 'space': SPACE_RULE, + } + self._refs = {} + self._refs_being_resolved = set() + + def _format_literal(self, literal): + escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal + ) + return f'"{escaped}"' + + def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str: + ''' + not_literal('a') -> '[^a]' + not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' + ''' + assert len(literal) > 0, 'Empty literal not supported' + def recurse(i: int): + c = literal[i] + if maybe_escaped_underscores and c == '_': + yield f'[^{c}\\\\]' + yield ' | ' + yield f'"\\\\"? "{c}"' + else: + yield f'[^{c}]' + if i < len(literal) - 1: + yield ' | ' + yield self._format_literal(c) + yield ' (' + yield from recurse(i + 1) + yield ')?' + + return ''.join(('(', *recurse(0), ')')) + + def _not_strings(self, strings): + class TrieNode: + def __init__(self): + self.children = {} + self.is_end_of_string = False + + def insert(self, string): + node = self + for c in string: + node = node.children.setdefault(c, TrieNode()) + node.is_end_of_string = True + + trie = TrieNode() + for s in strings: + trie.insert(s) + + char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + out = ['["] ( '] + + def visit(node): + rejects = [] + first = True + for c in sorted(node.children.keys()): + child = node.children[c] + rejects.append(c) + if first: + first = False + else: + out.append(' | ') + out.append(f'[{c}]') + if child.children: + out.append(f' (') + visit(child) + out.append(')') + elif child.is_end_of_string: + out.append(f' {char_rule}+') + if node.children: + if not first: + out.append(' | ') + out.append(f'[^"{"".join(rejects)}] {char_rule}*') + visit(trie) + + out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space') + return ''.join(out) + + def _add_rule(self, name, rule): + esc_name = INVALID_RULE_CHARS_RE.sub('-', name) + if esc_name not in self._rules or self._rules[esc_name] == rule: + key = esc_name + else: + i = 0 + while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule: + i += 1 + key = f'{esc_name}{i}' + self._rules[key] = rule + return key + + def resolve_refs(self, schema: dict, url: str): + ''' + Resolves all $ref fields in the given schema, fetching any remote schemas, + replacing $ref with absolute reference URL and populating self._refs with the + respective referenced (sub)schema dictionaries. + ''' + def visit(n: dict): + if isinstance(n, list): + return [visit(x) for x in n] + elif isinstance(n, dict): + ref = n.get('$ref') + if ref is not None and ref not in self._refs: + if ref.startswith('https://'): + assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' + import requests + + frag_split = ref.split('#') + base_url = frag_split[0] + + target = self._refs.get(base_url) + if target is None: + target = self.resolve_refs(requests.get(ref).json(), base_url) + self._refs[base_url] = target + + if len(frag_split) == 1 or frag_split[-1] == '': + return target + elif ref.startswith('#/'): + target = schema + ref = f'{url}{ref}' + n['$ref'] = ref + else: + raise ValueError(f'Unsupported ref {ref}') + + for sel in ref.split('#')[-1].split('/')[1:]: + assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] + + self._refs[ref] = target + else: + for v in n.values(): + visit(v) + + return n + return visit(schema) + + def _generate_union_rule(self, name, alt_schemas): + return ' | '.join(( + self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') + for i, alt_schema in enumerate(alt_schemas) + )) + + def _visit_pattern(self, pattern, name): + ''' + Transforms a regular expression pattern into a GBNF rule. + + Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions + Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + + Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. + + Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which + we define sub-rules to keep the output lean. + ''' + + assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' + pattern = pattern[1:-1] + sub_rule_ids = {} + + i = 0 + length = len(pattern) + + def to_rule(s: tuple[str, bool]) -> str: + (txt, is_literal) = s + return "\"" + txt + "\"" if is_literal else txt + + def transform() -> tuple[str, bool]: + ''' + Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. + ''' + nonlocal i + nonlocal pattern + nonlocal sub_rule_ids + + start = i + # For each component of this sequence, store its string representation and whether it's a literal. + # We only need a flat structure here to apply repetition operators to the last item, and + # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially + # (GBNF's syntax is luckily very close to regular expressions!) + seq: list[tuple[str, bool]] = [] + + def get_dot(): + if self._dotall: + rule = DOTALL + else: + # Accept any character... except \n and \r line break chars (\x0A and \xOD) + rule = DOT + return self._add_rule(f'dot', rule) + + def join_seq(): + nonlocal seq + ret = [] + for is_literal, g in itertools.groupby(seq, lambda x: x[1]): + if is_literal: + ret.append((''.join(x[0] for x in g), True)) + else: + ret.extend(g) + if len(ret) == 1: + return ret[0] + return (' '.join(to_rule(x) for x in seq), False) + + while i < length: + c = pattern[i] + if c == '.': + seq.append((get_dot(), False)) + i += 1 + elif c == '(': + i += 1 + if i < length: + assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' + seq.append((f'({to_rule(transform())})', False)) + elif c == ')': + i += 1 + assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}' + return join_seq() + elif c == '[': + square_brackets = c + i += 1 + while i < length and pattern[i] != ']': + if pattern[i] == '\\': + square_brackets += pattern[i:i+2] + i += 2 + else: + square_brackets += pattern[i] + i += 1 + assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}' + square_brackets += ']' + i += 1 + seq.append((square_brackets, False)) + elif c == '|': + seq.append(('|', False)) + i += 1 + elif c in ('*', '+', '?'): + seq[-1] = (to_rule(seq[-1]) + c, False) + i += 1 + elif c == '{': + curly_brackets = c + i += 1 + while i < length and pattern[i] != '}': + curly_brackets += pattern[i] + i += 1 + assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}' + curly_brackets += '}' + i += 1 + nums = [s.strip() for s in curly_brackets[1:-1].split(',')] + min_times = 0 + max_times = None + try: + if len(nums) == 1: + min_times = int(nums[0]) + max_times = min_times + else: + assert len(nums) == 2 + min_times = int(nums[0]) if nums[0] else 0 + max_times = int(nums[1]) if nums[1] else None + except ValueError: + raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/') + + (sub, sub_is_literal) = seq[-1] + + if not sub_is_literal: + id = sub_rule_ids.get(sub) + if id is None: + id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) + sub_rule_ids[sub] = id + sub = id + + seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False) + else: + literal = '' + while i < length: + if pattern[i] == '\\' and i < length - 1: + next = pattern[i + 1] + if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS: + i += 1 + literal += pattern[i] + i += 1 + else: + literal += pattern[i:i+2] + i += 2 + elif pattern[i] == '"' and not self._raw_pattern: + literal += '\\"' + i += 1 + elif pattern[i] not in NON_LITERAL_SET and \ + (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET): + literal += pattern[i] + i += 1 + else: + break + if literal: + seq.append((literal, True)) + + return join_seq() + + return self._add_rule( + name, + to_rule(transform()) if self._raw_pattern \ + else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space") + + + def _resolve_ref(self, ref): + ref_name = ref.split('/')[-1] + if ref_name not in self._rules and ref not in self._refs_being_resolved: + self._refs_being_resolved.add(ref) + resolved = self._refs[ref] + ref_name = self.visit(resolved, ref_name) + self._refs_being_resolved.remove(ref) + return ref_name + + def _generate_constant_rule(self, value): + return self._format_literal(json.dumps(value)) + + def visit(self, schema, name): + schema_type = schema.get('type') + schema_format = schema.get('format') + rule_name = name + '-' if name in RESERVED_NAMES else name or 'root' + + if (ref := schema.get('$ref')) is not None: + return self._add_rule(rule_name, self._resolve_ref(ref)) + + elif 'oneOf' in schema or 'anyOf' in schema: + return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) + + elif isinstance(schema_type, list): + return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type])) + + elif 'const' in schema: + return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space') + + elif 'enum' in schema: + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space' + return self._add_rule(rule_name, rule) + + elif schema_type in (None, 'object') and \ + ('properties' in schema or \ + ('additionalProperties' in schema and schema['additionalProperties'] is not True)): + required = set(schema.get('required', [])) + properties = list(schema.get('properties', {}).items()) + return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) + + elif schema_type in (None, 'object') and 'allOf' in schema: + required = set() + properties = [] + hybrid_name = name + def add_component(comp_schema, is_required): + if (ref := comp_schema.get('$ref')) is not None: + comp_schema = self._refs[ref] + + if 'properties' in comp_schema: + for prop_name, prop_schema in comp_schema['properties'].items(): + properties.append((prop_name, prop_schema)) + if is_required: + required.add(prop_name) + + for t in schema['allOf']: + if 'anyOf' in t: + for tt in t['anyOf']: + add_component(tt, is_required=False) + else: + add_component(t, is_required=True) + + return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None)) + + elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): + items = schema.get('items') or schema['prefixItems'] + if isinstance(items, list): + return self._add_rule( + rule_name, + '"[" space ' + + ' "," space '.join( + self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') + for i, item in enumerate(items)) + + ' "]" space') + else: + item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') + min_items = schema.get("minItems", 0) + max_items = schema.get("maxItems") + return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') + + elif schema_type in (None, 'string') and 'pattern' in schema: + return self._visit_pattern(schema['pattern'], rule_name) + + elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''): + return self._add_primitive( + 'root' if rule_name == 'root' else schema_format, + PRIMITIVE_RULES['uuid'] + ) + + elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES: + prim_name = f'{schema_format}-string' + return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name])) + + elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema): + char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + min_len = schema.get('minLength', 0) + max_len = schema.get('maxLength') + + return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') + + elif schema_type in (None, 'integer') and \ + ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema): + min_value = None + max_value = None + if 'minimum' in schema: + min_value = schema['minimum'] + elif 'exclusiveMinimum' in schema: + min_value = schema['exclusiveMinimum'] + 1 + if 'maximum' in schema: + max_value = schema['maximum'] + elif 'exclusiveMaximum' in schema: + max_value = schema['exclusiveMaximum'] - 1 + + out = ["("] + _generate_min_max_int(min_value, max_value, out) + out.append(") space") + return self._add_rule(rule_name, ''.join(out)) + + elif (schema_type == 'object') or (len(schema) == 0): + return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) + + else: + assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' + # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero + return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type]) + + def _add_primitive(self, name: str, rule: BuiltinRule): + n = self._add_rule(name, rule.content) + + for dep in rule.deps: + dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) + assert dep_rule, f'Rule {dep} not known' + if dep not in self._rules: + self._add_primitive(dep, dep_rule) + return n + + def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]): + prop_order = self._prop_order + # sort by position in prop_order (if specified) then by original order + sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] + + prop_kv_rule_names = {} + for prop_name, prop_schema in properties: + prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}') + prop_kv_rule_names[prop_name] = self._add_rule( + f'{name}{"-" if name else ""}{prop_name}-kv', + fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}' + ) + required_props = [k for k in sorted_props if k in required] + optional_props = [k for k in sorted_props if k not in required] + + if additional_properties is not None and additional_properties != False: + sub_name = f'{name}{"-" if name else ""}additional' + value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \ + self._add_primitive('value', PRIMITIVE_RULES['value']) + key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \ + else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props)) + + prop_kv_rule_names["*"] = self._add_rule( + f'{sub_name}-kv', + f'{key_rule} ":" space {value_rule}' + ) + optional_props.append("*") + + rule = '"{" space ' + rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) + + if optional_props: + rule += ' (' + if required_props: + rule += ' "," space ( ' + + def get_recursive_refs(ks, first_is_optional): + [k, *rest] = ks + kv_rule_name = prop_kv_rule_names[k] + comma_ref = f'( "," space {kv_rule_name} )' + if first_is_optional: + res = comma_ref + ('*' if k == '*' else '?') + else: + res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '') + if len(rest) > 0: + res += ' ' + self._add_rule( + f'{name}{"-" if name else ""}{k}-rest', + get_recursive_refs(rest, first_is_optional=True) + ) + return res + + rule += ' | '.join( + get_recursive_refs(optional_props[i:], first_is_optional=False) + for i in range(len(optional_props)) + ) + if required_props: + rule += ' )' + rule += ' )?' + + rule += ' "}" space' + + return rule + + def format_grammar(self): + return '\n'.join( + f'{name} ::= {rule}' + for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) + ) + + +def main(args_in = None): + parser = argparse.ArgumentParser( + description=''' + Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a + given JSON schema. Only a subset of JSON schema features are supported; more may be + added in the future. + ''', + ) + parser.add_argument( + '--prop-order', + default=[], + type=lambda s: s.split(','), + help=''' + comma-separated property names defining the order of precedence for object properties; + properties not specified here are given lower precedence than those that are, and + are kept in their original order from the schema. Required properties are always + given precedence over optional properties. + ''' + ) + parser.add_argument( + '--allow-fetch', + action='store_true', + default=False, + help='Whether to allow fetching referenced schemas over HTTPS') + parser.add_argument( + '--dotall', + action='store_true', + default=False, + help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns') + parser.add_argument( + '--raw-pattern', + action='store_true', + default=False, + help='Treats string patterns as raw patterns w/o quotes (or quote escapes)') + + parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)') + args = parser.parse_args(args_in) + + if args.schema.startswith('https://'): + url = args.schema + import requests + schema = requests.get(url).json() + elif args.schema == '-': + url = 'stdin' + schema = json.load(sys.stdin) + else: + url = f'file://{args.schema}' + with open(args.schema) as f: + schema = json.load(f) + converter = SchemaConverter( + prop_order={name: idx for idx, name in enumerate(args.prop_order)}, + allow_fetch=args.allow_fetch, + dotall=args.dotall, + raw_pattern=args.raw_pattern) + schema = converter.resolve_refs(schema, url) + converter.visit(schema, '') + print(converter.format_grammar()) + + +if __name__ == '__main__': + main() diff --git a/examples/llama.android/.gitignore b/examples/llama.android/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..347e252ef10e9c2052ee2017c929530eb0afc5f1 --- /dev/null +++ b/examples/llama.android/.gitignore @@ -0,0 +1,33 @@ +# Gradle files +.gradle/ +build/ + +# Local configuration file (sdk path, etc) +local.properties + +# Log/OS Files +*.log + +# Android Studio generated files and folders +captures/ +.externalNativeBuild/ +.cxx/ +*.apk +output.json + +# IntelliJ +*.iml +.idea/ +misc.xml +deploymentTargetDropDown.xml +render.experimental.xml + +# Keystore files +*.jks +*.keystore + +# Google Services (e.g. APIs or Firebase) +google-services.json + +# Android Profiling +*.hprof diff --git a/examples/llama.android/README.md b/examples/llama.android/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/llama.android/app/.gitignore b/examples/llama.android/app/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..796b96d1c402326528b4ba3c12ee9d92d0e212e9 --- /dev/null +++ b/examples/llama.android/app/.gitignore @@ -0,0 +1 @@ +/build diff --git a/examples/llama.android/app/build.gradle.kts b/examples/llama.android/app/build.gradle.kts new file mode 100644 index 0000000000000000000000000000000000000000..8d1b37195efd40a8336391d80b4d90064dba6ee0 --- /dev/null +++ b/examples/llama.android/app/build.gradle.kts @@ -0,0 +1,65 @@ +plugins { + id("com.android.application") + id("org.jetbrains.kotlin.android") +} + +android { + namespace = "com.example.llama" + compileSdk = 34 + + defaultConfig { + applicationId = "com.example.llama" + minSdk = 33 + targetSdk = 34 + versionCode = 1 + versionName = "1.0" + + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + vectorDrawables { + useSupportLibrary = true + } + } + + buildTypes { + release { + isMinifyEnabled = false + proguardFiles( + getDefaultProguardFile("proguard-android-optimize.txt"), + "proguard-rules.pro" + ) + } + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 + } + kotlinOptions { + jvmTarget = "1.8" + } + buildFeatures { + compose = true + } + composeOptions { + kotlinCompilerExtensionVersion = "1.5.1" + } +} + +dependencies { + + implementation("androidx.core:core-ktx:1.12.0") + implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2") + implementation("androidx.activity:activity-compose:1.8.2") + implementation(platform("androidx.compose:compose-bom:2023.08.00")) + implementation("androidx.compose.ui:ui") + implementation("androidx.compose.ui:ui-graphics") + implementation("androidx.compose.ui:ui-tooling-preview") + implementation("androidx.compose.material3:material3") + implementation(project(":llama")) + testImplementation("junit:junit:4.13.2") + androidTestImplementation("androidx.test.ext:junit:1.1.5") + androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") + androidTestImplementation(platform("androidx.compose:compose-bom:2023.08.00")) + androidTestImplementation("androidx.compose.ui:ui-test-junit4") + debugImplementation("androidx.compose.ui:ui-tooling") + debugImplementation("androidx.compose.ui:ui-test-manifest") +} diff --git a/examples/llama.android/app/proguard-rules.pro b/examples/llama.android/app/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/examples/llama.android/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/examples/llama.android/app/src/main/AndroidManifest.xml b/examples/llama.android/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..41a358a2991547aab2741e45dd6bb365bfd0beec --- /dev/null +++ b/examples/llama.android/app/src/main/AndroidManifest.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + + + + + + + diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt b/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt new file mode 100644 index 0000000000000000000000000000000000000000..78c231ae55d8c9b996c6f1b0ad90658ffa637378 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt @@ -0,0 +1,119 @@ +package com.example.llama + +import android.app.DownloadManager +import android.net.Uri +import android.util.Log +import androidx.compose.material3.Button +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableDoubleStateOf +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue +import androidx.core.database.getLongOrNull +import androidx.core.net.toUri +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import java.io.File + +data class Downloadable(val name: String, val source: Uri, val destination: File) { + companion object { + @JvmStatic + private val tag: String? = this::class.qualifiedName + + sealed interface State + data object Ready: State + data class Downloading(val id: Long): State + data class Downloaded(val downloadable: Downloadable): State + data class Error(val message: String): State + + @JvmStatic + @Composable + fun Button(viewModel: MainViewModel, dm: DownloadManager, item: Downloadable) { + var status: State by remember { + mutableStateOf( + if (item.destination.exists()) Downloaded(item) + else Ready + ) + } + var progress by remember { mutableDoubleStateOf(0.0) } + + val coroutineScope = rememberCoroutineScope() + + suspend fun waitForDownload(result: Downloading, item: Downloadable): State { + while (true) { + val cursor = dm.query(DownloadManager.Query().setFilterById(result.id)) + + if (cursor == null) { + Log.e(tag, "dm.query() returned null") + return Error("dm.query() returned null") + } + + if (!cursor.moveToFirst() || cursor.count < 1) { + cursor.close() + Log.i(tag, "cursor.moveToFirst() returned false or cursor.count < 1, download canceled?") + return Ready + } + + val pix = cursor.getColumnIndex(DownloadManager.COLUMN_BYTES_DOWNLOADED_SO_FAR) + val tix = cursor.getColumnIndex(DownloadManager.COLUMN_TOTAL_SIZE_BYTES) + val sofar = cursor.getLongOrNull(pix) ?: 0 + val total = cursor.getLongOrNull(tix) ?: 1 + cursor.close() + + if (sofar == total) { + return Downloaded(item) + } + + progress = (sofar * 1.0) / total + + delay(1000L) + } + } + + fun onClick() { + when (val s = status) { + is Downloaded -> { + viewModel.load(item.destination.path) + } + + is Downloading -> { + coroutineScope.launch { + status = waitForDownload(s, item) + } + } + + else -> { + item.destination.delete() + + val request = DownloadManager.Request(item.source).apply { + setTitle("Downloading model") + setDescription("Downloading model: ${item.name}") + setAllowedNetworkTypes(DownloadManager.Request.NETWORK_WIFI) + setDestinationUri(item.destination.toUri()) + } + + viewModel.log("Saving ${item.name} to ${item.destination.path}") + Log.i(tag, "Saving ${item.name} to ${item.destination.path}") + + val id = dm.enqueue(request) + status = Downloading(id) + onClick() + } + } + } + + Button(onClick = { onClick() }, enabled = status !is Downloading) { + when (status) { + is Downloading -> Text(text = "Downloading ${(progress * 100).toInt()}%") + is Downloaded -> Text("Load ${item.name}") + is Ready -> Text("Download ${item.name}") + is Error -> Text("Download ${item.name}") + } + } + } + + } +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt new file mode 100644 index 0000000000000000000000000000000000000000..9da04f7d3c32ef2180657c94717502ca58f811e8 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt @@ -0,0 +1,154 @@ +package com.example.llama + +import android.app.ActivityManager +import android.app.DownloadManager +import android.content.ClipData +import android.content.ClipboardManager +import android.net.Uri +import android.os.Bundle +import android.os.StrictMode +import android.os.StrictMode.VmPolicy +import android.text.format.Formatter +import androidx.activity.ComponentActivity +import androidx.activity.compose.setContent +import androidx.activity.viewModels +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.lazy.rememberLazyListState +import androidx.compose.material3.Button +import androidx.compose.material3.LocalContentColor +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.OutlinedTextField +import androidx.compose.material3.Surface +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import androidx.compose.ui.unit.dp +import androidx.core.content.getSystemService +import com.example.llama.ui.theme.LlamaAndroidTheme +import java.io.File + +class MainActivity( + activityManager: ActivityManager? = null, + downloadManager: DownloadManager? = null, + clipboardManager: ClipboardManager? = null, +): ComponentActivity() { + private val tag: String? = this::class.simpleName + + private val activityManager by lazy { activityManager ?: getSystemService()!! } + private val downloadManager by lazy { downloadManager ?: getSystemService()!! } + private val clipboardManager by lazy { clipboardManager ?: getSystemService()!! } + + private val viewModel: MainViewModel by viewModels() + + // Get a MemoryInfo object for the device's current memory status. + private fun availableMemory(): ActivityManager.MemoryInfo { + return ActivityManager.MemoryInfo().also { memoryInfo -> + activityManager.getMemoryInfo(memoryInfo) + } + } + + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + + StrictMode.setVmPolicy( + VmPolicy.Builder(StrictMode.getVmPolicy()) + .detectLeakedClosableObjects() + .build() + ) + + val free = Formatter.formatFileSize(this, availableMemory().availMem) + val total = Formatter.formatFileSize(this, availableMemory().totalMem) + + viewModel.log("Current memory: $free / $total") + viewModel.log("Downloads directory: ${getExternalFilesDir(null)}") + + val extFilesDir = getExternalFilesDir(null) + + val models = listOf( + Downloadable( + "Phi-2 7B (Q4_0, 1.6 GiB)", + Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true"), + File(extFilesDir, "phi-2-q4_0.gguf"), + ), + Downloadable( + "TinyLlama 1.1B (f16, 2.2 GiB)", + Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true"), + File(extFilesDir, "tinyllama-1.1-f16.gguf"), + ), + Downloadable( + "Phi 2 DPO (Q3_K_M, 1.48 GiB)", + Uri.parse("https://huggingface.co/TheBloke/phi-2-dpo-GGUF/resolve/main/phi-2-dpo.Q3_K_M.gguf?download=true"), + File(extFilesDir, "phi-2-dpo.Q3_K_M.gguf") + ), + ) + + setContent { + LlamaAndroidTheme { + // A surface container using the 'background' color from the theme + Surface( + modifier = Modifier.fillMaxSize(), + color = MaterialTheme.colorScheme.background + ) { + MainCompose( + viewModel, + clipboardManager, + downloadManager, + models, + ) + } + + } + } + } +} + +@Composable +fun MainCompose( + viewModel: MainViewModel, + clipboard: ClipboardManager, + dm: DownloadManager, + models: List +) { + Column { + val scrollState = rememberLazyListState() + + Box(modifier = Modifier.weight(1f)) { + LazyColumn(state = scrollState) { + items(viewModel.messages) { + Text( + it, + style = MaterialTheme.typography.bodyLarge.copy(color = LocalContentColor.current), + modifier = Modifier.padding(16.dp) + ) + } + } + } + OutlinedTextField( + value = viewModel.message, + onValueChange = { viewModel.updateMessage(it) }, + label = { Text("Message") }, + ) + Row { + Button({ viewModel.send() }) { Text("Send") } + Button({ viewModel.bench(8, 4, 1) }) { Text("Bench") } + Button({ viewModel.clear() }) { Text("Clear") } + Button({ + viewModel.messages.joinToString("\n").let { + clipboard.setPrimaryClip(ClipData.newPlainText("", it)) + } + }) { Text("Copy") } + } + + Column { + for (model in models) { + Downloadable.Button(viewModel, dm, model) + } + } + } +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt new file mode 100644 index 0000000000000000000000000000000000000000..45ac29938f441e91c95620eef19be76e38008a3c --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt @@ -0,0 +1,105 @@ +package com.example.llama + +import android.llama.cpp.LLamaAndroid +import android.util.Log +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.setValue +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import kotlinx.coroutines.flow.catch +import kotlinx.coroutines.launch + +class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() { + companion object { + @JvmStatic + private val NanosPerSecond = 1_000_000_000.0 + } + + private val tag: String? = this::class.simpleName + + var messages by mutableStateOf(listOf("Initializing...")) + private set + + var message by mutableStateOf("") + private set + + override fun onCleared() { + super.onCleared() + + viewModelScope.launch { + try { + llamaAndroid.unload() + } catch (exc: IllegalStateException) { + messages += exc.message!! + } + } + } + + fun send() { + val text = message + message = "" + + // Add to messages console. + messages += text + messages += "" + + viewModelScope.launch { + llamaAndroid.send(text) + .catch { + Log.e(tag, "send() failed", it) + messages += it.message!! + } + .collect { messages = messages.dropLast(1) + (messages.last() + it) } + } + } + + fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) { + viewModelScope.launch { + try { + val start = System.nanoTime() + val warmupResult = llamaAndroid.bench(pp, tg, pl, nr) + val end = System.nanoTime() + + messages += warmupResult + + val warmup = (end - start).toDouble() / NanosPerSecond + messages += "Warm up time: $warmup seconds, please wait..." + + if (warmup > 5.0) { + messages += "Warm up took too long, aborting benchmark" + return@launch + } + + messages += llamaAndroid.bench(512, 128, 1, 3) + } catch (exc: IllegalStateException) { + Log.e(tag, "bench() failed", exc) + messages += exc.message!! + } + } + } + + fun load(pathToModel: String) { + viewModelScope.launch { + try { + llamaAndroid.load(pathToModel) + messages += "Loaded $pathToModel" + } catch (exc: IllegalStateException) { + Log.e(tag, "load() failed", exc) + messages += exc.message!! + } + } + } + + fun updateMessage(newMessage: String) { + message = newMessage + } + + fun clear() { + messages = listOf() + } + + fun log(message: String) { + messages += message + } +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt new file mode 100644 index 0000000000000000000000000000000000000000..40c30e8d970772eceafa312db1c9c10c7bce3cc5 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt @@ -0,0 +1,11 @@ +package com.example.llama.ui.theme + +import androidx.compose.ui.graphics.Color + +val Purple80 = Color(0xFFD0BCFF) +val PurpleGrey80 = Color(0xFFCCC2DC) +val Pink80 = Color(0xFFEFB8C8) + +val Purple40 = Color(0xFF6650a4) +val PurpleGrey40 = Color(0xFF625b71) +val Pink40 = Color(0xFF7D5260) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt new file mode 100644 index 0000000000000000000000000000000000000000..e742220a8d71972f66ac6eb5545783698249ba3d --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt @@ -0,0 +1,70 @@ +package com.example.llama.ui.theme + +import android.app.Activity +import android.os.Build +import androidx.compose.foundation.isSystemInDarkTheme +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.darkColorScheme +import androidx.compose.material3.dynamicDarkColorScheme +import androidx.compose.material3.dynamicLightColorScheme +import androidx.compose.material3.lightColorScheme +import androidx.compose.runtime.Composable +import androidx.compose.runtime.SideEffect +import androidx.compose.ui.graphics.toArgb +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.platform.LocalView +import androidx.core.view.WindowCompat + +private val DarkColorScheme = darkColorScheme( + primary = Purple80, + secondary = PurpleGrey80, + tertiary = Pink80 +) + +private val LightColorScheme = lightColorScheme( + primary = Purple40, + secondary = PurpleGrey40, + tertiary = Pink40 + + /* Other default colors to override + background = Color(0xFFFFFBFE), + surface = Color(0xFFFFFBFE), + onPrimary = Color.White, + onSecondary = Color.White, + onTertiary = Color.White, + onBackground = Color(0xFF1C1B1F), + onSurface = Color(0xFF1C1B1F), + */ +) + +@Composable +fun LlamaAndroidTheme( + darkTheme: Boolean = isSystemInDarkTheme(), + // Dynamic color is available on Android 12+ + dynamicColor: Boolean = true, + content: @Composable () -> Unit +) { + val colorScheme = when { + dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> { + val context = LocalContext.current + if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context) + } + + darkTheme -> DarkColorScheme + else -> LightColorScheme + } + val view = LocalView.current + if (!view.isInEditMode) { + SideEffect { + val window = (view.context as Activity).window + window.statusBarColor = colorScheme.primary.toArgb() + WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme + } + } + + MaterialTheme( + colorScheme = colorScheme, + typography = Typography, + content = content + ) +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt new file mode 100644 index 0000000000000000000000000000000000000000..0b87946ca3ab11df3dbd86916f1e03ef79caff78 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt @@ -0,0 +1,34 @@ +package com.example.llama.ui.theme + +import androidx.compose.material3.Typography +import androidx.compose.ui.text.TextStyle +import androidx.compose.ui.text.font.FontFamily +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.sp + +// Set of Material typography styles to start with +val Typography = Typography( + bodyLarge = TextStyle( + fontFamily = FontFamily.Default, + fontWeight = FontWeight.Normal, + fontSize = 16.sp, + lineHeight = 24.sp, + letterSpacing = 0.5.sp + ) + /* Other default text styles to override + titleLarge = TextStyle( + fontFamily = FontFamily.Default, + fontWeight = FontWeight.Normal, + fontSize = 22.sp, + lineHeight = 28.sp, + letterSpacing = 0.sp + ), + labelSmall = TextStyle( + fontFamily = FontFamily.Default, + fontWeight = FontWeight.Medium, + fontSize = 11.sp, + lineHeight = 16.sp, + letterSpacing = 0.5.sp + ) + */ +) diff --git a/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml b/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 0000000000000000000000000000000000000000..07d5da9cbf141911847041df5d7b87f0dd5ef9d4 --- /dev/null +++ b/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml b/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml new file mode 100644 index 0000000000000000000000000000000000000000..7706ab9e6d40725c5e06d3d8e3ee18ddca8cac01 --- /dev/null +++ b/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + diff --git a/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml new file mode 100644 index 0000000000000000000000000000000000000000..b3e26b4c60c27edeec7775191674317dd7a9f652 --- /dev/null +++ b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml new file mode 100644 index 0000000000000000000000000000000000000000..b3e26b4c60c27edeec7775191674317dd7a9f652 --- /dev/null +++ b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp b/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp new file mode 100644 index 0000000000000000000000000000000000000000..c209e78ecd372343283f4157dcfd918ec5165bb3 Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp differ diff --git a/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp b/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp new file mode 100644 index 0000000000000000000000000000000000000000..b2dfe3d1ba5cf3ee31b3ecc1ced89044a1f3b7a9 Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp differ diff --git a/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp b/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp new file mode 100644 index 0000000000000000000000000000000000000000..4f0f1d64e58ba64d180ce43ee13bf9a17835fbca Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp differ diff --git a/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp b/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp new file mode 100644 index 0000000000000000000000000000000000000000..62b611da081676d42f6c3f78a2c91e7bcedddedb Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp differ diff --git a/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp b/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp new file mode 100644 index 0000000000000000000000000000000000000000..948a3070fe34c611c42c0d3ad3013a0dce358be0 Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp differ diff --git a/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp b/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp new file mode 100644 index 0000000000000000000000000000000000000000..1b9a6956b3acdc11f40ce2bb3f6efbd845cc243f Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp differ diff --git a/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp b/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp new file mode 100644 index 0000000000000000000000000000000000000000..28d4b77f9f036a47549d47db79c16788749dca10 Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp differ diff --git a/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp b/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp new file mode 100644 index 0000000000000000000000000000000000000000..9287f5083623b375139afb391af71cc533a7dd37 Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp differ diff --git a/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp b/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp new file mode 100644 index 0000000000000000000000000000000000000000..aa7d6427e6fa1074b79ccd52ef67ac15c5637e85 Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp differ diff --git a/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp b/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp new file mode 100644 index 0000000000000000000000000000000000000000..9126ae37cbc3587421d6889eadd1d91fbf1994d4 Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp differ diff --git a/examples/llama.android/app/src/main/res/values/colors.xml b/examples/llama.android/app/src/main/res/values/colors.xml new file mode 100644 index 0000000000000000000000000000000000000000..ca1931bca99e3b8a4c753ae1a1c41428e5622484 --- /dev/null +++ b/examples/llama.android/app/src/main/res/values/colors.xml @@ -0,0 +1,10 @@ + + + #FFBB86FC + #FF6200EE + #FF3700B3 + #FF03DAC5 + #FF018786 + #FF000000 + #FFFFFFFF + diff --git a/examples/llama.android/app/src/main/res/values/strings.xml b/examples/llama.android/app/src/main/res/values/strings.xml new file mode 100644 index 0000000000000000000000000000000000000000..7a9d314e2969bf413c0f9f61fe4d3a2f872807fd --- /dev/null +++ b/examples/llama.android/app/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ + + LlamaAndroid + diff --git a/examples/llama.android/app/src/main/res/values/themes.xml b/examples/llama.android/app/src/main/res/values/themes.xml new file mode 100644 index 0000000000000000000000000000000000000000..8a24fda56602cea7b0b55b5f82806c1495eaa286 --- /dev/null +++ b/examples/llama.android/app/src/main/res/values/themes.xml @@ -0,0 +1,5 @@ + + + + + + diff --git a/ggml/src/ggml-kompute/kompute/examples/android/android-simple/gradle.properties b/ggml/src/ggml-kompute/kompute/examples/android/android-simple/gradle.properties new file mode 100644 index 0000000000000000000000000000000000000000..7fa14279e6ebe43dcca9c37d69427a95d9de11ae --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/android/android-simple/gradle.properties @@ -0,0 +1,20 @@ +# Project-wide Gradle settings. + +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. + +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html + +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +# Default value: -Xmx10248m -XX:MaxPermSize=256m +# org.gradle.jvmargs=-Xmx2048m -XX:MaxPermSize=512m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 +org.gradle.jvmargs=-Xmx1536m + +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +android.useAndroidX=true \ No newline at end of file diff --git a/ggml/src/ggml-kompute/kompute/examples/android/android-simple/gradle/wrapper/gradle-wrapper.properties b/ggml/src/ggml-kompute/kompute/examples/android/android-simple/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000000000000000000000000000000000..cfcf34b60d96e6ea93e81ae30a09f67f149fef1f --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/android/android-simple/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Mon Jan 13 14:51:34 PST 2020 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-all.zip diff --git a/ggml/src/ggml-kompute/kompute/examples/android/android-simple/gradlew b/ggml/src/ggml-kompute/kompute/examples/android/android-simple/gradlew new file mode 100644 index 0000000000000000000000000000000000000000..9d82f78915133e1c35a6ea51252590fb38efac2f --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/android/android-simple/gradlew @@ -0,0 +1,160 @@ +#!/usr/bin/env bash + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn ( ) { + echo "$*" +} + +die ( ) { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; +esac + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules +function splitJvmOpts() { + JVM_OPTS=("$@") +} +eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS +JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" + +exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" diff --git a/ggml/src/ggml-kompute/kompute/examples/android/android-simple/settings.gradle b/ggml/src/ggml-kompute/kompute/examples/android/android-simple/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..573abcb323ce0079c54bdf98ecfd1e3658f335e1 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/android/android-simple/settings.gradle @@ -0,0 +1,2 @@ +include ':app' + diff --git a/ggml/src/ggml-kompute/kompute/examples/array_multiplication/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d63d4b416363697d8ba7cebabf5c37bd6bfaa7b2 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.20) +project(kompute_array_mult) + +set(CMAKE_CXX_STANDARD 14) + +# Options +option(KOMPUTE_OPT_GIT_TAG "The tag of the repo to use for the example" 1344ece4ac278f9b3be3b4555ffaace7a032b91f) +option(KOMPUTE_OPT_FROM_SOURCE "Whether to build example from source or from git fetch repo" 0) + +# Set a default build type if none was specified +# Based on: https://github.com/openchemistry/tomviz/blob/master/cmake/BuildType.cmake +set(DEFAULT_BUILD_TYPE "Release") + +if(EXISTS "${CMAKE_SOURCE_DIR}/.git") + set(DEFAULT_BUILD_TYPE "Debug") +endif() + +if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) + message(STATUS "Setting build type to '${DEFAULT_BUILD_TYPE}' as none was specified.") + set(CMAKE_BUILD_TYPE "${DEFAULT_BUILD_TYPE}" CACHE STRING "Choose the type of build." FORCE) + + # Set the possible values of build type for cmake-gui + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +if(WIN32) # Install dlls in the same directory as the executable on Windows + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) +endif() + +if(KOMPUTE_OPT_FROM_SOURCE) + add_subdirectory(../../ ${CMAKE_CURRENT_BINARY_DIR}/kompute_build) +else() + include(FetchContent) + FetchContent_Declare(kompute GIT_REPOSITORY https://github.com/KomputeProject/kompute.git + GIT_TAG ${KOMPUTE_OPT_GIT_TAG}) + FetchContent_MakeAvailable(kompute) + include_directories(${kompute_SOURCE_DIR}/src/include) +endif() + +# Add to the list, so CMake can later find the code to compile shaders to header files +list(APPEND CMAKE_PREFIX_PATH "${kompute_SOURCE_DIR}/cmake") + +add_subdirectory(shader) +add_subdirectory(src) diff --git a/ggml/src/ggml-kompute/kompute/examples/array_multiplication/README.md b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ba6f9cab54968405aadc8943f1bcaafcb58e4855 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/README.md @@ -0,0 +1,45 @@ +# Kompute Array Multiplication Example + +This folder contains an end to end Kompute Example that implements logistic regression. +This example is structured such that you will be able to extend it for your project. +It contains a CMake build configuration that can be used in your production applications. + +## Building the example + +You will notice that it's a standalone project, so you can re-use it for your application. +It uses CMake's [`fetch_content`](https://cmake.org/cmake/help/latest/module/FetchContent.html) to consume Kompute as a dependency. +To build you just need to run the CMake command in this folder as follows: + +```bash +git clone https://github.com/KomputeProject/kompute.git +cd kompute/examples/array_multiplication +mkdir build +cd build +cmake .. +cmake --build . +``` + +## Executing + +Form inside the `build/` directory run: + +### Linux + +```bash +./kompute_array_mult +``` + +### Windows + +```bash +.\Debug\kompute_array_mult.exe +``` + +## Pre-requisites + +In order to run this example, you will need the following dependencies: + +* REQUIRED + + The Vulkan SDK must be installed + +For the Vulkan SDK, the simplest way to install it is through [their website](https://vulkan.lunarg.com/sdk/home). You just have to follow the instructions for the relevant platform. diff --git a/ggml/src/ggml-kompute/kompute/examples/array_multiplication/shader/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/shader/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..684a6a82a91b4460d570505f8113f0a4a3559bfa --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/shader/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.20) + +# To add more shaders simply copy the vulkan_compile_shader command and replace it with your new shader +vulkan_compile_shader(INFILE my_shader.comp + OUTFILE my_shader.hpp + NAMESPACE "shader" + RELATIVE_PATH "${kompute_SOURCE_DIR}/cmake") + +# Then add it to the library, so you can access it later in your code +add_library(shader INTERFACE "${CMAKE_CURRENT_BINARY_DIR}/my_shader.hpp" + + # "${CMAKE_CURRENT_BINARY_DIR}/my_shader2.hpp" +) + +target_include_directories(shader INTERFACE $) diff --git a/ggml/src/ggml-kompute/kompute/examples/array_multiplication/shader/my_shader.comp b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/shader/my_shader.comp new file mode 100644 index 0000000000000000000000000000000000000000..d4cfe69900526bda68d00c66d71492832b58fc1f --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/shader/my_shader.comp @@ -0,0 +1,14 @@ +#version 450 + +// The execution structure +layout (local_size_x = 1) in; + +// The buffers are provided via the tensors +layout(binding = 0) buffer bufA { float a[]; }; +layout(binding = 1) buffer bufB { float b[]; }; +layout(binding = 2) buffer bufOut { float o[]; }; + +void main() { + uint index = gl_GlobalInvocationID.x; + o[index] = a[index] * b[index]; +} \ No newline at end of file diff --git a/ggml/src/ggml-kompute/kompute/examples/array_multiplication/src/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4b85931b10e550ceee8f56eab83fd0cb8d581b50 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/src/CMakeLists.txt @@ -0,0 +1,4 @@ +cmake_minimum_required(VERSION 3.20) + +add_executable(kompute_array_mult main.cpp) +target_link_libraries(kompute_array_mult PRIVATE shader kompute::kompute) diff --git a/ggml/src/ggml-kompute/kompute/examples/array_multiplication/src/main.cpp b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8637577b9f85f87d68a5c6cd33a48c8d874e13a --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/array_multiplication/src/main.cpp @@ -0,0 +1,45 @@ + +#include +#include +#include + +#include "my_shader.hpp" +#include + +int +main() +{ + kp::Manager mgr; + + std::shared_ptr> tensorInA = + mgr.tensor({ 2.0, 4.0, 6.0 }); + std::shared_ptr> tensorInB = + mgr.tensor({ 0.0, 1.0, 2.0 }); + std::shared_ptr> tensorOut = + mgr.tensor({ 0.0, 0.0, 0.0 }); + + const std::vector> params = { tensorInA, + tensorInB, + tensorOut }; + + const std::vector shader = std::vector( + shader::MY_SHADER_COMP_SPV.begin(), shader::MY_SHADER_COMP_SPV.end()); + std::shared_ptr algo = mgr.algorithm(params, shader); + + mgr.sequence() + ->record(params) + ->record(algo) + ->record(params) + ->eval(); + + // prints "Output { 0 4 12 }" + std::cout << "Output: { "; + for (const float& elem : tensorOut->vector()) { + std::cout << elem << " "; + } + std::cout << "}" << std::endl; + + if (tensorOut->vector() != std::vector{ 0, 4, 12 }) { + throw std::runtime_error("Result does not match"); + } +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/.gitignore b/ggml/src/ggml-kompute/kompute/examples/godot_examples/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ee3ffe884d99484515599c24993aceb44203f7be --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/.gitignore @@ -0,0 +1,2 @@ +.import +godot_engine/godot diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/README.md b/ggml/src/ggml-kompute/kompute/examples/godot_examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..51e82c0517094d52b8c03a31957df0028990a424 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/README.md @@ -0,0 +1,13 @@ +# Godot Summator GPU Kompute Example + +![](https://github.com/KomputeProject/kompute/raw/master/docs/images/komputer-godot-4.gif) + +This repository contains the accompanying code for the Blog post ["Supercharging Game Development with GPU Accelerated Machine Learning"](https://medium.com/@AxSaucedo/supercharging-game-development-with-gpu-accelerated-ml-using-vulkan-kompute-the-godot-game-engine-4e75a84ea9f0). + +This example folder contains three key components: +* The Godot Project file `project.godot` to run the example +* The [instructions for the Custom Module](./custom_module/) implementation +* The [instructions for the GdNative Library](./gdnative_shared/) implementation + +You can also find the advanced ML Logistic Regression example in the GPU in [this example folder](../godot_logistic_regression/). + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/.gdignore b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/.gdignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/.gitignore b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..371dc7b6d5b21587e07087da6e0f03d07eb44059 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/.gitignore @@ -0,0 +1,3 @@ +kompute +lib +godot diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/README.md b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c31c5472f353b93829b8d318759a74cefd3d9f34 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/README.md @@ -0,0 +1,70 @@ +# Kompute Godot Example + +![](https://github.com/KomputeProject/kompute/raw/master/docs/images/komputer-godot-4.gif) + +## Set Up Dependencies + +### Vulkan + +You will need the Vulkan SDK, in this case we use version `1.2.148.1`, which you can get at the official site https://vulkan.lunarg.com/sdk/home#windows + +This will have the following contents that will be required later on: + +* The VulkanSDK static library `vulkan-1` + +### Kompute + +We will be using v0.3.1 of Kompute, and similar to above we will need the built static library, but in this case we will build it. + +We can start by cloning the repository on the v0.3.1 branch: + +``` +git clone --branch v0.3.1 https://github.com/KomputeProject/kompute/ +``` + +You will be able to use cmake to generate the build files for your platform. + +``` +cmake kompute/. -Bkompute/build +``` + +You need to make sure that the build is configured with the same flags required for godot, for example, in windows you will need: + +* Release build +* Configuration type: static library +* Runtime lib: Multi-threaded / multi-threaded debug + +Now you should see the library built under `build/src/Release` + +## Building Godot + +Now to build godot you will need to set up a couple of things for the Scons file to work - namely setting up the following: + +* Copy the `vulkan-1` library from your vulkan sdk folder to `lib/vulkan-1.lib` +* Copy the `kompute.lib` library from the Kompute build to `lib/kompute.lib` +* Copy the `include/vulkan/` folder to the `./include/` folder +* Copy the `single_include/kompute/` to the `./include/` folder +* Make sure the versions above match as we provide the headers in the `include` folder - if you used different versions make sure these match as well + +### Clone godot repository + +Now we can clone the godot repository - it must be on a separate repository, so you can use the parent directory if you are on the Kompute repo. + +``` +cd ../../godot_engine + +git clone --branch 3.2.3-stable https://github.com/godotengine/godot + +cd godot/ +``` + +And now we can build against our module + +``` +wscons -j16 custom_modules=../../custom_module/ platform=windows target=release_debug +``` + +Once we have built it we can now run the generated godot engine in the `bin/` folder, and we will be able to access the custom module from anywhere in the project, as well as creating new nodes from the user interface. + + + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/KomputeSummatorNode.cpp b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/KomputeSummatorNode.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cef2d26d67da60b191a87ea6328914544bd66d14 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/KomputeSummatorNode.cpp @@ -0,0 +1,113 @@ +/* summator.cpp */ + +#include + +#include "KomputeSummatorNode.h" + +static std::vector +compileSource(const std::string& source) +{ + std::ofstream fileOut("tmp_kp_shader.comp"); + fileOut << source; + fileOut.close(); + if (system( + std::string( + "glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv") + .c_str())) + throw std::runtime_error("Error running glslangValidator command"); + std::ifstream fileStream("tmp_kp_shader.comp.spv", std::ios::binary); + std::vector buffer; + buffer.insert( + buffer.begin(), std::istreambuf_iterator(fileStream), {}); + return { (uint32_t*)buffer.data(), + (uint32_t*)(buffer.data() + buffer.size()) }; +} + +KomputeSummatorNode::KomputeSummatorNode() +{ + this->_init(); +} + +void +KomputeSummatorNode::add(float value) +{ + // Set the new data in the local device + this->mSecondaryTensor->setData({ value }); + // Execute recorded sequence + if (std::shared_ptr sq = this->mSequence) { + sq->eval(); + } else { + throw std::runtime_error("Sequence pointer no longer available"); + } +} + +void +KomputeSummatorNode::reset() +{} + +float +KomputeSummatorNode::get_total() const +{ + return this->mPrimaryTensor->data()[0]; +} + +void +KomputeSummatorNode::_init() +{ + std::cout << "CALLING INIT" << std::endl; + this->mPrimaryTensor = this->mManager.tensor({ 0.0 }); + this->mSecondaryTensor = this->mManager.tensor({ 0.0 }); + this->mSequence = this->mManager.sequence(); + + // We now record the steps in the sequence + if (std::shared_ptr sq = this->mSequence) { + + std::string shader(R"( + #version 450 + + layout (local_size_x = 1) in; + + layout(set = 0, binding = 0) buffer a { float pa[]; }; + layout(set = 0, binding = 1) buffer b { float pb[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = pb[index] + pa[index]; + } + )"); + + std::shared_ptr algo = this->mManager.algorithm( + { this->mPrimaryTensor, this->mSecondaryTensor }, + compileSource(shader)); + + // First we ensure secondary tensor loads to GPU + // No need to sync the primary tensor as it should not be changed + sq->record({ this->mSecondaryTensor }); + + // Then we run the operation with both tensors + sq->record(algo); + + // We map the result back to local + sq->record({ this->mPrimaryTensor }); + + } else { + throw std::runtime_error("Sequence pointer no longer available"); + } +} + +void +KomputeSummatorNode::_process(float delta) +{} + +void +KomputeSummatorNode::_bind_methods() +{ + ClassDB::bind_method(D_METHOD("_process", "delta"), + &KomputeSummatorNode::_process); + ClassDB::bind_method(D_METHOD("_init"), &KomputeSummatorNode::_init); + + ClassDB::bind_method(D_METHOD("add", "value"), &KomputeSummatorNode::add); + ClassDB::bind_method(D_METHOD("reset"), &KomputeSummatorNode::reset); + ClassDB::bind_method(D_METHOD("get_total"), + &KomputeSummatorNode::get_total); +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/KomputeSummatorNode.h b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/KomputeSummatorNode.h new file mode 100644 index 0000000000000000000000000000000000000000..0acfbf5904c185a343c49d7ee2b8ba4c5e3c700b --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/KomputeSummatorNode.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include "kompute/Kompute.hpp" + +#include "scene/main/node.h" + +class KomputeSummatorNode : public Node +{ + GDCLASS(KomputeSummatorNode, Node); + + public: + KomputeSummatorNode(); + + void add(float value); + void reset(); + float get_total() const; + + void _process(float delta); + void _init(); + + protected: + static void _bind_methods(); + + private: + kp::Manager mManager; + std::shared_ptr mSequence; + std::shared_ptr mPrimaryTensor; + std::shared_ptr mSecondaryTensor; +}; diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/SCsub b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/SCsub new file mode 100644 index 0000000000000000000000000000000000000000..6ce3176f4b0244b037e4923f8889de8bbb085905 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/SCsub @@ -0,0 +1,17 @@ +import os + +Import('env') + +dir_path = os.getcwd() + +# Kompute & Vulkan header files +env.Append(CPPPATH = ['include/']) + +env.add_source_files(env.modules_sources, "*.cpp") + +# Kompute & Vulkan libraries +env.Append(LIBS=[ + File(dir_path +'/lib/kompute.lib'), + File(dir_path +'/lib/vulkan-1.lib'), + ]) + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/config.py b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1c8cd12a2dc049bbc83fcda51d4d929cfe417206 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/config.py @@ -0,0 +1,5 @@ +def can_build(env, platform): + return True + +def configure(env): + pass diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/include/.gitignore b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/include/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..efb74446707933c31bd5a60f8fc4ce323a5811dd --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/include/.gitignore @@ -0,0 +1,2 @@ +kompute/ +vulkan/ diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/register_types.cpp b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/register_types.cpp new file mode 100644 index 0000000000000000000000000000000000000000..61fd32ca6f99873a21cc34004d06c70e3075d760 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/register_types.cpp @@ -0,0 +1,18 @@ +/* register_types.cpp */ + +#include "register_types.h" + +#include "KomputeSummatorNode.h" +#include "core/class_db.h" + +void +register_kompute_summator_types() +{ + ClassDB::register_class(); +} + +void +unregister_kompute_summator_types() +{ + // Nothing to do here in this example. +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/register_types.h b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/register_types.h new file mode 100644 index 0000000000000000000000000000000000000000..2c8cb4b7e49c098a5909dde9da54bd5827fef947 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/custom_module/kompute_summator/register_types.h @@ -0,0 +1,8 @@ +/* register_types.h */ +#pragma once + +void +register_kompute_summator_types(); +void +unregister_kompute_summator_types(); +/* yes, the word in the middle must be the same as the module folder name */ diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/.gdignore b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/.gdignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/.gitignore b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1fe38fec344c302991d750f413f7082edca42a22 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/.gitignore @@ -0,0 +1,4 @@ +kompute +lib +godot-cpp +.import diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1f54060cc3cf3f7de8f3a59f0bd9b8c14e456cbe --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/CMakeLists.txt @@ -0,0 +1,40 @@ +cmake_minimum_required(VERSION 3.17.0) +project(kompute_godot VERSION 0.1.0) + +set(CMAKE_CXX_STANDARD 14) + +set(KOMPUTE_EXTRA_CXX_FLAGS "" CACHE STRING "Extra compile flags for Kompute, see docs for full list") + +# It is necessary to pass the DEBUG or RELEASE flag accordingly to Kompute +set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DDEBUG=1 ${KOMPUTE_EXTRA_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DRELEASE=1 ${KOMPUTE_EXTRA_CXX_FLAGS}") + +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CC_FLAGS} /W3 /MDd /Zi /EHsc /Ox /DNDEBUG /FS") + +find_package(kompute REQUIRED) +find_package(Vulkan REQUIRED) + +add_library(kompute_godot + SHARED + src/KomputeSummator.cpp + src/KomputeGdNative.cpp) + +target_include_directories( + kompute_godot PUBLIC + src/ + godot-cpp/include + godot-cpp/include/core + godot-cpp/include/gen + godot-cpp/godot_headers + +) + +target_link_libraries(kompute_godot + kompute::kompute + Vulkan::Vulkan +) + +target_link_libraries(kompute_godot + ${CMAKE_CURRENT_SOURCE_DIR}/godot-cpp/bin/libgodot-cpp.windows.release.default.lib +) + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/README.md b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4ead7493804b490c7389d4cf2221827b689c5605 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/README.md @@ -0,0 +1,26 @@ +# Kompute Godot Example + +![](https://github.com/KomputeProject/kompute/raw/master/docs/images/komputer-godot-4.gif) + +## Set Up Dependencies + +We can get all the required dependencies from godot by running + +``` +git clone --branch 3.2 https://github.com/godotengine/godot-cpp + +cd godot-cpp +``` + +Then we can get all the submodules + +``` +git submodule sync +``` + +And we build the bindings + +``` +scons -j16 platform=linuxbsd target=debug + +``` diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/.gdignore b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/.gdignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/KomputeGdNative.cpp b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/KomputeGdNative.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d06ba981165c3529a608e18f86fc391494e48e35 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/KomputeGdNative.cpp @@ -0,0 +1,21 @@ +#include "KomputeSummator.hpp" + +extern "C" void GDN_EXPORT +godot_gdnative_init(godot_gdnative_init_options* o) +{ + godot::Godot::gdnative_init(o); +} + +extern "C" void GDN_EXPORT +godot_gdnative_terminate(godot_gdnative_terminate_options* o) +{ + godot::Godot::gdnative_terminate(o); +} + +extern "C" void GDN_EXPORT +godot_nativescript_init(void* handle) +{ + godot::Godot::nativescript_init(handle); + + godot::register_class(); +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/KomputeSummator.cpp b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/KomputeSummator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..277fb69ddabe8ccbbaa504cf44b4fbf0306aee6a --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/KomputeSummator.cpp @@ -0,0 +1,113 @@ +/* summator.cpp */ + +#include +#include + +#include "KomputeSummator.hpp" + +static std::vector +compileSource(const std::string& source) +{ + std::ofstream fileOut("tmp_kp_shader.comp"); + fileOut << source; + fileOut.close(); + if (system( + std::string( + "glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv") + .c_str())) + throw std::runtime_error("Error running glslangValidator command"); + std::ifstream fileStream("tmp_kp_shader.comp.spv", std::ios::binary); + std::vector buffer; + buffer.insert( + buffer.begin(), std::istreambuf_iterator(fileStream), {}); + return { (uint32_t*)buffer.data(), + (uint32_t*)(buffer.data() + buffer.size()) }; +} + +namespace godot { + +KomputeSummator::KomputeSummator() +{ + std::cout << "CALLING CONSTRUCTOR" << std::endl; + this->_init(); +} + +void +KomputeSummator::add(float value) +{ + // Set the new data in the local device + this->mSecondaryTensor->setData({ value }); + // Execute recorded sequence + this->mSequence->eval(); +} + +void +KomputeSummator::reset() +{} + +float +KomputeSummator::get_total() const +{ + return this->mPrimaryTensor->data()[0]; +} + +void +KomputeSummator::_init() +{ + std::cout << "CALLING INIT" << std::endl; + this->mPrimaryTensor = this->mManager.tensor({ 0.0 }); + this->mSecondaryTensor = this->mManager.tensor({ 0.0 }); + this->mSequence = this->mManager.sequence("AdditionSeq"); + + // We now record the steps in the sequence + { + std::string shader(R"( + #version 450 + + layout (local_size_x = 1) in; + + layout(set = 0, binding = 0) buffer a { float pa[]; }; + layout(set = 0, binding = 1) buffer b { float pb[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = pb[index] + pa[index]; + } + )"); + + this->mSequence->begin(); + + // First we ensure secondary tensor loads to GPU + // No need to sync the primary tensor as it should not be changed + this->mSequence->record( + { this->mSecondaryTensor }); + + // Then we run the operation with both tensors + this->mSequence->record( + { this->mPrimaryTensor, this->mSecondaryTensor }, + compileSource(shader)); + + // We map the result back to local + this->mSequence->record( + { this->mPrimaryTensor }); + + this->mSequence->end(); + } +} + +void +KomputeSummator::_process(float delta) +{} + +void +KomputeSummator::_register_methods() +{ + register_method((char*)"_process", &KomputeSummator::_process); + register_method((char*)"_init", &KomputeSummator::_init); + + register_method((char*)"add", &KomputeSummator::add); + register_method((char*)"reset", &KomputeSummator::reset); + register_method((char*)"get_total", &KomputeSummator::get_total); +} + +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/KomputeSummator.hpp b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/KomputeSummator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3aafa73d9209ae5abc9690ed5db11a94a30b4a40 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/gdnative_shared/src/KomputeSummator.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +#include + +#include "kompute/Kompute.hpp" + +namespace godot { +class KomputeSummator : public Node2D +{ + private: + GODOT_CLASS(KomputeSummator, Node2D); + + public: + KomputeSummator(); + + void add(float value); + void reset(); + float get_total() const; + + void _process(float delta); + void _init(); + + static void _register_methods(); + + private: + kp::Manager mManager; + std::shared_ptr mSequence; + std::shared_ptr mPrimaryTensor; + std::shared_ptr mSecondaryTensor; +}; + +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_engine/.gdignore b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_engine/.gdignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/CustomModuleExampleScene.tscn b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/CustomModuleExampleScene.tscn new file mode 100644 index 0000000000000000000000000000000000000000..a9837511272b8a901b63cb4cc9521b774294a67c --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/CustomModuleExampleScene.tscn @@ -0,0 +1,8 @@ +[gd_scene load_steps=2 format=2] + +[ext_resource path="res://godot_resources/scripts/CustomModuleExampleScene.gd" type="Script" id=1] + +[node name="Parent" type="Node2D"] +script = ExtResource( 1 ) + +[node name="CustomKomputeNode" type="KomputeSummatorNode" parent="."] diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/DynamicExampleScene.tscn b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/DynamicExampleScene.tscn new file mode 100644 index 0000000000000000000000000000000000000000..f390717c90648fbeaf3fe7e3759b10081f035c07 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/DynamicExampleScene.tscn @@ -0,0 +1,10 @@ +[gd_scene load_steps=3 format=2] + +[ext_resource path="res://godot_resources/scripts/DynamicExampleScript.gd" type="Script" id=1] +[ext_resource path="res://godot_resources/scripts/KomputeNativeClass.gdns" type="Script" id=2] + +[node name="Parent" type="Node2D"] +script = ExtResource( 1 ) + +[node name="KomputeNode" type="Node2D" parent="."] +script = ExtResource( 2 ) diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/default_env.tres b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/default_env.tres new file mode 100644 index 0000000000000000000000000000000000000000..20207a4aa2cfbae0c95636531b00d91502940d12 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/default_env.tres @@ -0,0 +1,7 @@ +[gd_resource type="Environment" load_steps=2 format=2] + +[sub_resource type="ProceduralSky" id=1] + +[resource] +background_mode = 2 +background_sky = SubResource( 1 ) diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/icon.png b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/icon.png new file mode 100644 index 0000000000000000000000000000000000000000..1d3cc3b28202d98ebbaf6bc7da0190d7b794ae7e Binary files /dev/null and b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/icon.png differ diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/icon.png.import b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/icon.png.import new file mode 100644 index 0000000000000000000000000000000000000000..edc407e0a3e638cecbe483f3ff146bba335c188f --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/assets/icon.png.import @@ -0,0 +1,34 @@ +[remap] + +importer="texture" +type="StreamTexture" +path="res://.import/icon.png-dceb2295401b72087458cab0725ade4f.stex" +metadata={ +"vram_texture": false +} + +[deps] + +source_file="res://godot_resources/assets/icon.png" +dest_files=[ "res://.import/icon.png-dceb2295401b72087458cab0725ade4f.stex" ] + +[params] + +compress/mode=0 +compress/lossy_quality=0.7 +compress/hdr_mode=0 +compress/bptc_ldr=0 +compress/normal_map=0 +flags/repeat=0 +flags/filter=true +flags/mipmaps=false +flags/anisotropic=false +flags/srgb=2 +process/fix_alpha_border=true +process/premult_alpha=false +process/HDR_as_SRGB=false +process/invert_color=false +stream=false +size_limit=0 +detect_3d=true +svg/scale=1.0 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/CustomModuleExampleScene.gd b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/CustomModuleExampleScene.gd new file mode 100644 index 0000000000000000000000000000000000000000..45f68783e80ec3ddfb9505b7d1d9f2e3cb363576 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/CustomModuleExampleScene.gd @@ -0,0 +1,28 @@ +extends Node2D + +# Called when the node enters the scene tree for the first time. +func _ready(): + print("hello") + + # Use existing node + print($CustomKomputeNode.get_total()) + + $CustomKomputeNode.add(10) + print($CustomKomputeNode.get_total()) + + $CustomKomputeNode.add(10) + print($CustomKomputeNode.get_total()) + + # Create new instance + var s = KomputeSummatorNode.new() + + # This will print 0 as it's a new instance + print(s.get_total()) + + # Now we can again send further commands + s.add(10) + print(s.get_total()) + + s.add(10) + print(s.get_total()) + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/DynamicExampleScript.gd b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/DynamicExampleScript.gd new file mode 100644 index 0000000000000000000000000000000000000000..454314a57e416728896769c5e4b6b15416d773c8 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/DynamicExampleScript.gd @@ -0,0 +1,29 @@ +extends Node2D + +# Called when the node enters the scene tree for the first time. +func _ready(): + + print("hello") + + # Use existing node + print($KomputeNode.get_total()) + + $KomputeNode.add(10) + print($KomputeNode.get_total()) + + $KomputeNode.add(10) + print($KomputeNode.get_total()) + + # Create new instance + var s = KomputeSummator.new() + + # This will print 0 as it's a new instance + print(s.get_total()) + + # Now we can again send further commands + s.add(10) + print(s.get_total()) + + s.add(10) + print(s.get_total()) + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/KomputeNativeClass.gdns b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/KomputeNativeClass.gdns new file mode 100644 index 0000000000000000000000000000000000000000..53ab73f77232d0752b81cdf089a3873e76a7ca38 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/KomputeNativeClass.gdns @@ -0,0 +1,8 @@ +[gd_resource type="NativeScript" load_steps=2 format=2] + +[ext_resource path="res://godot_resources/scripts/KomputeNativeLibrary.gdnlib" type="GDNativeLibrary" id=1] + +[resource] +class_name = "KomputeSummator" +library = ExtResource( 1 ) +script_class_name = "KomputeSummator" diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/KomputeNativeLibrary.gdnlib b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/KomputeNativeLibrary.gdnlib new file mode 100644 index 0000000000000000000000000000000000000000..6bc8c07007962bad5d05834dbfe5ffef00e47806 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/godot_resources/scripts/KomputeNativeLibrary.gdnlib @@ -0,0 +1,14 @@ +[general] + +singleton=false +load_once=true +symbol_prefix="godot_" +reloadable=false + +[entry] + +Windows.64="res://gdnative_shared/build/Release/kompute_godot.dll" + +[dependencies] + +Windows.64=[ ] diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_examples/project.godot b/ggml/src/ggml-kompute/kompute/examples/godot_examples/project.godot new file mode 100644 index 0000000000000000000000000000000000000000..ece04bef79799ea59c64c0ab1e91c03a0efa0ee1 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_examples/project.godot @@ -0,0 +1,29 @@ +; Engine configuration file. +; It's best edited using the editor UI and not directly, +; since the parameters that go here are not all obvious. +; +; Format: +; [section] ; section goes between [] +; param=value ; assign values to parameters + +config_version=4 + +_global_script_classes=[ { +"base": "Node2D", +"class": "KomputeSummator", +"language": "NativeScript", +"path": "res://godot_resources/scripts/KomputeNativeClass.gdns" +} ] +_global_script_class_icons={ +"KomputeSummator": "" +} + +[application] + +config/name="KomputeGame" +run/main_scene="res://godot_resources/assets/DynamicExampleScene.tscn" +config/icon="res://godot_resources/assets/icon.png" + +[rendering] + +environment/default_environment="res://godot_resources/assets/default_env.tres" diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/.gitignore b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ee3ffe884d99484515599c24993aceb44203f7be --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/.gitignore @@ -0,0 +1,2 @@ +.import +godot_engine/godot diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/README.md b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/README.md new file mode 100644 index 0000000000000000000000000000000000000000..510205ec981b0a6a08243feff7cc0eef38d7459b --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/README.md @@ -0,0 +1,14 @@ + +# Kompute Game Engine Integration (Godot) + +![](https://github.com/KomputeProject/kompute/raw/master/docs/images/komputer-godot-4.gif) + +This repository contains the accompanying code for the Blog post ["Supercharging Game Development with GPU Accelerated Machine Learning"](https://medium.com/@AxSaucedo/supercharging-game-development-with-gpu-accelerated-ml-using-vulkan-kompute-the-godot-game-engine-4e75a84ea9f0). + +This example folder contains three key components: +* The Godot Project file `project.godot` to run the example +* The [instructions for the Custom Module](./custom_module/) implementation +* The [instructions for the GdNative Library](./gdnative_shared/) implementation + +You can also find the simpler Godot Summator example implemented in the GPU in [this example folder](../godot_examples/). + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/.gdignore b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/.gdignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/.gitignore b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..371dc7b6d5b21587e07087da6e0f03d07eb44059 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/.gitignore @@ -0,0 +1,3 @@ +kompute +lib +godot diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/README.md b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7370830d28e2b0d5e0f20bc546e95ee1043e73f1 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/README.md @@ -0,0 +1,76 @@ + +## Godot Engine Integration: Godot Engine Source Module + +This is the accompanying code for the Blog post ["Supercharging Game Development with GPU Accelerated Machine Learning"](https://medium.com/@AxSaucedo/supercharging-game-development-with-gpu-accelerated-ml-using-vulkan-kompute-the-godot-game-engine-4e75a84ea9f0). + +This section contains the implementation of the Kompute module as a statically compile module built with the Godot engine source code. This approach requires re-compiling the Godot engine source code. + +![](https://github.com/KomputeProject/kompute/raw/master/docs/images/komputer-godot-4.gif) + +## Set Up Dependencies + +### Vulkan + +You will need the Vulkan SDK, in this case we use version `1.2.148.1`, which you can get at the official site https://vulkan.lunarg.com/sdk/home#windows + +This will have the following contents that will be required later on: + +* The VulkanSDK static library `vulkan-1` +* The Vulkan headers in the `include/` folder + +### Kompute + +We will be using v0.3.1 of Kompute, and similar to above we will need the built static library, but in this case we will build it. + +We can start by cloning the repository on the v0.3.1 branch: + +``` +git clone --branch v0.3.1 https://github.com/KomputeProject/kompute/ +``` + +You will be able to use cmake to generate the build files for your platform. + +``` +cmake kompute/. -Bkompute/build +``` + +You need to make sure that the build is configured with the same flags required for godot, for example, in windows you will need: + +* Release build +* Configuration type: static library +* Runtime lib: Multi-threaded / multi-threaded debug + +Now you should see the library built under `build/src/Release` + +## Building Godot + +Now to build godot you will need to set up a couple of things for the Scons file to work - namely setting up the following: + +* Copy the `vulkan-1` library from your vulkan sdk folder to `lib/vulkan-1.lib` +* Copy the `kompute.lib` library from the Kompute build to `lib/kompute.lib` +* Copy the `include/vulkan/` folder to the `./include/` folder +* Copy the `single_include/kompute/` to the `./include/` folder +* Make sure the versions above match as we provide the headers in the `include` folder - if you used different versions make sure these match as well + +### Clone godot repository + +Now we can clone the godot repository - it must be on a separate repository, so you can use the parent directory if you are on the Kompute repo. + +``` +cd ../../godot_engine + +git clone --branch 3.2.3-stable https://github.com/godotengine/godot + +cd godot/ +``` + +And now we can build against our module + +``` +wscons -j16 custom_modules=../../custom_module/ platform=windows target=release_debug +``` + +Once we have built it we can now run the generated godot engine in the `bin/` folder, and we will be able to access the custom module from anywhere in the project, as well as creating new nodes from the user interface. + + + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/KomputeModelMLNode.cpp b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/KomputeModelMLNode.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bf9b9ef85bf83be330e0fbcd12e9fb9eb71ea86b --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/KomputeModelMLNode.cpp @@ -0,0 +1,165 @@ +/* summator.cpp */ + +#include + +#include "KomputeModelMLNode.h" + +KomputeModelMLNode::KomputeModelMLNode() +{ + std::cout << "CALLING CONSTRUCTOR" << std::endl; + this->_init(); +} + +void +KomputeModelMLNode::train(Array yArr, Array xIArr, Array xJArr) +{ + + assert(yArr.size() == xIArr.size()); + assert(xIArr.size() == xJArr.size()); + + std::vector yData; + std::vector xIData; + std::vector xJData; + std::vector zerosData; + + for (size_t i = 0; i < yArr.size(); i++) { + yData.push_back(yArr[i]); + xIData.push_back(xIArr[i]); + xJData.push_back(xJArr[i]); + zerosData.push_back(0); + } + + uint32_t ITERATIONS = 100; + float learningRate = 0.1; + + { + kp::Manager mgr; + + std::shared_ptr xI = mgr.tensor(xIData); + std::shared_ptr xJ = mgr.tensor(xJData); + + std::shared_ptr y = mgr.tensor(yData); + + std::shared_ptr wIn = mgr.tensor({ 0.001, 0.001 }); + std::shared_ptr wOutI = mgr.tensor(zerosData); + std::shared_ptr wOutJ = mgr.tensor(zerosData); + + std::shared_ptr bIn = mgr.tensor({ 0 }); + std::shared_ptr bOut = mgr.tensor(zerosData); + + std::shared_ptr lOut = mgr.tensor(zerosData); + + std::vector> params = { xI, xJ, y, + wIn, wOutI, wOutJ, + bIn, bOut, lOut }; + + { + std::vector spirv( + (uint32_t*) + kp::shader_data::shaders_glsl_logisticregression_comp_spv, + (uint32_t*)(kp::shader_data:: + shaders_glsl_logisticregression_comp_spv + + kp::shader_data:: + shaders_glsl_logisticregression_comp_spv_len)); + + std::shared_ptr algo = mgr.algorithm(params, spirv); + + mgr.sequence()->eval(params); + + std::shared_ptr sq = + mgr.sequence() + ->record({ wIn, bIn }) + ->record(algo) + ->record({ wOutI, wOutJ, bOut, lOut }); + + // Iterate across all expected iterations + for (size_t i = 0; i < ITERATIONS; i++) { + + sq->eval(); + + for (size_t j = 0; j < bOut->size(); j++) { + wIn->data()[0] -= learningRate * wOutI->data()[j]; + wIn->data()[1] -= learningRate * wOutJ->data()[j]; + bIn->data()[0] -= learningRate * bOut->data()[j]; + } + } + } + + KP_LOG_INFO("RESULT: <<<<<<<<<<<<<<<<<<<"); + KP_LOG_INFO(wIn->data()[0]); + KP_LOG_INFO(wIn->data()[1]); + KP_LOG_INFO(bIn->data()[0]); + + this->mWeights = kp::Tensor(wIn->data()); + this->mBias = kp::Tensor(bIn->data()); + } +} + +Array +KomputeModelMLNode::predict(Array xI, Array xJ) +{ + assert(xI.size() == xJ.size()); + + Array retArray; + + // We run the inference in the CPU for simplicity + // BUt you can also implement the inference on GPU + // GPU implementation would speed up minibatching + for (size_t i = 0; i < xI.size(); i++) { + float xIVal = xI[i]; + float xJVal = xJ[i]; + float result = + (xIVal * this->mWeights.data()[0] + xJVal * this->mWeights.data()[1] + + this->mBias.data()[0]); + + // Instead of using sigmoid we'll just return full numbers + Variant var = result > 0 ? 1 : 0; + retArray.push_back(var); + } + + return retArray; +} + +Array +KomputeModelMLNode::get_params() +{ + Array retArray; + + KP_LOG_INFO(this->mWeights.size() + this->mBias.size()); + + if (this->mWeights.size() + this->mBias.size() == 0) { + return retArray; + } + + retArray.push_back(this->mWeights.data()[0]); + retArray.push_back(this->mWeights.data()[1]); + retArray.push_back(this->mBias.data()[0]); + retArray.push_back(99.0); + + return retArray; +} + +void +KomputeModelMLNode::_init() +{ + std::cout << "CALLING INIT" << std::endl; +} + +void +KomputeModelMLNode::_process(float delta) +{} + +void +KomputeModelMLNode::_bind_methods() +{ + ClassDB::bind_method(D_METHOD("_process", "delta"), + &KomputeModelMLNode::_process); + ClassDB::bind_method(D_METHOD("_init"), &KomputeModelMLNode::_init); + + ClassDB::bind_method(D_METHOD("train", "yArr", "xIArr", "xJArr"), + &KomputeModelMLNode::train); + ClassDB::bind_method(D_METHOD("predict", "xI", "xJ"), + &KomputeModelMLNode::predict); + ClassDB::bind_method(D_METHOD("get_params"), + &KomputeModelMLNode::get_params); +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/KomputeModelMLNode.h b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/KomputeModelMLNode.h new file mode 100644 index 0000000000000000000000000000000000000000..bdc2ce85de86022df5fc1f7085fc9811aad5e0fd --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/KomputeModelMLNode.h @@ -0,0 +1,88 @@ +#pragma once + +#include + +#include "kompute/Kompute.hpp" + +#include "scene/main/node.h" + +class KomputeModelMLNode : public Node +{ + GDCLASS(KomputeModelMLNode, Node); + + public: + KomputeModelMLNode(); + + void train(Array y, Array xI, Array xJ); + + Array predict(Array xI, Array xJ); + + Array get_params(); + + void _process(float delta); + void _init(); + + protected: + static void _bind_methods(); + + private: + kp::Tensor mWeights; + kp::Tensor mBias; +}; + +static std::string LR_SHADER = R"( +#version 450 + +layout (constant_id = 0) const uint M = 0; + +layout (local_size_x = 1) in; + +layout(set = 0, binding = 0) buffer bxi { float xi[]; }; +layout(set = 0, binding = 1) buffer bxj { float xj[]; }; +layout(set = 0, binding = 2) buffer by { float y[]; }; +layout(set = 0, binding = 3) buffer bwin { float win[]; }; +layout(set = 0, binding = 4) buffer bwouti { float wouti[]; }; +layout(set = 0, binding = 5) buffer bwoutj { float woutj[]; }; +layout(set = 0, binding = 6) buffer bbin { float bin[]; }; +layout(set = 0, binding = 7) buffer bbout { float bout[]; }; +layout(set = 0, binding = 8) buffer blout { float lout[]; }; + +float m = float(M); + +float sigmoid(float z) { + return 1.0 / (1.0 + exp(-z)); +} + +float inference(vec2 x, vec2 w, float b) { + // Compute the linear mapping function + float z = dot(w, x) + b; + // Calculate the y-hat with sigmoid + float yHat = sigmoid(z); + return yHat; +} + +float calculateLoss(float yHat, float y) { + return -(y * log(yHat) + (1.0 - y) * log(1.0 - yHat)); +} + +void main() { + uint idx = gl_GlobalInvocationID.x; + + vec2 wCurr = vec2(win[0], win[1]); + float bCurr = bin[0]; + + vec2 xCurr = vec2(xi[idx], xj[idx]); + float yCurr = y[idx]; + + float yHat = inference(xCurr, wCurr, bCurr); + + float dZ = yHat - yCurr; + vec2 dW = (1. / m) * xCurr * dZ; + float dB = (1. / m) * dZ; + wouti[idx] = dW.x; + woutj[idx] = dW.y; + bout[idx] = dB; + + lout[idx] = calculateLoss(yHat, yCurr); +} +)"; diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/SCsub b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/SCsub new file mode 100644 index 0000000000000000000000000000000000000000..6ce3176f4b0244b037e4923f8889de8bbb085905 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/SCsub @@ -0,0 +1,17 @@ +import os + +Import('env') + +dir_path = os.getcwd() + +# Kompute & Vulkan header files +env.Append(CPPPATH = ['include/']) + +env.add_source_files(env.modules_sources, "*.cpp") + +# Kompute & Vulkan libraries +env.Append(LIBS=[ + File(dir_path +'/lib/kompute.lib'), + File(dir_path +'/lib/vulkan-1.lib'), + ]) + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/config.py b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1c8cd12a2dc049bbc83fcda51d4d929cfe417206 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/config.py @@ -0,0 +1,5 @@ +def can_build(env, platform): + return True + +def configure(env): + pass diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/include/.gitignore b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/include/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..30c2017b74ce23d31298020a66f7aba1820e2551 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/include/.gitignore @@ -0,0 +1,2 @@ +vulkan/ +kompute/ diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/register_types.cpp b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/register_types.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f8e4fc92cd1885cb2ae2da6e1b12100cec936c86 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/register_types.cpp @@ -0,0 +1,18 @@ +/* register_types.cpp */ + +#include "register_types.h" + +#include "KomputeModelMLNode.h" +#include "core/class_db.h" + +void +register_kompute_model_ml_types() +{ + ClassDB::register_class(); +} + +void +unregister_kompute_model_ml_types() +{ + // Nothing to do here in this example. +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/register_types.h b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/register_types.h new file mode 100644 index 0000000000000000000000000000000000000000..734f1aa439b29756cf15404e9ffac81b50a3137d --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/custom_module/kompute_model_ml/register_types.h @@ -0,0 +1,8 @@ +/* register_types.h */ +#pragma once + +void +register_kompute_model_ml_types(); +void +unregister_kompute_model_ml_types(); +/* yes, the word in the middle must be the same as the module folder name */ diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/.gdignore b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/.gdignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/.gitignore b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1fe38fec344c302991d750f413f7082edca42a22 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/.gitignore @@ -0,0 +1,4 @@ +kompute +lib +godot-cpp +.import diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6091f86033b3e7b67f6ff450cb9859fcac0f2233 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/CMakeLists.txt @@ -0,0 +1,40 @@ +cmake_minimum_required(VERSION 3.17.0) +project(kompute_godot VERSION 0.1.0) + +set(CMAKE_CXX_STANDARD 14) + +set(KOMPUTE_EXTRA_CXX_FLAGS "" CACHE STRING "Extra compile flags for Kompute, see docs for full list") + +# It is necessary to pass the DEBUG or RELEASE flag accordingly to Kompute +set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DDEBUG=1 ${KOMPUTE_EXTRA_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DRELEASE=1 ${KOMPUTE_EXTRA_CXX_FLAGS}") + +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CC_FLAGS} /W3 /MDd /Zi /EHsc /Ox /DNDEBUG /FS") + +find_package(kompute REQUIRED) +find_package(Vulkan REQUIRED) + +add_library(kompute_godot + STATIC + src/KomputeModelML.cpp + src/KomputeGdNative.cpp) + +target_include_directories( + kompute_godot PUBLIC + src/ + godot-cpp/include + godot-cpp/include/core + godot-cpp/include/gen + godot-cpp/godot_headers + +) + +target_link_libraries(kompute_godot + kompute::kompute + Vulkan::Vulkan +) + +target_link_libraries(kompute_godot + ${CMAKE_CURRENT_SOURCE_DIR}/godot-cpp/bin/libgodot-cpp.windows.release.default.lib +) + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/README.md b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ec6bbf398c97170a6579dc9717289944149768bf --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/README.md @@ -0,0 +1,33 @@ + +## Godot Engine Integration: GdNative Library + +This is the accompanying code for the Blog post ["Supercharging Game Development with GPU Accelerated Machine Learning"](https://medium.com/@AxSaucedo/supercharging-game-development-with-gpu-accelerated-ml-using-vulkan-kompute-the-godot-game-engine-4e75a84ea9f0). + +This section contains the implementation of the Kompute module as a shared GdNative Library that can be loaded dynamically through the Godot engine. This approach does not require re-compiling the Godot engine source code. + + +![](https://github.com/KomputeProject/kompute/raw/master/docs/images/komputer-godot-4.gif) + +### Set Up Dependencies + +We can get all the required dependencies from godot by running + +``` +git clone --branch 3.2 https://github.com/godotengine/godot-cpp + +cd godot-cpp +``` + +Then we can get all the subomdules + +``` +git submodule sync +``` + +and we build the bindings + +``` +scons -j16 platform=linuxbsd target=debug + +``` + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/.gdignore b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/.gdignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/KomputeGdNative.cpp b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/KomputeGdNative.cpp new file mode 100644 index 0000000000000000000000000000000000000000..06dd72452fb1079920562d0cda714df54ebf671c --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/KomputeGdNative.cpp @@ -0,0 +1,21 @@ +#include "KomputeModelML.hpp" + +extern "C" void GDN_EXPORT +godot_gdnative_init(godot_gdnative_init_options* o) +{ + godot::Godot::gdnative_init(o); +} + +extern "C" void GDN_EXPORT +godot_gdnative_terminate(godot_gdnative_terminate_options* o) +{ + godot::Godot::gdnative_terminate(o); +} + +extern "C" void GDN_EXPORT +godot_nativescript_init(void* handle) +{ + godot::Godot::nativescript_init(handle); + + godot::register_class(); +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/KomputeModelML.cpp b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/KomputeModelML.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2efbbfddc1d89d28590758d62658af8fb0fee792 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/KomputeModelML.cpp @@ -0,0 +1,167 @@ +#pragma once + +#include +#include +#include + +#include "KomputeModelML.hpp" + +namespace godot { + +KomputeModelML::KomputeModelML() +{ + std::cout << "CALLING CONSTRUCTOR" << std::endl; + this->_init(); +} + +void +KomputeModelML::train(Array yArr, Array xIArr, Array xJArr) +{ + + assert(yArr.size() == xIArr.size()); + assert(xIArr.size() == xJArr.size()); + + std::vector yData; + std::vector xIData; + std::vector xJData; + std::vector zerosData; + + for (size_t i = 0; i < yArr.size(); i++) { + yData.push_back(yArr[i]); + xIData.push_back(xIArr[i]); + xJData.push_back(xJArr[i]); + zerosData.push_back(0); + } + + uint32_t ITERATIONS = 100; + float learningRate = 0.1; + + { + kp::Manager mgr; + + std::shared_ptr xI = mgr.tensor(xIData); + std::shared_ptr xJ = mgr.tensor(xJData); + + std::shared_ptr y = mgr.tensor(yData); + + std::shared_ptr wIn = mgr.tensor({ 0.001, 0.001 }); + std::shared_ptr wOutI = mgr.tensor(zerosData); + std::shared_ptr wOutJ = mgr.tensor(zerosData); + + std::shared_ptr bIn = mgr.tensor({ 0 }); + std::shared_ptr bOut = mgr.tensor(zerosData); + + std::shared_ptr lOut = mgr.tensor(zerosData); + + std::vector> params = { xI, xJ, y, + wIn, wOutI, wOutJ, + bIn, bOut, lOut }; + + { + std::vector spirv( + (uint32_t*) + kp::shader_data::shaders_glsl_logisticregression_comp_spv, + (uint32_t*)(kp::shader_data:: + shaders_glsl_logisticregression_comp_spv + + kp::shader_data:: + shaders_glsl_logisticregression_comp_spv_len)); + + std::shared_ptr algo = mgr.algorithm(params, spirv); + + mgr.sequence()->eval(params); + + std::shared_ptr sq = + mgr.sequence() + ->record({ wIn, bIn }) + ->record(algo) + ->record({ wOutI, wOutJ, bOut, lOut }); + + // Iterate across all expected iterations + for (size_t i = 0; i < ITERATIONS; i++) { + + sq->eval(); + + for (size_t j = 0; j < bOut->size(); j++) { + wIn->data()[0] -= learningRate * wOutI->data()[j]; + wIn->data()[1] -= learningRate * wOutJ->data()[j]; + bIn->data()[0] -= learningRate * bOut->data()[j]; + } + } + } + + KP_LOG_INFO("RESULT: <<<<<<<<<<<<<<<<<<<"); + KP_LOG_INFO(wIn->data()[0]); + KP_LOG_INFO(wIn->data()[1]); + KP_LOG_INFO(bIn->data()[0]); + + this->mWeights = wIn; + this->mBias = bIn; + } +} + +Array +KomputeModelML::predict(Array xI, Array xJ) +{ + assert(xI.size() == xJ.size()); + + Array retArray; + + // We run the inference in the CPU for simplicity + // BUt you can also implement the inference on GPU + // GPU implementation would speed up minibatching + for (size_t i = 0; i < xI.size(); i++) { + float xIVal = xI[i]; + float xJVal = xJ[i]; + float result = + (xIVal * this->mWeights->data()[0] + + xJVal * this->mWeights->data()[1] + this->mBias->data()[0]); + + // Instead of using sigmoid we'll just return full numbers + Variant var = result > 0 ? 1 : 0; + retArray.push_back(var); + } + + return retArray; +} + +Array +KomputeModelML::get_params() +{ + Array retArray; + + KP_LOG_INFO(this->mWeights->size() + this->mBias->size()); + + if (this->mWeights->size() + this->mBias->size() == 0) { + return retArray; + } + + retArray.push_back(this->mWeights->data()[0]); + retArray.push_back(this->mWeights->data()[1]); + retArray.push_back(this->mBias->data()[0]); + retArray.push_back(99.0); + + return retArray; +} + +void +KomputeModelML::_init() +{ + std::cout << "CALLING INIT" << std::endl; +} + +void +KomputeModelML::_process(float delta) +{} + +void +KomputeModelML::_register_methods() +{ + register_method((char*)"_process", &KomputeModelML::_process); + register_method((char*)"_init", &KomputeModelML::_init); + + register_method((char*)"train", &KomputeModelML::train); + register_method((char*)"predict", &KomputeModelML::predict); + register_method((char*)"get_params", &KomputeModelML::get_params); +} + +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/KomputeModelML.hpp b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/KomputeModelML.hpp new file mode 100644 index 0000000000000000000000000000000000000000..017d890c0c90cb4bf7f0c24f5377354d3f62bf42 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/gdnative_shared/src/KomputeModelML.hpp @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include + +#include + +#include "kompute/Kompute.hpp" + +namespace godot { +class KomputeModelML : public Node2D +{ + private: + GODOT_CLASS(KomputeModelML, Node2D); + + public: + KomputeModelML(); + + void train(Array y, Array xI, Array xJ); + + Array predict(Array xI, Array xJ); + + Array get_params(); + + void _process(float delta); + void _init(); + + static void _register_methods(); + + private: + std::shared_ptr mWeights; + std::shared_ptr mBias; +}; + +static std::string LR_SHADER = R"( +#version 450 + +layout (constant_id = 0) const uint M = 0; + +layout (local_size_x = 1) in; + +layout(set = 0, binding = 0) buffer bxi { float xi[]; }; +layout(set = 0, binding = 1) buffer bxj { float xj[]; }; +layout(set = 0, binding = 2) buffer by { float y[]; }; +layout(set = 0, binding = 3) buffer bwin { float win[]; }; +layout(set = 0, binding = 4) buffer bwouti { float wouti[]; }; +layout(set = 0, binding = 5) buffer bwoutj { float woutj[]; }; +layout(set = 0, binding = 6) buffer bbin { float bin[]; }; +layout(set = 0, binding = 7) buffer bbout { float bout[]; }; +layout(set = 0, binding = 8) buffer blout { float lout[]; }; + +float m = float(M); + +float sigmoid(float z) { + return 1.0 / (1.0 + exp(-z)); +} + +float inference(vec2 x, vec2 w, float b) { + // Compute the linear mapping function + float z = dot(w, x) + b; + // Calculate the y-hat with sigmoid + float yHat = sigmoid(z); + return yHat; +} + +float calculateLoss(float yHat, float y) { + return -(y * log(yHat) + (1.0 - y) * log(1.0 - yHat)); +} + +void main() { + uint idx = gl_GlobalInvocationID.x; + + vec2 wCurr = vec2(win[0], win[1]); + float bCurr = bin[0]; + + vec2 xCurr = vec2(xi[idx], xj[idx]); + float yCurr = y[idx]; + + float yHat = inference(xCurr, wCurr, bCurr); + + float dZ = yHat - yCurr; + vec2 dW = (1. / m) * xCurr * dZ; + float dB = (1. / m) * dZ; + wouti[idx] = dW.x; + woutj[idx] = dW.y; + bout[idx] = dB; + + lout[idx] = calculateLoss(yHat, yCurr); +} +)"; + +} diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_engine/.gdignore b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_engine/.gdignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/CustomModuleExampleScene.tscn b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/CustomModuleExampleScene.tscn new file mode 100644 index 0000000000000000000000000000000000000000..ea32697a72e55ca22548056c7aeee455bfda5d31 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/CustomModuleExampleScene.tscn @@ -0,0 +1,8 @@ +[gd_scene load_steps=2 format=2] + +[ext_resource path="res://godot_resources/scripts/CustomModuleExampleScene.gd" type="Script" id=1] + +[node name="Parent" type="Node2D"] +script = ExtResource( 1 ) + +[node name="EditorKomputeModelMLNode" type="KomputeModelMLNode" parent="."] diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/DynamicExampleScene.tscn b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/DynamicExampleScene.tscn new file mode 100644 index 0000000000000000000000000000000000000000..403879b7a355248bfafea1d4f8129a9bca273c75 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/DynamicExampleScene.tscn @@ -0,0 +1,309 @@ +[gd_scene load_steps=10 format=2] + +[ext_resource path="res://godot_resources/scripts/DynamicExampleScript.gd" type="Script" id=1] +[ext_resource path="res://godot_resources/scripts/KomputeNativeClass.gdns" type="Script" id=2] +[ext_resource path="res://godot_resources/assets/icon.png" type="Texture" id=3] +[ext_resource path="res://godot_resources/assets/TextFormat.theme" type="Theme" id=4] + +[sub_resource type="GradientTexture" id=1] + +[sub_resource type="StyleBoxTexture" id=2] +texture = SubResource( 1 ) +region_rect = Rect2( 0, 0, 2048, 1 ) + +[sub_resource type="DynamicFontData" id=3] +font_path = "res://godot_resources/assets/roboto.ttf" + +[sub_resource type="DynamicFont" id=4] +size = 27 +font_data = SubResource( 3 ) + +[sub_resource type="Theme" id=5] +default_font = SubResource( 4 ) + +[node name="Parent" type="Node2D"] +script = ExtResource( 1 ) + +[node name="KomputeNode" type="Node2D" parent="."] +script = ExtResource( 2 ) + +[node name="UI" type="Node" parent="."] + +[node name="UIVBoxContainer" type="VBoxContainer" parent="UI"] +anchor_right = 1.0 +anchor_bottom = 1.0 +theme = ExtResource( 4 ) +__meta__ = { +"_edit_use_anchors_": false +} + +[node name="TitleLabel" type="Label" parent="UI/UIVBoxContainer"] +margin_right = 1024.0 +margin_bottom = 60.0 +text = "Godot ML Kompute " +align = 1 + +[node name="LogoHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"] +margin_top = 64.0 +margin_right = 1024.0 +margin_bottom = 160.0 +alignment = 1 + +[node name="TextureRect" type="TextureRect" parent="UI/UIVBoxContainer/LogoHBoxContainer"] +margin_left = 464.0 +margin_right = 560.0 +margin_bottom = 96.0 +texture = ExtResource( 3 ) + +[node name="XIHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"] +margin_top = 164.0 +margin_right = 1024.0 +margin_bottom = 234.0 + +[node name="VSeparator" type="VSeparator" parent="UI/UIVBoxContainer/XIHBoxContainer"] +margin_right = 20.0 +margin_bottom = 70.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="Label" type="Label" parent="UI/UIVBoxContainer/XIHBoxContainer"] +margin_left = 24.0 +margin_top = 5.0 +margin_right = 193.0 +margin_bottom = 65.0 +text = "Xi Input" + +[node name="VSeparator2" type="VSeparator" parent="UI/UIVBoxContainer/XIHBoxContainer"] +margin_left = 197.0 +margin_right = 217.0 +margin_bottom = 70.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="LineEdit" type="LineEdit" parent="UI/UIVBoxContainer/XIHBoxContainer"] +margin_left = 221.0 +margin_right = 1000.0 +margin_bottom = 70.0 +size_flags_horizontal = 3 +text = "[ 0, 0, 1, 1, 1, 1 ]" +align = 1 + +[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/XIHBoxContainer"] +margin_left = 1004.0 +margin_right = 1024.0 +margin_bottom = 70.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="XJHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"] +margin_top = 238.0 +margin_right = 1024.0 +margin_bottom = 308.0 + +[node name="VSeparator" type="VSeparator" parent="UI/UIVBoxContainer/XJHBoxContainer"] +margin_right = 20.0 +margin_bottom = 70.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="Label" type="Label" parent="UI/UIVBoxContainer/XJHBoxContainer"] +margin_left = 24.0 +margin_top = 5.0 +margin_right = 193.0 +margin_bottom = 65.0 +text = "Xj Input" + +[node name="VSeparator2" type="VSeparator" parent="UI/UIVBoxContainer/XJHBoxContainer"] +margin_left = 197.0 +margin_right = 217.0 +margin_bottom = 70.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="LineEdit" type="LineEdit" parent="UI/UIVBoxContainer/XJHBoxContainer"] +margin_left = 221.0 +margin_right = 1000.0 +margin_bottom = 70.0 +size_flags_horizontal = 3 +text = "[ 0, 0, 0, 0, 1, 1 ]" +align = 1 + +[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/XJHBoxContainer"] +margin_left = 1004.0 +margin_right = 1024.0 +margin_bottom = 70.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="YHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"] +margin_top = 312.0 +margin_right = 1024.0 +margin_bottom = 382.0 + +[node name="VSeparator" type="VSeparator" parent="UI/UIVBoxContainer/YHBoxContainer"] +margin_right = 20.0 +margin_bottom = 70.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="Label" type="Label" parent="UI/UIVBoxContainer/YHBoxContainer"] +margin_left = 24.0 +margin_top = 5.0 +margin_right = 192.0 +margin_bottom = 65.0 +text = "Y Input " + +[node name="VSeparator2" type="VSeparator" parent="UI/UIVBoxContainer/YHBoxContainer"] +margin_left = 196.0 +margin_right = 216.0 +margin_bottom = 70.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="LineEdit" type="LineEdit" parent="UI/UIVBoxContainer/YHBoxContainer"] +margin_left = 220.0 +margin_right = 1000.0 +margin_bottom = 70.0 +size_flags_horizontal = 3 +text = "[ 0, 0, 0, 0, 1, 1 ]" +align = 1 + +[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/YHBoxContainer"] +margin_left = 1004.0 +margin_right = 1024.0 +margin_bottom = 70.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="Button" type="Button" parent="UI/UIVBoxContainer"] +margin_top = 386.0 +margin_right = 1024.0 +margin_bottom = 452.0 +text = "Kompute Train & Predict ML" + +[node name="Panel" type="PanelContainer" parent="UI/UIVBoxContainer"] +margin_top = 456.0 +margin_right = 1024.0 +margin_bottom = 600.0 +size_flags_vertical = 3 +custom_styles/panel = SubResource( 2 ) + +[node name="VBoxContainer" type="VBoxContainer" parent="UI/UIVBoxContainer/Panel"] +margin_right = 1024.0 +margin_bottom = 144.0 + +[node name="VSplitContainer2" type="VSplitContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"] +margin_right = 1024.0 +margin_bottom = 10.0 +rect_min_size = Vector2( 0, 10 ) + +[node name="PredHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"] +margin_top = 14.0 +margin_right = 1024.0 +margin_bottom = 47.0 +theme = SubResource( 5 ) +__meta__ = { +"_edit_use_anchors_": false +} + +[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_right = 20.0 +margin_bottom = 33.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 24.0 +margin_right = 144.0 +margin_bottom = 33.0 +text = "Weight 1: " + +[node name="Weight1Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 148.0 +margin_right = 332.0 +margin_bottom = 33.0 +size_flags_horizontal = 3 +text = "n/a" +align = 1 + +[node name="VSeparator4" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 336.0 +margin_right = 356.0 +margin_bottom = 33.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="VSeparator5" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 360.0 +margin_right = 380.0 +margin_bottom = 33.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="Label2" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 384.0 +margin_right = 504.0 +margin_bottom = 33.0 +text = "Weight 2: " + +[node name="Weight2Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 508.0 +margin_right = 692.0 +margin_bottom = 33.0 +size_flags_horizontal = 3 +text = "n/a" +align = 1 + +[node name="VSeparator6" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 696.0 +margin_right = 716.0 +margin_bottom = 33.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="VSeparator7" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 720.0 +margin_right = 740.0 +margin_bottom = 33.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="Label3" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 744.0 +margin_right = 811.0 +margin_bottom = 33.0 +text = "Bias: " + +[node name="BiasLabel" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 815.0 +margin_right = 999.0 +margin_bottom = 33.0 +size_flags_horizontal = 3 +text = "n/a" +align = 1 + +[node name="VSeparator8" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"] +margin_left = 1003.0 +margin_right = 1023.0 +margin_bottom = 33.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="VSplitContainer" type="VSplitContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"] +margin_top = 51.0 +margin_right = 1024.0 +margin_bottom = 71.0 +rect_min_size = Vector2( 0, 20 ) + +[node name="PredHBoxContainer2" type="HBoxContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"] +margin_top = 75.0 +margin_right = 1024.0 +margin_bottom = 135.0 +__meta__ = { +"_edit_use_anchors_": false +} + +[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2"] +margin_right = 20.0 +margin_bottom = 60.0 +rect_min_size = Vector2( 20, 0 ) + +[node name="Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2"] +margin_left = 24.0 +margin_right = 399.0 +margin_bottom = 60.0 +text = "Prediction result:" + +[node name="PredictionsLabel" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2"] +margin_left = 403.0 +margin_right = 1024.0 +margin_bottom = 60.0 +size_flags_horizontal = 3 +text = "n/a" +align = 1 +[connection signal="pressed" from="UI/UIVBoxContainer/Button" to="." method="compute_ml"] diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/TextFormat.theme b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/TextFormat.theme new file mode 100644 index 0000000000000000000000000000000000000000..1bd17107fbe8ff70678cce9f6f62034cb881a651 Binary files /dev/null and b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/TextFormat.theme differ diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/default_env.tres b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/default_env.tres new file mode 100644 index 0000000000000000000000000000000000000000..20207a4aa2cfbae0c95636531b00d91502940d12 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/default_env.tres @@ -0,0 +1,7 @@ +[gd_resource type="Environment" load_steps=2 format=2] + +[sub_resource type="ProceduralSky" id=1] + +[resource] +background_mode = 2 +background_sky = SubResource( 1 ) diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/icon.png b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/icon.png new file mode 100644 index 0000000000000000000000000000000000000000..1d3cc3b28202d98ebbaf6bc7da0190d7b794ae7e Binary files /dev/null and b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/icon.png differ diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/icon.png.import b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/icon.png.import new file mode 100644 index 0000000000000000000000000000000000000000..edc407e0a3e638cecbe483f3ff146bba335c188f --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/icon.png.import @@ -0,0 +1,34 @@ +[remap] + +importer="texture" +type="StreamTexture" +path="res://.import/icon.png-dceb2295401b72087458cab0725ade4f.stex" +metadata={ +"vram_texture": false +} + +[deps] + +source_file="res://godot_resources/assets/icon.png" +dest_files=[ "res://.import/icon.png-dceb2295401b72087458cab0725ade4f.stex" ] + +[params] + +compress/mode=0 +compress/lossy_quality=0.7 +compress/hdr_mode=0 +compress/bptc_ldr=0 +compress/normal_map=0 +flags/repeat=0 +flags/filter=true +flags/mipmaps=false +flags/anisotropic=false +flags/srgb=2 +process/fix_alpha_border=true +process/premult_alpha=false +process/HDR_as_SRGB=false +process/invert_color=false +stream=false +size_limit=0 +detect_3d=true +svg/scale=1.0 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/roboto.ttf b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/roboto.ttf new file mode 100644 index 0000000000000000000000000000000000000000..80a0457a9eb127ec3ada59cd0eb3d795a635890c --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/assets/roboto.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:017c0be9aaa6d0359737e1fa762ad304c0e0107927faff5a6c1f415c7f5244ed +size 171272 diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/CustomModuleExampleScene.gd b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/CustomModuleExampleScene.gd new file mode 100644 index 0000000000000000000000000000000000000000..30849545f65eae13985b7a6bb60a81c4de38b464 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/CustomModuleExampleScene.gd @@ -0,0 +1,28 @@ +extends Node2D + +# Called when the node enters the scene tree for the first time. +func _ready(): + var xi = [0, 1, 1, 1, 1, 1] + var xj = [0, 0, 0, 0, 1, 1] + var y = [0, 0, 0, 0, 1, 1] + + print("Running training and predict on existing node") + + $EditorKomputeModelMLNode.train(y, xi, xj) + + var preds = $EditorKomputeModelMLNode.predict(xi, xj) + + print(preds) + + print("Running training and predict on new instance") + + # Create new instance + var s = KomputeModelMLNode.new() + + s.train(y, xi, xj) + print("") + + preds = s.predict(xi, xj) + + print(preds) + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/DynamicExampleScript.gd b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/DynamicExampleScript.gd new file mode 100644 index 0000000000000000000000000000000000000000..8e148f7a3bad0f61bda867c16ada6250995f407a --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/DynamicExampleScript.gd @@ -0,0 +1,36 @@ +extends Node2D + +onready var xi_node = $UI/UIVBoxContainer/XIHBoxContainer/LineEdit +onready var xj_node = $UI/UIVBoxContainer/XJHBoxContainer/LineEdit +onready var y_node = $UI/UIVBoxContainer/XJHBoxContainer/LineEdit +onready var preds_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2/PredictionsLabel +onready var w1_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer/Weight1Label +onready var w2_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer/Weight2Label +onready var bias_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer/BiasLabel + +# Called when the node enters the scene tree for the first time. +func _ready(): + pass + +func compute_ml(): + + var xi = str2var(xi_node.text) + var xj = str2var(xj_node.text) + var y = str2var(y_node.text) + + var s = KomputeModelML.new() + + s.train(y, xi, xj) + + var preds = s.predict(xi, xj) + + preds_node.text = str(preds) + + var params = s.get_params() + + w1_node.set_text(str(params[0])) + w2_node.set_text(str(params[1])) + bias_node.set_text(str(params[2])) + + + diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/KomputeNativeClass.gdns b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/KomputeNativeClass.gdns new file mode 100644 index 0000000000000000000000000000000000000000..8d6f9a2cba95a3581f80da069bd9fb583679aea0 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/KomputeNativeClass.gdns @@ -0,0 +1,8 @@ +[gd_resource type="NativeScript" load_steps=2 format=2] + +[ext_resource path="res://godot_resources/scripts/KomputeNativeLibrary.gdnlib" type="GDNativeLibrary" id=1] + +[resource] +class_name = "KomputeModelML" +library = ExtResource( 1 ) +script_class_name = "KomputeModelML" diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/KomputeNativeLibrary.gdnlib b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/KomputeNativeLibrary.gdnlib new file mode 100644 index 0000000000000000000000000000000000000000..6bc8c07007962bad5d05834dbfe5ffef00e47806 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/godot_resources/scripts/KomputeNativeLibrary.gdnlib @@ -0,0 +1,14 @@ +[general] + +singleton=false +load_once=true +symbol_prefix="godot_" +reloadable=false + +[entry] + +Windows.64="res://gdnative_shared/build/Release/kompute_godot.dll" + +[dependencies] + +Windows.64=[ ] diff --git a/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/project.godot b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/project.godot new file mode 100644 index 0000000000000000000000000000000000000000..ffd7325e14bce8b1bc1144f7ae7c13892169e42c --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/godot_logistic_regression/project.godot @@ -0,0 +1,29 @@ +; Engine configuration file. +; It's best edited using the editor UI and not directly, +; since the parameters that go here are not all obvious. +; +; Format: +; [section] ; section goes between [] +; param=value ; assign values to parameters + +config_version=4 + +_global_script_classes=[ { +"base": "Node2D", +"class": "KomputeModelML", +"language": "NativeScript", +"path": "res://godot_resources/scripts/KomputeNativeClass.gdns" +} ] +_global_script_class_icons={ +"KomputeModelML": "" +} + +[application] + +config/name="KomputeGame" +run/main_scene="res://godot_resources/assets/DynamicExampleScene.tscn" +config/icon="res://godot_resources/assets/icon.png" + +[rendering] + +environment/default_environment="res://godot_resources/assets/default_env.tres" diff --git a/ggml/src/ggml-kompute/kompute/examples/logistic_regression/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..246e81338bfcb2dda5dbaaa5deb28fefc8035a6e --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.20) +project(kompute_logistic_regression) + +set(CMAKE_CXX_STANDARD 14) + +# Options +option(KOMPUTE_OPT_GIT_TAG "The tag of the repo to use for the example" 1344ece4ac278f9b3be3b4555ffaace7a032b91f) +option(KOMPUTE_OPT_FROM_SOURCE "Whether to build example from source or from git fetch repo" 0) + +# Set a default build type if none was specified +# Based on: https://github.com/openchemistry/tomviz/blob/master/cmake/BuildType.cmake +set(DEFAULT_BUILD_TYPE "Release") + +if(EXISTS "${CMAKE_SOURCE_DIR}/.git") + set(DEFAULT_BUILD_TYPE "Debug") +endif() + +if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) + message(STATUS "Setting build type to '${DEFAULT_BUILD_TYPE}' as none was specified.") + set(CMAKE_BUILD_TYPE "${DEFAULT_BUILD_TYPE}" CACHE STRING "Choose the type of build." FORCE) + + # Set the possible values of build type for cmake-gui + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +if(WIN32) # Install dlls in the same directory as the executable on Windows + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) +endif() + +if(KOMPUTE_OPT_FROM_SOURCE) + add_subdirectory(../../ ${CMAKE_CURRENT_BINARY_DIR}/kompute_build) +else() + include(FetchContent) + FetchContent_Declare(kompute GIT_REPOSITORY https://github.com/KomputeProject/kompute.git + GIT_TAG ${KOMPUTE_OPT_GIT_TAG}) + FetchContent_MakeAvailable(kompute) + include_directories(${kompute_SOURCE_DIR}/src/include) +endif() + +# Add to the list, so CMake can later find the code to compile shaders to header files +list(APPEND CMAKE_PREFIX_PATH "${kompute_SOURCE_DIR}/cmake") + +add_subdirectory(shader) +add_subdirectory(src) diff --git a/ggml/src/ggml-kompute/kompute/examples/logistic_regression/README.md b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/README.md new file mode 100644 index 0000000000000000000000000000000000000000..204180f0adabe6deaaae3625b8a74d96f704c8e7 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/README.md @@ -0,0 +1,45 @@ +# Kompute Logistic Regression Example + +This folder contains an end to end Kompute Example that implements logistic regression. +This example is structured such that you will be able to extend it for your project. +It contains a CMake build configuration that can be used in your production applications. + +## Building the example + +You will notice that it's a standalone project, so you can re-use it for your application. +It uses CMake's [`fetch_content`](https://cmake.org/cmake/help/latest/module/FetchContent.html) to consume Kompute as a dependency. +To build you just need to run the CMake command in this folder as follows: + +```bash +git clone https://github.com/KomputeProject/kompute.git +cd kompute/examples/logistic_regression +mkdir build +cd build +cmake .. +cmake --build . +``` + +## Executing + +Form inside the `build/` directory run: + +### Linux + +```bash +./kompute_logistic_regression +``` + +### Windows + +```bash +.\Debug\kompute_logistic_regression.exe +``` + +## Pre-requisites + +In order to run this example, you will need the following dependencies: + +* REQUIRED + + The Vulkan SDK must be installed + +For the Vulkan SDK, the simplest way to install it is through [their website](https://vulkan.lunarg.com/sdk/home). You just have to follow the instructions for the relevant platform. diff --git a/ggml/src/ggml-kompute/kompute/examples/logistic_regression/shader/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/shader/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..684a6a82a91b4460d570505f8113f0a4a3559bfa --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/shader/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.20) + +# To add more shaders simply copy the vulkan_compile_shader command and replace it with your new shader +vulkan_compile_shader(INFILE my_shader.comp + OUTFILE my_shader.hpp + NAMESPACE "shader" + RELATIVE_PATH "${kompute_SOURCE_DIR}/cmake") + +# Then add it to the library, so you can access it later in your code +add_library(shader INTERFACE "${CMAKE_CURRENT_BINARY_DIR}/my_shader.hpp" + + # "${CMAKE_CURRENT_BINARY_DIR}/my_shader2.hpp" +) + +target_include_directories(shader INTERFACE $) diff --git a/ggml/src/ggml-kompute/kompute/examples/logistic_regression/shader/my_shader.comp b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/shader/my_shader.comp new file mode 100644 index 0000000000000000000000000000000000000000..6934fd29af348262e8aa1b909fda2246cbc1b276 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/shader/my_shader.comp @@ -0,0 +1,53 @@ +#version 450 + +layout (constant_id = 0) const float m = 0; + +layout (local_size_x = 1) in; + +layout(set = 0, binding = 0) buffer bxi { float xi[]; }; +layout(set = 0, binding = 1) buffer bxj { float xj[]; }; +layout(set = 0, binding = 2) buffer by { float y[]; }; +layout(set = 0, binding = 3) buffer bwin { float win[]; }; +layout(set = 0, binding = 4) buffer bwouti { float wouti[]; }; +layout(set = 0, binding = 5) buffer bwoutj { float woutj[]; }; +layout(set = 0, binding = 6) buffer bbin { float bin[]; }; +layout(set = 0, binding = 7) buffer bbout { float bout[]; }; +layout(set = 0, binding = 8) buffer blout { float lout[]; }; + +float sigmoid(float z) { + return 1.0 / (1.0 + exp(-z)); +} + +float inference(vec2 x, vec2 w, float b) { + // Compute the linear mapping function + float z = dot(w, x) + b; + // Calculate the y-hat with sigmoid + float yHat = sigmoid(z); + return yHat; +} + +float calculateLoss(float yHat, float y) { + return -(y * log(yHat) + (1.0 - y) * log(1.0 - yHat)); +} + +void main() { + uint idx = gl_GlobalInvocationID.x; + + vec2 wCurr = vec2(win[0], win[1]); + float bCurr = bin[0]; + + vec2 xCurr = vec2(xi[idx], xj[idx]); + float yCurr = y[idx]; + + float yHat = inference(xCurr, wCurr, bCurr); + + float dZ = yHat - yCurr; + vec2 dW = (1. / m) * xCurr * dZ; + float dB = (1. / m) * dZ; + wouti[idx] = dW.x; + woutj[idx] = dW.y; + bout[idx] = dB; + + lout[idx] = calculateLoss(yHat, yCurr); +} + diff --git a/ggml/src/ggml-kompute/kompute/examples/logistic_regression/src/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..75273e79cb1706617d4740b3ac77c7191d9fdb4a --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/src/CMakeLists.txt @@ -0,0 +1,4 @@ +cmake_minimum_required(VERSION 3.20) + +add_executable(kompute_logistic_regression main.cpp) +target_link_libraries(kompute_logistic_regression PRIVATE shader kompute::kompute) diff --git a/ggml/src/ggml-kompute/kompute/examples/logistic_regression/src/main.cpp b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3fc18f344ad7891c8b044e0ee7d729ffe8baaa86 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/logistic_regression/src/main.cpp @@ -0,0 +1,78 @@ + +#include +#include +#include + +#include "kompute/Tensor.hpp" +#include "my_shader.hpp" +#include + +int +main() +{ + uint32_t ITERATIONS = 100; + float learningRate = 0.1; + + kp::Manager mgr; + + std::shared_ptr> xI = mgr.tensor({ 0, 1, 1, 1, 1 }); + std::shared_ptr> xJ = mgr.tensor({ 0, 0, 0, 1, 1 }); + + std::shared_ptr> y = mgr.tensor({ 0, 0, 0, 1, 1 }); + + std::shared_ptr> wIn = mgr.tensor({ 0.001, 0.001 }); + std::shared_ptr> wOutI = + mgr.tensor({ 0, 0, 0, 0, 0 }); + std::shared_ptr> wOutJ = + mgr.tensor({ 0, 0, 0, 0, 0 }); + + std::shared_ptr> bIn = mgr.tensor({ 0 }); + std::shared_ptr> bOut = + mgr.tensor({ 0, 0, 0, 0, 0 }); + + std::shared_ptr> lOut = + mgr.tensor({ 0, 0, 0, 0, 0 }); + + std::vector> params = { xI, xJ, y, + wIn, wOutI, wOutJ, + bIn, bOut, lOut }; + + mgr.sequence()->eval(params); + + std::vector spirv2{ 0x1, 0x2 }; + + std::vector spirv( + shader::MY_SHADER_COMP_SPV.begin(), + shader::MY_SHADER_COMP_SPV.end()); + + std::shared_ptr algorithm = mgr.algorithm( + params, spirv, kp::Workgroup({ 5 }), std::vector({ 5.0 })); + + std::shared_ptr sq = + mgr.sequence() + ->record({ wIn, bIn }) + ->record(algorithm) + ->record({ wOutI, wOutJ, bOut, lOut }); + + // Iterate across all expected iterations + for (size_t i = 0; i < ITERATIONS; i++) { + sq->eval(); + + for (size_t j = 0; j < bOut->size(); j++) { + wIn->data()[0] -= learningRate * wOutI->data()[j]; + wIn->data()[1] -= learningRate * wOutJ->data()[j]; + bIn->data()[0] -= learningRate * bOut->data()[j]; + } + } + + KP_LOG_WARN("Result wIn i: {}, wIn j: {}, bIn: {}", + wIn->data()[0], + wIn->data()[1], + bIn->data()[0]); + + if (wIn->data()[0] > 0.01 || + wIn->data()[1] < 1.0 || + bIn->data()[0] > 0.0) { + throw std::runtime_error("Result does not match"); + } +} diff --git a/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/.gitignore b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..999940f78fd5ab1f30f3747e03816ab45304627e --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/.gitignore @@ -0,0 +1,2 @@ +model-kipper +model.json diff --git a/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/README.md b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cbf9635c211f9cadcfe5734eb3ce177219d90b96 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/README.md @@ -0,0 +1,49 @@ +# Convolutional Neural Network (CNN) VGG7 implementation + +This example provides an implementation of a convolutional neural network (CNN) that enables for image resolution upscaling, which means that images can improve their quality through purely the machine learning implementation. + +This example demonstrates performing image upscaling using Kompute on the test image below. + +In this example we will be doing the following: + +* Import pre-trained model +* Create Kompute code that loads model weights +* Create Kompute shader that performs inference on image +* Run model against image to perform upscale + +## Import pre-trained model + +To import the no-noise-compensation VGG7 model (into `model-kipper`): + +``` +curl -o model.json https://raw.githubusercontent.com/nagadomi/waifu2x/master/models/vgg_7/art/scale2.0x_model.json +python3 import_vgg7.py model.json +``` + +Other models from the vgg\_7 set (such as `https://raw.githubusercontent.com/nagadomi/waifu2x/master/models/vgg_7/photo/noise0_model.json`) can be subsituted as desired. + +## Create code that loads model weights + +We implement the kompute logic under run_vgg7 that loads the model weights and coordinates the execution of the inference. + +## Create Kompute shader that performs inference on image + +Similarly, we created a compute shader that performs an inference iteration on an image provided to perfrom upscaling. + +## Run model against image to perfrom upscale + +We now execute model against an image created by us to show how upscaling works. The image used will be the one below: + +![](https://raw.githubusercontent.com/KomputeProject/kompute/master/examples/neural_network_vgg7/w2wbinit.png) + +To execute that model no tiling is performed, so be careful about image sizes. + +We can now run the command below to perform inference against the image blow. + +`python3 run_vgg7.py w2wbinit.png out.png` + +This would successfully upscale the resolution using the machine learning model, and the result is below: + +![](https://raw.githubusercontent.com/KomputeProject/kompute/master/examples/neural_network_vgg7/out.png) + + diff --git a/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/import_vgg7.py b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/import_vgg7.py new file mode 100644 index 0000000000000000000000000000000000000000..c86ff36c8d0f7fc0b3f5ffee034b8c19d10d9341 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/import_vgg7.py @@ -0,0 +1,30 @@ +import numpy +import json +import os +import sys +import time +import sh_common + +if len(sys.argv) != 2: + print("import_vgg7.py JSONPATH") + print(" i.e. import_vgg7.py /home/you/Documents/External/waifu2x/models/vgg_7/art/scale2.0x_model.json") + sys.exit(1) + +try: + os.mkdir("model-kipper") +except: + pass + +data_list = json.load(open(sys.argv[1], "rb")) + +idx = 0 +for i in range(7): + layer = data_list[i] + w = numpy.array(layer["weight"]) + w.reshape((-1, 3, 3)).transpose((0, 2, 1)) + b = numpy.array(layer["bias"]) + sh_common.save_param("kipper", idx, w) + idx += 1 + sh_common.save_param("kipper", idx, b) + idx += 1 + diff --git a/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/out.png b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/out.png new file mode 100644 index 0000000000000000000000000000000000000000..ffb718eef13ea0d51ad1e2b2f75daf4db4a09d64 Binary files /dev/null and b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/out.png differ diff --git a/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/run_vgg7.py b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/run_vgg7.py new file mode 100644 index 0000000000000000000000000000000000000000..f5d88e8413b0c1616a6268e33a41272ff33a974c --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/run_vgg7.py @@ -0,0 +1,125 @@ +import kp +import numpy +import os +import sys +import time +import sh_conv +import sh_common + +if len(sys.argv) != 3: + print("run_vgg7.py INPUT OUTPUT") + print(" Tiling is not implemented, but padding is implemented") + sys.exit(1) + +# NOTES: +# + Tiling is not implemented, but padding is implemented +# So don't run anything too big through it + +if False: + kpm = kp.Manager(1) + if kpm.get_device_properties()["device_name"].count("RAVEN") > 0: + raise "Safety cut-out triggered. Sorry!" +else: + kpm = kp.Manager() + +image = sh_common.image_load(sys.argv[1]) +image = image.repeat(2, 0).repeat(2, 1) +image = numpy.pad(image, [[7, 7], [7, 7], [0, 0]], mode = "edge") + +# Ensure image has 4 channels even though they will be unused. +# This is because of vectorization vec4 magic. +while image.shape[2] < sh_common.VSZ: + image = numpy.pad(image, [[0, 0], [0, 0], [0, 1]], mode = "constant") + +# sh_common.image_save("pad.png", image) + +# Prepare the initial tensor. + +tensor_in = kpm.tensor(image) +tensor_in_h = image.shape[0] +tensor_in_w = image.shape[1] +tensor_in_cg = 1 +tensor_in_c = 3 + +# Run things. +channels = [32, 32, 64, 64, 128, 128, 3] + +for i in range(7): + # Prepare tensors. + # 'c' is the total amount of channels, while 'cg' is the amount of vec4s (channel-groups). + # This is important because weights have to be padded for the shader. + tensor_out_h = tensor_in_h - 2 + tensor_out_w = tensor_in_w - 2 + tensor_out_c = channels[i] + tensor_out_cg = (channels[i] + (sh_common.VSZ - 1)) // sh_common.VSZ + # TODO: How to produce a blank tensor we don't care about the contents of? + # This isn't being synced, and from experience so far that should handle most of it, + # but what about memory usage? + # *Most* of these tensors live entirely on-device except when debugging. + # Can that be handled? (Also good question: Does it even need to be handled?) + tensor_out = kpm.tensor(numpy.zeros((tensor_out_h * tensor_out_w * tensor_out_cg * sh_common.VSZ))) + weight = kpm.tensor(sh_common.load_weights_padded("kipper", (i * 2) + 0, tensor_out_c, tensor_in_c, 3)) + bias = kpm.tensor(sh_common.load_biases_padded("kipper", (i * 2) + 1, tensor_out_c)) + # Compute. + # TODO: It'd be nice to wrap this up into a class for optimization purposes. + workgroup = ((tensor_out_w + 7) // 8, (tensor_out_h + 1) // 2, tensor_out_cg) + alg = kpm.algorithm( + # tensors + [tensor_in, bias, weight, tensor_out], + # spirv + sh_conv.conv_shader, + # workgroup + workgroup, + # spec_consts + [tensor_in_w, tensor_in_h, tensor_in_cg, tensor_out_w, tensor_out_h, tensor_out_cg], + # push_consts + [] + ) + + print("Step complexity " + str(workgroup)) + print("Step channel layout " + str(tensor_in_cg) + " " + str(tensor_out_cg)) + + # Do this first. Keep in mind "syncs" are copies. + last_seq = kpm.sequence() + things_to_sync_to_device = [bias, weight] + if i == 0: + # For first layer, the input isn't on-device yet + things_to_sync_to_device.append(tensor_in) + last_seq.eval_async(kp.OpTensorSyncDevice(things_to_sync_to_device)) + last_seq.eval_await() + + # Prepare + seq = (kpm.sequence() + .record(kp.OpAlgoDispatch(alg, [])) + ) + # Run + seq.eval() + + print("Done with step") + + if False: + # DEBUG: + # We want to see the output, copy it to local + last_seq = kpm.sequence() + last_seq.eval_async(kp.OpTensorSyncLocal([tensor_out])) + last_seq.eval_await() + tensor_out.data().astype(" numpy.ndarray: + """ + Loads an image. + Doesn't Tensor it, in case you need to do further work with it. + Shape is (h, w, 3). + """ + # file + na = numpy.array(Image.open(path)) + # change type + na = na.astype("float32") / 255.0 + return na + +def image_save(path, na: numpy.ndarray): + """ + Saves an image. + However, note this expects a numpy array. + Shape is (h, w, 3). + """ + # change type + na = numpy.fmax(numpy.fmin(na * 255.0, 255), 0).astype("uint8") + # file + Image.fromarray(na).save(path) + +def load_param(mdl, idx, expected): + npa = numpy.fromfile("model-" + mdl + "/snoop_bin_" + str(idx) + ".bin", " + # [outputCGroups][kernelH][kernelW][inputCGroups][outputChannels][inputChannels] + weight_na = load_param(mdl, idx, tensor_out_c * tensor_in_c * weight_s * weight_s) + # start by putting in the initial shape + weight_na = weight_na.reshape(tensor_out_c, tensor_in_c, weight_s, weight_s) + # then by padding + # NOTE: It is *critically important* that weight padding is done with the "zero" mode. + # The shader WILL NOT ignore these values, but zeroing them causes them to have no effect. + if (tensor_in_c & 3) != 0: + weight_na = numpy.pad(weight_na, [[0, 0], [0, 4 - (tensor_in_c & 3)], [0, 0], [0, 0]], mode = "constant") + if (tensor_out_c & 3) != 0: + weight_na = numpy.pad(weight_na, [[0, 4 - (tensor_out_c & 3)], [0, 0], [0, 0], [0, 0]], mode = "constant") + # reshape to finish splitting things up + weight_na = weight_na.reshape(tensor_out_cg, 4, tensor_in_cg, 4, weight_s, weight_s) + # result is: + # [outputCGroups][outputChannels][inputCGroups][inputChannels][kernelH][kernelW] + # and move output channels to the right... + weight_na = numpy.moveaxis(weight_na, 1, 5) + # result is: + # [outputCGroups][inputCGroups][inputChannels][kernelH][kernelW][outputChannels] + # and move input channels to the right... + weight_na = numpy.moveaxis(weight_na, 2, 5) + # result is: + # [outputCGroups][inputCGroups][kernelH][kernelW][outputChannels][inputChannels] + # and move input cgroups to the right... + weight_na = numpy.moveaxis(weight_na, 1, 3) + return weight_na + +def load_biases_padded(mdl, idx, tensor_out_c): + tensor_out_cg = (tensor_out_c + 3) // 4 + # [outputCGroups][outputChannels] + # biases merely need padding + # Again, has to be zero + bias_na = load_param(mdl, idx, tensor_out_c) + if (tensor_out_c & 3) != 0: + bias_na = numpy.pad(bias_na, [[0, 4 - (tensor_out_c & 3)]], mode = "constant") + return bias_na + diff --git a/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/sh_conv.py b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/sh_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..e284a6162253ce5b21a4dd42d2259b6fb054cfa3 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/sh_conv.py @@ -0,0 +1,77 @@ +import os + +import kp + +def compile_source(source): + os.system("glslangValidator --stdin -S comp -V -o tmp_kp_shader.comp.spv << END\n" + source + "\nEND") + return open("tmp_kp_shader.comp.spv", "rb").read() + + +# This is the convolution & leakyrelu shader. +global conv_shader +conv_shader = compile_source(""" +#version 450 + +layout (local_size_x = 8, local_size_y = 2) in; + +// [y][x][group] (vec4: channels) +layout (set = 0, binding = 0) buffer buf_in_image { readonly restrict vec4 in_image[]; }; +// [outputCGroups] (vec4: output channels) +layout (set = 0, binding = 1) buffer buf_in_bias { readonly restrict vec4 in_bias[]; }; +// [outputCGroups][kernelH][kernelW][inputCGroups] (mat4: input & output channels) +layout (set = 0, binding = 2) buffer buf_in_weight { readonly restrict mat4 in_weight[]; }; +// [y][x][group] (vec4: channels) +layout (set = 0, binding = 3) buffer buf_out_image { writeonly restrict vec4 out_image[]; }; + +// The 'c' measures in cgroups. +// Some maths changes as a result. +layout (constant_id = 0) const float in_w = 0; +layout (constant_id = 1) const float in_h = 0; +layout (constant_id = 2) const float in_cg = 0; +layout (constant_id = 3) const float out_w = 0; +layout (constant_id = 4) const float out_h = 0; +layout (constant_id = 5) const float out_cg = 0; + +uint index_in_no_ic(uvec2 pos) { + return (pos.x + (pos.y * uint(in_w))) * uint(in_cg); +} + +uint index_out(uvec2 pos) { + return ((pos.x + (pos.y * uint(out_w))) * uint(out_cg)) + gl_GlobalInvocationID.z; +} + +void main() { + // out x/y is gl_GlobalInvocationID.xy + // we need to account for workgroupy padding *here* + // so long as we aren't trying to output to a pixel that doesn't exist, + // we won't read from any pixels that don't exist + if ( + (gl_GlobalInvocationID.x < (uint(in_w) - 2)) && + (gl_GlobalInvocationID.y < (uint(in_h) - 2)) + ) { + vec4 value = in_bias[gl_GlobalInvocationID.z]; + for (uint x = 0; x < 3; x++) { + for (uint y = 0; y < 3; y++) { + uint weight_ptr = ((gl_GlobalInvocationID.z * 9) + (x + (y * 3))) * uint(in_cg); + // specific pixel + // important to note is that since in position has a border around it, + // no further transformation is necessary (the - is implied) + uvec2 in_pos = gl_GlobalInvocationID.xy + uvec2(x, y); + uint in_ptr = index_in_no_ic(in_pos); + for (uint icg = 0; icg < uint(in_cg); icg++) { + // input channel group + vec4 iCG = in_image[in_ptr]; + // handle all 4 input components + value += iCG * in_weight[weight_ptr]; + weight_ptr += 1; + in_ptr += 1; + } + } + } + // leakyrelu slope 0.1 + value = (max(value, 0.0) * 0.9) + (value * 0.1); + out_image[index_out(gl_GlobalInvocationID.xy)] = value; + } +} +""") + diff --git a/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/w2wbinit.png b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/w2wbinit.png new file mode 100644 index 0000000000000000000000000000000000000000..fc3a908e52a8103132f33e9109d4969938b1c4f4 Binary files /dev/null and b/ggml/src/ggml-kompute/kompute/examples/neural_network_vgg7/w2wbinit.png differ diff --git a/ggml/src/ggml-kompute/kompute/examples/pi4_mesa_build/README.md b/ggml/src/ggml-kompute/kompute/examples/pi4_mesa_build/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6912d29e4fbfea5f03d4fd30447f3ea50c800e34 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/pi4_mesa_build/README.md @@ -0,0 +1,139 @@ +# Kompute Pi4 Mesa Build Example + +The Raspberry Pi 4 is an impressive little computer when you consider that the Broadcom GPU is able to run 2x 4K displays. This example intends to show how to get Kompute up and running on a Raspberry Pi 4. This has huge potential for edge processing using the power of the Pi 4 GPU. + +Special thanks to [Alejandro Piñeiro](https://blogs.igalia.com/apinheiro/) and others for their work on Broadcom drivers for [Mesa](https://gitlab.freedesktop.org/mesa/mesa) which make this example possible. + +## Raspberry Pi Operating System + +For this experiment we used [RaspiOS Lite 2021-01-12](https://downloads.raspberrypi.org/raspios_lite_armhf/images/raspios_lite_armhf-2021-01-12/2021-01-11-raspios-buster-armhf-lite.zip), though it is likely best to start with the latest available operation system from [Raspberry Pi operating system images](https://www.raspberrypi.org/software/operating-systems/). In other experiments the full Raspberry Pi operating system (with desktop environment) was found to work. However, when attempting to use Ubuntu on the Raspberry Pi we were not able to run the Python Kompute examples. + +## Running the Pi headless + +By far the easiest way to get up and running with a Raspberry Pi is to configure it for headless operation. This removes the requirement to have a monitor, keyboard or mouse. To run headless the Pi needs access to the internet and for SSH enabled. The following guides from the Raspberry Pi foundation should help. + +- [Setting up a Raspberry Pi headless](https://www.raspberrypi.org/documentation/configuration/wireless/headless.md) +- [SSH (Secure Shell)](https://www.raspberrypi.org/documentation/remote-access/ssh/) + +## Ensure all packages are using the latest version + +``` +sudo apt-get update +sudo apt-get upgrade +``` + +## Install dependencies for building mesa and running Kompute + +``` +sudo apt-get install \ + git build-essential cmake \ + python3-dev python3-mako python3-venv \ + flex bison meson ninja-build \ + libxcb-shm0-dev libxcb1-dev libxcb-*-dev \ + libx11-dev libx11-xcb-dev x11proto-dri2-dev x11proto-dri3-dev \ + libdrm-dev libxshmfence-dev libxrandr-dev libxfixes-dev \ + vulkan-tools libvulkan-dev +``` + +## Clone mesa repository + +``` +git clone --depth 1 https://gitlab.freedesktop.org/mesa/mesa.git +``` + +## Build mesa + +Use meson and ninja to build mesa using the Broadcom Vulkan SDK drivers. For information on the Gallium drivers please see [V3D — The Mesa 3D Graphics Library latest documentation](https://docs.mesa3d.org/drivers/v3d.html). + +``` +meson --libdir lib \ + --prefix /mesa-install \ + -D platforms=x11 \ + -D vulkan-drivers=broadcom \ + -D gallium-drivers=v3d \ + -D dri-drivers=[] \ + -D buildtype=debug \ + build + +ninja -C build +sudo ninja -C build install +``` + +## Configure preferred Vulkan SDK driver + +Export the path for the Broadcom drivers, this command will need to be run for every new terminal session. + +``` +export VK_ICD_FILENAMES=/mesa-install/share/vulkan/icd.d/broadcom_icd.armv7l.json +``` + +## Allow access to render + +In order to access the render from remote login there are two options. Both options work. + +**Option 1: provide read write access to everyone.** + +``` +sudo chmod ugo+rw /dev/dri/renderD128 +``` + +**Option 2: Change group from render to video.** + +``` +sudo chown root:video /dev/dri/renderD128 +``` + +## Confirm correct Vulkan SDK operation + +To confirm that mesa was configured and built correctly run the following command. + +``` +vulkaninfo +``` + +## Clone Kompute + +Clone Kompute for access to the latest Python tests. + +``` +git clone https://github.com/KomputeProject/kompute.git +``` + +## Install dependencies to run the tests + +Navigate to the available tests and install required dependencies. + +``` +cd kompute/python/test +python3 -m venv .venv +source .venv/bin/activate +pip install --upgrade pip wheel +pip install -r requirements-dev.txt +pip install git+git://github.com/KomputeProject/kompute.git +``` + +## Run the available tests + +Use the following command to run the python tests for Kompute. + +``` +pytest +``` + +If the tests pass then congratulations! You are now able to make full use of the Pi 4 Broadcom GPU for running parallel computing. If however, there are any issues with the tests they can be run in debug mode to see the logs. + +``` +pytest --log-cli-level debug +``` + +Please share any issues with the maintainers and they will be more than happy to help. + +## Closing remarks + +To avoid the need to export `VK_ICD_FILENAMES` every time you login, it is possible to symlink the json file into the default directory. The Vulkan SDK loader looks in the `/etc/vulkan/icd.d/` directory for `.json` files. + +``` +sudo ln -s /mesa-install/share/vulkan/icd.d/broadcom_icd.armv7l.json /etc/vulkan/icd.d/broadcom_icd.armv7l.json +``` + +As a word of warning, configuring the icd filenames in this way will stop certain tests being skipped. At the time of writing this will mean that some tests fail when running on the Pi. diff --git a/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/README.md b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0688bb079341a546dc8a2a7bbc625ef81b8086a8 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/README.md @@ -0,0 +1,9 @@ +# Naive Matmul Implementation + +This demonstrate a basic matmul implementation using Python and vulkan-kompute. Many thanks for the very helpful [SGEMM in WebGL2-compute](https://www.ibiblio.org/e-notes/webgl/gpu/mul/sgemm.htm) article on the public library [ibiblio.org](https://www.ibiblio.org/). + +To test the implementation simply run the `matmul.py` script : + +``` +python matmul.py +``` diff --git a/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/benchmark.py b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..768a854d4864e91554f8bb08a189a3150ddad3b5 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/benchmark.py @@ -0,0 +1,56 @@ +import time + +import kp +import numpy as np +from imp1_naive import MatMulOp as MatMulOp1 +from imp2_tiled import MatMulOp as MatMulOp2 +from imp3_better_tiling import MatMulOp as MatMulOp3 + + +def main(): + mgr = kp.Manager() + for tensor_size, experiment_count in [(512, 1000), (4096, 5)]: + tensor_shape = [tensor_size, tensor_size] + tensor_shape = [tensor_size, tensor_size] + mat_1 = np.triu(np.ones(tensor_shape)) + mat_2 = np.triu(np.ones(tensor_shape)) + + tensor_in_1 = mgr.tensor(mat_1) + tensor_in_2 = mgr.tensor(mat_2) + tensor_out = mgr.tensor(np.zeros(tensor_shape)) + if tensor_size <= 512: + mat_result = mat_1 @ mat_2 + else: + MatMulOp1(mgr)(tensor_shape, tensor_in_1, tensor_in_2, tensor_out) + mat_result = tensor_out.data().reshape(tensor_shape) # CPU is too slow for big sizes + + print(f'{tensor_shape} input tensors:\n' + f'{mat_1}\n' + f'{mat_2}\n') + print(f'Output :\n{mat_result}') + + for MatMulOp in [MatMulOp1, MatMulOp2, MatMulOp3]: + tensor_out.data()[:] = 0 + mgr.sequence().record(kp.OpTensorSyncDevice([tensor_out])) + matmul_op = MatMulOp(mgr) + matmul_op(tensor_shape, tensor_in_1, tensor_in_2, tensor_out) + + start_time = time.time() + for _ in range(experiment_count): + matmul_op(tensor_shape, tensor_in_1, tensor_in_2, tensor_out) + end_time = time.time() + experiment_time = end_time - start_time + op_count = tensor_shape[0] * tensor_shape[1] * ((tensor_shape[1] * 2) - 1) + + # print(tensor_out.data().reshape(tensor_shape)) + if (tensor_out.data().reshape(tensor_shape) == mat_result).all(): + print(f'From {MatMulOp.__module__} : {experiment_count} matmul time : ' + f'{experiment_time * 1000:0.2f}ms => ' + f'{experiment_count / experiment_time:0.2f}op/s or ' + f'{experiment_count * op_count / (1e9 * experiment_time):0.2f} GFLOPS') + else: + print(f'Test failed => output tensor is wrong :\n{tensor_out.data().reshape(tensor_shape)}') + + +if __name__ == '__main__': + main() diff --git a/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/first_example.py b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/first_example.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7caa871f3d85ed1a7ceb78f4b4e111ae0ab24e --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/first_example.py @@ -0,0 +1,60 @@ +import kp +import numpy as np + + +def main(): + mgr = kp.Manager() + + tensor_size = 4 + tensor_shape = [tensor_size, tensor_size] + tensor_in_1 = mgr.tensor(np.triu(np.ones(tensor_shape))) + tensor_in_2 = mgr.tensor(np.triu(np.ones(tensor_shape))) + tensor_out = mgr.tensor(np.zeros(tensor_shape)) + + print(f'Input tensors:\n' + f'{tensor_in_1.data().reshape(tensor_shape)}\n' + f'{tensor_in_2.data().reshape(tensor_shape)}\n') + + params = [tensor_in_1, tensor_in_2, tensor_out] + + matmul_shader = kp.Shader.compile_source(''' +#version 450 + +layout (local_size_x = 1, local_size_y = 1) in; + +layout (set = 0, binding = 0) readonly buffer buf_in_tensor_1 { float in_tensor_1[]; }; +layout (set = 0, binding = 1) readonly buffer buf_in_tensor_2 { float in_tensor_2[]; }; +layout (set = 0, binding = 2) writeonly buffer buf_out_tensor { float out_tensor[]; }; + +layout (constant_id = 0) const float tensor_size_f = 0; + + +void main() +{ + uint globalRow = gl_GlobalInvocationID.x; + uint globalCol = gl_GlobalInvocationID.y; + uint tensor_size = uint(tensor_size_f); + float acc = 0.0; + for(uint k = 0u; k < tensor_size; k++) + acc += in_tensor_1[(k * tensor_size) + globalRow] * in_tensor_2[(globalCol * tensor_size) + k]; + out_tensor[(globalCol * tensor_size) + globalRow] = acc; +}''') + + algo = mgr.algorithm( + params, # params + matmul_shader, # spirv + (*tensor_shape, 1), # workgroup + [float(tensor_size)], # spec_consts + []) # push_consts + + (mgr.sequence() + .record(kp.OpTensorSyncDevice(params)) + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpTensorSyncLocal(params)) + .eval()) + + print(f'Output :\n{tensor_out.data().reshape(tensor_shape)}') + + +if __name__ == '__main__': + main() diff --git a/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/imp1_naive.py b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/imp1_naive.py new file mode 100644 index 0000000000000000000000000000000000000000..a791662d27c11dfd80cd1038ae164e777fac00ee --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/imp1_naive.py @@ -0,0 +1,133 @@ +import time + +import kp +import numpy as np + + +class MatMulOp: + def __init__(self, manager: kp.Manager, local_size_x: int = -1, local_size_y: int = -1): + self.mgr = manager + + props = self.mgr.get_device_properties() + max_workgroup_invocation = props['max_work_group_invocations'] + max_workgroup_size = props['max_work_group_size'] + if local_size_x < 1: + if local_size_y > 0: + local_size_x = 1 + while (2 * local_size_x * local_size_y <= max_workgroup_invocation + and 2 * local_size_x <= max_workgroup_size[0]): + local_size_x *= 2 + else: + local_size_x = 1 + local_size_y = 1 + while 2 * local_size_x * local_size_y <= max_workgroup_invocation: + if 2 * local_size_x <= max_workgroup_size[0]: + local_size_x *= 2 + if 2 * local_size_y <= max_workgroup_size[1]: + local_size_y *= 2 + elif 2 * local_size_x > max_workgroup_size[0]: # stop if neither x nor y can be double + break + elif local_size_y < 0: + local_size_y = 1 + while (2 * local_size_x * local_size_y <= max_workgroup_invocation + and 2 * local_size_x <= max_workgroup_size[0]): + local_size_y *= 2 + + assert local_size_x > 0 + assert local_size_y > 0 + assert local_size_x * local_size_y <= max_workgroup_invocation + assert local_size_x <= max_workgroup_size[0] + assert local_size_y <= max_workgroup_size[1] + self.local_size_x = local_size_x + self.local_size_y = local_size_y + + self.shader = ''' +#version 450 + +layout (local_size_x = {local_size_x}, local_size_y = {local_size_y}) in; + +layout (set = 0, binding = 0) readonly buffer buf_in_tensor_1 {{ float in_tensor_1[]; }}; +layout (set = 0, binding = 1) readonly buffer buf_in_tensor_2 {{ float in_tensor_2[]; }}; +layout (set = 0, binding = 2) writeonly buffer buf_out_tensor {{ float out_tensor[]; }}; + +layout (constant_id = 0) const float tensor_size_f = 0; + + +void main() +{{ + uint globalRow = gl_GlobalInvocationID.x; + uint globalCol = gl_GlobalInvocationID.y; + uint tensor_size = uint(tensor_size_f); + float acc = 0.0; + for(uint k = 0u; k < tensor_size; k++) + acc += in_tensor_1[(k * tensor_size) + globalRow] * in_tensor_2[(globalCol * tensor_size) + k]; + out_tensor[(globalCol * tensor_size) + globalRow] = acc; +}}''' + self.compiled_shader = kp.Shader.compile_source(self.shader.format( + local_size_x=self.local_size_x, local_size_y=self.local_size_y)) + self.tensor_shape: tuple[int, int] = (0, 0) + self.params: list[kp.Tensor] = [] + self.algo = None + + def __call__(self, tensor_shape: tuple[int, int], tensor_in_1: kp.Tensor, tensor_in_2: kp.Tensor, + tensor_out: kp.Tensor): + params = [tensor_in_1, tensor_in_2, tensor_out] + + if self.algo is None or self.tensor_shape != tensor_shape or self.params != params: + self.tensor_shape = tensor_shape + self.params = params + local_size_x = min(self.local_size_x, tensor_shape[0]) + local_size_y = min(self.local_size_y, tensor_shape[1]) + self.compiled_shader = kp.Shader.compile_source(self.shader.format( + local_size_x=local_size_x, local_size_y=local_size_y)) + workgroup = (tensor_shape[0] // local_size_x, tensor_shape[1] // local_size_y, 1) + print(f'{workgroup=} {self.local_size_x=} {self.local_size_y=}') + self.algo = self.mgr.algorithm( + params, # params + self.compiled_shader, # spirv + workgroup, # workgroup + [float(tensor_shape[0])], # spec_consts + []) # push_consts + + (self.mgr.sequence() + .record(kp.OpTensorSyncDevice([tensor_in_1, tensor_in_2])) + .record(kp.OpAlgoDispatch(self.algo)) + .record(kp.OpTensorSyncLocal([tensor_out])) + .eval()) + + +def main(): + mgr = kp.Manager() + + matmul_op = MatMulOp(mgr) + + tensor_size = 4064 + tensor_shape = [tensor_size, tensor_size] + tensor_in_1 = mgr.tensor(np.triu(np.ones(tensor_shape))) + tensor_in_2 = mgr.tensor(np.triu(np.ones(tensor_shape))) + tensor_out = mgr.tensor(np.zeros(tensor_shape)) + + print(f'{tensor_shape} input tensors:\n' + f'{tensor_in_1.data().reshape(tensor_shape)}\n' + f'{tensor_in_2.data().reshape(tensor_shape)}\n') + + matmul_op(tensor_shape, tensor_in_1, tensor_in_2, tensor_out) + + experiment_count = 8 + start_time = time.time() + for _ in range(experiment_count): + matmul_op(tensor_shape, tensor_in_1, tensor_in_2, tensor_out) + end_time = time.time() + experiment_time = end_time - start_time + op_count = tensor_shape[0] * tensor_shape[1] * ((tensor_shape[1] * 2) - 1) + + print(f'Output :\n{tensor_out.data().reshape(tensor_shape)}') + + print(f'{experiment_count} matmul time : ' + f'{experiment_time * 1000:0.2f}ms => ' + f'{experiment_count / experiment_time:0.2f}op/s or ' + f'{experiment_count * op_count / (1e9 * experiment_time):0.2f} GFLOPS') + + +if __name__ == '__main__': + main() diff --git a/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/imp2_tiled.py b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/imp2_tiled.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac13e8580c95e465e33f5e5336fda592ea7f8a9 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/imp2_tiled.py @@ -0,0 +1,132 @@ +import time + +import kp +import numpy as np + + +class MatMulOp: + def __init__(self, manager: kp.Manager, tile_size: int = -1): + self.mgr = manager + + props = self.mgr.get_device_properties() + max_workgroup_invocation = props['max_work_group_invocations'] + max_workgroup_size = props['max_work_group_size'] + if tile_size < 0: + tile_size = 1 + while (4 * tile_size * tile_size <= max_workgroup_invocation + and 2 * tile_size <= max_workgroup_size[0] + and 2 * tile_size <= max_workgroup_size[1]): + tile_size *= 2 + + assert tile_size > 0 + assert tile_size * tile_size <= max_workgroup_invocation + assert tile_size <= max_workgroup_size[0] + assert tile_size <= max_workgroup_size[1] + self.tile_size = tile_size + + self.shader = ''' +#version 450 + +layout (local_size_x = {tile_size}, local_size_y = {tile_size}) in; + +layout (set = 0, binding = 0) readonly buffer buf_in_tensor_1 {{ float in_tensor_1[]; }}; +layout (set = 0, binding = 1) readonly buffer buf_in_tensor_2 {{ float in_tensor_2[]; }}; +layout (set = 0, binding = 2) writeonly buffer buf_out_tensor {{ float out_tensor[]; }}; + +layout (constant_id = 0) const float tensor_size_f = 0; + +shared float sub_tensor_1[{tile_size}][{tile_size}]; +shared float sub_tensor_2[{tile_size}][{tile_size}]; + +void main() +{{ + uint row = gl_LocalInvocationID.x; // 0 .. tile_size + uint col = gl_LocalInvocationID.y; // 0 .. tile_size + // gl_WorkGroupID : 0 .. tensor_size / tile_size + uint globalRow = {tile_size} * gl_WorkGroupID.x + row; + uint globalCol = {tile_size} * gl_WorkGroupID.y + col; + + uint tensor_size = uint(tensor_size_f); + float acc = 0.0; + uint numTiles = tensor_size / {tile_size}; + for(uint t = 0u; t < numTiles; t++) + {{ + uint tiledRow = ({tile_size} * t) + row; + uint tiledCol = ({tile_size} * t) + col; + sub_tensor_1[col][row] = in_tensor_1[(tiledCol * tensor_size) + globalRow]; + sub_tensor_2[col][row] = in_tensor_2[(globalCol * tensor_size) + tiledRow]; + + memoryBarrierShared(); + barrier(); + + for(uint k = 0u; k < {tile_size}; k++) + acc += sub_tensor_1[k][row] * sub_tensor_2[col][k]; + + barrier(); + }} + out_tensor[tensor_size * globalCol + globalRow] = acc; +}}''' + self.compiled_shader = kp.Shader.compile_source(self.shader.format(tile_size=tile_size)) + self.tensor_shape: tuple[int, int] = (0, 0) + self.params: list[kp.Tensor] = [] + self.algo = None + + def __call__(self, tensor_shape: tuple[int, int], tensor_in_1: kp.Tensor, tensor_in_2: kp.Tensor, + tensor_out: kp.Tensor): + params = [tensor_in_1, tensor_in_2, tensor_out] + + if self.algo is None or self.tensor_shape != tensor_shape or self.params != params: + self.tensor_shape = tensor_shape + self.params = params + tile_size = min(tensor_shape[0], tensor_shape[1], self.tile_size) + self.compiled_shader = kp.Shader.compile_source(self.shader.format(tile_size=tile_size)) + workgroup = [tensor_shape[0] // tile_size, tensor_shape[1] // tile_size, 1] + self.algo = self.mgr.algorithm( + params, # params + self.compiled_shader, # spirv + workgroup, # workgroup + [float(tensor_shape[0])], # spec_consts + []) # push_consts + + (self.mgr.sequence() + .record(kp.OpTensorSyncDevice([tensor_in_1, tensor_in_2])) + .record(kp.OpAlgoDispatch(self.algo)) + .record(kp.OpTensorSyncLocal([tensor_out])) + .eval()) + + +def main(): + mgr = kp.Manager() + + matmul_op = MatMulOp(mgr) + + tensor_size = 4096 + tensor_shape = [tensor_size, tensor_size] + tensor_in_1 = mgr.tensor(np.triu(np.ones(tensor_shape))) + tensor_in_2 = mgr.tensor(np.triu(np.ones(tensor_shape))) + tensor_out = mgr.tensor(np.zeros(tensor_shape)) + + print(f'{tensor_shape} input tensors:\n' + f'{tensor_in_1.data().reshape(tensor_shape)}\n' + f'{tensor_in_2.data().reshape(tensor_shape)}\n') + + matmul_op(tensor_shape, tensor_in_1, tensor_in_2, tensor_out) + + experiment_count = 8 + start_time = time.time() + for _ in range(experiment_count): + matmul_op(tensor_shape, tensor_in_1, tensor_in_2, tensor_out) + end_time = time.time() + experiment_time = end_time - start_time + op_count = tensor_shape[0] * tensor_shape[1] * ((tensor_shape[1] * 2) - 1) + + print(f'Output :\n{tensor_out.data().reshape(tensor_shape)}') + + print(f'{experiment_count} matmul time : ' + f'{experiment_time * 1000:0.2f}ms => ' + f'{experiment_count / experiment_time:0.2f}op/s or ' + f'{experiment_count * op_count / (1e9 * experiment_time):0.2f} GFLOPS') + + +if __name__ == '__main__': + main() diff --git a/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/imp3_better_tiling.py b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/imp3_better_tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd44277be6c3a02e1c6c5661206be5efb4fb0a7 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/imp3_better_tiling.py @@ -0,0 +1,153 @@ +import time + +import kp +import numpy as np + + +class MatMulOp: + def __init__(self, manager: kp.Manager, tile_size: int = -1, thread_work_ratio: int = 16): + self.mgr = manager + + props = self.mgr.get_device_properties() + max_workgroup_invocation = props['max_work_group_invocations'] + max_workgroup_size = props['max_work_group_size'] + if tile_size < 0: + tile_size = 1 + local_size_y = tile_size // thread_work_ratio + while (4 * tile_size * tile_size <= max_workgroup_invocation + and 2 * tile_size <= max_workgroup_size[0] + and 2 * tile_size <= max_workgroup_size[1]): + tile_size *= 2 + local_size_y = tile_size // thread_work_ratio + else: + local_size_y = tile_size // thread_work_ratio + + assert tile_size > 0 + assert thread_work_ratio > 0 + assert tile_size * local_size_y <= max_workgroup_invocation + assert tile_size <= max_workgroup_size[0] + assert local_size_y <= max_workgroup_size[1] + self.tile_size = tile_size + self.thread_work_ratio = thread_work_ratio + + self.local_size_x = tile_size + self.local_size_y = tile_size // thread_work_ratio + self.shader = ''' +#version 450 + +layout (local_size_x = {tile_size}, local_size_y = {local_size_y}) in; + +layout (set = 0, binding = 0) readonly buffer buf_in_tensor_1 {{ float in_tensor_1[]; }}; +layout (set = 0, binding = 1) readonly buffer buf_in_tensor_2 {{ float in_tensor_2[]; }}; +layout (set = 0, binding = 2) writeonly buffer buf_out_tensor {{ float out_tensor[]; }}; + +layout (constant_id = 0) const float tensor_size_f = 0; + +shared float sub_tensor_1[{tile_size}][{tile_size}]; +shared float sub_tensor_2[{tile_size}][{tile_size}]; + +void main() +{{ + uint row = gl_LocalInvocationID.x; + uint col = gl_LocalInvocationID.y; + uint globalRow = {tile_size} * gl_WorkGroupID.x + row; + uint globalCol = {tile_size} * gl_WorkGroupID.y + col; + + uint tensor_size = uint(tensor_size_f); + float acc[{thread_work_ratio}]; + for(uint w = 0u; w < {thread_work_ratio}; w++) + acc[w] = 0.0; + + uint numTiles = tensor_size / {tile_size}; + for(uint t = 0u; t < numTiles; t++) + {{ + for(uint w = 0u; w < {thread_work_ratio}; w++) + {{ + uint tiledRow = {tile_size} * t + row; + uint tiledCol = {tile_size} * t + col; + sub_tensor_1[col + w * {local_size_y}][row] = in_tensor_1[ + (tiledCol + w * {local_size_y}) * tensor_size + globalRow]; + sub_tensor_2[col + w * {local_size_y}][row] = in_tensor_2[ + (globalCol + w * {local_size_y})* tensor_size + tiledRow]; + }} + + memoryBarrierShared(); + barrier(); + + for(uint k = 0u; k < {tile_size}; k++) + for(uint w = 0u; w < {thread_work_ratio}; w++) + acc[w] += sub_tensor_1[k][row] * sub_tensor_2[col + w * {local_size_y}][k]; + + barrier(); + }} + for(uint w = 0u; w < {thread_work_ratio}; w++) + out_tensor[(globalCol + w * {local_size_y}) * tensor_size + globalRow] = acc[w]; +}}''' + self.compiled_shader = kp.Shader.compile_source(self.shader.format( + tile_size=tile_size, thread_work_ratio=thread_work_ratio, local_size_y=local_size_y)) + self.tensor_shape: tuple[int, int] = (0, 0) + self.params: list[kp.Tensor] = [] + self.algo = None + + def __call__(self, tensor_shape: tuple[int, int], tensor_in_1: kp.Tensor, tensor_in_2: kp.Tensor, + tensor_out: kp.Tensor): + params = [tensor_in_1, tensor_in_2, tensor_out] + + if self.algo is None or self.tensor_shape != tensor_shape or self.params != params: + self.tensor_shape = tensor_shape + self.params = params + tile_size = min(self.tensor_shape[0], self.tile_size) + thread_work_ratio = min(self.tensor_shape[1] // self.tile_size, self.thread_work_ratio) + local_size_y = tile_size // thread_work_ratio + self.compiled_shader = kp.Shader.compile_source(self.shader.format( + tile_size=tile_size, thread_work_ratio=thread_work_ratio, local_size_y=local_size_y)) + workgroup = (tensor_shape[0] // self.local_size_x, tensor_shape[1] // self.local_size_y, 1) + self.algo = self.mgr.algorithm( + params, # params + self.compiled_shader, # spirv + workgroup, # workgroup + [float(tensor_shape[0])], # spec_consts + []) # push_consts + + (self.mgr.sequence() + .record(kp.OpTensorSyncDevice([tensor_in_1, tensor_in_2])) + .record(kp.OpAlgoDispatch(self.algo)) + .record(kp.OpTensorSyncLocal([tensor_out])) + .eval()) + + +def main(): + mgr = kp.Manager() + + matmul_op = MatMulOp(mgr) + + tensor_size = 4096 + tensor_shape = [tensor_size, tensor_size] + tensor_in_1 = mgr.tensor(np.triu(np.ones(tensor_shape))) + tensor_in_2 = mgr.tensor(np.triu(np.ones(tensor_shape))) + tensor_out = mgr.tensor(np.zeros(tensor_shape)) + + print(f'{tensor_shape} input tensors:\n' + f'{tensor_in_1.data().reshape(tensor_shape)}\n' + f'{tensor_in_2.data().reshape(tensor_shape)}\n') + + matmul_op(tensor_shape, tensor_in_1, tensor_in_2, tensor_out) + + experiment_count = 2 + start_time = time.time() + for _ in range(experiment_count): + matmul_op(tensor_shape, tensor_in_1, tensor_in_2, tensor_out) + end_time = time.time() + experiment_time = end_time - start_time + op_count = tensor_shape[0] * tensor_shape[1] * ((tensor_shape[1] * 2) - 1) + + print(f'Output :\n{tensor_out.data().reshape(tensor_shape)}') + + print(f'{experiment_count} matmul time : ' + f'{experiment_time * 1000:0.2f}ms => ' + f'{experiment_count / experiment_time:0.2f}op/s or ' + f'{experiment_count * op_count / (1e9 * experiment_time):0.2f} GFLOPS') + + +if __name__ == '__main__': + main() diff --git a/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/matmul_plot.py b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/matmul_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..81763a439358ef13bc0cb1e6985c0e9cdaf54aed --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/examples/python_naive_matmul/matmul_plot.py @@ -0,0 +1,97 @@ +from argparse import ArgumentParser + +import cv2 +import numpy as np + + +def plot_tensor(window_name: str, tensor: np.ndarray, coord_highlight: tuple[int, int] = None): + font_size = 48 + image = np.zeros((tensor.shape[1] * font_size, tensor.shape[0] * font_size, 3), dtype=np.uint8) + + for y in range(tensor.shape[1]): + for x in range(tensor.shape[0]): + if coord_highlight and x == coord_highlight[1] and y == coord_highlight[0]: + cv2.putText( + image, str(int(tensor[y, x])), (x * font_size, int((y + 0.8) * font_size)), + cv2.FONT_HERSHEY_TRIPLEX, 1., (127, 127, 255)) + else: + cv2.putText( + image, str(int(tensor[y, x])), (x * font_size, int((y + 0.8) * font_size)), + cv2.FONT_HERSHEY_TRIPLEX, 1., (255, 255, 255)) + + cv2.imshow(window_name, image) + + +def main(): + parser = ArgumentParser() + parser.add_argument('tensor_size', type=int, help='Size of the square tensors') + parser.add_argument('tile_size', type=int) + parser.add_argument('local_size', type=int, nargs=2) + parser.add_argument('workgroup', type=int, nargs=2) + arguments = parser.parse_args() + + tensor_size: int = arguments.tensor_size + tile_size: int = arguments.tile_size + local_size: tuple[int, int, int] = tuple(arguments.local_size) + workgroup: tuple[int, int, int] = tuple(arguments.workgroup) + + tensor_shape = (tensor_size, tensor_size) + tensor_1 = np.triu(np.ones(tensor_shape)) + tensor_2 = np.triu(np.ones(tensor_shape)) + tensor_out = np.zeros(tensor_shape) + tensor_test_1 = np.zeros(tensor_shape) + tensor_test_2 = np.zeros(tensor_shape) + tensor_test_3 = np.zeros(tensor_shape) + tensor_test_4 = np.zeros(tensor_shape) + tensor_test_5 = np.zeros(tensor_shape) + + plot_tensor('tensor_1', tensor_1) + plot_tensor('tensor_2', tensor_2) + plot_tensor('tensor_out', tensor_out) + plot_tensor('tensor_test_1', tensor_test_1) + plot_tensor('tensor_test_2', tensor_test_2) + plot_tensor('tensor_test_3', tensor_test_3) + plot_tensor('tensor_test_4', tensor_test_4) + plot_tensor('tensor_test_5', tensor_test_5) + cv2.waitKey(-1) + + print(f'{workgroup=} {local_size=}') + for workgroup_x in range(workgroup[0]): + for workgroup_y in range(workgroup[1]): + for invocation_x in range(workgroup_x * local_size[0], (workgroup_x + 1) * local_size[0]): + for invocation_y in range(workgroup_y * local_size[1], (workgroup_y + 1) * local_size[1]): + row = invocation_x + col = invocation_y + globalRow = (tile_size * workgroup_x) + row + globalCol = (tile_size * workgroup_y) + col + try: + tensor_out[row, col] = row + tensor_test_1[row, col] = col + tensor_test_2[row, col] = workgroup_x + tensor_test_3[row, col] = workgroup_y + tensor_test_4[row, col] = globalRow + tensor_test_5[row, col] = globalCol + plot_tensor('tensor_out', tensor_out, (row, col)) + plot_tensor('tensor_test_1', tensor_test_1, (row, col)) + plot_tensor('tensor_test_2', tensor_test_2, (row, col)) + plot_tensor('tensor_test_3', tensor_test_3, (row, col)) + plot_tensor('tensor_test_4', tensor_test_4, (row, col)) + plot_tensor('tensor_test_5', tensor_test_5, (row, col)) + cv2.waitKey(-1) + except IndexError as error: + print(f'{workgroup_x=} {workgroup_y=} {row=} {col=}') + raise error + + plot_tensor('tensor_1', tensor_1) + plot_tensor('tensor_2', tensor_2) + plot_tensor('tensor_out', tensor_out) + plot_tensor('tensor_test_1', tensor_test_1) + plot_tensor('tensor_test_2', tensor_test_2) + plot_tensor('tensor_test_3', tensor_test_3) + plot_tensor('tensor_test_4', tensor_test_4) + plot_tensor('tensor_test_5', tensor_test_5) + cv2.waitKey(-1) + + +if __name__ == '__main__': + main() diff --git a/ggml/src/ggml-kompute/kompute/external/bin/xxd.c b/ggml/src/ggml-kompute/kompute/external/bin/xxd.c new file mode 100644 index 0000000000000000000000000000000000000000..4d16f77c4757fa70884457f51acefe0f35c37ffd --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/external/bin/xxd.c @@ -0,0 +1,809 @@ +/* +As indicated at https://lists.debian.org/debian-legal/2015/01/msg00037.html, +the author has permitted redistribution of xxd under the MIT license, as follows: + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * xxd: my hexdump facility. jw + * + * 2.10.90 changed to word output + * 3.03.93 new indent style, dumb bug inserted and fixed. + * -c option, mls + * 26.04.94 better option parser, -ps, -l, -s added. + * 1.07.94 -r badly needs - as input file. Per default autoskip over + * consecutive lines of zeroes, as unix od does. + * -a shows them too. + * -i dump as c-style #include "file.h" + * 1.11.95 if "xxd -i" knows the filename, an 'unsigned char filename_bits[]' + * array is written in correct c-syntax. + * -s improved, now defaults to absolute seek, relative requires a '+'. + * -r improved, now -r -s -0x... is supported. + * change/suppress leading '\0' bytes. + * -l n improved: stops exactly after n bytes. + * -r improved, better handling of partial lines with trailing garbage. + * -r improved, now -r -p works again! + * -r improved, less flushing, much faster now! (that was silly) + * 3.04.96 Per repeated request of a single person: autoskip defaults to off. + * 15.05.96 -v added. They want to know the version. + * -a fixed, to show last line inf file ends in all zeros. + * -u added: Print upper case hex-letters, as preferred by unix bc. + * -h added to usage message. Usage message extended. + * Now using outfile if specified even in normal mode, aehem. + * No longer mixing of ints and longs. May help doze people. + * Added binify ioctl for same reason. (Enough Doze stress for 1996!) + * 16.05.96 -p improved, removed occasional superfluous linefeed. + * 20.05.96 -l 0 fixed. tried to read anyway. + * 21.05.96 -i fixed. now honours -u, and prepends __ to numeric filenames. + * compile -DWIN32 for NT or W95. George V. Reilly, * -v improved :-) + * support --gnuish-longhorn-options + * 25.05.96 MAC support added: CodeWarrior already uses ``outline'' in Types.h + * which is included by MacHeaders (Axel Kielhorn). Renamed to + * xxdline(). + * 7.06.96 -i printed 'int' instead of 'char'. *blush* + * added Bram's OS2 ifdefs... + * 18.07.96 gcc -Wall @ SunOS4 is now slient. + * Added osver for MSDOS/DJGPP/WIN32. + * 29.08.96 Added size_t to strncmp() for Amiga. + * 24.03.97 Windows NT support (Phil Hanna). Clean exit for Amiga WB (Bram) + * 02.04.97 Added -E option, to have EBCDIC translation instead of ASCII + * (azc10@yahoo.com) + * 22.05.97 added -g (group octets) option (jcook@namerica.kla.com). + * 23.09.98 nasty -p -r misfeature fixed: slightly wrong output, when -c was + * missing or wrong. + * 26.09.98 Fixed: 'xxd -i infile outfile' did not truncate outfile. + * 27.10.98 Fixed: -g option parser required blank. + * option -b added: 01000101 binary output in normal format. + * 16.05.00 Added VAXC changes by Stephen P. Wall + * 16.05.00 Improved MMS file and merge for VMS by Zoltan Arpadffy + * + * (c) 1990-1998 by Juergen Weigert (jnweiger@informatik.uni-erlangen.de) + * + * Small changes made afterwards by Bram Moolenaar et al. + * + * Distribute freely and credit me, + * make money and share with me, + * lose money and don't ask me. + * + * + */ + +/* Visual Studio 2005 has 'deprecated' many of the standard CRT functions */ +#if _MSC_VER >= 1400 +# define _CRT_SECURE_NO_DEPRECATE +# define _CRT_NONSTDC_NO_DEPRECATE +#endif + +#include +#ifdef VAXC +# include +#else +# include +#endif +#ifdef __TSC__ +# define MSDOS +#endif +#if !defined(OS2) && defined(__EMX__) +# define OS2 +#endif +#if defined(MSDOS) || defined(WIN32) || defined(OS2) || defined(__BORLANDC__) || defined(CYGWIN) +# include /* for setmode() */ +#else +# ifdef UNIX +# include +# endif +#endif +#include +#include /* for strncmp() */ +#include /* for isalnum() */ +#if __MWERKS__ && !defined(BEBOX) +# include /* for fdopen() on MAC */ +#endif + +#if defined(__BORLANDC__) && __BORLANDC__ <= 0x0410 && !defined(fileno) +/* Missing define and prototype grabbed from the BC 4.0 */ +# define fileno(f) ((f)->fd) +FILE _FAR *_Cdecl _FARFUNC fdopen(int __handle, char _FAR *__type); +#endif + + +/* This corrects the problem of missing prototypes for certain functions + * in some GNU installations (e.g. SunOS 4.1.x). + * Darren Hiebert (sparc-sun-sunos4.1.3_U1/2.7.2.2) + */ +#if defined(__GNUC__) && defined(__STDC__) +# ifndef __USE_FIXED_PROTOTYPES__ +# define __USE_FIXED_PROTOTYPES__ +# endif +#endif + +#ifndef __USE_FIXED_PROTOTYPES__ +/* + * This is historic and works only if the compiler really has no prototypes: + * + * Include prototypes for Sun OS 4.x, when using an ANSI compiler. + * FILE is defined on OS 4.x, not on 5.x (Solaris). + * if __SVR4 is defined (some Solaris versions), don't include this. + */ +#if defined(sun) && defined(FILE) && !defined(__SVR4) && defined(__STDC__) +# define __P(a) a +/* excerpt from my sun_stdlib.h */ +extern int fprintf __P((FILE *, char *, ...)); +extern int fputs __P((char *, FILE *)); +extern int _flsbuf __P((unsigned char, FILE *)); +extern int _filbuf __P((FILE *)); +extern int fflush __P((FILE *)); +extern int fclose __P((FILE *)); +extern int fseek __P((FILE *, long, int)); +extern int rewind __P((FILE *)); + +extern void perror __P((char *)); +# endif +#endif + +extern long int strtol(); +extern long int ftell(); + +char version[] = "xxd V1.10 27oct98 by Juergen Weigert"; +#ifdef WIN32 +char osver[] = " (Win32)"; +#else +# ifdef DJGPP +char osver[] = " (dos 32 bit)"; +# else +# ifdef MSDOS +char osver[] = " (dos 16 bit)"; +# else +char osver[] = ""; +# endif +# endif +#endif + +#if !defined(CYGWIN) && (defined(CYGWIN32) || defined(__CYGWIN__) || defined(__CYGWIN32__)) +# define CYGWIN +#endif +#if defined(MSDOS) || defined(WIN32) || defined(OS2) +# define BIN_READ(yes) ((yes) ? "rb" : "rt") +# define BIN_WRITE(yes) ((yes) ? "wb" : "wt") +# define BIN_CREAT(yes) ((yes) ? (O_CREAT|O_BINARY) : O_CREAT) +# define BIN_ASSIGN(fp, yes) setmode(fileno(fp), (yes) ? O_BINARY : O_TEXT) +# define PATH_SEP '\\' +#elif defined(CYGWIN) +# define BIN_READ(yes) ((yes) ? "rb" : "rt") +# define BIN_WRITE(yes) ((yes) ? "wb" : "w") +# define BIN_CREAT(yes) ((yes) ? (O_CREAT|O_BINARY) : O_CREAT) +# define BIN_ASSIGN(fp, yes) ((yes) ? (void) setmode(fileno(fp), O_BINARY) : (void) (fp)) +# define PATH_SEP '/' +#else +# ifdef VMS +# define BIN_READ(dummy) "r" +# define BIN_WRITE(dummy) "w" +# define BIN_CREAT(dummy) O_CREAT +# define BIN_ASSIGN(fp, dummy) fp +# define PATH_SEP ']' +# define FILE_SEP '.' +# else +# define BIN_READ(dummy) "r" +# define BIN_WRITE(dummy) "w" +# define BIN_CREAT(dummy) O_CREAT +# define BIN_ASSIGN(fp, dummy) fp +# define PATH_SEP '/' +# endif +#endif + +/* open has only to arguments on the Mac */ +#if __MWERKS__ +# define OPEN(name, mode, umask) open(name, mode) +#else +# define OPEN(name, mode, umask) open(name, mode, umask) +#endif + +#ifdef AMIGA +# define STRNCMP(s1, s2, l) strncmp(s1, s2, (size_t)l) +#else +# define STRNCMP(s1, s2, l) strncmp(s1, s2, l) +#endif + +#ifndef __P +# if defined(__STDC__) || defined(MSDOS) || defined(WIN32) || defined(OS2) \ + || defined(__BORLANDC__) +# define __P(a) a +# else +# define __P(a) () +# endif +#endif + +/* Let's collect some prototypes */ +/* CodeWarrior is really picky about missing prototypes */ +static void exit_with_usage __P((char *)); +static int huntype __P((FILE *, FILE *, FILE *, char *, int, int, long)); +static void xxdline __P((FILE *, char *, int)); + +#define TRY_SEEK /* attempt to use lseek, or skip forward by reading */ +#define COLS 256 /* change here, if you ever need more columns */ +#define LLEN (11 + (9*COLS-1)/1 + COLS + 2) + +char hexxa[] = "0123456789abcdef0123456789ABCDEF", *hexx = hexxa; + +/* the different hextypes known by this program: */ +#define HEX_NORMAL 0 +#define HEX_POSTSCRIPT 1 +#define HEX_CINCLUDE 2 +#define HEX_BITS 3 /* not hex a dump, but bits: 01111001 */ + +static void +exit_with_usage(pname) +char *pname; +{ + fprintf(stderr, "Usage:\n %s [options] [infile [outfile]]\n", pname); + fprintf(stderr, " or\n %s -r [-s [-]offset] [-c cols] [-ps] [infile [outfile]]\n", pname); + fprintf(stderr, "Options:\n"); + fprintf(stderr, " -a toggle autoskip: A single '*' replaces nul-lines. Default off.\n"); + fprintf(stderr, " -b binary digit dump (incompatible with -p,-i,-r). Default hex.\n"); + fprintf(stderr, " -c cols format octets per line. Default 16 (-i: 12, -ps: 30).\n"); + fprintf(stderr, " -E show characters in EBCDIC. Default ASCII.\n"); + fprintf(stderr, " -g number of octets per group in normal output. Default 2.\n"); + fprintf(stderr, " -h print this summary.\n"); + fprintf(stderr, " -i output in C include file style.\n"); + fprintf(stderr, " -l len stop after octets.\n"); + fprintf(stderr, " -ps output in postscript plain hexdump style.\n"); + fprintf(stderr, " -r reverse operation: convert (or patch) hexdump into binary.\n"); + fprintf(stderr, " -r -s off revert with added to file positions found in hexdump.\n"); + fprintf(stderr, " -s %sseek start at bytes abs. %sinfile offset.\n", +#ifdef TRY_SEEK + "[+][-]", "(or +: rel.) "); +#else + "", ""); +#endif + fprintf(stderr, " -u use upper case hex letters.\n"); + fprintf(stderr, " -v show version: \"%s%s\".\n", version, osver); + exit(1); +} + +/* + * Max. cols binary characters are decoded from the input stream per line. + * Two adjacent garbage characters after evaluated data delimit valid data. + * Everything up to the next newline is discarded. + * + * The name is historic and came from 'undo type opt h'. + */ +static int +huntype(fpi, fpo, fperr, pname, cols, hextype, base_off) +FILE *fpi, *fpo, *fperr; +char *pname; +int cols, hextype; +long base_off; +{ + int c, ign_garb = 1, n1 = -1, n2 = 0, n3, p = cols; + long have_off = 0, want_off = 0; + + rewind(fpi); + + while ((c = getc(fpi)) != EOF) + { + if (c == '\r') /* Doze style input file? */ + continue; + +#if 0 /* this doesn't work when there is normal text after the hex codes in + the last line that looks like hex */ + if (c == ' ' || c == '\n' || c == '\t') /* allow multiple spaces */ + continue; +#endif + + n3 = n2; + n2 = n1; + + if (c >= '0' && c <= '9') + n1 = c - '0'; + else if (c >= 'a' && c <= 'f') + n1 = c - 'a' + 10; + else if (c >= 'A' && c <= 'F') + n1 = c - 'A' + 10; + else + { + n1 = -1; + if (ign_garb) + continue; + } + + ign_garb = 0; + + if (p >= cols) + { + if (!hextype) + { + if (n1 < 0) + { + p = 0; + continue; + } + want_off = (want_off << 4) | n1; + continue; + } + else + p = 0; + } + + if (base_off + want_off != have_off) + { + fflush(fpo); +#ifdef TRY_SEEK + c = fseek(fpo, base_off + want_off - have_off, 1); + if (c >= 0) + have_off = base_off + want_off; +#endif + if (base_off + want_off < have_off) + { + fprintf(fperr, "%s: sorry, cannot seek backwards.\n", pname); + return 5; + } + for (; have_off < base_off + want_off; have_off++) + putc(0, fpo); + } + + if (n2 >= 0 && n1 >= 0) + { + putc((n2 << 4) | n1, fpo); + have_off++; + want_off++; + n1 = -1; + if ((++p >= cols) && !hextype) + { + /* skip rest of line as garbage */ + want_off = 0; + while ((c = getc(fpi)) != '\n' && c != EOF) + ; + ign_garb = 1; + } + } + else if (n1 < 0 && n2 < 0 && n3 < 0) + { + /* already stumbled into garbage, skip line, wait and see */ + if (!hextype) + want_off = 0; + while ((c = getc(fpi)) != '\n' && c != EOF) + ; + ign_garb = 1; + } + } + fflush(fpo); +#ifdef TRY_SEEK + fseek(fpo, 0L, 2); +#endif + fclose(fpo); + fclose(fpi); + return 0; +} + +/* + * Print line l. If nz is false, xxdline regards the line a line of + * zeroes. If there are three or more consecutive lines of zeroes, + * they are replaced by a single '*' character. + * + * If the output ends with more than two lines of zeroes, you + * should call xxdline again with l being the last line and nz + * negative. This ensures that the last line is shown even when + * it is all zeroes. + * + * If nz is always positive, lines are never suppressed. + */ +static void +xxdline(fp, l, nz) +FILE *fp; +char *l; +int nz; +{ + static char z[LLEN+1]; + static int zero_seen = 0; + + if (!nz && zero_seen == 1) + strcpy(z, l); + + if (nz || !zero_seen++) + { + if (nz) + { + if (nz < 0) + zero_seen--; + if (zero_seen == 2) + fputs(z, fp); + if (zero_seen > 2) + fputs("*\n", fp); + } + if (nz >= 0 || zero_seen > 0) + fputs(l, fp); + if (nz) + zero_seen = 0; + } +} + +/* This is an EBCDIC to ASCII conversion table */ +/* from a proposed BTL standard April 16, 1979 */ +static unsigned char etoa64[] = +{ + 0040,0240,0241,0242,0243,0244,0245,0246, + 0247,0250,0325,0056,0074,0050,0053,0174, + 0046,0251,0252,0253,0254,0255,0256,0257, + 0260,0261,0041,0044,0052,0051,0073,0176, + 0055,0057,0262,0263,0264,0265,0266,0267, + 0270,0271,0313,0054,0045,0137,0076,0077, + 0272,0273,0274,0275,0276,0277,0300,0301, + 0302,0140,0072,0043,0100,0047,0075,0042, + 0303,0141,0142,0143,0144,0145,0146,0147, + 0150,0151,0304,0305,0306,0307,0310,0311, + 0312,0152,0153,0154,0155,0156,0157,0160, + 0161,0162,0136,0314,0315,0316,0317,0320, + 0321,0345,0163,0164,0165,0166,0167,0170, + 0171,0172,0322,0323,0324,0133,0326,0327, + 0330,0331,0332,0333,0334,0335,0336,0337, + 0340,0341,0342,0343,0344,0135,0346,0347, + 0173,0101,0102,0103,0104,0105,0106,0107, + 0110,0111,0350,0351,0352,0353,0354,0355, + 0175,0112,0113,0114,0115,0116,0117,0120, + 0121,0122,0356,0357,0360,0361,0362,0363, + 0134,0237,0123,0124,0125,0126,0127,0130, + 0131,0132,0364,0365,0366,0367,0370,0371, + 0060,0061,0062,0063,0064,0065,0066,0067, + 0070,0071,0372,0373,0374,0375,0376,0377 +}; + +int +main(argc, argv) +int argc; +char *argv[]; +{ + FILE *fp, *fpo; + int c, e, p = 0, relseek = 1, negseek = 0, revert = 0; + int cols = 0, nonzero = 0, autoskip = 0, hextype = HEX_NORMAL; + int ebcdic = 0; + int octspergrp = -1; /* number of octets grouped in output */ + int grplen; /* total chars per octet group */ + long length = -1, n = 0, seekoff = 0; + char l[LLEN+1]; + char *pname, *pp; + +#ifdef AMIGA + /* This program doesn't work when started from the Workbench */ + if (argc == 0) + exit(1); +#endif + + pname = argv[0]; + for (pp = pname; *pp; ) + if (*pp++ == PATH_SEP) + pname = pp; +#ifdef FILE_SEP + for (pp = pname; *pp; pp++) + if (*pp == FILE_SEP) + { + *pp = '\0'; + break; + } +#endif + + while (argc >= 2) + { + pp = argv[1] + (!STRNCMP(argv[1], "--", 2) && argv[1][2]); + if (!STRNCMP(pp, "-a", 2)) autoskip = 1 - autoskip; + else if (!STRNCMP(pp, "-b", 2)) hextype = HEX_BITS; + else if (!STRNCMP(pp, "-u", 2)) hexx = hexxa + 16; + else if (!STRNCMP(pp, "-p", 2)) hextype = HEX_POSTSCRIPT; + else if (!STRNCMP(pp, "-i", 2)) hextype = HEX_CINCLUDE; + else if (!STRNCMP(pp, "-r", 2)) revert++; + else if (!STRNCMP(pp, "-E", 2)) ebcdic++; + else if (!STRNCMP(pp, "-v", 2)) + { + fprintf(stderr, "%s%s\n", version, osver); + exit(0); + } + else if (!STRNCMP(pp, "-c", 2)) + { + if (pp[2] && STRNCMP("ols", pp + 2, 3)) + cols = (int)strtol(pp + 2, NULL, 0); + else + { + if (!argv[2]) + exit_with_usage(pname); + cols = (int)strtol(argv[2], NULL, 0); + argv++; + argc--; + } + } + else if (!STRNCMP(pp, "-g", 2)) + { + if (pp[2] && STRNCMP("roupsize", pp + 2, 8)) + octspergrp = (int)strtol(pp + 2, NULL, 0); + else + { + if (!argv[2]) + exit_with_usage(pname); + octspergrp = (int)strtol(argv[2], NULL, 0); + argv++; + argc--; + } + } + else if (!STRNCMP(pp, "-s", 2)) + { + relseek = 0; + negseek = 0; + if (pp[2] && STRNCMP("kip", pp+2, 3) && STRNCMP("eek", pp+2, 3)) + { +#ifdef TRY_SEEK + if (pp[2] == '+') + relseek++; + if (pp[2+relseek] == '-') + negseek++; +#endif + seekoff = strtol(pp + 2+relseek+negseek, (char **)NULL, 0); + } + else + { + if (!argv[2]) + exit_with_usage(pname); +#ifdef TRY_SEEK + if (argv[2][0] == '+') + relseek++; + if (argv[2][relseek] == '-') + negseek++; +#endif + seekoff = strtol(argv[2] + relseek+negseek, (char **)NULL, 0); + argv++; + argc--; + } + } + else if (!STRNCMP(pp, "-l", 2)) + { + if (pp[2] && STRNCMP("en", pp + 2, 2)) + length = strtol(pp + 2, (char **)NULL, 0); + else + { + if (!argv[2]) + exit_with_usage(pname); + length = strtol(argv[2], (char **)NULL, 0); + argv++; + argc--; + } + } + else if (!strcmp(pp, "--")) /* end of options */ + { + argv++; + argc--; + break; + } + else if (pp[0] == '-' && pp[1]) /* unknown option */ + exit_with_usage(pname); + else + break; /* not an option */ + + argv++; /* advance to next argument */ + argc--; + } + + if (!cols) + switch (hextype) + { + case HEX_POSTSCRIPT: cols = 30; break; + case HEX_CINCLUDE: cols = 12; break; + case HEX_BITS: cols = 6; break; + case HEX_NORMAL: + default: cols = 16; break; + } + + if (octspergrp < 0) + switch (hextype) + { + case HEX_BITS: octspergrp = 1; break; + case HEX_NORMAL: octspergrp = 2; break; + case HEX_POSTSCRIPT: + case HEX_CINCLUDE: + default: octspergrp = 0; break; + } + + if (cols < 1 || ((hextype == HEX_NORMAL || hextype == HEX_BITS) + && (cols > COLS))) + { + fprintf(stderr, "%s: invalid number of columns (max. %d).\n", pname, COLS); + exit(1); + } + + if (octspergrp < 1) + octspergrp = cols; + + if (argc > 3) + exit_with_usage(pname); + + if (argc == 1 || (argv[1][0] == '-' && !argv[1][1])) + BIN_ASSIGN(fp = stdin, !revert); + else + { + if ((fp = fopen(argv[1], BIN_READ(!revert))) == NULL) + { + fprintf(stderr,"%s: ", pname); + perror(argv[1]); + return 2; + } + } + + if (argc < 3 || (argv[2][0] == '-' && !argv[2][1])) + BIN_ASSIGN(fpo = stdout, revert); + else + { + int fd; + int mode = revert ? O_WRONLY : (O_TRUNC|O_WRONLY); + + if (((fd = OPEN(argv[2], mode | BIN_CREAT(revert), 0666)) < 0) || + (fpo = fdopen(fd, BIN_WRITE(revert))) == NULL) + { + fprintf(stderr, "%s: ", pname); + perror(argv[2]); + return 3; + } + rewind(fpo); + } + + if (revert) + { + if (hextype && (hextype != HEX_POSTSCRIPT)) + { + fprintf(stderr, "%s: sorry, cannot revert this type of hexdump\n", pname); + return -1; + } + return huntype(fp, fpo, stderr, pname, cols, hextype, + negseek ? -seekoff : seekoff); + } + + if (seekoff || negseek || !relseek) + { +#ifdef TRY_SEEK + if (relseek) + e = fseek(fp, negseek ? -seekoff : seekoff, 1); + else + e = fseek(fp, negseek ? -seekoff : seekoff, negseek ? 2 : 0); + if (e < 0 && negseek) + { + fprintf(stderr, "%s: sorry cannot seek.\n", pname); + return 4; + } + if (e >= 0) + seekoff = ftell(fp); + else +#endif + { + long s = seekoff; + + while (s--) + (void)getc(fp); + } + } + + if (hextype == HEX_CINCLUDE) + { + if (fp != stdin) + { + fprintf(fpo, "unsigned char %s", isdigit((int)argv[1][0]) ? "__" : ""); + for (e = 0; (c = argv[1][e]) != 0; e++) + putc(isalnum(c) ? c : '_', fpo); + fputs("[] = {\n", fpo); + } + + p = 0; + while ((length < 0 || p < length) && (c = getc(fp)) != EOF) + { + fprintf(fpo, (hexx == hexxa) ? "%s0x%02x" : "%s0X%02X", + (p % cols) ? ", " : ",\n "+2*!p, c); + p++; + } + + if (p) + fputs("\n};\n"+3*(fp == stdin), fpo); + + if (fp != stdin) + { + fprintf(fpo, "unsigned int %s", isdigit((int)argv[1][0]) ? "__" : ""); + for (e = 0; (c = argv[1][e]) != 0; e++) + putc(isalnum(c) ? c : '_', fpo); + fprintf(fpo, "_len = %d;\n", p); + } + + fclose(fp); + fclose(fpo); + return 0; + } + + if (hextype == HEX_POSTSCRIPT) + { + p = cols; + while ((length < 0 || n < length) && (e = getc(fp)) != EOF) + { + putchar(hexx[(e >> 4) & 0xf]); + putchar(hexx[(e ) & 0xf]); + n++; + if (!--p) + { + putchar('\n'); + p = cols; + } + } + if (p < cols) + putchar('\n'); + fclose(fp); + fclose(fpo); + return 0; + } + + /* hextype: HEX_NORMAL or HEX_BITS */ + + if (hextype == HEX_NORMAL) + grplen = octspergrp + octspergrp + 1; /* chars per octet group */ + else /* hextype == HEX_BITS */ + grplen = 8 * octspergrp + 1; + + while ((length < 0 || n < length) && (e = getc(fp)) != EOF) + { + if (p == 0) + { + sprintf(l, "%07lx: ", n + seekoff); + for (c = 9; c < LLEN; l[c++] = ' '); + } + if (hextype == HEX_NORMAL) + { + l[c = (9 + (grplen * p) / octspergrp)] = hexx[(e >> 4) & 0xf]; + l[++c] = hexx[ e & 0xf]; + } + else /* hextype == HEX_BITS */ + { + int i; + + c = (9 + (grplen * p) / octspergrp) - 1; + for (i = 7; i >= 0; i--) + l[++c] = (e & (1 << i)) ? '1' : '0'; + } + if (ebcdic) + e = (e < 64) ? '.' : etoa64[e-64]; + /* When changing this update definition of LLEN above. */ + l[11 + (grplen * cols - 1)/octspergrp + p] = +#ifdef __MVS__ + (e >= 64) +#else + (e > 31 && e < 127) +#endif + ? e : '.'; + if (e) + nonzero++; + n++; + if (++p == cols) + { + l[c = (11 + (grplen * cols - 1)/octspergrp + p)] = '\n'; l[++c] = '\0'; + xxdline(fpo, l, autoskip ? nonzero : 1); + nonzero = 0; + p = 0; + } + } + if (p) + { + l[c = (11 + (grplen * cols - 1)/octspergrp + p)] = '\n'; l[++c] = '\0'; + xxdline(fpo, l, 1); + } + else if (autoskip) + xxdline(fpo, l, -1); /* last chance to flush out suppressed lines */ + + fclose(fp); + fclose(fpo); + return 0; +} diff --git a/ggml/src/ggml-kompute/kompute/kompute-config.cmake b/ggml/src/ggml-kompute/kompute/kompute-config.cmake new file mode 100644 index 0000000000000000000000000000000000000000..10425252ce476ff740cf3f0cf78a9d70c75f9e98 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/kompute-config.cmake @@ -0,0 +1,28 @@ +# General purpose GPU compute framework built on Vulkan to +# support 1000s of cross vendor graphics cards +# (AMD, Qualcomm, NVIDIA & friends). Blazing fast, mobile-enabled, +# asynchronous and optimized for advanced GPU data processing use cases. +# Backed by the Linux Foundation. +# +# Finding this module will define the following variables: +# KOMPUTE_FOUND - True if the core library has been found +# KOMPUTE_LIBRARIES - Path to the core library archive +# KOMPUTE_INCLUDE_DIRS - Path to the include directories. Gives access +# to kompute.h, as a single include which must be included in every +# file that uses this interface. Else it also points to the +# directory for individual includes. + +find_path(KOMPUTE_INCLUDE_DIR + NAMES kompute.h) + +find_library(KOMPUTE_LIBRARY + NAMES kompute + HINTS ${KOMPUTE_LIBRARY_ROOT}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(KOMPUTE REQUIRED_VARS KOMPUTE_LIBRARY KOMPUTE_INCLUDE_DIR) + +if(KOMPUTE_FOUND) + set(KOMPUTE_LIBRARIES ${KOMPUTE_LIBRARY}) + set(KOMPUTE_INCLUDE_DIRS ${KOMPUTE_INCLUDE_DIR}) +endif() diff --git a/ggml/src/ggml-kompute/kompute/pylintrc b/ggml/src/ggml-kompute/kompute/pylintrc new file mode 100644 index 0000000000000000000000000000000000000000..f22b2db5537b3d4afbeef8264ff8aafecf1fc4e0 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/pylintrc @@ -0,0 +1,2 @@ +[TYPECHECK] +generated-members=sh diff --git a/ggml/src/ggml-kompute/kompute/python/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/python/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b4598fb74b588cbdd2b26a0d46ccf70b7b6cdc1 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/CMakeLists.txt @@ -0,0 +1,9 @@ + +pybind11_add_module(kp src/main.cpp) + +include_directories( + ${PROJECT_SOURCE_DIR}/include/) + +target_link_libraries( + kp PRIVATE + kompute::kompute) diff --git a/ggml/src/ggml-kompute/kompute/python/README.md b/ggml/src/ggml-kompute/kompute/python/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fd0a55ce9f374f25d3d0d841d1bd698ccc8c548d --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/README.md @@ -0,0 +1,3 @@ +# Python Package for Kompute + +Read the documentation under `/docs/overview` for details on the python package. diff --git a/ggml/src/ggml-kompute/kompute/python/src/docstrings.hpp b/ggml/src/ggml-kompute/kompute/python/src/docstrings.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fb1def637c810095049baf096845b02a14d4dc09 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/src/docstrings.hpp @@ -0,0 +1,879 @@ +/* + This file contains docstrings for use in the Python bindings. + Do not edit! They were automatically extracted by pybind11_mkdoc. + */ + +#define __EXPAND(x) x +#define __COUNT(_1, _2, _3, _4, _5, _6, _7, COUNT, ...) COUNT +#define __VA_SIZE(...) __EXPAND(__COUNT(__VA_ARGS__, 7, 6, 5, 4, 3, 2, 1)) +#define __CAT1(a, b) a ## b +#define __CAT2(a, b) __CAT1(a, b) +#define __DOC1(n1) __doc_##n1 +#define __DOC2(n1, n2) __doc_##n1##_##n2 +#define __DOC3(n1, n2, n3) __doc_##n1##_##n2##_##n3 +#define __DOC4(n1, n2, n3, n4) __doc_##n1##_##n2##_##n3##_##n4 +#define __DOC5(n1, n2, n3, n4, n5) __doc_##n1##_##n2##_##n3##_##n4##_##n5 +#define __DOC6(n1, n2, n3, n4, n5, n6) __doc_##n1##_##n2##_##n3##_##n4##_##n5##_##n6 +#define __DOC7(n1, n2, n3, n4, n5, n6, n7) __doc_##n1##_##n2##_##n3##_##n4##_##n5##_##n6##_##n7 +#define DOC(...) __EXPAND(__EXPAND(__CAT2(__DOC, __VA_SIZE(__VA_ARGS__)))(__VA_ARGS__)) + +#if defined(__GNUG__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-variable" +#endif + + +static const char *__doc_kp_Algorithm = +R"doc(Abstraction for compute shaders that are run on top of tensors grouped +via ParameterGroups (which group descriptorsets))doc"; + +static const char *__doc_kp_Algorithm_Algorithm = +R"doc(Main constructor for algorithm with configuration parameters to create +the underlying resources. + +@param device The Vulkan device to use for creating resources @param +tensors (optional) The tensors to use to create the descriptor +resources @param spirv (optional) The spirv code to use to create the +algorithm @param workgroup (optional) The kp::Workgroup to use for the +dispatch which defaults to kp::Workgroup(tensor[0].size(), 1, 1) if +not set. @param specializationConstants (optional) The std::vector +to use to initialize the specialization constants which cannot be +changed once set. @param pushConstants (optional) The std::vector to +use when initializing the pipeline, which set the size of the push +constants - these can be modified but all new values must have the +same vector size as this initial value.)doc"; + +static const char *__doc_kp_Algorithm_createParameters = R"doc()doc"; + +static const char *__doc_kp_Algorithm_createPipeline = R"doc()doc"; + +static const char *__doc_kp_Algorithm_createShaderModule = R"doc()doc"; + +static const char *__doc_kp_Algorithm_destroy = R"doc()doc"; + +static const char *__doc_kp_Algorithm_getPush = +R"doc(Gets the specialization constants of the current algorithm. + +@returns The std::vector currently set for push constants)doc"; + +static const char *__doc_kp_Algorithm_getSpecializationConstants = +R"doc(Gets the specialization constants of the current algorithm. + +@returns The std::vector currently set for specialization constants)doc"; + +static const char *__doc_kp_Algorithm_getTensors = +R"doc(Gets the current tensors that are used in the algorithm. + +@returns The list of tensors used in the algorithm.)doc"; + +static const char *__doc_kp_Algorithm_getWorkgroup = +R"doc(Gets the current workgroup from the algorithm. + +@param The kp::Constant to use to set the push constants to use in the +next bindPush(...) calls. The constants provided must be of the same +size as the ones created during initialization.)doc"; + +static const char *__doc_kp_Algorithm_isInit = +R"doc(function that checks all the gpu resource components to verify if +these have been created and returns true if all are valid. + +@returns returns true if the algorithm is currently initialized.)doc"; + +static const char *__doc_kp_Algorithm_mDescriptorPool = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mDescriptorSet = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mDescriptorSetLayout = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mDevice = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mFreeDescriptorPool = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mFreeDescriptorSet = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mFreeDescriptorSetLayout = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mFreePipeline = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mFreePipelineCache = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mFreePipelineLayout = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mFreeShaderModule = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mPipeline = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mPipelineCache = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mPipelineLayout = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mPushConstants = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mShaderModule = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mSpecializationConstants = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mSpirv = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mTensors = R"doc()doc"; + +static const char *__doc_kp_Algorithm_mWorkgroup = R"doc()doc"; + +static const char *__doc_kp_Algorithm_rebuild = +R"doc(Rebuild function to reconstruct algorithm with configuration +parameters to create the underlying resources. + +@param tensors The tensors to use to create the descriptor resources +@param spirv The spirv code to use to create the algorithm @param +workgroup (optional) The kp::Workgroup to use for the dispatch which +defaults to kp::Workgroup(tensor[0].size(), 1, 1) if not set. @param +specializationConstants (optional) The std::vector to use to +initialize the specialization constants which cannot be changed once +set. @param pushConstants (optional) The std::vector to use when +initializing the pipeline, which set the size of the push constants - +these can be modified but all new values must have the same vector +size as this initial value.)doc"; + +static const char *__doc_kp_Algorithm_recordBindCore = +R"doc(Records command that binds the "core" algorithm components which +consist of binding the pipeline and binding the descriptorsets. + +@param commandBuffer Command buffer to record the algorithm resources +to)doc"; + +static const char *__doc_kp_Algorithm_recordBindPush = +R"doc(Records command that binds the push constants to the command buffer +provided - it is required that the pushConstants provided are of the +same size as the ones provided during initialization. + +@param commandBuffer Command buffer to record the algorithm resources +to)doc"; + +static const char *__doc_kp_Algorithm_recordDispatch = +R"doc(Records the dispatch function with the provided template parameters or +alternatively using the size of the tensor by default. + +@param commandBuffer Command buffer to record the algorithm resources +to)doc"; + +static const char *__doc_kp_Algorithm_setPush = +R"doc(Sets the push constants to the new value provided to use in the next +bindPush() + +@param The kp::Constant to use to set the push constants to use in the +next bindPush(...) calls. The constants provided must be of the same +size as the ones created during initialization.)doc"; + +static const char *__doc_kp_Algorithm_setWorkgroup = +R"doc(Sets the work group to use in the recordDispatch + +@param workgroup The kp::Workgroup value to use to update the +algorithm. It must have a value greater than 1 on the x value (index +1) otherwise it will be initialized on the size of the first tensor +(ie. this->mTensor[0]->size()))doc"; + +static const char *__doc_kp_Manager = +R"doc(Base orchestrator which creates and manages device and child +components)doc"; + +static const char *__doc_kp_Manager_Manager = +R"doc(Base constructor and default used which creates the base resources +including choosing the device 0 by default.)doc"; + +static const char *__doc_kp_Manager_Manager_2 = +R"doc(Similar to base constructor but allows for further configuration to +use when creating the Vulkan resources. + +@param physicalDeviceIndex The index of the physical device to use +@param familyQueueIndices (Optional) List of queue indices to add for +explicit allocation @param desiredExtensions The desired extensions to +load from physicalDevice)doc"; + +static const char *__doc_kp_Manager_Manager_3 = +R"doc(Manager constructor which allows your own vulkan application to +integrate with the kompute use. + +@param instance Vulkan compute instance to base this application +@param physicalDevice Vulkan physical device to use for application +@param device Vulkan logical device to use for all base resources +@param physicalDeviceIndex Index for vulkan physical device used)doc"; + +static const char *__doc_kp_Manager_algorithm = +R"doc(Create a managed algorithm that will be destroyed by this manager if +it hasn't been destroyed by its reference count going to zero. + +@param tensors (optional) The tensors to initialise the algorithm with +@param spirv (optional) The SPIRV bytes for the algorithm to dispatch +@param workgroup (optional) kp::Workgroup for algorithm to use, and +defaults to (tensor[0].size(), 1, 1) @param specializationConstants +(optional) kp::Constant to use for specialization constants, and +defaults to an empty constant @param pushConstants (optional) +kp::Constant to use for push constants, and defaults to an empty +constant @returns Shared pointer with initialised algorithm)doc"; + +static const char *__doc_kp_Manager_clear = +R"doc(Run a pseudo-garbage collection to release all the managed resources +that have been already freed due to these reaching to zero ref count.)doc"; + +static const char *__doc_kp_Manager_createDevice = R"doc()doc"; + +static const char *__doc_kp_Manager_createInstance = R"doc()doc"; + +static const char *__doc_kp_Manager_destroy = R"doc(Destroy the GPU resources and all managed resources by manager.)doc"; + +static const char *__doc_kp_Manager_mComputeQueueFamilyIndices = R"doc()doc"; + +static const char *__doc_kp_Manager_mComputeQueues = R"doc()doc"; + +static const char *__doc_kp_Manager_mDevice = R"doc()doc"; + +static const char *__doc_kp_Manager_mFreeDevice = R"doc()doc"; + +static const char *__doc_kp_Manager_mFreeInstance = R"doc()doc"; + +static const char *__doc_kp_Manager_mInstance = R"doc()doc"; + +static const char *__doc_kp_Manager_mManageResources = R"doc()doc"; + +static const char *__doc_kp_Manager_mManagedAlgorithms = R"doc()doc"; + +static const char *__doc_kp_Manager_mManagedSequences = R"doc()doc"; + +static const char *__doc_kp_Manager_mManagedTensors = R"doc()doc"; + +static const char *__doc_kp_Manager_mPhysicalDevice = R"doc()doc"; + +static const char *__doc_kp_Manager_sequence = +R"doc(Create a managed sequence that will be destroyed by this manager if it +hasn't been destroyed by its reference count going to zero. + +@param queueIndex The queue to use from the available queues @param +nrOfTimestamps The maximum number of timestamps to allocate. If zero +(default), disables latching of timestamps. @returns Shared pointer +with initialised sequence)doc"; + +static const char *__doc_kp_Manager_tensor = R"doc()doc"; + +static const char *__doc_kp_Manager_tensor_2 = R"doc()doc"; + +static const char *__doc_kp_Manager_tensorT = +R"doc(Create a managed tensor that will be destroyed by this manager if it +hasn't been destroyed by its reference count going to zero. + +@param data The data to initialize the tensor with @param tensorType +The type of tensor to initialize @returns Shared pointer with +initialised tensor)doc"; + +static const char *__doc_kp_OpAlgoDispatch = +R"doc(Operation that provides a general abstraction that simplifies the use +of algorithm and parameter components which can be used with shaders. +By default it enables the user to provide a dynamic number of tensors +which are then passed as inputs.)doc"; + +static const char *__doc_kp_OpAlgoDispatch_OpAlgoDispatch = +R"doc(Constructor that stores the algorithm to use as well as the relevant +push constants to override when recording. + +@param algorithm The algorithm object to use for dispatch @param +pushConstants The push constants to use for override)doc"; + +static const char *__doc_kp_OpAlgoDispatch_mAlgorithm = R"doc()doc"; + +static const char *__doc_kp_OpAlgoDispatch_mPushConstants = R"doc()doc"; + +static const char *__doc_kp_OpAlgoDispatch_postEval = +R"doc(Does not perform any postEval commands. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpAlgoDispatch_preEval = +R"doc(Does not perform any preEval commands. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpAlgoDispatch_record = +R"doc(This records the commands that are to be sent to the GPU. This +includes the barriers that ensure the memory has been copied before +going in and out of the shader, as well as the dispatch operation that +sends the shader processing to the gpu. This function also records the +GPU memory copy of the output data for the staging buffer so it can be +read by the host. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpBase = +R"doc(Base Operation which provides the high level interface that Kompute +operations implement in order to perform a set of actions in the GPU. + +Operations can perform actions on tensors, and optionally can also own +an Algorithm with respective parameters. kp::Operations with +kp::Algorithms would inherit from kp::OpBaseAlgo.)doc"; + +static const char *__doc_kp_OpBase_postEval = +R"doc(Post eval is called after the Sequence has called eval and submitted +the commands to the GPU for processing, and can be used to perform any +tear-down steps required as the computation iteration finishes. It's +worth noting that there are situations where eval can be called +multiple times, so the resources that are destroyed should not require +a re-init unless explicitly provided by the user. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpBase_preEval = +R"doc(Pre eval is called before the Sequence has called eval and submitted +the commands to the GPU for processing, and can be used to perform any +per-eval setup steps required as the computation iteration begins. +It's worth noting that there are situations where eval can be called +multiple times, so the resources that are created should be idempotent +in case it's called multiple times in a row. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpBase_record = +R"doc(The record function is intended to only send a record command or run +commands that are expected to record operations that are to be +submitted as a batch into the GPU. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpMult = +R"doc(Operation that performs multiplication on two tensors and outpus on +third tensor.)doc"; + +static const char *__doc_kp_OpMult_OpMult = +R"doc(Default constructor with parameters that provides the bare minimum +requirements for the operations to be able to create and manage their +sub-components. + +@param tensors Tensors that are to be used in this operation @param +algorithm An algorithm that will be overridden with the OpMult shader +data and the tensors provided which are expected to be 3)doc"; + +static const char *__doc_kp_OpTensorCopy = +R"doc(Operation that copies the data from the first tensor to the rest of +the tensors provided, using a record command for all the vectors. This +operation does not own/manage the memory of the tensors passed to it. +The operation must only receive tensors of type)doc"; + +static const char *__doc_kp_OpTensorCopy_OpTensorCopy = +R"doc(Default constructor with parameters that provides the core vulkan +resources and the tensors that will be used in the operation. + +@param tensors Tensors that will be used to create in operation.)doc"; + +static const char *__doc_kp_OpTensorCopy_mTensors = R"doc()doc"; + +static const char *__doc_kp_OpTensorCopy_postEval = +R"doc(Copies the local vectors for all the tensors to sync the data with the +gpu. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpTensorCopy_preEval = +R"doc(Does not perform any preEval commands. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpTensorCopy_record = +R"doc(Records the copy commands from the first tensor into all the other +tensors provided. Also optionally records a barrier. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpTensorSyncDevice = +R"doc(Operation that syncs tensor's device by mapping local data into the +device memory. For TensorTypes::eDevice it will use a record operation +for the memory to be syncd into GPU memory which means that the +operation will be done in sync with GPU commands. For +TensorTypes::eHost it will only map the data into host memory which +will happen during preEval before the recorded commands are +dispatched.)doc"; + +static const char *__doc_kp_OpTensorSyncDevice_OpTensorSyncDevice = +R"doc(Default constructor with parameters that provides the core vulkan +resources and the tensors that will be used in the operation. The +tensos provided cannot be of type TensorTypes::eStorage. + +@param tensors Tensors that will be used to create in operation.)doc"; + +static const char *__doc_kp_OpTensorSyncDevice_mTensors = R"doc()doc"; + +static const char *__doc_kp_OpTensorSyncDevice_postEval = +R"doc(Does not perform any postEval commands. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpTensorSyncDevice_preEval = +R"doc(Does not perform any preEval commands. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpTensorSyncDevice_record = +R"doc(For device tensors, it records the copy command for the tensor to copy +the data from its staging to device memory. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpTensorSyncLocal = +R"doc(Operation that syncs tensor's local memory by mapping device data into +the local CPU memory. For TensorTypes::eDevice it will use a record +operation for the memory to be syncd into GPU memory which means that +the operation will be done in sync with GPU commands. For +TensorTypes::eHost it will only map the data into host memory which +will happen during preEval before the recorded commands are +dispatched.)doc"; + +static const char *__doc_kp_OpTensorSyncLocal_OpTensorSyncLocal = +R"doc(Default constructor with parameters that provides the core vulkan +resources and the tensors that will be used in the operation. The +tensors provided cannot be of type TensorTypes::eStorage. + +@param tensors Tensors that will be used to create in operation.)doc"; + +static const char *__doc_kp_OpTensorSyncLocal_mTensors = R"doc()doc"; + +static const char *__doc_kp_OpTensorSyncLocal_postEval = +R"doc(For host tensors it performs the map command from the host memory into +local memory. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpTensorSyncLocal_preEval = +R"doc(Does not perform any preEval commands. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_OpTensorSyncLocal_record = +R"doc(For device tensors, it records the copy command for the tensor to copy +the data from its device to staging memory. + +@param commandBuffer The command buffer to record the command into.)doc"; + +static const char *__doc_kp_Sequence = R"doc(Container of operations that can be sent to GPU as batch)doc"; + +static const char *__doc_kp_Sequence_Sequence = +R"doc(Main constructor for sequence which requires core vulkan components to +generate all dependent resources. + +@param physicalDevice Vulkan physical device @param device Vulkan +logical device @param computeQueue Vulkan compute queue @param +queueIndex Vulkan compute queue index in device @param totalTimestamps +Maximum number of timestamps to allocate)doc"; + +static const char *__doc_kp_Sequence_begin = +R"doc(Begins recording commands for commands to be submitted into the +command buffer. + +@return Boolean stating whether execution was successful.)doc"; + +static const char *__doc_kp_Sequence_clear = +R"doc(Clear function clears all operations currently recorded and starts +recording again.)doc"; + +static const char *__doc_kp_Sequence_createCommandBuffer = R"doc()doc"; + +static const char *__doc_kp_Sequence_createCommandPool = R"doc()doc"; + +static const char *__doc_kp_Sequence_createTimestampQueryPool = R"doc()doc"; + +static const char *__doc_kp_Sequence_destroy = +R"doc(Destroys and frees the GPU resources which include the buffer and +memory and sets the sequence as init=False.)doc"; + +static const char *__doc_kp_Sequence_end = +R"doc(Ends the recording and stops recording commands when the record +command is sent. + +@return Boolean stating whether execution was successful.)doc"; + +static const char *__doc_kp_Sequence_eval = +R"doc(Eval sends all the recorded and stored operations in the vector of +operations into the gpu as a submit job synchronously (with a +barrier). + +@return shared_ptr of the Sequence class itself)doc"; + +static const char *__doc_kp_Sequence_eval_2 = +R"doc(Resets all the recorded and stored operations, records the operation +provided and submits into the gpu as a submit job synchronously (with +a barrier). + +@return shared_ptr of the Sequence class itself)doc"; + +static const char *__doc_kp_Sequence_eval_3 = +R"doc(Eval sends all the recorded and stored operations in the vector of +operations into the gpu as a submit job with a barrier. + +@param tensors Vector of tensors to use for the operation @param TArgs +Template parameters that are used to initialise operation which allows +for extensible configurations on initialisation. @return +shared_ptr of the Sequence class itself)doc"; + +static const char *__doc_kp_Sequence_eval_4 = +R"doc(Eval sends all the recorded and stored operations in the vector of +operations into the gpu as a submit job with a barrier. + +@param algorithm Algorithm to use for the record often used for OpAlgo +operations @param TArgs Template parameters that are used to +initialise operation which allows for extensible configurations on +initialisation. @return shared_ptr of the Sequence class +itself)doc"; + +static const char *__doc_kp_Sequence_evalAsync = +R"doc(Eval Async sends all the recorded and stored operations in the vector +of operations into the gpu as a submit job without a barrier. +EvalAwait() must ALWAYS be called after to ensure the sequence is +terminated correctly. + +@return Boolean stating whether execution was successful.)doc"; + +static const char *__doc_kp_Sequence_evalAsync_2 = +R"doc(Clears currnet operations to record provided one in the vector of +operations into the gpu as a submit job without a barrier. EvalAwait() +must ALWAYS be called after to ensure the sequence is terminated +correctly. + +@return Boolean stating whether execution was successful.)doc"; + +static const char *__doc_kp_Sequence_evalAsync_3 = +R"doc(Eval sends all the recorded and stored operations in the vector of +operations into the gpu as a submit job with a barrier. + +@param tensors Vector of tensors to use for the operation @param TArgs +Template parameters that are used to initialise operation which allows +for extensible configurations on initialisation. @return +shared_ptr of the Sequence class itself)doc"; + +static const char *__doc_kp_Sequence_evalAsync_4 = +R"doc(Eval sends all the recorded and stored operations in the vector of +operations into the gpu as a submit job with a barrier. + +@param algorithm Algorithm to use for the record often used for OpAlgo +operations @param TArgs Template parameters that are used to +initialise operation which allows for extensible configurations on +initialisation. @return shared_ptr of the Sequence class +itself)doc"; + +static const char *__doc_kp_Sequence_evalAwait = +R"doc(Eval Await waits for the fence to finish processing and then once it +finishes, it runs the postEval of all operations. + +@param waitFor Number of milliseconds to wait before timing out. +@return shared_ptr of the Sequence class itself)doc"; + +static const char *__doc_kp_Sequence_getTimestamps = +R"doc(Return the timestamps that were latched at the beginning and after +each operation during the last eval() call.)doc"; + +static const char *__doc_kp_Sequence_isInit = +R"doc(Returns true if the sequence has been initialised, and it's based on +the GPU resources being refrenced. + +@return Boolean stating if is initialized)doc"; + +static const char *__doc_kp_Sequence_isRecording = +R"doc(Returns true if the sequence is currently in recording activated. + +@return Boolean stating if recording ongoing.)doc"; + +static const char *__doc_kp_Sequence_isRunning = +R"doc(Returns true if the sequence is currently running - mostly used for +async workloads. + +@return Boolean stating if currently running.)doc"; + +static const char *__doc_kp_Sequence_mCommandBuffer = R"doc()doc"; + +static const char *__doc_kp_Sequence_mCommandPool = R"doc()doc"; + +static const char *__doc_kp_Sequence_mComputeQueue = R"doc()doc"; + +static const char *__doc_kp_Sequence_mDevice = R"doc()doc"; + +static const char *__doc_kp_Sequence_mFence = R"doc()doc"; + +static const char *__doc_kp_Sequence_mFreeCommandBuffer = R"doc()doc"; + +static const char *__doc_kp_Sequence_mFreeCommandPool = R"doc()doc"; + +static const char *__doc_kp_Sequence_mIsRunning = R"doc()doc"; + +static const char *__doc_kp_Sequence_mOperations = R"doc()doc"; + +static const char *__doc_kp_Sequence_mPhysicalDevice = R"doc()doc"; + +static const char *__doc_kp_Sequence_mQueueIndex = R"doc()doc"; + +static const char *__doc_kp_Sequence_mRecording = R"doc()doc"; + +static const char *__doc_kp_Sequence_record = +R"doc(Record function for operation to be added to the GPU queue in batch. +This template requires classes to be derived from the OpBase class. +This function also requires the Sequence to be recording, otherwise it +will not be able to add the operation. + +@param op Object derived from kp::BaseOp that will be recoreded by the +sequence which will be used when the operation is evaluated. @return +shared_ptr of the Sequence class itself)doc"; + +static const char *__doc_kp_Sequence_record_2 = +R"doc(Record function for operation to be added to the GPU queue in batch. +This template requires classes to be derived from the OpBase class. +This function also requires the Sequence to be recording, otherwise it +will not be able to add the operation. + +@param tensors Vector of tensors to use for the operation @param TArgs +Template parameters that are used to initialise operation which allows +for extensible configurations on initialisation. @return +shared_ptr of the Sequence class itself)doc"; + +static const char *__doc_kp_Sequence_record_3 = +R"doc(Record function for operation to be added to the GPU queue in batch. +This template requires classes to be derived from the OpBase class. +This function also requires the Sequence to be recording, otherwise it +will not be able to add the operation. + +@param algorithm Algorithm to use for the record often used for OpAlgo +operations @param TArgs Template parameters that are used to +initialise operation which allows for extensible configurations on +initialisation. @return shared_ptr of the Sequence class +itself)doc"; + +static const char *__doc_kp_Sequence_rerecord = +R"doc(Clears command buffer and triggers re-record of all the current +operations saved, which is useful if the underlying kp::Tensors or +kp::Algorithms are modified and need to be re-recorded.)doc"; + +static const char *__doc_kp_Sequence_timestampQueryPool = R"doc()doc"; + +static const char *__doc_kp_Shader = R"doc(Shader utily class with functions to compile and process glsl files.)doc"; + +static const char *__doc_kp_Shader_compileSource = +R"doc(Compile a single glslang source from string value. Currently this +function uses the glslang C++ interface which is not thread safe so +this funciton should not be called from multiple threads concurrently. +If you have a online shader processing multithreading use-case that +can't use offline compilation please open an issue. + +@param source An individual raw glsl shader in string format @param +entryPoint The function name to use as entry point @param definitions +List of pairs containing key value definitions @param resourcesLimit A +list that contains the resource limits for the GLSL compiler @return +The compiled SPIR-V binary in unsigned int32 format)doc"; + +static const char *__doc_kp_Shader_compileSources = +R"doc(Compile multiple sources with optional filenames. Currently this +function uses the glslang C++ interface which is not thread safe so +this funciton should not be called from multiple threads concurrently. +If you have a online shader processing multithreading use-case that +can't use offline compilation please open an issue. + +@param sources A list of raw glsl shaders in string format @param +files A list of file names respective to each of the sources @param +entryPoint The function name to use as entry point @param definitions +List of pairs containing key value definitions @param resourcesLimit A +list that contains the resource limits for the GLSL compiler @return +The compiled SPIR-V binary in unsigned int32 format)doc"; + +static const char *__doc_kp_Tensor = +R"doc(Structured data used in GPU operations. + +Tensors are the base building block in Kompute to perform operations +across GPUs. Each tensor would have a respective Vulkan memory and +buffer, which would be used to store their respective data. The +tensors can be used for GPU data storage or transfer.)doc"; + +static const char *__doc_kp_TensorT = R"doc()doc"; + +static const char *__doc_kp_TensorT_TensorT = R"doc()doc"; + +static const char *__doc_kp_TensorT_data = R"doc()doc"; + +static const char *__doc_kp_TensorT_dataType = R"doc()doc"; + +static const char *__doc_kp_TensorT_operator_array = R"doc()doc"; + +static const char *__doc_kp_TensorT_setData = R"doc()doc"; + +static const char *__doc_kp_TensorT_vector = R"doc()doc"; + +static const char *__doc_kp_Tensor_Tensor = +R"doc(Constructor with data provided which would be used to create the +respective vulkan buffer and memory. + +@param physicalDevice The physical device to use to fetch properties +@param device The device to use to create the buffer and memory from +@param data Non-zero-sized vector of data that will be used by the +tensor @param tensorTypes Type for the tensor which is of type +TensorTypes)doc"; + +static const char *__doc_kp_Tensor_TensorDataTypes = R"doc()doc"; + +static const char *__doc_kp_Tensor_TensorDataTypes_eBool = R"doc()doc"; + +static const char *__doc_kp_Tensor_TensorDataTypes_eDouble = R"doc()doc"; + +static const char *__doc_kp_Tensor_TensorDataTypes_eFloat = R"doc()doc"; + +static const char *__doc_kp_Tensor_TensorDataTypes_eInt = R"doc()doc"; + +static const char *__doc_kp_Tensor_TensorDataTypes_eUnsignedInt = R"doc()doc"; + +static const char *__doc_kp_Tensor_TensorTypes = +R"doc(Type for tensors created: Device allows memory to be transferred from +staging buffers. Staging are host memory visible. Storage are device +visible but are not set up to transfer or receive data (only for +shader storage).)doc"; + +static const char *__doc_kp_Tensor_TensorTypes_eDevice = R"doc(< Type is device memory, source and destination)doc"; + +static const char *__doc_kp_Tensor_TensorTypes_eHost = R"doc(< Type is host memory, source and destination)doc"; + +static const char *__doc_kp_Tensor_TensorTypes_eStorage = R"doc(< Type is Device memory (only))doc"; + +static const char *__doc_kp_Tensor_allocateBindMemory = R"doc()doc"; + +static const char *__doc_kp_Tensor_allocateMemoryCreateGPUResources = R"doc()doc"; + +static const char *__doc_kp_Tensor_constructDescriptorBufferInfo = +R"doc(Constructs a vulkan descriptor buffer info which can be used to +specify and reference the underlying buffer component of the tensor +without exposing it. + +@return Descriptor buffer info with own buffer)doc"; + +static const char *__doc_kp_Tensor_createBuffer = R"doc()doc"; + +static const char *__doc_kp_Tensor_data = R"doc()doc"; + +static const char *__doc_kp_Tensor_dataType = +R"doc(Retrieve the underlying data type of the Tensor + +@return Data type of tensor of type kp::Tensor::TensorDataTypes)doc"; + +static const char *__doc_kp_Tensor_dataTypeMemorySize = R"doc()doc"; + +static const char *__doc_kp_Tensor_destroy = +R"doc(Destroys and frees the GPU resources which include the buffer and +memory.)doc"; + +static const char *__doc_kp_Tensor_getPrimaryBufferUsageFlags = R"doc()doc"; + +static const char *__doc_kp_Tensor_getPrimaryMemoryPropertyFlags = R"doc()doc"; + +static const char *__doc_kp_Tensor_getStagingBufferUsageFlags = R"doc()doc"; + +static const char *__doc_kp_Tensor_getStagingMemoryPropertyFlags = R"doc()doc"; + +static const char *__doc_kp_Tensor_isInit = +R"doc(Check whether tensor is initialized based on the created gpu +resources. + +@returns Boolean stating whether tensor is initialized)doc"; + +static const char *__doc_kp_Tensor_mDataType = R"doc()doc"; + +static const char *__doc_kp_Tensor_mDataTypeMemorySize = R"doc()doc"; + +static const char *__doc_kp_Tensor_mDevice = R"doc()doc"; + +static const char *__doc_kp_Tensor_mFreePrimaryBuffer = R"doc()doc"; + +static const char *__doc_kp_Tensor_mFreePrimaryMemory = R"doc()doc"; + +static const char *__doc_kp_Tensor_mFreeStagingBuffer = R"doc()doc"; + +static const char *__doc_kp_Tensor_mFreeStagingMemory = R"doc()doc"; + +static const char *__doc_kp_Tensor_mPhysicalDevice = R"doc()doc"; + +static const char *__doc_kp_Tensor_mPrimaryBuffer = R"doc()doc"; + +static const char *__doc_kp_Tensor_mPrimaryMemory = R"doc()doc"; + +static const char *__doc_kp_Tensor_mRawData = R"doc()doc"; + +static const char *__doc_kp_Tensor_mSize = R"doc()doc"; + +static const char *__doc_kp_Tensor_mStagingBuffer = R"doc()doc"; + +static const char *__doc_kp_Tensor_mStagingMemory = R"doc()doc"; + +static const char *__doc_kp_Tensor_mTensorType = R"doc()doc"; + +static const char *__doc_kp_Tensor_mapRawData = R"doc()doc"; + +static const char *__doc_kp_Tensor_memorySize = R"doc()doc"; + +static const char *__doc_kp_Tensor_rawData = R"doc()doc"; + +static const char *__doc_kp_Tensor_rebuild = +R"doc(Function to trigger reinitialisation of the tensor buffer and memory +with new data as well as new potential device type. + +@param data Vector of data to use to initialise vector from @param +tensorType The type to use for the tensor)doc"; + +static const char *__doc_kp_Tensor_recordBufferMemoryBarrier = +R"doc(Records the buffer memory barrier into the command buffer which +ensures that relevant data transfers are carried out correctly. + +@param commandBuffer Vulkan Command Buffer to record the commands into +@param srcAccessMask Access flags for source access mask @param +dstAccessMask Access flags for destination access mask @param +scrStageMask Pipeline stage flags for source stage mask @param +dstStageMask Pipeline stage flags for destination stage mask)doc"; + +static const char *__doc_kp_Tensor_recordCopyBuffer = R"doc()doc"; + +static const char *__doc_kp_Tensor_recordCopyFrom = +R"doc(Records a copy from the memory of the tensor provided to the current +thensor. This is intended to pass memory into a processing, to perform +a staging buffer transfer, or to gather output (between others). + +@param commandBuffer Vulkan Command Buffer to record the commands into +@param copyFromTensor Tensor to copy the data from @param +createBarrier Whether to create a barrier that ensures the data is +copied before further operations. Default is true.)doc"; + +static const char *__doc_kp_Tensor_recordCopyFromDeviceToStaging = +R"doc(Records a copy from the internal device memory to the staging memory +using an optional barrier to wait for the operation. This function +would only be relevant for kp::Tensors of type eDevice. + +@param commandBuffer Vulkan Command Buffer to record the commands into +@param createBarrier Whether to create a barrier that ensures the data +is copied before further operations. Default is true.)doc"; + +static const char *__doc_kp_Tensor_recordCopyFromStagingToDevice = +R"doc(Records a copy from the internal staging memory to the device memory +using an optional barrier to wait for the operation. This function +would only be relevant for kp::Tensors of type eDevice. + +@param commandBuffer Vulkan Command Buffer to record the commands into +@param createBarrier Whether to create a barrier that ensures the data +is copied before further operations. Default is true.)doc"; + +static const char *__doc_kp_Tensor_setRawData = +R"doc(Sets / resets the vector data of the tensor. This function does not +perform any copies into GPU memory and is only performed on the host.)doc"; + +static const char *__doc_kp_Tensor_size = +R"doc(Returns the size/magnitude of the Tensor, which will be the total +number of elements across all dimensions + +@return Unsigned integer representing the total number of elements)doc"; + +static const char *__doc_kp_Tensor_tensorType = +R"doc(Retrieve the tensor type of the Tensor + +@return Tensor type of tensor)doc"; + +static const char *__doc_kp_Tensor_unmapRawData = R"doc()doc"; + +static const char *__doc_kp_Tensor_vector = R"doc()doc"; + +#if defined(__GNUG__) +#pragma GCC diagnostic pop +#endif + diff --git a/ggml/src/ggml-kompute/kompute/python/src/main.cpp b/ggml/src/ggml-kompute/kompute/python/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c0f640c6c901d7a46ff80b121f1e7f5cff2f4b9 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/src/main.cpp @@ -0,0 +1,567 @@ +#include +#include +#include + +#include + +#include + +#include "docstrings.hpp" +#include "utils.hpp" + +namespace py = pybind11; + +// used in Core.hpp +py::object kp_trace, kp_debug, kp_info, kp_warning, kp_error; + +std::unique_ptr +opAlgoDispatchPyInit(std::shared_ptr& algorithm, + const py::array& push_consts) +{ + const py::buffer_info info = push_consts.request(); + KP_LOG_DEBUG("Kompute Python Manager creating tensor_T with push_consts " + "size {} dtype {}", + push_consts.size(), + std::string(py::str(push_consts.dtype()))); + + if (push_consts.dtype().is(py::dtype::of())) { + std::vector dataVec((float*)info.ptr, + ((float*)info.ptr) + info.size); + return std::unique_ptr{ new kp::OpAlgoDispatch( + algorithm, dataVec) }; + } else if (push_consts.dtype().is(py::dtype::of())) { + std::vector dataVec((uint32_t*)info.ptr, + ((uint32_t*)info.ptr) + info.size); + return std::unique_ptr{ new kp::OpAlgoDispatch( + algorithm, dataVec) }; + } else if (push_consts.dtype().is(py::dtype::of())) { + std::vector dataVec((int32_t*)info.ptr, + ((int32_t*)info.ptr) + info.size); + return std::unique_ptr{ new kp::OpAlgoDispatch( + algorithm, dataVec) }; + } else if (push_consts.dtype().is(py::dtype::of())) { + std::vector dataVec((double*)info.ptr, + ((double*)info.ptr) + info.size); + return std::unique_ptr{ new kp::OpAlgoDispatch( + algorithm, dataVec) }; + } else { + throw std::runtime_error("Kompute Python no valid dtype supported"); + } +} + +PYBIND11_MODULE(kp, m) +{ + + // The logging modules are used in the Kompute.hpp file + py::module_ logging = py::module_::import("logging"); + py::object kp_logger = logging.attr("getLogger")("kp"); + kp_trace = kp_logger.attr( + "debug"); // Same as for debug since python has no trace logging level + kp_debug = kp_logger.attr("debug"); + kp_info = kp_logger.attr("info"); + kp_warning = kp_logger.attr("warning"); + kp_error = kp_logger.attr("error"); + logging.attr("basicConfig")(); + + py::module_ np = py::module_::import("numpy"); + + py::enum_(m, "TensorTypes") + .value("device", + kp::Tensor::TensorTypes::eDevice, + DOC(kp, Tensor, TensorTypes, eDevice)) + .value("host", + kp::Tensor::TensorTypes::eHost, + DOC(kp, Tensor, TensorTypes, eHost)) + .value("storage", + kp::Tensor::TensorTypes::eStorage, + DOC(kp, Tensor, TensorTypes, eStorage)) + .export_values(); + + py::class_>( + m, "OpBase", DOC(kp, OpBase)); + + py::class_>( + m, "OpTensorSyncDevice", DOC(kp, OpTensorSyncDevice)) + .def(py::init>&>(), + DOC(kp, OpTensorSyncDevice, OpTensorSyncDevice)); + + py::class_>( + m, "OpTensorSyncLocal", DOC(kp, OpTensorSyncLocal)) + .def(py::init>&>(), + DOC(kp, OpTensorSyncLocal, OpTensorSyncLocal)); + + py::class_>( + m, "OpTensorCopy", DOC(kp, OpTensorCopy)) + .def(py::init>&>(), + DOC(kp, OpTensorCopy, OpTensorCopy)); + + py::class_>( + m, "OpAlgoDispatch", DOC(kp, OpAlgoDispatch)) + .def(py::init&, + const std::vector&>(), + DOC(kp, OpAlgoDispatch, OpAlgoDispatch), + py::arg("algorithm"), + py::arg("push_consts") = std::vector()) + .def(py::init(&opAlgoDispatchPyInit), + DOC(kp, OpAlgoDispatch, OpAlgoDispatch), + py::arg("algorithm"), + py::arg("push_consts")); + + py::class_>( + m, "OpMult", DOC(kp, OpMult)) + .def(py::init>&, + const std::shared_ptr&>(), + DOC(kp, OpMult, OpMult)); + + py::class_>( + m, "Algorithm", DOC(kp, Algorithm, Algorithm)) + .def("get_tensors", + &kp::Algorithm::getTensors, + DOC(kp, Algorithm, getTensors)) + .def("destroy", &kp::Algorithm::destroy, DOC(kp, Algorithm, destroy)) + .def("is_init", &kp::Algorithm::isInit, DOC(kp, Algorithm, isInit)); + + py::class_>( + m, "Tensor", DOC(kp, Tensor)) + .def( + "data", + [](kp::Tensor& self) { + // Non-owning container exposing the underlying pointer + switch (self.dataType()) { + case kp::Tensor::TensorDataTypes::eFloat: + return py::array( + self.size(), self.data(), py::cast(&self)); + case kp::Tensor::TensorDataTypes::eUnsignedInt: + return py::array( + self.size(), self.data(), py::cast(&self)); + case kp::Tensor::TensorDataTypes::eInt: + return py::array( + self.size(), self.data(), py::cast(&self)); + case kp::Tensor::TensorDataTypes::eDouble: + return py::array( + self.size(), self.data(), py::cast(&self)); + case kp::Tensor::TensorDataTypes::eBool: + return py::array( + self.size(), self.data(), py::cast(&self)); + default: + throw std::runtime_error( + "Kompute Python data type not supported"); + } + }, + DOC(kp, Tensor, data)) + .def("size", &kp::Tensor::size, DOC(kp, Tensor, size)) + .def("__len__", &kp::Tensor::size, DOC(kp, Tensor, size)) + .def("tensor_type", &kp::Tensor::tensorType, DOC(kp, Tensor, tensorType)) + .def("data_type", &kp::Tensor::dataType, DOC(kp, Tensor, dataType)) + .def("is_init", &kp::Tensor::isInit, DOC(kp, Tensor, isInit)) + .def("destroy", &kp::Tensor::destroy, DOC(kp, Tensor, destroy)); + + py::class_>(m, "Sequence") + .def( + "record", + [](kp::Sequence& self, std::shared_ptr op) { + return self.record(op); + }, + DOC(kp, Sequence, record)) + .def( + "eval", + [](kp::Sequence& self) { return self.eval(); }, + DOC(kp, Sequence, eval)) + .def( + "eval", + [](kp::Sequence& self, std::shared_ptr op) { + return self.eval(op); + }, + DOC(kp, Sequence, eval_2)) + .def( + "eval_async", + [](kp::Sequence& self) { return self.eval(); }, + DOC(kp, Sequence, evalAwait)) + .def( + "eval_async", + [](kp::Sequence& self, std::shared_ptr op) { + return self.evalAsync(op); + }, + DOC(kp, Sequence, evalAsync)) + .def( + "eval_await", + [](kp::Sequence& self) { return self.evalAwait(); }, + DOC(kp, Sequence, evalAwait)) + .def( + "eval_await", + [](kp::Sequence& self, uint32_t wait) { return self.evalAwait(wait); }, + DOC(kp, Sequence, evalAwait)) + .def("is_recording", + &kp::Sequence::isRecording, + DOC(kp, Sequence, isRecording)) + .def("is_running", &kp::Sequence::isRunning, DOC(kp, Sequence, isRunning)) + .def("is_init", &kp::Sequence::isInit, DOC(kp, Sequence, isInit)) + .def("clear", &kp::Sequence::clear, DOC(kp, Sequence, clear)) + .def("rerecord", &kp::Sequence::rerecord, DOC(kp, Sequence, rerecord)) + .def("get_timestamps", + &kp::Sequence::getTimestamps, + DOC(kp, Sequence, getTimestamps)) + .def("destroy", &kp::Sequence::destroy, DOC(kp, Sequence, destroy)); + + py::class_>( + m, "Manager", DOC(kp, Manager)) + .def(py::init(), DOC(kp, Manager, Manager)) + .def(py::init(), DOC(kp, Manager, Manager_2)) + .def(py::init&, + const std::vector&>(), + DOC(kp, Manager, Manager_2), + py::arg("device") = 0, + py::arg("family_queue_indices") = std::vector(), + py::arg("desired_extensions") = std::vector()) + .def("destroy", &kp::Manager::destroy, DOC(kp, Manager, destroy)) + .def("sequence", + &kp::Manager::sequence, + DOC(kp, Manager, sequence), + py::arg("queue_index") = 0, + py::arg("total_timestamps") = 0) + .def( + "tensor", + [np](kp::Manager& self, + const py::array_t& data, + kp::Tensor::TensorTypes tensor_type) { + const py::array_t& flatdata = np.attr("ravel")(data); + const py::buffer_info info = flatdata.request(); + KP_LOG_DEBUG("Kompute Python Manager tensor() creating tensor " + "float with data size {}", + flatdata.size()); + return self.tensor(info.ptr, + flatdata.size(), + sizeof(float), + kp::Tensor::TensorDataTypes::eFloat, + tensor_type); + }, + DOC(kp, Manager, tensor), + py::arg("data"), + py::arg("tensor_type") = kp::Tensor::TensorTypes::eDevice) + .def( + "tensor_t", + [np](kp::Manager& self, + const py::array& data, + kp::Tensor::TensorTypes tensor_type) { + // TODO: Suppport strides in numpy format + const py::array& flatdata = np.attr("ravel")(data); + const py::buffer_info info = flatdata.request(); + KP_LOG_DEBUG("Kompute Python Manager creating tensor_T with data " + "size {} dtype {}", + flatdata.size(), + std::string(py::str(flatdata.dtype()))); + if (flatdata.dtype().is(py::dtype::of())) { + return self.tensor(info.ptr, + flatdata.size(), + sizeof(float), + kp::Tensor::TensorDataTypes::eFloat, + tensor_type); + } else if (flatdata.dtype().is(py::dtype::of())) { + return self.tensor(info.ptr, + flatdata.size(), + sizeof(uint32_t), + kp::Tensor::TensorDataTypes::eUnsignedInt, + tensor_type); + } else if (flatdata.dtype().is(py::dtype::of())) { + return self.tensor(info.ptr, + flatdata.size(), + sizeof(int32_t), + kp::Tensor::TensorDataTypes::eInt, + tensor_type); + } else if (flatdata.dtype().is(py::dtype::of())) { + return self.tensor(info.ptr, + flatdata.size(), + sizeof(double), + kp::Tensor::TensorDataTypes::eDouble, + tensor_type); + } else if (flatdata.dtype().is(py::dtype::of())) { + return self.tensor(info.ptr, + flatdata.size(), + sizeof(bool), + kp::Tensor::TensorDataTypes::eBool, + tensor_type); + } else { + throw std::runtime_error( + "Kompute Python no valid dtype supported"); + } + }, + DOC(kp, Manager, tensorT), + py::arg("data"), + py::arg("tensor_type") = kp::Tensor::TensorTypes::eDevice) + .def( + "algorithm", + [](kp::Manager& self, + const std::vector>& tensors, + const py::bytes& spirv, + const kp::Workgroup& workgroup, + const std::vector& spec_consts, + const std::vector& push_consts) { + py::buffer_info info(py::buffer(spirv).request()); + const char* data = reinterpret_cast(info.ptr); + size_t length = static_cast(info.size); + std::vector spirvVec((uint32_t*)data, + (uint32_t*)(data + length)); + return self.algorithm( + tensors, spirvVec, workgroup, spec_consts, push_consts); + }, + DOC(kp, Manager, algorithm), + py::arg("tensors"), + py::arg("spirv"), + py::arg("workgroup") = kp::Workgroup(), + py::arg("spec_consts") = std::vector(), + py::arg("push_consts") = std::vector()) + .def( + "algorithm", + [np](kp::Manager& self, + const std::vector>& tensors, + const py::bytes& spirv, + const kp::Workgroup& workgroup, + const py::array& spec_consts, + const py::array& push_consts) { + py::buffer_info info(py::buffer(spirv).request()); + const char* data = reinterpret_cast(info.ptr); + size_t length = static_cast(info.size); + std::vector spirvVec((uint32_t*)data, + (uint32_t*)(data + length)); + + const py::buffer_info pushInfo = push_consts.request(); + const py::buffer_info specInfo = spec_consts.request(); + + KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with " + "push consts data size {} dtype {} and spec const " + "data size {} dtype {}", + push_consts.size(), + std::string(py::str(push_consts.dtype())), + spec_consts.size(), + std::string(py::str(spec_consts.dtype()))); + + // We have to iterate across a combination of parameters due to the + // lack of support for templating + if (spec_consts.dtype().is(py::dtype::of())) { + std::vector specConstsVec( + (float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size); + if (spec_consts.dtype().is(py::dtype::of())) { + std::vector pushConstsVec((float*)pushInfo.ptr, + ((float*)pushInfo.ptr) + + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specConstsVec, + pushConstsVec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushConstsVec( + (int32_t*)pushInfo.ptr, + ((int32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specConstsVec, + pushConstsVec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushConstsVec( + (uint32_t*)pushInfo.ptr, + ((uint32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specConstsVec, + pushConstsVec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushConstsVec((double*)pushInfo.ptr, + ((double*)pushInfo.ptr) + + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specConstsVec, + pushConstsVec); + } + } else if (spec_consts.dtype().is(py::dtype::of())) { + std::vector specconstsvec((int32_t*)specInfo.ptr, + ((int32_t*)specInfo.ptr) + + specInfo.size); + if (spec_consts.dtype().is(py::dtype::of())) { + std::vector pushconstsvec((float*)pushInfo.ptr, + ((float*)pushInfo.ptr) + + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushconstsvec( + (int32_t*)pushInfo.ptr, + ((int32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushconstsvec( + (uint32_t*)pushInfo.ptr, + ((uint32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushconstsvec((double*)pushInfo.ptr, + ((double*)pushInfo.ptr) + + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } + } else if (spec_consts.dtype().is(py::dtype::of())) { + std::vector specconstsvec((uint32_t*)specInfo.ptr, + ((uint32_t*)specInfo.ptr) + + specInfo.size); + if (spec_consts.dtype().is(py::dtype::of())) { + std::vector pushconstsvec((float*)pushInfo.ptr, + ((float*)pushInfo.ptr) + + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushconstsvec( + (int32_t*)pushInfo.ptr, + ((int32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushconstsvec( + (uint32_t*)pushInfo.ptr, + ((uint32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushconstsvec((double*)pushInfo.ptr, + ((double*)pushInfo.ptr) + + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } + } else if (spec_consts.dtype().is(py::dtype::of())) { + std::vector specconstsvec((double*)specInfo.ptr, + ((double*)specInfo.ptr) + + specInfo.size); + if (spec_consts.dtype().is(py::dtype::of())) { + std::vector pushconstsvec((float*)pushInfo.ptr, + ((float*)pushInfo.ptr) + + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushconstsvec((int32_t*)pushInfo.ptr, + ((int32_t*)pushInfo.ptr) + + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushconstsvec((uint32_t*)pushInfo.ptr, + ((uint32_t*)pushInfo.ptr) + + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } else if (spec_consts.dtype().is( + py::dtype::of())) { + std::vector pushconstsvec((double*)pushInfo.ptr, + ((double*)pushInfo.ptr) + + pushInfo.size); + return self.algorithm(tensors, + spirvVec, + workgroup, + specconstsvec, + pushconstsvec); + } + } + // If reach then no valid dtype supported + throw std::runtime_error("Kompute Python no valid dtype supported"); + }, + DOC(kp, Manager, algorithm), + py::arg("tensors"), + py::arg("spirv"), + py::arg("workgroup") = kp::Workgroup(), + py::arg("spec_consts") = std::vector(), + py::arg("push_consts") = std::vector()) + .def( + "list_devices", + [](kp::Manager& self) { + const std::vector devices = self.listDevices(); + py::list list; + for (const vk::PhysicalDevice& device : devices) { + list.append(kp::py::vkPropertiesToDict(device.getProperties())); + } + return list; + }, + "Return a dict containing information about the device") + .def( + "get_device_properties", + [](kp::Manager& self) { + const vk::PhysicalDeviceProperties properties = + self.getDeviceProperties(); + + return kp::py::vkPropertiesToDict(properties); + }, + "Return a dict containing information about the device"); + + auto atexit = py::module_::import("atexit"); + atexit.attr("register")(py::cpp_function([]() { + kp_trace = py::none(); + kp_debug = py::none(); + kp_info = py::none(); + kp_warning = py::none(); + kp_error = py::none(); + })); + +#ifdef VERSION_INFO + m.attr("__version__") = VERSION_INFO; +#else + m.attr("__version__") = "dev"; +#endif +} diff --git a/ggml/src/ggml-kompute/kompute/python/src/utils.hpp b/ggml/src/ggml-kompute/kompute/python/src/utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..69cd2a653beb670df712b17d7e2b1dea39898474 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/src/utils.hpp @@ -0,0 +1,32 @@ + +#include +#include +#include + +using namespace pybind11::literals; // for the `_a` literal + +namespace kp { +namespace py { +static pybind11::dict +vkPropertiesToDict(const vk::PhysicalDeviceProperties& properties) +{ + std::string deviceName = properties.deviceName; + pybind11::dict pyDict( + "device_name"_a = deviceName, + "max_work_group_count"_a = + pybind11::make_tuple(properties.limits.maxComputeWorkGroupCount[0], + properties.limits.maxComputeWorkGroupCount[1], + properties.limits.maxComputeWorkGroupCount[2]), + "max_work_group_invocations"_a = + properties.limits.maxComputeWorkGroupInvocations, + "max_work_group_size"_a = + pybind11::make_tuple(properties.limits.maxComputeWorkGroupSize[0], + properties.limits.maxComputeWorkGroupSize[1], + properties.limits.maxComputeWorkGroupSize[2]), + "timestamps_supported"_a = + (bool)properties.limits.timestampComputeAndGraphics); + + return pyDict; +} +} +} diff --git a/ggml/src/ggml-kompute/kompute/python/test/__init__.py b/ggml/src/ggml-kompute/kompute/python/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/test/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/ggml/src/ggml-kompute/kompute/python/test/requirements-dev.txt b/ggml/src/ggml-kompute/kompute/python/test/requirements-dev.txt new file mode 100644 index 0000000000000000000000000000000000000000..99d811d39f7c4108b9294bebaf0777d377566cb7 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/test/requirements-dev.txt @@ -0,0 +1,4 @@ +pyshader==0.7.0 +numpy==1.22.4 +pytest==7.1.2 + diff --git a/ggml/src/ggml-kompute/kompute/python/test/test_array_multiplication.py b/ggml/src/ggml-kompute/kompute/python/test/test_array_multiplication.py new file mode 100644 index 0000000000000000000000000000000000000000..e8de68328f4a9f0a0844fa55f1ca24a5f6ac59fc --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/test/test_array_multiplication.py @@ -0,0 +1,34 @@ +import pyshader as ps +import kp +import numpy as np + + +def test_array_multiplication(): + + # 1. Create Kompute Manager (selects device 0 by default) + mgr = kp.Manager() + + # 2. Create Kompute Tensors to hold data + tensor_in_a = mgr.tensor(np.array([2, 2, 2])) + tensor_in_b = mgr.tensor(np.array([1, 2, 3])) + tensor_out = mgr.tensor(np.array([0, 0, 0])) + + params = [tensor_in_a, tensor_in_b, tensor_out] + + # 4. Define the multiplication shader code to run on the GPU + @ps.python2shader + def compute_mult(index=("input", "GlobalInvocationId", ps.ivec3), + data1=("buffer", 0, ps.Array(ps.f32)), + data2=("buffer", 1, ps.Array(ps.f32)), + data3=("buffer", 2, ps.Array(ps.f32))): + i = index.x + data3[i] = data1[i] * data2[i] + + (mgr.sequence() + .record(kp.OpTensorSyncDevice(params)) + .record(kp.OpAlgoDispatch(mgr.algorithm(params, compute_mult.to_spirv()))) + .record(kp.OpTensorSyncLocal([tensor_out])) + .eval()) + + assert tensor_out.data().tolist() == [2.0, 4.0, 6.0] + assert np.all(tensor_out.data() == [2.0, 4.0, 6.0]) diff --git a/ggml/src/ggml-kompute/kompute/python/test/test_kompute.py b/ggml/src/ggml-kompute/kompute/python/test/test_kompute.py new file mode 100644 index 0000000000000000000000000000000000000000..4b7ec5450ea6c6f55b2003f57949dd3b8a5c58ed --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/test/test_kompute.py @@ -0,0 +1,286 @@ +import os + +import kp +import numpy as np +import logging +import pyshader as ps + +from .utils import compile_source + +DIRNAME = os.path.dirname(os.path.abspath(__file__)) + +kp_log = logging.getLogger("kp") + + +def test_end_to_end(): + + mgr = kp.Manager() + + tensor_in_a = mgr.tensor([2, 2, 2]) + tensor_in_b = mgr.tensor([1, 2, 3]) + # Explicit type constructor supports int, in32, double, float and int + tensor_out_a = mgr.tensor_t(np.array([0, 0, 0], dtype=np.uint32)) + tensor_out_b = mgr.tensor_t(np.array([0, 0, 0], dtype=np.uint32)) + + params = [tensor_in_a, tensor_in_b, tensor_out_a, tensor_out_b] + + shader = """ + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer buf_in_a { float in_a[]; }; + layout(set = 0, binding = 1) buffer buf_in_b { float in_b[]; }; + layout(set = 0, binding = 2) buffer buf_out_a { uint out_a[]; }; + layout(set = 0, binding = 3) buffer buf_out_b { uint out_b[]; }; + + // Kompute supports push constants updated on dispatch + layout(push_constant) uniform PushConstants { + float val; + } push_const; + + // Kompute also supports spec constants on initalization + layout(constant_id = 0) const float const_one = 0; + + void main() { + uint index = gl_GlobalInvocationID.x; + out_a[index] += uint( in_a[index] * in_b[index] ); + out_b[index] += uint( const_one * push_const.val ); + } + """ + + workgroup = (3, 1, 1) + spec_consts = [2] + push_consts_a = [2] + push_consts_b = [3] + + algo = mgr.algorithm(params, compile_source(shader), workgroup, spec_consts, push_consts_a) + + (mgr.sequence() + .record(kp.OpTensorSyncDevice(params)) + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpAlgoDispatch(algo, push_consts_b)) + .eval()) + + sq = mgr.sequence() + sq.eval_async(kp.OpTensorSyncLocal(params)) + + sq.eval_await() + + assert tensor_out_a.data().tolist() == [4, 8, 12] + assert tensor_out_b.data().tolist() == [10, 10, 10] + + +def test_shader_str(): + """ + Test basic OpAlgoBase operation + """ + + shader = """ +#version 450 +layout(set = 0, binding = 0) buffer tensorLhs {float valuesLhs[];}; +layout(set = 0, binding = 1) buffer tensorRhs {float valuesRhs[];}; +layout(set = 0, binding = 2) buffer tensorOutput { float valuesOutput[];}; +layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +void main() +{ + uint index = gl_GlobalInvocationID.x; + valuesOutput[index] = valuesLhs[index] * valuesRhs[index]; +} + """ + + spirv = compile_source(shader) + + mgr = kp.Manager() + + tensor_in_a = mgr.tensor([2, 2, 2]) + tensor_in_b = mgr.tensor([1, 2, 3]) + tensor_out = mgr.tensor([0, 0, 0]) + + params = [tensor_in_a, tensor_in_b, tensor_out] + + algo = mgr.algorithm(params, spirv) + + (mgr.sequence() + .record(kp.OpTensorSyncDevice(params)) + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpTensorSyncLocal(params)) + .eval()) + + assert tensor_out.data().tolist() == [2.0, 4.0, 6.0] + + +def test_sequence(): + """ + Test basic OpAlgoBase operation + """ + + shader = """ + #version 450 + layout(set = 0, binding = 0) buffer tensorLhs {float valuesLhs[];}; + layout(set = 0, binding = 1) buffer tensorRhs {float valuesRhs[];}; + layout(set = 0, binding = 2) buffer tensorOutput { float valuesOutput[];}; + layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + void main() + { + uint index = gl_GlobalInvocationID.x; + valuesOutput[index] = valuesLhs[index] * valuesRhs[index]; + } + """ + + spirv = compile_source(shader) + + mgr = kp.Manager(0) + + tensor_in_a = mgr.tensor([2, 2, 2]) + tensor_in_b = mgr.tensor([1, 2, 3]) + tensor_out = mgr.tensor([0, 0, 0]) + + params = [tensor_in_a, tensor_in_b, tensor_out] + + algo = mgr.algorithm(params, spirv) + + sq = mgr.sequence() + + sq.record(kp.OpTensorSyncDevice(params)) + sq.record(kp.OpAlgoDispatch(algo)) + sq.record(kp.OpTensorSyncLocal(params)) + + sq.eval() + + assert sq.is_init() == True + + sq.destroy() + + assert sq.is_init() == False + + assert tensor_out.data().tolist() == [2.0, 4.0, 6.0] + assert np.all(tensor_out.data() == [2.0, 4.0, 6.0]) + + tensor_in_a.destroy() + tensor_in_b.destroy() + tensor_out.destroy() + + assert tensor_in_a.is_init() == False + assert tensor_in_b.is_init() == False + assert tensor_out.is_init() == False + + +def test_pushconsts(): + + spirv = compile_source(""" + #version 450 + layout(push_constant) uniform PushConstants { + float x; + float y; + float z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y; + pa[2] += pcs.z; + } + """) + + mgr = kp.Manager() + + tensor = mgr.tensor([0, 0, 0]) + + algo = mgr.algorithm([tensor], spirv, (1, 1, 1), [], [0.1, 0.2, 0.3]) + + (mgr.sequence() + .record(kp.OpTensorSyncDevice([tensor])) + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpAlgoDispatch(algo, [0.3, 0.2, 0.1])) + .record(kp.OpAlgoDispatch(algo, [0.3, 0.2, 0.1])) + .record(kp.OpTensorSyncLocal([tensor])) + .eval()) + + assert np.allclose(tensor.data(), np.array([0.7, 0.6, 0.5], dtype=np.float32)) + + +def test_pushconsts_int(): + + spirv = compile_source(""" + #version 450 + layout(push_constant) uniform PushConstants { + int x; + int y; + int z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { int pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y; + pa[2] += pcs.z; + } + """) + + mgr = kp.Manager() + + tensor = mgr.tensor_t(np.array([0, 0, 0], dtype=np.int32)) + + spec_consts = np.array([], dtype=np.int32) + push_consts = np.array([-1, -1, -1], dtype=np.int32) + + algo = mgr.algorithm([tensor], spirv, (1, 1, 1), spec_consts, push_consts) + + (mgr.sequence() + .record(kp.OpTensorSyncDevice([tensor])) + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpAlgoDispatch(algo, np.array([-1, -1, -1], dtype=np.int32))) + .record(kp.OpAlgoDispatch(algo, np.array([-1, -1, -1], dtype=np.int32))) + .record(kp.OpTensorSyncLocal([tensor])) + .eval()) + + assert np.all(tensor.data() == np.array([-3, -3, -3], dtype=np.int32)) + + +def test_workgroup(): + mgr = kp.Manager(0) + + tensor_a = mgr.tensor(np.zeros([16,8])) + tensor_b = mgr.tensor(np.zeros([16,8])) + + @ps.python2shader + def compute_shader_wg(gl_idx=("input", "GlobalInvocationId", ps.ivec3), + gl_wg_id=("input", "WorkgroupId", ps.ivec3), + gl_wg_num=("input", "NumWorkgroups", ps.ivec3), + data1=("buffer", 0, ps.Array(ps.f32)), + data2=("buffer", 1, ps.Array(ps.f32))): + i = gl_wg_id.x * gl_wg_num.y + gl_wg_id.y + data1[i] = f32(gl_idx.x) + data2[i] = f32(gl_idx.y) + + algo = mgr.algorithm([tensor_a, tensor_b], compute_shader_wg.to_spirv(), (16,8,1)) + + (mgr.sequence() + .record(kp.OpTensorSyncDevice([tensor_a, tensor_b])) + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpTensorSyncLocal([tensor_a, tensor_b])) + .eval()) + + print(tensor_a.data()) + print(tensor_b.data()) + + assert np.all(tensor_a.data() == np.stack([np.arange(16)]*8, axis=1).ravel()) + assert np.all(tensor_b.data() == np.stack([np.arange(8)]*16, axis=0).ravel()) + + +def test_mgr_utils(): + mgr = kp.Manager() + + props = mgr.get_device_properties() + + assert "device_name" in props + + devices = mgr.list_devices() + + assert len(devices) > 0 + assert "device_name" in devices[0] diff --git a/ggml/src/ggml-kompute/kompute/python/test/test_logistic_regression.py b/ggml/src/ggml-kompute/kompute/python/test/test_logistic_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c50aa7ce199b8ddef5e000b49f2e8b302bbcb5 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/test/test_logistic_regression.py @@ -0,0 +1,102 @@ +import pyshader as ps +import numpy as np +import kp + +def test_logistic_regression(): + + @ps.python2shader + def compute_shader( + index = ("input", "GlobalInvocationId", ps.ivec3), + x_i = ("buffer", 0, ps.Array(ps.f32)), + x_j = ("buffer", 1, ps.Array(ps.f32)), + y = ("buffer", 2, ps.Array(ps.f32)), + w_in = ("buffer", 3, ps.Array(ps.f32)), + w_out_i = ("buffer", 4, ps.Array(ps.f32)), + w_out_j = ("buffer", 5, ps.Array(ps.f32)), + b_in = ("buffer", 6, ps.Array(ps.f32)), + b_out = ("buffer", 7, ps.Array(ps.f32)), + l_out = ("buffer", 8, ps.Array(ps.f32)), + M = ("buffer", 9, ps.Array(ps.f32))): + + i = index.x + + m = M[0] + + w_curr = vec2(w_in[0], w_in[1]) + b_curr = b_in[0] + + x_curr = vec2(x_i[i], x_j[i]) + y_curr = y[i] + + z_dot = w_curr @ x_curr + z = z_dot + b_curr + y_hat = 1.0 / (1.0 + exp(-z)) + + d_z = y_hat - y_curr + d_w = (1.0 / m) * x_curr * d_z + d_b = (1.0 / m) * d_z + + loss = -((y_curr * log(y_hat)) + ((1.0 + y_curr) * log(1.0 - y_hat))) + + w_out_i[i] = d_w.x + w_out_j[i] = d_w.y + b_out[i] = d_b + l_out[i] = loss + + + mgr = kp.Manager(0) + + # First we create input and ouput tensors for shader + tensor_x_i = mgr.tensor(np.array([0.0, 1.0, 1.0, 1.0, 1.0])) + tensor_x_j = mgr.tensor(np.array([0.0, 0.0, 0.0, 1.0, 1.0])) + + tensor_y = mgr.tensor(np.array([0.0, 0.0, 0.0, 1.0, 1.0])) + + tensor_w_in = mgr.tensor(np.array([0.001, 0.001])) + tensor_w_out_i = mgr.tensor(np.array([0.0, 0.0, 0.0, 0.0, 0.0])) + tensor_w_out_j = mgr.tensor(np.array([0.0, 0.0, 0.0, 0.0, 0.0])) + + tensor_b_in = mgr.tensor(np.array([0.0])) + tensor_b_out = mgr.tensor(np.array([0.0, 0.0, 0.0, 0.0, 0.0])) + + tensor_l_out = mgr.tensor(np.array([0.0, 0.0, 0.0, 0.0, 0.0])) + + tensor_m = mgr.tensor(np.array([ tensor_y.size() ])) + + # We store them in an array for easier interaction + params = [tensor_x_i, tensor_x_j, tensor_y, tensor_w_in, tensor_w_out_i, + tensor_w_out_j, tensor_b_in, tensor_b_out, tensor_l_out, tensor_m] + + mgr.sequence().eval(kp.OpTensorSyncDevice(params)) + + # Create a managed sequence + sq = mgr.sequence() + + # Record operation to sync memory from local to GPU memory + sq.record(kp.OpTensorSyncDevice([tensor_w_in, tensor_b_in])) + + # Record operation to execute GPU shader against all our parameters + sq.record(kp.OpAlgoDispatch(mgr.algorithm(params, compute_shader.to_spirv()))) + + # Record operation to sync memory from GPU to local memory + sq.record(kp.OpTensorSyncLocal([tensor_w_out_i, tensor_w_out_j, tensor_b_out, tensor_l_out])) + + ITERATIONS = 100 + learning_rate = 0.1 + + # Perform machine learning training and inference across all input X and Y + for i_iter in range(ITERATIONS): + + # Execute an iteration of the algorithm + sq.eval() + + # Calculate the parameters based on the respective derivatives calculated + for j_iter in range(tensor_b_out.size()): + tensor_w_in.data()[0] -= learning_rate * tensor_w_out_i.data()[j_iter] + tensor_w_in.data()[1] -= learning_rate * tensor_w_out_j.data()[j_iter] + tensor_b_in.data()[0] -= learning_rate * tensor_b_out.data()[j_iter] + + assert tensor_w_in.data()[0] < 0.01 + assert tensor_w_in.data()[0] > 0.0 + assert tensor_w_in.data()[1] > 1.5 + assert tensor_b_in.data()[0] < 0.7 diff --git a/ggml/src/ggml-kompute/kompute/python/test/test_tensor_types.py b/ggml/src/ggml-kompute/kompute/python/test/test_tensor_types.py new file mode 100644 index 0000000000000000000000000000000000000000..d615a55b3694d0a9d321f7f667c7cbba6b1720e1 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/test/test_tensor_types.py @@ -0,0 +1,232 @@ +import os +import pytest +import kp +import numpy as np + +from .utils import compile_source + +VK_ICD_FILENAMES = os.environ.get("VK_ICD_FILENAMES", "") + +def test_type_float(): + + shader = """ + #version 450 + layout(set = 0, binding = 0) buffer tensorLhs {float valuesLhs[];}; + layout(set = 0, binding = 1) buffer tensorRhs {float valuesRhs[];}; + layout(set = 0, binding = 2) buffer tensorOutput { float valuesOutput[];}; + layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + void main() + { + uint index = gl_GlobalInvocationID.x; + valuesOutput[index] = valuesLhs[index] * valuesRhs[index]; + } + """ + + spirv = compile_source(shader) + + arr_in_a = np.array([123., 153., 231.], dtype=np.float32) + arr_in_b = np.array([9482, 1208, 1238], dtype=np.float32) + arr_out = np.array([0, 0, 0], dtype=np.float32) + + mgr = kp.Manager() + + tensor_in_a = mgr.tensor(arr_in_a) + tensor_in_b = mgr.tensor(arr_in_b) + tensor_out = mgr.tensor(arr_out) + + params = [tensor_in_a, tensor_in_b, tensor_out] + + (mgr.sequence() + .record(kp.OpTensorSyncDevice(params)) + .record(kp.OpAlgoDispatch(mgr.algorithm(params, spirv))) + .record(kp.OpTensorSyncLocal([tensor_out])) + .eval()) + + assert np.all(tensor_out.data() == arr_in_a * arr_in_b) + + +def test_type_float_double_incorrect(): + + shader = """ + #version 450 + layout(set = 0, binding = 0) buffer tensorLhs {float valuesLhs[];}; + layout(set = 0, binding = 1) buffer tensorRhs {float valuesRhs[];}; + layout(set = 0, binding = 2) buffer tensorOutput { float valuesOutput[];}; + layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + void main() + { + uint index = gl_GlobalInvocationID.x; + valuesOutput[index] = valuesLhs[index] * valuesRhs[index]; + } + """ + + spirv = compile_source(shader) + + arr_in_a = np.array([123., 153., 231.], dtype=np.float32) + arr_in_b = np.array([9482, 1208, 1238], dtype=np.uint32) + arr_out = np.array([0, 0, 0], dtype=np.float32) + + mgr = kp.Manager() + + tensor_in_a = mgr.tensor_t(arr_in_a) + tensor_in_b = mgr.tensor_t(arr_in_b) + tensor_out = mgr.tensor_t(arr_out) + + params = [tensor_in_a, tensor_in_b, tensor_out] + + (mgr.sequence() + .record(kp.OpTensorSyncDevice(params)) + .record(kp.OpAlgoDispatch(mgr.algorithm(params, spirv))) + .record(kp.OpTensorSyncLocal([tensor_out])) + .eval()) + + assert np.all(tensor_out.data() != arr_in_a * arr_in_b) + +@pytest.mark.skipif("broadcom" in VK_ICD_FILENAMES, + reason="Broadcom doesn't support double") +@pytest.mark.skipif("swiftshader" in VK_ICD_FILENAMES, + reason="Swiftshader doesn't support double") +def test_type_double(): + + shader = """ + #version 450 + layout(set = 0, binding = 0) buffer tensorLhs { double valuesLhs[]; }; + layout(set = 0, binding = 1) buffer tensorRhs { double valuesRhs[]; }; + layout(set = 0, binding = 2) buffer tensorOutput { double valuesOutput[]; }; + layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + void main() + { + uint index = gl_GlobalInvocationID.x; + valuesOutput[index] = valuesLhs[index] * valuesRhs[index]; + } + """ + + spirv = compile_source(shader) + + arr_in_a = np.array([123., 153., 231.], dtype=np.float64) + arr_in_b = np.array([9482, 1208, 1238], dtype=np.float64) + arr_out = np.array([0, 0, 0], dtype=np.float64) + + mgr = kp.Manager() + + tensor_in_a = mgr.tensor_t(arr_in_a) + tensor_in_b = mgr.tensor_t(arr_in_b) + tensor_out = mgr.tensor_t(arr_out) + + params = [tensor_in_a, tensor_in_b, tensor_out] + + (mgr.sequence() + .record(kp.OpTensorSyncDevice(params)) + .record(kp.OpAlgoDispatch(mgr.algorithm(params, spirv))) + .record(kp.OpTensorSyncLocal([tensor_out])) + .eval()) + + print(f"Dtype value {tensor_out.data().dtype}") + + assert np.all(tensor_out.data() == arr_in_a * arr_in_b) + +def test_type_int(): + + shader = """ + #version 450 + layout(set = 0, binding = 0) buffer tensorLhs { int valuesLhs[]; }; + layout(set = 0, binding = 1) buffer tensorRhs { int valuesRhs[]; }; + layout(set = 0, binding = 2) buffer tensorOutput { int valuesOutput[]; }; + layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + void main() + { + uint index = gl_GlobalInvocationID.x; + valuesOutput[index] = valuesLhs[index] * valuesRhs[index]; + } + """ + + spirv = compile_source(shader) + + arr_in_a = np.array([123, 153, 231], dtype=np.int32) + arr_in_b = np.array([9482, 1208, 1238], dtype=np.int32) + arr_out = np.array([0, 0, 0], dtype=np.int32) + + mgr = kp.Manager() + + tensor_in_a = mgr.tensor_t(arr_in_a) + tensor_in_b = mgr.tensor_t(arr_in_b) + tensor_out = mgr.tensor_t(arr_out) + + params = [tensor_in_a, tensor_in_b, tensor_out] + + (mgr.sequence() + .record(kp.OpTensorSyncDevice(params)) + .record(kp.OpAlgoDispatch(mgr.algorithm(params, spirv))) + .record(kp.OpTensorSyncLocal([tensor_out])) + .eval()) + + print(f"Dtype value {tensor_out.data().dtype}") + + assert np.all(tensor_out.data() == arr_in_a * arr_in_b) + +def test_type_unsigned_int(): + + shader = """ + #version 450 + layout(set = 0, binding = 0) buffer tensorLhs { uint valuesLhs[]; }; + layout(set = 0, binding = 1) buffer tensorRhs { uint valuesRhs[]; }; + layout(set = 0, binding = 2) buffer tensorOutput { uint valuesOutput[]; }; + layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + void main() + { + uint index = gl_GlobalInvocationID.x; + valuesOutput[index] = valuesLhs[index] * valuesRhs[index]; + } + """ + + spirv = compile_source(shader) + + arr_in_a = np.array([123, 153, 231], dtype=np.uint32) + arr_in_b = np.array([9482, 1208, 1238], dtype=np.uint32) + arr_out = np.array([0, 0, 0], dtype=np.uint32) + + mgr = kp.Manager() + + tensor_in_a = mgr.tensor_t(arr_in_a) + tensor_in_b = mgr.tensor_t(arr_in_b) + tensor_out = mgr.tensor_t(arr_out) + + params = [tensor_in_a, tensor_in_b, tensor_out] + + (mgr.sequence() + .record(kp.OpTensorSyncDevice(params)) + .record(kp.OpAlgoDispatch(mgr.algorithm(params, spirv))) + .record(kp.OpTensorSyncLocal([tensor_out])) + .eval()) + + print(f"Dtype value {tensor_out.data().dtype}") + + assert np.all(tensor_out.data() == arr_in_a * arr_in_b) + +def test_tensor_numpy_ownership(): + + arr_in = np.array([1, 2, 3]) + + m = kp.Manager() + + t = m.tensor(arr_in) + + # This should increment refcount for tensor sharedptr + td = t.data() + + assert td.base.is_init() == True + assert np.all(td == arr_in) + + del t + + assert td.base.is_init() == True + assert np.all(td == arr_in) + + m.destroy() + + assert td.base.is_init() == False diff --git a/ggml/src/ggml-kompute/kompute/python/test/utils.py b/ggml/src/ggml-kompute/kompute/python/test/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d02f6adde711cc9307309c39737d91d18497d032 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/python/test/utils.py @@ -0,0 +1,7 @@ +import os + + +def compile_source(source): + open("tmp_kp_shader.comp", "w").write(source) + os.system("glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv") + return open("tmp_kp_shader.comp.spv", "rb").read() diff --git a/ggml/src/ggml-kompute/kompute/scripts/convert_shaders.py b/ggml/src/ggml-kompute/kompute/scripts/convert_shaders.py new file mode 100644 index 0000000000000000000000000000000000000000..11a3ab974d6a6d7a3d10e55a79fc75117e432ea6 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/scripts/convert_shaders.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +""" + Script to handle conversion of compute shaders to spirv and to headers +""" +import os +import sys +import logging +import click +import subprocess + +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler()) + +is_windows = sys.platform.startswith('win') + +CWD=os.path.dirname(os.path.abspath(__file__)) +XXD_LINUX_CMD="xxd" +XXD_WINDOWS_CMD=os.path.abspath(os.path.join(CWD, "..\\external\\bin\\", "xxd.exe")) + +SHADER_GENERATED_NOTICE = """/* + THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT + + --- + + Copyright 2020 The Institute for Ethical AI & Machine Learning + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +""" + +@click.command() +@click.option( + "--shader-path", + "-p", + envvar="KOMPUTE_SHADER_PATH", + required=True, + help="The path for the directory to build and convert shaders", +) +@click.option( + "--shader-binary", + "-s", + envvar="KOMPUTE_SHADER_BINARY", + required=True, + help="The path for the directory to build and convert shaders", +) +@click.option( + "--header-path", + "-c", + envvar="KOMPUTE_HEADER_PATH", + default="", + required=False, + help="The (optional) output file for the cpp header files", +) +@click.option( + "--verbose", + "-v", + envvar="KOMPUTE_HEADER_PATH", + default=False, + is_flag=True, + help="Enable versbosity if flag is provided", +) +def run_cli( + shader_path: str = None, + shader_binary: str = None, + header_path: bool = None, + verbose: bool = None, +): + """ + CLI function for shader generation + """ + + if verbose: + logger.setLevel(logging.DEBUG) + else: + logger.setLevel(logging.WARNING) + + logger.debug(f"Starting script with variables: {locals()}") + + if is_windows: + logger.debug(f"Running on windows, converting input paths") + shader_path = shader_path.replace("/", "\\") + header_path = header_path.replace("/", "\\") + + shader_files = [] + for root, directory, files in os.walk(shader_path): + for file in files: + if file.endswith(".comp"): + shader_files.append(os.path.join(root, file)) + + run_cmd = lambda *args: subprocess.check_output([*args]).decode() + + logger.debug(f"Output spirv path: {shader_path}") + logger.debug(f"Converting files to spirv: {shader_files}") + + spirv_files = [] + for file in shader_files: + logger.debug(f"Converting to spirv: {file}") + spirv_file = f"{file}.spv" + run_cmd(shader_binary, "-V", file, "-o", spirv_file) + spirv_files.append(spirv_file) + + # Create cpp files if header_path provided + if header_path: + logger.debug(f"Header path provided. Converting bin files to hpp.") + logger.debug(f"Output header path: {shader_path}") + + # Check if xxd command options are available + if is_windows: + xxd_cmd = XXD_WINDOWS_CMD + else: + xxd_cmd = XXD_LINUX_CMD + + for file in spirv_files: + print(xxd_cmd) + header_data = str(run_cmd(xxd_cmd, "-i", file)) + # Ensuring the variable is a static const unsigned + header_data = header_data.replace("unsigned", "static const unsigned") + if is_windows: + raw_file_name = file.split("\\")[-1] + else: + raw_file_name = file.split("/")[-1] + file_name = f"shader{raw_file_name}" + header_file = file_name.replace(".comp.spv", ".hpp") + header_file_define = "SHADEROP_" + header_file.replace(".", "_").upper() + logger.debug(f"Converting to hpp: {file_name}") + with open(os.path.join(header_path, header_file), "w+", newline='\n') as fstream: + fstream.write(f"{SHADER_GENERATED_NOTICE}\n") + fstream.write(f"#ifndef {header_file_define}\n") + fstream.write(f"#define {header_file_define}\n\n") + fstream.write("namespace kp {\n") + fstream.write("namespace shader_data {\n") + fstream.write(f"{header_data}") + fstream.write("}\n") + fstream.write("}\n") + fstream.write(f"#endif // define {header_file_define}\n") + + +if __name__ == "__main__": + run_cli() diff --git a/ggml/src/ggml-kompute/kompute/scripts/push_folder_to_branch.sh b/ggml/src/ggml-kompute/kompute/scripts/push_folder_to_branch.sh new file mode 100644 index 0000000000000000000000000000000000000000..51e621be1a2b38f569e8cdc99c6a6fc0c87ffb43 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/scripts/push_folder_to_branch.sh @@ -0,0 +1,215 @@ +#!/usr/bin/env bash +set -o errexit #abort if any command fails +me=$(basename "$0") + +help_message="\ +Usage: $me [-c FILE] [] +Deploy generated files to a git branch. +Options: + -h, --help Show this help information. + -v, --verbose Increase verbosity. Useful for debugging. + -e, --allow-empty Allow deployment of an empty directory. + -m, --message MESSAGE Specify the message used when committing on the + deploy branch. + -n, --no-hash Don't append the source commit's hash to the deploy + commit's message. + -c, --config-file PATH Override default & environment variables' values + with those in set in the file at 'PATH'. Must be the + first option specified. +Variables: + GIT_DEPLOY_DIR Folder path containing the files to deploy. + GIT_DEPLOY_BRANCH Commit deployable files to this branch. + GIT_DEPLOY_REPO Push the deploy branch to this repository. +These variables have default values defined in the script. The defaults can be +overridden by environment variables. Any environment variables are overridden +by values set in a '.env' file (if it exists), and in turn by those set in a +file specified by the '--config-file' option." + +parse_args() { + # Set args from a local environment file. + if [ -e ".env" ]; then + source .env + fi + + # Set args from file specified on the command-line. + if [[ $1 = "-c" || $1 = "--config-file" ]]; then + source "$2" + shift 2 + fi + + # Parse arg flags + # If something is exposed as an environment variable, set/overwrite it + # here. Otherwise, set/overwrite the internal variable instead. + while : ; do + if [[ $1 = "-h" || $1 = "--help" ]]; then + echo "$help_message" + return 0 + elif [[ $1 = "-v" || $1 = "--verbose" ]]; then + verbose=true + shift + elif [[ $1 = "-e" || $1 = "--allow-empty" ]]; then + allow_empty=true + shift + elif [[ ( $1 = "-m" || $1 = "--message" ) && -n $2 ]]; then + commit_message=$2 + shift 2 + elif [[ $1 = "-n" || $1 = "--no-hash" ]]; then + GIT_DEPLOY_APPEND_HASH=false + shift + else + break + fi + done + + # Set internal option vars from the environment and arg flags. All internal + # vars should be declared here, with sane defaults if applicable. + + # Source directory & target branch. + deploy_directory=${GIT_DEPLOY_DIR:-dist} + deploy_branch=${GIT_DEPLOY_BRANCH:-gh-pages} + + #if no user identity is already set in the current git environment, use this: + default_username=${GIT_DEPLOY_USERNAME:-deploy.sh} + default_email=${GIT_DEPLOY_EMAIL:-} + + #repository to deploy to. must be readable and writable. + repo=${GIT_DEPLOY_REPO:-origin} + + #append commit hash to the end of message by default + append_hash=${GIT_DEPLOY_APPEND_HASH:-true} +} + +main() { + parse_args "$@" + + enable_expanded_output + + if ! git diff --exit-code --quiet --cached; then + echo Aborting due to uncommitted changes in the index >&2 + return 1 + fi + + commit_title=`git log -n 1 --format="%s" HEAD` + commit_hash=` git log -n 1 --format="%H" HEAD` + + #default commit message uses last title if a custom one is not supplied + if [[ -z $commit_message ]]; then + commit_message="publish: $commit_title" + fi + + #append hash to commit message unless no hash flag was found + if [ $append_hash = true ]; then + commit_message="$commit_message"$'\n\n'"generated from commit $commit_hash" + fi + + previous_branch=`git rev-parse --abbrev-ref HEAD` + + if [ ! -d "$deploy_directory" ]; then + echo "Deploy directory '$deploy_directory' does not exist. Aborting." >&2 + return 1 + fi + + # must use short form of flag in ls for compatibility with OS X and BSD + if [[ -z `ls -A "$deploy_directory" 2> /dev/null` && -z $allow_empty ]]; then + echo "Deploy directory '$deploy_directory' is empty. Aborting. If you're sure you want to deploy an empty tree, use the --allow-empty / -e flag." >&2 + return 1 + fi + + if git ls-remote --exit-code $repo "refs/heads/$deploy_branch" ; then + # deploy_branch exists in $repo; make sure we have the latest version + + disable_expanded_output + git fetch --force $repo $deploy_branch:$deploy_branch + enable_expanded_output + fi + + # check if deploy_branch exists locally + if git show-ref --verify --quiet "refs/heads/$deploy_branch" + then incremental_deploy + else initial_deploy + fi + + restore_head +} + +initial_deploy() { + git --work-tree "$deploy_directory" checkout --orphan $deploy_branch + git --work-tree "$deploy_directory" add --all + commit+push +} + +incremental_deploy() { + #make deploy_branch the current branch + git symbolic-ref HEAD refs/heads/$deploy_branch + #put the previously committed contents of deploy_branch into the index + git --work-tree "$deploy_directory" reset --mixed --quiet + git --work-tree "$deploy_directory" add --all + + set +o errexit + diff=$(git --work-tree "$deploy_directory" diff --exit-code --quiet HEAD --)$? + set -o errexit + case $diff in + 0) echo No changes to files in $deploy_directory. Skipping commit.;; + 1) commit+push;; + *) + echo git diff exited with code $diff. Aborting. Staying on branch $deploy_branch so you can debug. To switch back to master, use: git symbolic-ref HEAD refs/heads/master && git reset --mixed >&2 + return $diff + ;; + esac +} + +commit+push() { + set_user_id + git --work-tree "$deploy_directory" commit -m "$commit_message" + + disable_expanded_output + #--quiet is important here to avoid outputting the repo URL, which may contain a secret token + git push --quiet $repo $deploy_branch + enable_expanded_output +} + +#echo expanded commands as they are executed (for debugging) +enable_expanded_output() { + if [ $verbose ]; then + set -o xtrace + set +o verbose + fi +} + +#this is used to avoid outputting the repo URL, which may contain a secret token +disable_expanded_output() { + if [ $verbose ]; then + set +o xtrace + set -o verbose + fi +} + +set_user_id() { + if [[ -z `git config user.name` ]]; then + git config user.name "$default_username" + fi + if [[ -z `git config user.email` ]]; then + git config user.email "$default_email" + fi +} + +restore_head() { + if [[ $previous_branch = "HEAD" ]]; then + #we weren't on any branch before, so just set HEAD back to the commit it was on + git update-ref --no-deref HEAD $commit_hash $deploy_branch + else + git symbolic-ref HEAD refs/heads/$previous_branch + fi + + git reset --mixed +} + +filter() { + sed -e "s|$repo|\$repo|g" +} + +sanitize() { + "$@" 2> >(filter 1>&2) | filter +} + +[[ $1 = --source-only ]] || main "$@" diff --git a/ggml/src/ggml-kompute/kompute/scripts/requirements.txt b/ggml/src/ggml-kompute/kompute/scripts/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4da0425044e90fb494eaa8365b966d4c2a94fb16 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/scripts/requirements.txt @@ -0,0 +1,11 @@ +# CLI dependencies +click==7.1.2 + +# Dev dependencies +black==19.10b0 +quom==1.2.0 +Sphinx==3.2.1 +sphinx_material==0.0.30 +breathe==4.20.0 +m2r2==0.2.5 +git+git://github.com/pybind/pybind11_mkdoc.git@master diff --git a/ggml/src/ggml-kompute/kompute/setup.py b/ggml/src/ggml-kompute/kompute/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..09faa8d1a7d322b84249d2ac471ca8d8c71e34f6 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/setup.py @@ -0,0 +1,93 @@ +import os +import re +import platform +import sys +import sysconfig +import subprocess + +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext +from distutils.version import LooseVersion + +curr_dir = os.path.abspath(os.path.dirname(__file__)) +with open(os.path.join(curr_dir, 'README.md'), encoding='utf-8') as f: + long_description = f.read() + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=''): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CMakeBuild(build_ext): + def run(self): + try: + out = subprocess.check_output(['cmake', '--version']) + except OSError: + raise RuntimeError("CMake must be installed to build the following extensions: " + + ", ".join(e.name for e in self.extensions)) + + cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1)) + if cmake_version < '3.15': + raise RuntimeError("CMake >= 3.15 is required") + + for ext in self.extensions: + self.build_extension(ext) + + def build_extension(self, ext): + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + # required for auto-detection of auxiliary "native" libs + if not extdir.endswith(os.path.sep): + extdir += os.path.sep + + cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, + '-DKOMPUTE_OPT_BUILD_PYTHON=ON', + '-DKOMPUTE_OPT_LOG_LEVEL=Off', + '-DKOMPUTE_OPT_USE_SPDLOG=Off', + '-DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON' + '-DPYTHON_EXECUTABLE=' + sys.executable, + '-DPYTHON_INCLUDE_DIR=' + sysconfig.get_path('include'), + '-DPYTHON_LIBRARY=' + sysconfig.get_path('stdlib'), + ] + + cfg = 'Debug' if self.debug else 'Release' + build_args = ['--config', cfg] + + env = os.environ.copy() + oldCxxFlags = env.get('CXXFLAGS', '') + env['CXXFLAGS'] = f'{oldCxxFlags} -DVERSION_INFO=\\"{self.distribution.get_version()}\\"' + + if platform.system() == "Windows": + cmake_args += [f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}'] + if sys.maxsize > 2**32: + cmake_args += ['-A', 'x64'] + build_args += ['--', '/m'] + else: + env['CXXFLAGS'] += ' -fPIC' + cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg] + build_args += ['--', '-j'] + # Optional environment variable to limit the number of parallel jobs for GitHub actions to reduce RAM usage + if 'KOMPUTE_PYTHON_NUM_PARALLEL_THREADS' in env: + build_args += env['KOMPUTE_PYTHON_NUM_PARALLEL_THREADS'] + + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + + subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env) + subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp) + +setup( + name='kp', + version='0.8.1', + author='Alejandro Saucedo', + description='Kompute: Blazing fast, mobile-enabled, asynchronous, and optimized for advanced GPU processing usecases.', + long_description=long_description, + long_description_content_type='text/markdown', + ext_modules=[CMakeExtension('kp')], + install_requires=[ + "numpy<2.0.0" + ], + cmdclass=dict(build_ext=CMakeBuild), + zip_safe=False, + include_package_data=True, +) diff --git a/ggml/src/ggml-kompute/kompute/src/Algorithm.cpp b/ggml/src/ggml-kompute/kompute/src/Algorithm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2d8554e1fc5e44c083c59268844313735271e5f --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/Algorithm.cpp @@ -0,0 +1,418 @@ +// SPDX-License-Identifier: Apache-2.0 +#include + +#include "kompute/Algorithm.hpp" + +namespace kp { + +Algorithm::~Algorithm() +{ + KP_LOG_DEBUG("Kompute Algorithm Destructor started"); + + this->destroy(); +} + +bool +Algorithm::isInit() +{ + return this->mPipeline && this->mPipelineCache && this->mPipelineLayout && + this->mDescriptorPool && this->mDescriptorSet && + this->mDescriptorSetLayout && this->mShaderModule; +} + +void +Algorithm::destroy() +{ + // We don't have to free memory on destroy as it's freed by the + // commandBuffer destructor if (this->mPushConstantsData) { + // free(this->mPushConstantsData); + // } + // if (this->mSpecializationConstantsData) { + // free(this->mSpecializationConstantsData); + // } + + if (!this->mDevice) { + KP_LOG_WARN("Kompute Algorithm destroy function reached with null " + "Device pointer"); + return; + } + + if (this->mFreePipeline && this->mPipeline) { + KP_LOG_DEBUG("Kompute Algorithm Destroying pipeline"); + if (!this->mPipeline) { + KP_LOG_WARN("Kompute Algorithm Error requested to destroy " + "pipeline but it is null"); + } + this->mDevice->destroy( + *this->mPipeline, + (vk::Optional)nullptr); + this->mPipeline = nullptr; + } + + if (this->mFreePipelineLayout && this->mPipelineLayout) { + KP_LOG_DEBUG("Kompute Algorithm Destroying pipeline layout"); + if (!this->mPipelineLayout) { + KP_LOG_WARN("Kompute Algorithm Error requested to destroy " + "pipeline layout but it is null"); + } + this->mDevice->destroy( + *this->mPipelineLayout, + (vk::Optional)nullptr); + this->mPipelineLayout = nullptr; + } + + if (this->mFreeShaderModule && this->mShaderModule) { + KP_LOG_DEBUG("Kompute Algorithm Destroying shader module"); + if (!this->mShaderModule) { + KP_LOG_WARN("Kompute Algorithm Error requested to destroy shader " + "module but it is null"); + } + this->mDevice->destroy( + *this->mShaderModule, + (vk::Optional)nullptr); + this->mShaderModule = nullptr; + } + + freeParameters(); +} + +void +Algorithm::freeParameters() +{ + if (this->mFreeDescriptorSetLayout && this->mDescriptorSetLayout) { + KP_LOG_DEBUG("Kompute Algorithm Destroying Descriptor Set Layout"); + if (!this->mDescriptorSetLayout) { + KP_LOG_WARN("Kompute Algorithm Error requested to destroy " + "descriptor set layout but it is null"); + } + this->mDevice->destroy( + *this->mDescriptorSetLayout, + (vk::Optional)nullptr); + this->mDescriptorSetLayout = nullptr; + } +} + +void +Algorithm::createParameters() +{ + KP_LOG_DEBUG("Kompute Algorithm createParameters started"); + if (!*this->mDescriptorPool) { + KP_LOG_ERROR("Kompute Algorithm can not create descriptor pool"); + return; + } + + std::vector descriptorSetBindings; + for (size_t i = 0; i < this->mTensors.size(); i++) { + descriptorSetBindings.push_back( + vk::DescriptorSetLayoutBinding(i, // Binding index + vk::DescriptorType::eStorageBuffer, + 1, // Descriptor count + vk::ShaderStageFlagBits::eCompute)); + } + + // This is the component that is fed into the pipeline + vk::DescriptorSetLayoutCreateInfo descriptorSetLayoutInfo( + vk::DescriptorSetLayoutCreateFlags(), + static_cast(descriptorSetBindings.size()), + descriptorSetBindings.data()); + + KP_LOG_DEBUG("Kompute Algorithm creating descriptor set layout"); + this->mDescriptorSetLayout = std::make_shared(); + vk::Result result = this->mDevice->createDescriptorSetLayout( + &descriptorSetLayoutInfo, nullptr, this->mDescriptorSetLayout.get()); + + if (result != vk::Result::eSuccess) { + KP_LOG_ERROR("Failed to create descriptor set layout. Error code: {}", vk::to_string(result)); + } else { + this->mFreeDescriptorSetLayout = true; + KP_LOG_DEBUG("Successfully allocated descriptor set layout."); + } + + vk::DescriptorSetAllocateInfo descriptorSetAllocateInfo( + *this->mDescriptorPool, + 1, // Descriptor set layout count + this->mDescriptorSetLayout.get()); + + KP_LOG_DEBUG("Kompute Algorithm allocating descriptor sets"); + this->mDescriptorSet = std::make_shared(); + result = this->mDevice->allocateDescriptorSets(&descriptorSetAllocateInfo, + this->mDescriptorSet.get()); + + if (result != vk::Result::eSuccess) { + KP_LOG_ERROR("Failed to allocate descriptor sets. Error code: {}", vk::to_string(result)); + } else { + this->mFreeDescriptorSet = true; + KP_LOG_DEBUG("Successfully allocated descriptor sets."); + } + + this->mFreeDescriptorSet = true; + + KP_LOG_DEBUG("Kompute Algorithm updating descriptor sets"); + for (size_t i = 0; i < this->mTensors.size(); i++) { + std::vector computeWriteDescriptorSets; + + vk::DescriptorBufferInfo descriptorBufferInfo = + this->mTensors[i]->constructDescriptorBufferInfo(); + + computeWriteDescriptorSets.push_back( + vk::WriteDescriptorSet(*this->mDescriptorSet, + i, // Destination binding + 0, // Destination array element + 1, // Descriptor count + vk::DescriptorType::eStorageBuffer, + nullptr, // Descriptor image info + &descriptorBufferInfo)); + + this->mDevice->updateDescriptorSets(computeWriteDescriptorSets, + nullptr); + } + + KP_LOG_DEBUG("Kompute Algorithm successfully run init"); +} + +void +Algorithm::updateParameters() +{ + KP_LOG_DEBUG("Kompute Algorithm updateParameters started"); + if (!*this->mDescriptorPool) { + KP_LOG_ERROR("Kompute Algorithm can not create descriptor pool"); + return; + } + + vk::DescriptorSetAllocateInfo descriptorSetAllocateInfo( + *this->mDescriptorPool, + 1, // Descriptor set layout count + this->mDescriptorSetLayout.get()); + + KP_LOG_DEBUG("Kompute Algorithm allocating descriptor sets"); + this->mDescriptorSet = std::make_shared(); + vk::Result result = this->mDevice->allocateDescriptorSets(&descriptorSetAllocateInfo, + this->mDescriptorSet.get()); + + if (result != vk::Result::eSuccess) { + KP_LOG_ERROR("Failed to allocate descriptor sets. Error code: {}", vk::to_string(result)); + } else { + this->mFreeDescriptorSet = true; + KP_LOG_DEBUG("Successfully allocated descriptor sets."); + } + + this->mFreeDescriptorSet = true; + + KP_LOG_DEBUG("Kompute Algorithm updating descriptor sets"); + for (size_t i = 0; i < this->mTensors.size(); i++) { + std::vector computeWriteDescriptorSets; + + vk::DescriptorBufferInfo descriptorBufferInfo = + this->mTensors[i]->constructDescriptorBufferInfo(); + + computeWriteDescriptorSets.push_back( + vk::WriteDescriptorSet(*this->mDescriptorSet, + i, // Destination binding + 0, // Destination array element + 1, // Descriptor count + vk::DescriptorType::eStorageBuffer, + nullptr, // Descriptor image info + &descriptorBufferInfo)); + + this->mDevice->updateDescriptorSets(computeWriteDescriptorSets, + nullptr); + } + + KP_LOG_DEBUG("Kompute Algorithm successfully run init"); +} + +void +Algorithm::createShaderModule() +{ + KP_LOG_DEBUG("Kompute Algorithm createShaderModule started"); + + vk::ShaderModuleCreateInfo shaderModuleInfo(vk::ShaderModuleCreateFlags(), + sizeof(uint32_t) * + this->mSpirv.size(), + this->mSpirv.data()); + + KP_LOG_DEBUG("Kompute Algorithm Creating shader module. ShaderFileSize: {}", + this->mSpirv.size()); + this->mFreeShaderModule = true; + this->mShaderModule = std::make_shared(); + this->mDevice->createShaderModule( + &shaderModuleInfo, nullptr, this->mShaderModule.get()); + this->mFreeShaderModule = true; + + KP_LOG_DEBUG("Kompute Algorithm create shader module success"); +} + +void +Algorithm::createPipeline() +{ + KP_LOG_DEBUG("Kompute Algorithm calling create Pipeline"); + + vk::PipelineLayoutCreateInfo pipelineLayoutInfo( + vk::PipelineLayoutCreateFlags(), + 1, // Set layout count + this->mDescriptorSetLayout.get()); + + vk::PushConstantRange pushConstantRange; + if (this->mPushConstantsSize) { + pushConstantRange.setStageFlags(vk::ShaderStageFlagBits::eCompute); + pushConstantRange.setOffset(0); + pushConstantRange.setSize(this->mPushConstantsDataTypeMemorySize * + this->mPushConstantsSize); + + pipelineLayoutInfo.setPushConstantRangeCount(1); + pipelineLayoutInfo.setPPushConstantRanges(&pushConstantRange); + } + + this->mPipelineLayout = std::make_shared(); + this->mDevice->createPipelineLayout( + &pipelineLayoutInfo, nullptr, this->mPipelineLayout.get()); + this->mFreePipelineLayout = true; + + std::vector specializationEntries; + + for (uint32_t i = 0; i < this->mSpecializationConstantsSize; i++) { + vk::SpecializationMapEntry specializationEntry( + static_cast(i), + static_cast( + this->mSpecializationConstantsDataTypeMemorySize * i), + this->mSpecializationConstantsDataTypeMemorySize); + + specializationEntries.push_back(specializationEntry); + } + + // This passes ownership of the memory so we remove ownership from + // specialization container by using "transferDataOwnership" + vk::SpecializationInfo specializationInfo( + static_cast(specializationEntries.size()), + specializationEntries.data(), + this->mSpecializationConstantsDataTypeMemorySize * + this->mSpecializationConstantsSize, + this->mSpecializationConstantsData); + + vk::PipelineShaderStageCreateInfo shaderStage( + vk::PipelineShaderStageCreateFlags(), + vk::ShaderStageFlagBits::eCompute, + *this->mShaderModule, + "main", + &specializationInfo); + + vk::ComputePipelineCreateInfo pipelineInfo(vk::PipelineCreateFlags(), + shaderStage, + *this->mPipelineLayout, + vk::Pipeline(), + 0); + +#ifdef KOMPUTE_CREATE_PIPELINE_RESULT_VALUE + vk::ResultValue pipelineResult = + this->mDevice->createComputePipeline(*mPipelineCache, pipelineInfo); + + if (pipelineResult.result != vk::Result::eSuccess) { + throw std::runtime_error("Failed to create pipeline result: " + + vk::to_string(pipelineResult.result)); + } + + vk::Pipeline& pipeline = pipelineResult.value; + this->mPipeline = std::make_shared(pipeline); + this->mFreePipeline = true; +#else + vk::Pipeline pipeline = + this->mDevice->createComputePipeline(*mPipelineCache, pipelineInfo) + .value; + this->mPipeline = std::make_shared(pipeline); + this->mFreePipeline = true; +#endif + + // TODO: Update to consistent + // this->mPipeline = std::make_shared(); + // this->mDevice->createComputePipelines( + // *this->mPipelineCache, 1, &pipelineInfo, nullptr, + // this->mPipeline.get()); + + KP_LOG_DEBUG("Kompute Algorithm Create Pipeline Success"); +} + +void +Algorithm::recordBindCore(const vk::CommandBuffer& commandBuffer) +{ + KP_LOG_DEBUG("Kompute Algorithm binding pipeline"); + + commandBuffer.bindPipeline(vk::PipelineBindPoint::eCompute, + *this->mPipeline); + + KP_LOG_DEBUG("Kompute Algorithm binding descriptor sets"); + + commandBuffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, + *this->mPipelineLayout, + 0, // First set + *this->mDescriptorSet, + nullptr // Dispatcher + ); +} + +void +Algorithm::recordBindPush(const vk::CommandBuffer& commandBuffer) +{ + if (this->mPushConstantsSize) { + KP_LOG_DEBUG("Kompute Algorithm binding push constants memory size: {}", + this->mPushConstantsSize * + this->mPushConstantsDataTypeMemorySize); + + commandBuffer.pushConstants(*this->mPipelineLayout, + vk::ShaderStageFlagBits::eCompute, + 0, + this->mPushConstantsSize * + this->mPushConstantsDataTypeMemorySize, + this->mPushConstantsData); + } +} + +void +Algorithm::recordDispatch(const vk::CommandBuffer& commandBuffer) +{ + KP_LOG_DEBUG("Kompute Algorithm recording dispatch"); + + commandBuffer.dispatch( + this->mWorkgroup[0], this->mWorkgroup[1], this->mWorkgroup[2]); +} + +void +Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize) +{ + KP_LOG_INFO("Kompute OpAlgoCreate setting dispatch size"); + + // The dispatch size is set up based on either explicitly provided template + // parameters or by default it would take the shape and size of the tensors + if (workgroup[0] > 0) { + // If at least the x value is provided we use mainly the parameters + // provided + this->mWorkgroup = { workgroup[0], + workgroup[1] > 0 ? workgroup[1] : 1, + workgroup[2] > 0 ? workgroup[2] : 1 }; + } else { + this->mWorkgroup = { minSize, 1, 1 }; + } + + KP_LOG_INFO("Kompute OpAlgoCreate set dispatch size X: {}, Y: {}, Z: {}", + this->mWorkgroup[0], + this->mWorkgroup[1], + this->mWorkgroup[2]); +} + +const Workgroup& +Algorithm::getWorkgroup() +{ + return this->mWorkgroup; +} + +const std::vector>& +Algorithm::getTensors() +{ + return this->mTensors; +} + +void Algorithm::setTensors(const std::vector>& tensors) +{ + this->mTensors = tensors; +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..42b7d07f5e5b90598cb8a4066b6fa762942a39c7 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/CMakeLists.txt @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 + +cmake_minimum_required(VERSION 3.20) + +if(KOMPUTE_OPT_ANDROID_BUILD) + find_library(android android) +endif() + +cmake_minimum_required(VERSION 3.20) + +add_library(kompute STATIC Algorithm.cpp + Manager.cpp + OpAlgoDispatch.cpp + OpMemoryBarrier.cpp + OpTensorCopy.cpp + OpTensorFill.cpp + OpTensorSyncDevice.cpp + OpTensorSyncLocal.cpp + OpBufferSyncDevice.cpp + OpBufferSyncLocal.cpp + Sequence.cpp + Tensor.cpp + Core.cpp) + +add_library(kompute::kompute ALIAS kompute) + +# Set version for shared libraries. +set_target_properties(kompute + PROPERTIES + VERSION ${${PROJECT_NAME}_VERSION} + SOVERSION ${${PROJECT_NAME}_VERSION_MAJOR} + POSITION_INDEPENDENT_CODE TRUE) + +# Import GNU common install directory variables +include(GNUInstallDirs) + +install(TARGETS kompute + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +# Include CMake helpers for package config files +# Follow this installation guideline: https://cmake.org/cmake/help/latest/manual/cmake-packages.7.html +include(CMakePackageConfigHelpers) + +configure_package_config_file(${PROJECT_SOURCE_DIR}/cmake/komputeConfig.cmake.in + "${PROJECT_BINARY_DIR}/kompute/komputeConfig.cmake" + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/kompute) + +#install(FILES ${PROJECT_BINARY_DIR}/kompute/komputeConfig.cmake +# ${PROJECT_BINARY_DIR}/kompute/komputeConfigVersion.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/kompute) + +# #################################################### +# Linking +# #################################################### +if(KOMPUTE_OPT_ANDROID_BUILD) + target_link_libraries(kompute PUBLIC vulkanAndroid + android + kp_logger + kp_shader + fmt::fmt-header-only) +else() + target_link_libraries(kompute PUBLIC + kp_logger + kp_shader + fmt::fmt-header-only) +endif() + +if(KOMPUTE_OPT_BUILD_PYTHON) + include_directories(${PYTHON_INCLUDE_DIRS}) + + target_link_libraries(kompute PRIVATE pybind11::headers ${PYTHON_LIBRARIES}) +endif() + +if(KOMPUTE_OPT_USE_BUILT_IN_VULKAN_HEADER) + target_link_libraries(kompute PUBLIC Vulkan-Headers) +else() + target_link_libraries(kompute PUBLIC Vulkan::Headers) +endif() + +# #################################################### +# Misc +# #################################################### +add_subdirectory(logger) +add_subdirectory(shaders) +add_subdirectory(include) diff --git a/ggml/src/ggml-kompute/kompute/src/Core.cpp b/ggml/src/ggml-kompute/kompute/src/Core.cpp new file mode 100644 index 0000000000000000000000000000000000000000..020f44160402286ed3426c03e362f068fe78b6f9 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/Core.cpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/Core.hpp" + +#ifndef KOMPUTE_VK_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE +#define KOMPUTE_VK_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE +/** + * Ensures support for dynamic loading of Vulkan functions on Android. + * Acts as a default store for loaded functions. + * More information: + * https://github.com/KhronosGroup/Vulkan-Hpp#vulkan_hpp_default_dispatcher + **/ +VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE +#endif // !KOMPUTE_VK_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE + +namespace kp { +} // namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/Manager.cpp b/ggml/src/ggml-kompute/kompute/src/Manager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0c588e19be7b9726e5ae8dae3569dea0fd990af1 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/Manager.cpp @@ -0,0 +1,512 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/Manager.hpp" +#include "fmt/format.h" +#include "kompute/logger/Logger.hpp" +#include +#include +#include +#include +#include + +namespace kp { + +#ifndef KOMPUTE_DISABLE_VK_DEBUG_LAYERS +static VKAPI_ATTR VkBool32 VKAPI_CALL +debugMessageCallback(VkDebugReportFlagsEXT /*flags*/, + VkDebugReportObjectTypeEXT /*objectType*/, + uint64_t /*object*/, + size_t /*location*/, + int32_t /*messageCode*/, +#if KOMPUTE_OPT_ACTIVE_LOG_LEVEL <= KOMPUTE_LOG_LEVEL_DEBUG + const char* pLayerPrefix, + const char* pMessage, +#else + const char* /*pLayerPrefix*/, + const char* /*pMessage*/, +#endif + void* /*pUserData*/) +{ + KP_LOG_DEBUG("[VALIDATION]: {} - {}", pLayerPrefix, pMessage); + return VK_FALSE; +} +#endif + +Manager::Manager() +{ + this->mManageResources = true; + +// Make sure the logger is setup +#if !KOMPUTE_OPT_LOG_LEVEL_DISABLED + logger::setupLogger(); +#endif + this->createInstance(); +} + +void Manager::initializeDevice(uint32_t physicalDeviceIndex, + const std::vector& familyQueueIndices, + const std::vector& desiredExtensions) +{ + this->createDevice( + familyQueueIndices, physicalDeviceIndex, desiredExtensions); +} + +Manager::~Manager() +{ + KP_LOG_DEBUG("Kompute Manager Destructor started"); + this->destroy(); +} + +void +Manager::destroy() +{ + + KP_LOG_DEBUG("Kompute Manager destroy() started"); + + if (this->mDevice == nullptr) { + KP_LOG_ERROR( + "Kompute Manager destructor reached with null Device pointer"); + return; + } + + if (this->mManageResources && this->mManagedSequences.size()) { + KP_LOG_DEBUG("Kompute Manager explicitly running destructor for " + "managed sequences"); + for (const std::weak_ptr& weakSq : this->mManagedSequences) { + if (std::shared_ptr sq = weakSq.lock()) { + sq->destroy(); + } + } + this->mManagedSequences.clear(); + } + + if (this->mManageResources && !this->mManagedAlgorithmsMap.empty()) { + KP_LOG_DEBUG("Kompute Manager explicitly freeing algorithms"); + for (const auto& kv : this->mManagedAlgorithmsMap) { + if (std::shared_ptr algorithm = kv.second) { + algorithm->destroy(); + } + } + this->mManagedAlgorithmsMap.clear(); + } + + if (this->mManageResources && this->mManagedTensors.size()) { + KP_LOG_DEBUG("Kompute Manager explicitly freeing tensors"); + for (const std::weak_ptr& weakTensor : this->mManagedTensors) { + if (std::shared_ptr tensor = weakTensor.lock()) { + tensor->destroy(); + } + } + this->mManagedTensors.clear(); + } + + if (this->mPipelineCache) { + KP_LOG_DEBUG("Kompute Manager Destroying pipeline cache"); + if (!this->mPipelineCache) { + KP_LOG_WARN("Kompute Manager Error requested to destroy " + "pipeline cache but it is null"); + } + this->mDevice->destroy( + *this->mPipelineCache, + (vk::Optional)nullptr); + this->mPipelineCache = nullptr; + } + + if (this->mFreeDevice) { + KP_LOG_INFO("Destroying device"); + this->mDevice->destroy( + (vk::Optional)nullptr); + this->mDevice = nullptr; + KP_LOG_DEBUG("Kompute Manager Destroyed Device"); + } + + if (this->mInstance == nullptr) { + KP_LOG_ERROR( + "Kompute Manager destructor reached with null Instance pointer"); + return; + } + +#ifndef KOMPUTE_DISABLE_VK_DEBUG_LAYERS + if (this->mDebugReportCallback) { + this->mInstance->destroyDebugReportCallbackEXT( + this->mDebugReportCallback, nullptr, this->mDebugDispatcher); + KP_LOG_DEBUG("Kompute Manager Destroyed Debug Report Callback"); + } +#endif + + if (this->mFreeInstance) { + this->mInstance->destroy( + (vk::Optional)nullptr); + this->mInstance = nullptr; + KP_LOG_DEBUG("Kompute Manager Destroyed Instance"); + } +} + +void +Manager::createInstance() +{ + + KP_LOG_DEBUG("Kompute Manager creating instance"); + + this->mFreeInstance = true; + + vk::ApplicationInfo applicationInfo; + applicationInfo.pApplicationName = "Kompute"; + applicationInfo.pEngineName = "Kompute"; + applicationInfo.apiVersion = KOMPUTE_VK_API_VERSION; + applicationInfo.engineVersion = KOMPUTE_VK_API_VERSION; + applicationInfo.applicationVersion = KOMPUTE_VK_API_VERSION; + + std::vector applicationExtensions; + +#ifndef KOMPUTE_DISABLE_VK_DEBUG_LAYERS + applicationExtensions.push_back(VK_EXT_DEBUG_REPORT_EXTENSION_NAME); +#endif + + vk::InstanceCreateInfo computeInstanceCreateInfo; + computeInstanceCreateInfo.pApplicationInfo = &applicationInfo; + if (!applicationExtensions.empty()) { + computeInstanceCreateInfo.enabledExtensionCount = + (uint32_t)applicationExtensions.size(); + computeInstanceCreateInfo.ppEnabledExtensionNames = + applicationExtensions.data(); + } + + try { + mDynamicLoader = std::make_shared(); + } catch (const std::exception & err) { + return; + } + + PFN_vkGetInstanceProcAddr vkGetInstanceProcAddr = + mDynamicLoader->getProcAddress("vkGetInstanceProcAddr"); + VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr); + +#ifndef KOMPUTE_DISABLE_VK_DEBUG_LAYERS + KP_LOG_DEBUG("Kompute Manager adding debug validation layers"); + // We'll identify the layers that are supported + std::vector validLayerNames; + std::vector desiredLayerNames = { + "VK_LAYER_LUNARG_assistant_layer", + "VK_LAYER_LUNARG_standard_validation", + "VK_LAYER_KHRONOS_validation", + }; + std::vector envLayerNames; + const char* envLayerNamesVal = std::getenv("KOMPUTE_ENV_DEBUG_LAYERS"); + if (envLayerNamesVal != nullptr && *envLayerNamesVal != '\0') { + KP_LOG_DEBUG("Kompute Manager adding environment layers: {}", + envLayerNamesVal); + std::istringstream iss(envLayerNamesVal); + std::istream_iterator beg(iss); + std::istream_iterator end; + envLayerNames = std::vector(beg, end); + for (const std::string& layerName : envLayerNames) { + desiredLayerNames.push_back(layerName.c_str()); + } + KP_LOG_DEBUG("Desired layers: {}", fmt::join(desiredLayerNames, ", ")); + } + + // Identify the valid layer names based on the desiredLayerNames + { + std::set uniqueLayerNames; + std::vector availableLayerProperties = + vk::enumerateInstanceLayerProperties(); + for (vk::LayerProperties layerProperties : availableLayerProperties) { + std::string layerName(layerProperties.layerName.data()); + uniqueLayerNames.insert(layerName); + } + KP_LOG_DEBUG("Available layers: {}", fmt::join(uniqueLayerNames, ", ")); + for (const char* desiredLayerName : desiredLayerNames) { + if (uniqueLayerNames.count(desiredLayerName) != 0) { + validLayerNames.push_back(desiredLayerName); + } + } + } + + if (!validLayerNames.empty()) { + KP_LOG_DEBUG( + "Kompute Manager Initializing instance with valid layers: {}", + fmt::join(validLayerNames, ", ")); + computeInstanceCreateInfo.enabledLayerCount = + static_cast(validLayerNames.size()); + computeInstanceCreateInfo.ppEnabledLayerNames = validLayerNames.data(); + } else { + KP_LOG_WARN("Kompute Manager no valid layer names found from desired " + "layer names"); + } +#endif + + this->mInstance = std::make_shared(); + vk::Result r = vk::createInstance( + &computeInstanceCreateInfo, nullptr, this->mInstance.get()); + if (r != vk::Result::eSuccess) { + KP_LOG_ERROR( + "Kompute Manager Error allocating vulkan instance", vk::to_string(r)); + this->mInstance = nullptr; + this->mFreeInstance = false; + return; + } + + VULKAN_HPP_DEFAULT_DISPATCHER.init(*this->mInstance); + + KP_LOG_DEBUG("Kompute Manager Instance Created"); + +#ifndef KOMPUTE_DISABLE_VK_DEBUG_LAYERS + KP_LOG_DEBUG("Kompute Manager adding debug callbacks"); + if (validLayerNames.size() > 0) { + vk::DebugReportFlagsEXT debugFlags = + vk::DebugReportFlagBitsEXT::eError | + vk::DebugReportFlagBitsEXT::eWarning; + vk::DebugReportCallbackCreateInfoEXT debugCreateInfo = {}; + debugCreateInfo.pfnCallback = + (PFN_vkDebugReportCallbackEXT)debugMessageCallback; + debugCreateInfo.flags = debugFlags; + + this->mDebugDispatcher.init(*this->mInstance, vkGetInstanceProcAddr); + this->mDebugReportCallback = + this->mInstance->createDebugReportCallbackEXT( + debugCreateInfo, nullptr, this->mDebugDispatcher); + } +#endif +} + +void +Manager::clear() +{ + if (this->mManageResources) { + this->mManagedTensors.erase( + std::remove_if(begin(this->mManagedTensors), + end(this->mManagedTensors), + [](std::weak_ptr t) { return t.expired(); }), + end(this->mManagedTensors)); + for (auto it = this->mManagedAlgorithmsMap.begin(); + it != this->mManagedAlgorithmsMap.end();) { + if (it->second) { + it = this->mManagedAlgorithmsMap.erase(it); + } else { + ++it; + } + } + this->mManagedSequences.erase( + std::remove_if(begin(this->mManagedSequences), + end(this->mManagedSequences), + [](std::weak_ptr t) { return t.expired(); }), + end(this->mManagedSequences)); + } +} + +void +Manager::createDevice(const std::vector& familyQueueIndices, + uint32_t physicalDeviceIndex, + const std::vector& desiredExtensions) +{ + + KP_LOG_DEBUG("Kompute Manager creating Device"); + + if (this->mInstance == nullptr) { + throw std::runtime_error("Kompute Manager instance is null"); + } + + this->mFreeDevice = true; + + // Getting an integer that says how many vuklan devices we have + std::vector physicalDevices = + this->mInstance->enumeratePhysicalDevices(); + uint32_t deviceCount = physicalDevices.size(); + + // This means there are no devices at all + if (deviceCount == 0) { + throw std::runtime_error("Failed to find GPUs with Vulkan support! " + "Maybe you haven't installed vulkan drivers?"); + } + + // This means that we're exceeding our device limit, for + // example if we have 2 devices, just physicalDeviceIndex + // 0 and 1 are acceptable. Hence, physicalDeviceIndex should + // always be less than deviceCount, else we raise an error + if (!(deviceCount > physicalDeviceIndex)) { + throw std::runtime_error("There is no such physical index or device, " + "please use your existing device"); + } + + vk::PhysicalDevice physicalDevice = physicalDevices[physicalDeviceIndex]; + + this->mPhysicalDevice = + std::make_shared(physicalDevice); + +#if KOMPUTE_OPT_ACTIVE_LOG_LEVEL <= KOMPUTE_LOG_LEVEL_INFO + vk::PhysicalDeviceProperties physicalDeviceProperties = + physicalDevice.getProperties(); +#endif + + KP_LOG_INFO("Using physical device index {} found {}", + physicalDeviceIndex, + physicalDeviceProperties.deviceName.data()); + + if (familyQueueIndices.empty()) { + // Find compute queue + std::vector allQueueFamilyProperties = + physicalDevice.getQueueFamilyProperties(); + + uint32_t computeQueueFamilyIndex = 0; + bool computeQueueSupported = false; + for (uint32_t i = 0; i < allQueueFamilyProperties.size(); i++) { + vk::QueueFamilyProperties queueFamilyProperties = + allQueueFamilyProperties[i]; + + if (queueFamilyProperties.queueFlags & + vk::QueueFlagBits::eCompute) { + computeQueueFamilyIndex = i; + computeQueueSupported = true; + break; + } + } + + if (!computeQueueSupported) { + throw std::runtime_error("Compute queue is not supported"); + } + + this->mComputeQueueFamilyIndices.push_back(computeQueueFamilyIndex); + } else { + this->mComputeQueueFamilyIndices = familyQueueIndices; + } + + std::unordered_map familyQueueCounts; + std::unordered_map> familyQueuePriorities; + for (const auto& value : this->mComputeQueueFamilyIndices) { + familyQueueCounts[value]++; + familyQueuePriorities[value].push_back(1.0f); + } + + std::unordered_map familyQueueIndexCount; + std::vector deviceQueueCreateInfos; + for (const auto& familyQueueInfo : familyQueueCounts) { + // Setting the device count to 0 + familyQueueIndexCount[familyQueueInfo.first] = 0; + + // Creating the respective device queue + vk::DeviceQueueCreateInfo deviceQueueCreateInfo( + vk::DeviceQueueCreateFlags(), + familyQueueInfo.first, + familyQueueInfo.second, + familyQueuePriorities[familyQueueInfo.first].data()); + deviceQueueCreateInfos.push_back(deviceQueueCreateInfo); + } + + KP_LOG_DEBUG("Kompute Manager desired extension layers {}", + fmt::join(desiredExtensions, ", ")); + + std::vector deviceExtensions = + this->mPhysicalDevice->enumerateDeviceExtensionProperties(); + + std::set uniqueExtensionNames; + for (const vk::ExtensionProperties& ext : deviceExtensions) { + uniqueExtensionNames.insert(ext.extensionName); + } + KP_LOG_DEBUG("Kompute Manager available extensions {}", + fmt::join(uniqueExtensionNames, ", ")); + std::vector validExtensions; + for (const std::string& ext : desiredExtensions) { + if (uniqueExtensionNames.count(ext) != 0) { + validExtensions.push_back(ext.c_str()); + } + } + if (desiredExtensions.size() != validExtensions.size()) { + KP_LOG_ERROR("Kompute Manager not all extensions were added: {}", + fmt::join(validExtensions, ", ")); + } + + vk::PhysicalDeviceFeatures features; + features.shaderInt16 = true; + + vk::PhysicalDeviceVulkan11Features features11; + features11.uniformAndStorageBuffer16BitAccess = true; + features11.storageBuffer16BitAccess = true; + features11.pNext = nullptr; + + vk::PhysicalDeviceVulkan12Features features12; + features12.storageBuffer8BitAccess = true; + features12.uniformAndStorageBuffer8BitAccess = true; + features12.shaderFloat16 = true; + features12.shaderInt8 = true; + features12.pNext = &features11; + + vk::DeviceCreateInfo deviceCreateInfo(vk::DeviceCreateFlags(), + deviceQueueCreateInfos.size(), + deviceQueueCreateInfos.data(), + {}, + {}, + validExtensions.size(), + validExtensions.data(), + &features); + + deviceCreateInfo.pNext = &features12; + + this->mDevice = std::make_shared(); + vk::Result r = physicalDevice.createDevice( + &deviceCreateInfo, nullptr, this->mDevice.get()); + if (r != vk::Result::eSuccess) { + KP_LOG_ERROR("Kompute Manager could not create device"); + } + + KP_LOG_DEBUG("Kompute Manager device created"); + + for (const uint32_t& familyQueueIndex : this->mComputeQueueFamilyIndices) { + std::shared_ptr currQueue = std::make_shared(); + + this->mDevice->getQueue(familyQueueIndex, + familyQueueIndexCount[familyQueueIndex], + currQueue.get()); + + familyQueueIndexCount[familyQueueIndex]++; + + this->mComputeQueues.push_back(currQueue); + } + + KP_LOG_DEBUG("Kompute Manager compute queue obtained"); + + mPipelineCache = std::make_shared(); + vk::PipelineCacheCreateInfo pipelineCacheInfo = + vk::PipelineCacheCreateInfo(); + this->mDevice->createPipelineCache( + &pipelineCacheInfo, nullptr, mPipelineCache.get()); +} + +std::shared_ptr +Manager::sequence(uint32_t queueIndex, uint32_t totalTimestamps) +{ + KP_LOG_DEBUG("Kompute Manager sequence() with queueIndex: {}", queueIndex); + + std::shared_ptr sq{ new kp::Sequence( + this->mPhysicalDevice, + this->mDevice, + this->mComputeQueues[queueIndex], + this->mComputeQueueFamilyIndices[queueIndex], + totalTimestamps) }; + + if (this->mManageResources) { + this->mManagedSequences.push_back(sq); + } + + return sq; +} + +vk::PhysicalDeviceProperties +Manager::getDeviceProperties() const +{ + return this->mPhysicalDevice->getProperties(); +} + +std::vector +Manager::listDevices() const +{ + return this->mInstance->enumeratePhysicalDevices(); +} + +std::shared_ptr +Manager::getVkInstance() const +{ + return this->mInstance; +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/OpAlgoDispatch.cpp b/ggml/src/ggml-kompute/kompute/src/OpAlgoDispatch.cpp new file mode 100644 index 0000000000000000000000000000000000000000..edc0f6eb6344875bba58338c5ccae41427497754 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/OpAlgoDispatch.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/operations/OpAlgoDispatch.hpp" + +namespace kp { + +OpAlgoDispatch::~OpAlgoDispatch() +{ + KP_LOG_DEBUG("Kompute OpAlgoDispatch destructor started"); + + if (this->mPushConstantsData) { + KP_LOG_DEBUG("Kompute freeing push constants data"); + free(this->mPushConstantsData); + } +} + +void +OpAlgoDispatch::record(const vk::CommandBuffer& commandBuffer) +{ + KP_LOG_DEBUG("Kompute OpAlgoDispatch record called"); + + // Barrier to ensure the data is finished writing to buffer memory + for (const std::shared_ptr& tensor : + this->mAlgorithm->getTensors()) { + tensor->recordPrimaryBufferMemoryBarrier( + commandBuffer, + vk::AccessFlagBits::eShaderWrite, + vk::AccessFlagBits::eShaderRead, + vk::PipelineStageFlagBits::eComputeShader, + vk::PipelineStageFlagBits::eComputeShader); + } + + if (this->mPushConstantsSize) { + this->mAlgorithm->setPushConstants( + this->mPushConstantsData, + this->mPushConstantsSize, + this->mPushConstantsDataTypeMemorySize); + } + + this->mAlgorithm->recordBindCore(commandBuffer); + this->mAlgorithm->recordBindPush(commandBuffer); + this->mAlgorithm->recordDispatch(commandBuffer); +} + +void +OpAlgoDispatch::preEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpAlgoDispatch preEval called"); +} + +void +OpAlgoDispatch::postEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpAlgoDispatch postSubmit called"); +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/OpBufferSyncDevice.cpp b/ggml/src/ggml-kompute/kompute/src/OpBufferSyncDevice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1812d04b2428ed804949ab5c4885da24f46ce016 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/OpBufferSyncDevice.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/operations/OpBufferSyncDevice.hpp" + +namespace kp { + +OpBufferSyncDevice::OpBufferSyncDevice( + vk::Buffer *primaryBuffer, + vk::Buffer *stagingBuffer, + vk::DeviceSize size) + : mPrimaryBuffer(primaryBuffer) + , mStagingBuffer(stagingBuffer) + , mSize(size) +{ + KP_LOG_DEBUG("Kompute OpBufferSyncDevice constructor with params"); +} + +OpBufferSyncDevice::~OpBufferSyncDevice() +{ + KP_LOG_DEBUG("Kompute OpBufferSyncDevice destructor started"); +} + +void +OpBufferSyncDevice::record(const vk::CommandBuffer& commandBuffer) +{ + KP_LOG_DEBUG("Kompute OpBufferSyncDevice record called"); + vk::BufferCopy copyRegion(0, 0, mSize); + commandBuffer.copyBuffer(*mStagingBuffer, *mPrimaryBuffer, copyRegion); +} + +void +OpBufferSyncDevice::preEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpBufferSyncDevice preEval called"); +} + +void +OpBufferSyncDevice::postEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpBufferSyncDevice postEval called"); +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/OpBufferSyncLocal.cpp b/ggml/src/ggml-kompute/kompute/src/OpBufferSyncLocal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a829819fa603a05528b6ab96be5802a5f1b2f8ad --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/OpBufferSyncLocal.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/operations/OpBufferSyncLocal.hpp" + +namespace kp { + +OpBufferSyncLocal::OpBufferSyncLocal( + vk::Buffer *primaryBuffer, + vk::Buffer *stagingBuffer, + vk::DeviceSize size) + : mPrimaryBuffer(primaryBuffer) + , mStagingBuffer(stagingBuffer) + , mSize(size) +{ + KP_LOG_DEBUG("Kompute OpBufferSyncLocal constructor with params"); +} + +OpBufferSyncLocal::~OpBufferSyncLocal() +{ + KP_LOG_DEBUG("Kompute OpBufferSyncLocal destructor started"); +} + +void +OpBufferSyncLocal::record(const vk::CommandBuffer& commandBuffer) +{ + KP_LOG_DEBUG("Kompute OpBufferSyncLocal record called"); + vk::BufferCopy copyRegion(0, 0, mSize); + commandBuffer.copyBuffer(*mPrimaryBuffer, *mStagingBuffer, copyRegion); +} + +void +OpBufferSyncLocal::preEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpBufferSyncLocal preEval called"); +} + +void +OpBufferSyncLocal::postEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpBufferSyncLocal postEval called"); +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/OpMemoryBarrier.cpp b/ggml/src/ggml-kompute/kompute/src/OpMemoryBarrier.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1f075a3c434e5c9798ce0fdbc25e42c62d37c71d --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/OpMemoryBarrier.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/operations/OpMemoryBarrier.hpp" + +namespace kp { + +OpMemoryBarrier::OpMemoryBarrier( + const std::vector>& tensors, + const vk::AccessFlagBits& srcAccessMask, + const vk::AccessFlagBits& dstAccessMask, + const vk::PipelineStageFlagBits& srcStageMask, + const vk::PipelineStageFlagBits& dstStageMask, + bool barrierOnPrimary) + : mSrcAccessMask(srcAccessMask) + , mDstAccessMask(dstAccessMask) + , mSrcStageMask(srcStageMask) + , mDstStageMask(dstStageMask) + , mBarrierOnPrimary(barrierOnPrimary) + , mTensors(tensors) +{ + KP_LOG_DEBUG("Kompute OpMemoryBarrier constructor"); +} + +OpMemoryBarrier::~OpMemoryBarrier() +{ + KP_LOG_DEBUG("Kompute OpMemoryBarrier destructor started"); +} + +void +OpMemoryBarrier::record(const vk::CommandBuffer& commandBuffer) +{ + KP_LOG_DEBUG("Kompute OpMemoryBarrier record called"); + + // Barrier to ensure the data is finished writing to buffer memory + if (this->mBarrierOnPrimary) { + for (const std::shared_ptr& tensor : this->mTensors) { + tensor->recordPrimaryBufferMemoryBarrier(commandBuffer, + this->mSrcAccessMask, + this->mDstAccessMask, + this->mSrcStageMask, + this->mDstStageMask); + } + } else { + for (const std::shared_ptr& tensor : this->mTensors) { + tensor->recordStagingBufferMemoryBarrier(commandBuffer, + this->mSrcAccessMask, + this->mDstAccessMask, + this->mSrcStageMask, + this->mDstStageMask); + } + } +} + +void +OpMemoryBarrier::preEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpMemoryBarrier preEval called"); +} + +void +OpMemoryBarrier::postEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpMemoryBarrier postSubmit called"); +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/OpTensorCopy.cpp b/ggml/src/ggml-kompute/kompute/src/OpTensorCopy.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1eaf428b85556fd808a910709f8e1e6f9c8529c5 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/OpTensorCopy.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/operations/OpTensorCopy.hpp" +#include "kompute/Tensor.hpp" + +namespace kp { + +OpTensorCopy::OpTensorCopy(const std::vector>& tensors) +{ + KP_LOG_DEBUG("Kompute OpTensorCopy constructor with params"); + + this->mTensors = tensors; + + if (this->mTensors.size() < 2) { + throw std::runtime_error( + "Kompute OpTensorCopy called with less than 2 tensor"); + } + + kp::Tensor::TensorDataTypes dataType = this->mTensors[0]->dataType(); + uint32_t size = this->mTensors[0]->size(); + for (const std::shared_ptr& tensor : tensors) { + if (tensor->dataType() != dataType) { + throw std::runtime_error(fmt::format( + "Attempting to copy tensors of different types from {} to {}", + Tensor::toString(dataType), + Tensor::toString(tensor->dataType()))); + } + if (tensor->size() != size) { + throw std::runtime_error(fmt::format( + "Attempting to copy tensors of different sizes from {} to {}", + size, + tensor->size())); + } + } +} + +OpTensorCopy::~OpTensorCopy() +{ + KP_LOG_DEBUG("Kompute OpTensorCopy destructor started"); +} + +void +OpTensorCopy::record(const vk::CommandBuffer& commandBuffer) +{ + KP_LOG_DEBUG("Kompute OpTensorCopy record called"); + + // We iterate from the second tensor onwards and record a copy to all + for (size_t i = 1; i < this->mTensors.size(); i++) { + this->mTensors[i]->recordCopyFrom(commandBuffer, this->mTensors[0]); + } +} + +void +OpTensorCopy::preEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpTensorCopy preEval called"); +} + +void +OpTensorCopy::postEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpTensorCopy postEval called"); + + // Do not copy on CPU side if source is storage tensor + if (this->mTensors[0]->tensorType() == kp::Tensor::TensorTypes::eStorage) + { + KP_LOG_DEBUG("Kompute OpTensorCopy not copying tensor source given it's of eStorage type"); + return; + } + void* data = this->mTensors[0]->rawData(); + + // Copy the data from the first tensor into all the tensors + for (size_t i = 1; i < this->mTensors.size(); i++) { + if (this->mTensors[i]->tensorType() == kp::Tensor::TensorTypes::eStorage) { + KP_LOG_DEBUG("Kompute OpTensorCopy not copying to tensor dest given it's of eStorage type"); + continue; + } + this->mTensors[i]->setRawData(data); + } +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/OpTensorFill.cpp b/ggml/src/ggml-kompute/kompute/src/OpTensorFill.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bda7d6040eb219ca7eb56ddfb5f402c69d22d572 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/OpTensorFill.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/operations/OpTensorFill.hpp" +#include "kompute/Tensor.hpp" + +namespace kp { + +OpTensorFill::OpTensorFill(const std::vector>& tensors) +{ + KP_LOG_DEBUG("Kompute OpTensorFill constructor with params"); + + if (tensors.size() < 1) { + throw std::runtime_error( + "Kompute OpTensorFill called with less than 1 tensor"); + } + + this->mTensors = tensors; +} + +OpTensorFill::~OpTensorFill() +{ + KP_LOG_DEBUG("Kompute OpTensorFill destructor started"); +} + +void +OpTensorFill::record(const vk::CommandBuffer& commandBuffer) +{ + KP_LOG_DEBUG("Kompute OpTensorFill record called"); + + for (size_t i = 0; i < this->mTensors.size(); i++) { + this->mTensors[i]->recordFill(commandBuffer, 0); + } +} + +void +OpTensorFill::preEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpTensorFill preEval called"); +} + +void +OpTensorFill::postEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpTensorFill postEval called"); +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/OpTensorSyncDevice.cpp b/ggml/src/ggml-kompute/kompute/src/OpTensorSyncDevice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a2542357fb3619764bf063b183ee63f2b7673908 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/OpTensorSyncDevice.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/operations/OpTensorSyncDevice.hpp" + +namespace kp { + +OpTensorSyncDevice::OpTensorSyncDevice( + const std::vector>& tensors) +{ + KP_LOG_DEBUG("Kompute OpTensorSyncDevice constructor with params"); + + if (tensors.size() < 1) { + throw std::runtime_error( + "Kompute OpTensorSyncDevice called with less than 1 tensor"); + } + + this->mTensors = tensors; +} + +OpTensorSyncDevice::~OpTensorSyncDevice() +{ + KP_LOG_DEBUG("Kompute OpTensorSyncDevice destructor started"); + + this->mTensors.clear(); +} + +void +OpTensorSyncDevice::record(const vk::CommandBuffer& commandBuffer) +{ + KP_LOG_DEBUG("Kompute OpTensorSyncDevice record called"); + + for (size_t i = 0; i < this->mTensors.size(); i++) { + if (this->mTensors[i]->tensorType() == Tensor::TensorTypes::eDevice) { + this->mTensors[i]->recordCopyFromStagingToDevice(commandBuffer); + } + } +} + +void +OpTensorSyncDevice::preEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpTensorSyncDevice preEval called"); +} + +void +OpTensorSyncDevice::postEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpTensorSyncDevice postEval called"); +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/OpTensorSyncLocal.cpp b/ggml/src/ggml-kompute/kompute/src/OpTensorSyncLocal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7818db565aaa7bffdab0a8486c483d3e29c4813c --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/OpTensorSyncLocal.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/Tensor.hpp" + +#include "kompute/operations/OpTensorSyncLocal.hpp" + +namespace kp { + +OpTensorSyncLocal::OpTensorSyncLocal( + const std::vector>& tensors) +{ + KP_LOG_DEBUG("Kompute OpTensorSyncLocal constructor with params"); + + if (tensors.size() < 1) { + throw std::runtime_error( + "Kompute OpTensorSyncLocal called with less than 1 tensor"); + } + + this->mTensors = tensors; +} + +OpTensorSyncLocal::~OpTensorSyncLocal() +{ + KP_LOG_DEBUG("Kompute OpTensorSyncLocal destructor started"); +} + +void +OpTensorSyncLocal::record(const vk::CommandBuffer& commandBuffer) +{ + KP_LOG_DEBUG("Kompute OpTensorSyncLocal record called"); + + for (size_t i = 0; i < this->mTensors.size(); i++) { + if (this->mTensors[i]->tensorType() == Tensor::TensorTypes::eDevice) { + + this->mTensors[i]->recordPrimaryBufferMemoryBarrier( + commandBuffer, + vk::AccessFlagBits::eShaderWrite, + vk::AccessFlagBits::eTransferRead, + vk::PipelineStageFlagBits::eComputeShader, + vk::PipelineStageFlagBits::eTransfer); + + this->mTensors[i]->recordCopyFromDeviceToStaging(commandBuffer); + + this->mTensors[i]->recordPrimaryBufferMemoryBarrier( + commandBuffer, + vk::AccessFlagBits::eTransferWrite, + vk::AccessFlagBits::eHostRead, + vk::PipelineStageFlagBits::eTransfer, + vk::PipelineStageFlagBits::eHost); + } + } +} + +void +OpTensorSyncLocal::preEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpTensorSyncLocal preEval called"); +} + +void +OpTensorSyncLocal::postEval(const vk::CommandBuffer& /*commandBuffer*/) +{ + KP_LOG_DEBUG("Kompute OpTensorSyncLocal postEval called"); + + KP_LOG_DEBUG("Kompute OpTensorSyncLocal mapping data into tensor local"); +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/Sequence.cpp b/ggml/src/ggml-kompute/kompute/src/Sequence.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da3b379a3104c260c7bacf2d0adaea2884d5892c --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/Sequence.cpp @@ -0,0 +1,388 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/Sequence.hpp" + +namespace kp { + +Sequence::Sequence(std::shared_ptr physicalDevice, + std::shared_ptr device, + std::shared_ptr computeQueue, + uint32_t queueIndex, + uint32_t totalTimestamps) +{ + KP_LOG_DEBUG("Kompute Sequence Constructor with existing device & queue"); + + this->mPhysicalDevice = physicalDevice; + this->mDevice = device; + this->mComputeQueue = computeQueue; + this->mQueueIndex = queueIndex; + + this->createCommandPool(); + this->createCommandBuffer(); + if (totalTimestamps > 0) + this->createTimestampQueryPool(totalTimestamps + + 1); //+1 for the first one +} + +Sequence::~Sequence() +{ + KP_LOG_DEBUG("Kompute Sequence Destructor started"); + + if (this->mDevice) { + this->destroy(); + } +} + +void +Sequence::begin() +{ + KP_LOG_DEBUG("Kompute sequence called BEGIN"); + + if (this->isRecording()) { + KP_LOG_DEBUG("Kompute Sequence begin called when already recording"); + return; + } + + if (this->isRunning()) { + throw std::runtime_error( + "Kompute Sequence begin called when sequence still running"); + } + + KP_LOG_INFO("Kompute Sequence command now started recording"); + this->mCommandBuffer->begin(vk::CommandBufferBeginInfo()); + this->mRecording = true; + + // latch the first timestamp before any commands are submitted + if (this->timestampQueryPool) + this->mCommandBuffer->writeTimestamp( + vk::PipelineStageFlagBits::eAllCommands, + *this->timestampQueryPool, + 0); +} + +void +Sequence::end() +{ + KP_LOG_DEBUG("Kompute Sequence calling END"); + + if (this->isRunning()) { + throw std::runtime_error( + "Kompute Sequence begin called when sequence still running"); + } + + if (!this->isRecording()) { + KP_LOG_WARN("Kompute Sequence end called when not recording"); + return; + } else { + KP_LOG_INFO("Kompute Sequence command recording END"); + this->mCommandBuffer->end(); + this->mRecording = false; + } +} + +void +Sequence::clear() +{ + KP_LOG_DEBUG("Kompute Sequence calling clear"); + if (this->isRecording()) { + this->end(); + } +} + +std::shared_ptr +Sequence::eval() +{ + KP_LOG_DEBUG("Kompute sequence EVAL BEGIN"); + + return this->evalAsync()->evalAwait(); +} + +std::shared_ptr +Sequence::eval(std::shared_ptr op) +{ + this->clear(); + return this->record(op)->eval(); +} + +std::shared_ptr +Sequence::evalAsync() +{ + if (this->isRecording()) { + this->end(); + } + + if (this->mIsRunning) { + throw std::runtime_error( + "Kompute Sequence evalAsync called when an eval async was " + "called without successful wait"); + } + + this->mIsRunning = true; + + for (size_t i = 0; i < this->mOperations.size(); i++) { + this->mOperations[i]->preEval(*this->mCommandBuffer); + } + + vk::SubmitInfo submitInfo( + 0, nullptr, nullptr, 1, this->mCommandBuffer.get()); + + this->mFence = this->mDevice->createFence(vk::FenceCreateInfo()); + + KP_LOG_DEBUG( + "Kompute sequence submitting command buffer into compute queue"); + + this->mComputeQueue->submit(1, &submitInfo, this->mFence); + + return shared_from_this(); +} + +std::shared_ptr +Sequence::evalAsync(std::shared_ptr op) +{ + this->clear(); + this->record(op); + this->evalAsync(); + return shared_from_this(); +} + +std::shared_ptr +Sequence::evalAwait(uint64_t waitFor) +{ + if (!this->mIsRunning) { + KP_LOG_WARN("Kompute Sequence evalAwait called without existing eval"); + return shared_from_this(); + } + + vk::Result result = + this->mDevice->waitForFences(1, &this->mFence, VK_TRUE, waitFor); + this->mDevice->destroy( + this->mFence, (vk::Optional)nullptr); + + this->mIsRunning = false; + + if (result == vk::Result::eTimeout) { + KP_LOG_WARN("Kompute Sequence evalAwait reached timeout of {}", + waitFor); + return shared_from_this(); + } + + for (size_t i = 0; i < this->mOperations.size(); i++) { + this->mOperations[i]->postEval(*this->mCommandBuffer); + } + + return shared_from_this(); +} + +bool +Sequence::isRunning() const +{ + return this->mIsRunning; +} + +bool +Sequence::isRecording() const +{ + return this->mRecording; +} + +bool +Sequence::isInit() const +{ + return this->mDevice && this->mCommandPool && this->mCommandBuffer && + this->mComputeQueue; +} + +void +Sequence::rerecord() +{ + this->end(); + std::vector> ops = this->mOperations; + this->mOperations.clear(); + for (const std::shared_ptr& op : ops) { + this->record(op); + } +} + +void +Sequence::destroy() +{ + KP_LOG_DEBUG("Kompute Sequence destroy called"); + + if (!this->mDevice) { + KP_LOG_WARN("Kompute Sequence destroy called " + "with null Device pointer"); + return; + } + + if (this->mFreeCommandBuffer) { + KP_LOG_INFO("Freeing CommandBuffer"); + if (!this->mCommandBuffer) { + KP_LOG_WARN("Kompute Sequence destroy called with null " + "CommandPool pointer"); + return; + } + this->mDevice->freeCommandBuffers( + *this->mCommandPool, 1, this->mCommandBuffer.get()); + + this->mCommandBuffer = nullptr; + this->mFreeCommandBuffer = false; + + KP_LOG_DEBUG("Kompute Sequence Freed CommandBuffer"); + } + + if (this->mFreeCommandPool) { + KP_LOG_INFO("Destroying CommandPool"); + if (this->mCommandPool == nullptr) { + KP_LOG_WARN("Kompute Sequence destroy called with null " + "CommandPool pointer"); + return; + } + this->mDevice->destroy( + *this->mCommandPool, + (vk::Optional)nullptr); + + this->mCommandPool = nullptr; + this->mFreeCommandPool = false; + + KP_LOG_DEBUG("Kompute Sequence Destroyed CommandPool"); + } + + if (this->mOperations.size()) { + KP_LOG_INFO("Kompute Sequence clearing operations buffer"); + this->mOperations.clear(); + } + + if (this->timestampQueryPool) { + KP_LOG_INFO("Destroying QueryPool"); + this->mDevice->destroy( + *this->timestampQueryPool, + (vk::Optional)nullptr); + + this->timestampQueryPool = nullptr; + KP_LOG_DEBUG("Kompute Sequence Destroyed QueryPool"); + } + + if (this->mDevice) { + this->mDevice = nullptr; + } + if (this->mPhysicalDevice) { + this->mPhysicalDevice = nullptr; + } + if (this->mComputeQueue) { + this->mComputeQueue = nullptr; + } +} + +std::shared_ptr +Sequence::record(std::shared_ptr op) +{ + KP_LOG_DEBUG("Kompute Sequence record function started"); + + this->begin(); + + KP_LOG_DEBUG( + "Kompute Sequence running record on OpBase derived class instance"); + + op->record(*this->mCommandBuffer); + + this->mOperations.push_back(op); + + if (this->timestampQueryPool) + this->mCommandBuffer->writeTimestamp( + vk::PipelineStageFlagBits::eAllCommands, + *this->timestampQueryPool, + this->mOperations.size()); + + return shared_from_this(); +} + +void +Sequence::createCommandPool() +{ + KP_LOG_DEBUG("Kompute Sequence creating command pool"); + + if (!this->mDevice) { + throw std::runtime_error("Kompute Sequence device is null"); + } + + this->mFreeCommandPool = true; + + vk::CommandPoolCreateInfo commandPoolInfo(vk::CommandPoolCreateFlags(), + this->mQueueIndex); + this->mCommandPool = std::make_shared(); + this->mDevice->createCommandPool( + &commandPoolInfo, nullptr, this->mCommandPool.get()); + KP_LOG_DEBUG("Kompute Sequence Command Pool Created"); +} + +void +Sequence::createCommandBuffer() +{ + KP_LOG_DEBUG("Kompute Sequence creating command buffer"); + if (!this->mDevice) { + throw std::runtime_error("Kompute Sequence device is null"); + } + if (!this->mCommandPool) { + throw std::runtime_error("Kompute Sequence command pool is null"); + } + + this->mFreeCommandBuffer = true; + + vk::CommandBufferAllocateInfo commandBufferAllocateInfo( + *this->mCommandPool, vk::CommandBufferLevel::ePrimary, 1); + + this->mCommandBuffer = std::make_shared(); + this->mDevice->allocateCommandBuffers(&commandBufferAllocateInfo, + this->mCommandBuffer.get()); + KP_LOG_DEBUG("Kompute Sequence Command Buffer Created"); +} + +void +Sequence::createTimestampQueryPool(uint32_t totalTimestamps) +{ + KP_LOG_DEBUG("Kompute Sequence creating query pool"); + if (!this->isInit()) { + throw std::runtime_error( + "createTimestampQueryPool() called on uninitialized Sequence"); + } + if (!this->mPhysicalDevice) { + throw std::runtime_error("Kompute Sequence physical device is null"); + } + + vk::PhysicalDeviceProperties physicalDeviceProperties = + this->mPhysicalDevice->getProperties(); + + if (physicalDeviceProperties.limits.timestampComputeAndGraphics) { + vk::QueryPoolCreateInfo queryPoolInfo; + queryPoolInfo.setQueryCount(totalTimestamps); + queryPoolInfo.setQueryType(vk::QueryType::eTimestamp); + this->timestampQueryPool = std::make_shared( + this->mDevice->createQueryPool(queryPoolInfo)); + + KP_LOG_DEBUG("Query pool for timestamps created"); + } else { + throw std::runtime_error("Device does not support timestamps"); + } +} + +std::vector +Sequence::getTimestamps() +{ + if (!this->timestampQueryPool) + throw std::runtime_error("Timestamp latching not enabled"); + + const auto n = this->mOperations.size() + 1; + std::vector timestamps(n, 0); + this->mDevice->getQueryPoolResults( + *this->timestampQueryPool, + 0, + n, + timestamps.size() * sizeof(std::uint64_t), + timestamps.data(), + sizeof(uint64_t), + vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait); + + return timestamps; +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/Tensor.cpp b/ggml/src/ggml-kompute/kompute/src/Tensor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84dce08e02457027b2c4ec918c1a42a4a4a89471 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/Tensor.cpp @@ -0,0 +1,450 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "kompute/Tensor.hpp" + +namespace kp { + +std::string +Tensor::toString(Tensor::TensorDataTypes dt) +{ + switch (dt) { + case TensorDataTypes::eBool: + return "eBool"; + case TensorDataTypes::eInt: + return "eInt"; + case TensorDataTypes::eUnsignedInt: + return "eUnsignedInt"; + case TensorDataTypes::eFloat: + return "eFloat"; + case TensorDataTypes::eDouble: + return "eDouble"; + default: + return "unknown"; + } +} + +std::string +Tensor::toString(Tensor::TensorTypes dt) +{ + switch (dt) { + case TensorTypes::eDevice: + return "eDevice"; + case TensorTypes::eHost: + return "eHost"; + case TensorTypes::eStorage: + return "eStorage"; + default: + return "unknown"; + } +} + +Tensor::Tensor(std::shared_ptr physicalDevice, + std::shared_ptr device, + void* data, + uint32_t elementTotalCount, + uint32_t elementMemorySize, + const TensorDataTypes& dataType, + vk::DeviceMemory *primaryMemory, + vk::Buffer *primaryBuffer, + vk::DeviceMemory *stagingMemory, + vk::Buffer *stagingBuffer, + vk::DeviceSize offset, + const TensorTypes& tensorType) +{ + KP_LOG_DEBUG("Kompute Tensor constructor data length: {}, and type: {}", + elementTotalCount, + Tensor::toString(tensorType)); + + this->mPhysicalDevice = physicalDevice; + this->mDevice = device; + this->mDataType = dataType; + this->mTensorType = tensorType; + + this->rebuild(data, elementTotalCount, elementMemorySize, primaryMemory, primaryBuffer, stagingMemory, stagingBuffer, offset); +} + +Tensor::~Tensor() +{ + KP_LOG_DEBUG("Kompute Tensor destructor started. Type: {}", + Tensor::toString(this->tensorType())); + + if (this->mDevice) { + this->destroy(); + } + + KP_LOG_DEBUG("Kompute Tensor destructor success"); +} + +void +Tensor::rebuild(void* /*data*/, + uint32_t elementTotalCount, + uint64_t memorySize, + vk::DeviceMemory *primaryMemory, + vk::Buffer *primaryBuffer, + vk::DeviceMemory *stagingMemory, + vk::Buffer *stagingBuffer, + vk::DeviceSize offset) +{ + KP_LOG_DEBUG("Kompute Tensor rebuilding with size {}", elementTotalCount); + + this->mSize = elementTotalCount; + this->mMemorySize = memorySize; + this->mOffset = offset; + + if (this->mPrimaryBuffer || this->mPrimaryMemory) { + KP_LOG_DEBUG( + "Kompute Tensor destroying existing resources before rebuild"); + this->destroy(); + } + + this->setGPUResources(primaryMemory, primaryBuffer, stagingMemory, stagingBuffer, offset); +} + +Tensor::TensorTypes +Tensor::tensorType() +{ + return this->mTensorType; +} + +bool +Tensor::isInit() +{ + return this->mDevice && this->mPrimaryBuffer && this->mPrimaryMemory && + this->mRawData; +} + +uint32_t +Tensor::size() +{ + return this->mSize; +} + +uint64_t +Tensor::memorySize() +{ + return this->mMemorySize; +} + +kp::Tensor::TensorDataTypes +Tensor::dataType() +{ + return this->mDataType; +} + +void* +Tensor::rawData() +{ + return this->mRawData; +} + +void +Tensor::setRawData(const void* data) +{ + memcpy(this->mRawData, data, this->memorySize()); +} + +void +Tensor::recordCopyFrom(const vk::CommandBuffer& commandBuffer, + std::shared_ptr copyFromTensor) +{ + + vk::DeviceSize bufferSize(this->memorySize()); + vk::BufferCopy copyRegion(mOffset, mOffset, bufferSize); + + KP_LOG_DEBUG("Kompute Tensor recordCopyFrom data size {}.", bufferSize); + + this->recordCopyBuffer(commandBuffer, + copyFromTensor->mPrimaryBuffer, + this->mPrimaryBuffer, + bufferSize, + copyRegion); +} + +void +Tensor::recordCopyFromStagingToDevice(const vk::CommandBuffer& commandBuffer) +{ + if (!this->mStagingBuffer) + return; + + vk::DeviceSize bufferSize(this->memorySize()); + vk::BufferCopy copyRegion(mOffset, mOffset, bufferSize); + + KP_LOG_DEBUG("Kompute Tensor copying data size {}.", bufferSize); + + this->recordCopyBuffer(commandBuffer, + this->mStagingBuffer, + this->mPrimaryBuffer, + bufferSize, + copyRegion); +} + +void +Tensor::recordCopyFromDeviceToStaging(const vk::CommandBuffer& commandBuffer) +{ + if (!this->mStagingBuffer) + return; + + vk::DeviceSize bufferSize(this->memorySize()); + vk::BufferCopy copyRegion(mOffset, mOffset, bufferSize); + + KP_LOG_DEBUG("Kompute Tensor copying data size {}.", bufferSize); + + this->recordCopyBuffer(commandBuffer, + this->mPrimaryBuffer, + this->mStagingBuffer, + bufferSize, + copyRegion); +} + +void +Tensor::recordCopyBuffer(const vk::CommandBuffer& commandBuffer, + vk::Buffer *bufferFrom, + vk::Buffer *bufferTo, + vk::DeviceSize /*bufferSize*/, + vk::BufferCopy copyRegion) +{ + + commandBuffer.copyBuffer(*bufferFrom, *bufferTo, copyRegion); +} + +void +Tensor::recordFill(const vk::CommandBuffer &commandBuffer, + uint32_t fill) +{ + commandBuffer.fillBuffer(*this->mPrimaryBuffer, mOffset, this->memorySize(), fill); +} + +void +Tensor::recordPrimaryBufferMemoryBarrier(const vk::CommandBuffer& commandBuffer, + vk::AccessFlagBits srcAccessMask, + vk::AccessFlagBits dstAccessMask, + vk::PipelineStageFlagBits srcStageMask, + vk::PipelineStageFlagBits dstStageMask) +{ + KP_LOG_DEBUG("Kompute Tensor recording PRIMARY buffer memory barrier"); + + this->recordBufferMemoryBarrier(commandBuffer, + *this->mPrimaryBuffer, + srcAccessMask, + dstAccessMask, + srcStageMask, + dstStageMask); +} + +void +Tensor::recordStagingBufferMemoryBarrier(const vk::CommandBuffer& commandBuffer, + vk::AccessFlagBits srcAccessMask, + vk::AccessFlagBits dstAccessMask, + vk::PipelineStageFlagBits srcStageMask, + vk::PipelineStageFlagBits dstStageMask) +{ + if (!this->mStagingBuffer) + return; + + KP_LOG_DEBUG("Kompute Tensor recording STAGING buffer memory barrier"); + + this->recordBufferMemoryBarrier(commandBuffer, + *this->mStagingBuffer, + srcAccessMask, + dstAccessMask, + srcStageMask, + dstStageMask); +} + +void +Tensor::recordBufferMemoryBarrier(const vk::CommandBuffer& commandBuffer, + const vk::Buffer& buffer, + vk::AccessFlagBits srcAccessMask, + vk::AccessFlagBits dstAccessMask, + vk::PipelineStageFlagBits srcStageMask, + vk::PipelineStageFlagBits dstStageMask) +{ + KP_LOG_DEBUG("Kompute Tensor recording buffer memory barrier"); + + vk::DeviceSize bufferSize = this->memorySize(); + + vk::BufferMemoryBarrier bufferMemoryBarrier; + bufferMemoryBarrier.buffer = buffer; + bufferMemoryBarrier.size = bufferSize; + bufferMemoryBarrier.srcAccessMask = srcAccessMask; + bufferMemoryBarrier.dstAccessMask = dstAccessMask; + bufferMemoryBarrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + bufferMemoryBarrier.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + + commandBuffer.pipelineBarrier(srcStageMask, + dstStageMask, + vk::DependencyFlags(), + nullptr, + bufferMemoryBarrier, + nullptr); +} + +vk::DescriptorBufferInfo +Tensor::constructDescriptorBufferInfo() +{ + KP_LOG_DEBUG("Kompute Tensor construct descriptor buffer info size {}", + this->memorySize()); + vk::DeviceSize bufferSize = this->memorySize(); + return vk::DescriptorBufferInfo(*this->mPrimaryBuffer, + mOffset, // offset + bufferSize); +} + +vk::BufferUsageFlags +Tensor::getPrimaryBufferUsageFlags() +{ + switch (this->mTensorType) { + case TensorTypes::eDevice: + return vk::BufferUsageFlagBits::eStorageBuffer | + vk::BufferUsageFlagBits::eTransferSrc | + vk::BufferUsageFlagBits::eTransferDst; + break; + case TensorTypes::eHost: + return vk::BufferUsageFlagBits::eStorageBuffer | + vk::BufferUsageFlagBits::eTransferSrc | + vk::BufferUsageFlagBits::eTransferDst; + break; + case TensorTypes::eStorage: + return vk::BufferUsageFlagBits::eStorageBuffer; + break; + default: + throw std::runtime_error("Kompute Tensor invalid tensor type"); + } +} + +vk::MemoryPropertyFlags +Tensor::getPrimaryMemoryPropertyFlags() +{ + switch (this->mTensorType) { + case TensorTypes::eDevice: + return vk::MemoryPropertyFlagBits::eDeviceLocal; + break; + case TensorTypes::eHost: + return vk::MemoryPropertyFlagBits::eHostVisible | + vk::MemoryPropertyFlagBits::eHostCoherent; + break; + case TensorTypes::eStorage: + return vk::MemoryPropertyFlagBits::eDeviceLocal; + break; + default: + throw std::runtime_error("Kompute Tensor invalid tensor type"); + } +} + +vk::BufferUsageFlags +Tensor::getStagingBufferUsageFlags() +{ + switch (this->mTensorType) { + case TensorTypes::eDevice: + return vk::BufferUsageFlagBits::eTransferSrc | + vk::BufferUsageFlagBits::eTransferDst; + break; + default: + throw std::runtime_error("Kompute Tensor invalid tensor type"); + } +} + +vk::MemoryPropertyFlags +Tensor::getStagingMemoryPropertyFlags() +{ + switch (this->mTensorType) { + case TensorTypes::eDevice: + return vk::MemoryPropertyFlagBits::eHostVisible | + vk::MemoryPropertyFlagBits::eHostCoherent; + break; + default: + throw std::runtime_error("Kompute Tensor invalid tensor type"); + } +} + +void +Tensor::setGPUResources(vk::DeviceMemory *primaryMemory, + vk::Buffer *primaryBuffer, + vk::DeviceMemory *stagingMemory, + vk::Buffer *stagingBuffer, + vk::DeviceSize /*offset*/) +{ + KP_LOG_DEBUG("Kompute Tensor creating buffer"); + + if (!this->mPhysicalDevice) { + throw std::runtime_error("Kompute Tensor phyisical device is null"); + } + if (!this->mDevice) { + throw std::runtime_error("Kompute Tensor device is null"); + } + + KP_LOG_DEBUG("Kompute Tensor creating primary buffer and memory"); + + this->mPrimaryBuffer = primaryBuffer; + this->mPrimaryMemory = primaryMemory; + + if (this->mTensorType == TensorTypes::eDevice) { + KP_LOG_DEBUG("Kompute Tensor creating staging buffer and memory"); + + this->mStagingBuffer = stagingBuffer; + this->mStagingMemory = stagingMemory; + } + + KP_LOG_DEBUG("Kompute Tensor buffer & memory creation successful"); +} + +void +Tensor::destroy() +{ + KP_LOG_DEBUG("Kompute Tensor started destroy()"); + + // Setting raw data to null regardless whether device is available to + // invalidate Tensor + this->mRawData = nullptr; + this->mSize = 0; + this->mMemorySize = 0; + + if (!this->mDevice) { + KP_LOG_WARN( + "Kompute Tensor destructor reached with null Device pointer"); + return; + } + + if (this->mDevice) { + this->mDevice = nullptr; + } + + KP_LOG_DEBUG("Kompute Tensor successful destroy()"); +} + +template<> +Tensor::TensorDataTypes +TensorT::dataType() +{ + return Tensor::TensorDataTypes::eBool; +} + +template<> +Tensor::TensorDataTypes +TensorT::dataType() +{ + return Tensor::TensorDataTypes::eInt; +} + +template<> +Tensor::TensorDataTypes +TensorT::dataType() +{ + return Tensor::TensorDataTypes::eUnsignedInt; +} + +template<> +Tensor::TensorDataTypes +TensorT::dataType() +{ + return Tensor::TensorDataTypes::eFloat; +} + +template<> +Tensor::TensorDataTypes +TensorT::dataType() +{ + return Tensor::TensorDataTypes::eDouble; +} + +} diff --git a/ggml/src/ggml-kompute/kompute/src/include/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/src/include/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..53e9d8ae616be8d95b2381cf3e3da62ced88e1a7 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.20) + +# #################################################### +# Kompute +# #################################################### +target_include_directories(kompute PUBLIC $ + $) + +target_sources(kompute PRIVATE + + # Header files (useful in IDEs) + kompute/Algorithm.hpp + kompute/Core.hpp + kompute/Kompute.hpp + kompute/Manager.hpp + kompute/Sequence.hpp + kompute/Tensor.hpp + + kompute/operations/OpAlgoDispatch.hpp + kompute/operations/OpBase.hpp + kompute/operations/OpMemoryBarrier.hpp + kompute/operations/OpMult.hpp + kompute/operations/OpTensorCopy.hpp + kompute/operations/OpTensorFill.hpp + kompute/operations/OpTensorSyncDevice.hpp + kompute/operations/OpTensorSyncLocal.hpp + kompute/operations/OpBufferSyncDevice.hpp + kompute/operations/OpBufferSyncLocal.hpp + + kompute/logger/Logger.hpp +) + +#install(DIRECTORY kompute DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + +# #################################################### +# Logger +# #################################################### +target_include_directories(kp_logger PUBLIC $ + $) + +target_sources(kp_logger PRIVATE + + # Header files (useful in IDEs) + kompute/logger/Logger.hpp +) + +#install(DIRECTORY logger DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) \ No newline at end of file diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/Algorithm.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/Algorithm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e5fef1f56d849b4c9afe54df8265405f915946f8 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/Algorithm.hpp @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/Core.hpp" + +#include "fmt/format.h" +#include "kompute/Tensor.hpp" +#include "logger/Logger.hpp" + +namespace kp { + +/** + Abstraction for compute shaders that are run on top of tensors grouped via + ParameterGroups (which group descriptorsets) +*/ +class Algorithm +{ + public: + /** + * Main constructor for algorithm with configuration parameters to create + * the underlying resources. + * + * @param device The Vulkan device to use for creating resources + * @param tensors (optional) The tensors to use to create the descriptor + * resources + * @param spirv (optional) The spirv code to use to create the algorithm + * @param workgroup (optional) The kp::Workgroup to use for the dispatch + * which defaults to kp::Workgroup(tensor[0].size(), 1, 1) if not set. + * @param specializationConstants (optional) The templatable param is to be + * used to initialize the specialization constants which cannot be changed + * once set. + * @param pushConstants (optional) This templatable param is to be used + * when initializing the pipeline, which set the size of the push constants + * - these can be modified but all new values must have the same data type + * and length as otherwise it will result in errors. + */ + template + Algorithm(std::shared_ptr device, + vk::PipelineCache *pipelineCache, + vk::DescriptorPool *pool, + const std::vector>& tensors = {}, + const std::vector& spirv = {}, + const Workgroup& workgroup = {}, + const std::vector& specializationConstants = {}, + const std::vector

& pushConstants = {}) + { + KP_LOG_DEBUG("Kompute Algorithm Constructor with device"); + + this->mDevice = device; + this->mPipelineCache = pipelineCache; + this->mDescriptorPool = pool; + + if (tensors.size() && spirv.size()) { + KP_LOG_INFO( + "Kompute Algorithm initialising with tensor size: {} and " + "spirv size: {}", + tensors.size(), + spirv.size()); + this->rebuild(tensors, + spirv, + workgroup, + specializationConstants, + pushConstants); + } else { + KP_LOG_INFO( + "Kompute Algorithm constructor with empty tensors and or " + "spirv so not rebuilding vulkan components"); + } + } + + /** + * Rebuild function to reconstruct algorithm with configuration parameters + * to create the underlying resources. + * + * @param tensors The tensors to use to create the descriptor resources + * @param spirv The spirv code to use to create the algorithm + * @param workgroup (optional) The kp::Workgroup to use for the dispatch + * which defaults to kp::Workgroup(tensor[0].size(), 1, 1) if not set. + * @param specializationConstants (optional) The std::vector to use + * to initialize the specialization constants which cannot be changed once + * set. + * @param pushConstants (optional) The std::vector to use when + * initializing the pipeline, which set the size of the push constants - + * these can be modified but all new values must have the same vector size + * as this initial value. + */ + template + void rebuild(const std::vector>& tensors, + const std::vector& spirv, + const Workgroup& workgroup = {}, + const std::vector& specializationConstants = {}, + const std::vector

& pushConstants = {}) + { + KP_LOG_DEBUG("Kompute Algorithm rebuild started"); + + this->mTensors = tensors; + this->mSpirv = spirv; + + if (specializationConstants.size()) { + if (this->mSpecializationConstantsData) { + free(this->mSpecializationConstantsData); + } + uint32_t memorySize = + sizeof(decltype(specializationConstants.back())); + uint32_t size = specializationConstants.size(); + uint32_t totalSize = size * memorySize; + this->mSpecializationConstantsData = malloc(totalSize); + memcpy(this->mSpecializationConstantsData, + specializationConstants.data(), + totalSize); + this->mSpecializationConstantsDataTypeMemorySize = memorySize; + this->mSpecializationConstantsSize = size; + } + + if (pushConstants.size()) { + if (this->mPushConstantsData) { + free(this->mPushConstantsData); + } + uint32_t memorySize = sizeof(decltype(pushConstants.back())); + uint32_t size = pushConstants.size(); + uint32_t totalSize = size * memorySize; + this->mPushConstantsData = malloc(totalSize); + memcpy(this->mPushConstantsData, pushConstants.data(), totalSize); + this->mPushConstantsDataTypeMemorySize = memorySize; + this->mPushConstantsSize = size; + } + + this->setWorkgroup( + workgroup, this->mTensors.size() ? this->mTensors[0]->size() : 1); + + // Descriptor pool is created first so if available then destroy all + // before rebuild + if (this->isInit()) { + this->destroy(); + } + + this->createParameters(); + this->createShaderModule(); + this->createPipeline(); + } + + /** + * Destructor for Algorithm which is responsible for freeing and desroying + * respective pipelines and owned parameter groups. + */ + ~Algorithm(); + + /** + * Records the dispatch function with the provided template parameters or + * alternatively using the size of the tensor by default. + * + * @param commandBuffer Command buffer to record the algorithm resources to + */ + void recordDispatch(const vk::CommandBuffer& commandBuffer); + + /** + * Records command that binds the "core" algorithm components which consist + * of binding the pipeline and binding the descriptorsets. + * + * @param commandBuffer Command buffer to record the algorithm resources to + */ + void recordBindCore(const vk::CommandBuffer& commandBuffer); + + /** + * Records command that binds the push constants to the command buffer + * provided + * - it is required that the pushConstants provided are of the same size as + * the ones provided during initialization. + * + * @param commandBuffer Command buffer to record the algorithm resources to + */ + void recordBindPush(const vk::CommandBuffer& commandBuffer); + + /** + * function that checks all the gpu resource components to verify if these + * have been created and returns true if all are valid. + * + * @returns returns true if the algorithm is currently initialized. + */ + bool isInit(); + + /** + * Sets the work group to use in the recordDispatch + * + * @param workgroup The kp::Workgroup value to use to update the algorithm. + * It must have a value greater than 1 on the x value (index 1) otherwise it + * will be initialized on the size of the first tensor (ie. + * this->mTensor[0]->size()) + */ + void setWorkgroup(const Workgroup& workgroup, uint32_t minSize = 1); + /** + * Sets the push constants to the new value provided to use in the next + * bindPush() + * + * @param pushConstants The templatable vector is to be used to set the push + * constants to use in the next bindPush(...) calls. The constants provided + * must be of the same size as the ones created during initialization. + */ + template + void setPushConstants(const std::vector& pushConstants) + { + uint32_t memorySize = sizeof(decltype(pushConstants.back())); + uint32_t size = pushConstants.size(); + this->setPushConstants(pushConstants.data(), size, memorySize); + } + + void updateDescriptors(vk::DescriptorPool *pool) + { + this->mDescriptorPool = pool; + this->setWorkgroup( + this->mWorkgroup, this->mTensors.size() ? this->mTensors[0]->size() : 1); + + this->updateParameters(); // TODO: See if we can reduce this + } + + /** + * Sets the push constants to the new value provided to use in the next + * bindPush() with the raw memory block location and memory size to be used. + * + * @param data The raw data point to copy the data from, without modifying + * the pointer. + * @param size The number of data elements provided in the data + * @param memorySize The memory size of each of the data elements in bytes. + */ + void setPushConstants(const void* data, uint32_t size, uint32_t memorySize) + { + + uint32_t totalSize = memorySize * size; + uint32_t previousTotalSize = + this->mPushConstantsDataTypeMemorySize * this->mPushConstantsSize; + + if (totalSize != previousTotalSize) { + throw std::runtime_error(fmt::format( + "Kompute Algorithm push " + "constant total memory size provided is {} but expected {} bytes", + totalSize, + previousTotalSize)); + } + if (this->mPushConstantsData) { + free(this->mPushConstantsData); + } + + this->mPushConstantsData = malloc(totalSize); + memcpy(this->mPushConstantsData, data, totalSize); + this->mPushConstantsDataTypeMemorySize = memorySize; + this->mPushConstantsSize = size; + } + + /** + * Gets the current workgroup from the algorithm. + * + * @param The kp::Constant to use to set the push constants to use in the + * next bindPush(...) calls. The constants provided must be of the same size + * as the ones created during initialization. + */ + const Workgroup& getWorkgroup(); + /** + * Gets the specialization constants of the current algorithm. + * + * @returns The std::vector currently set for specialization + * constants + */ + template + const std::vector getSpecializationConstants() + { + return { (T*)this->mSpecializationConstantsData, + ((T*)this->mSpecializationConstantsData) + + this->mSpecializationConstantsSize }; + } + /** + * Gets the specialization constants of the current algorithm. + * + * @returns The std::vector currently set for push constants + */ + template + const std::vector getPushConstants() + { + return { (T*)this->mPushConstantsData, + ((T*)this->mPushConstantsData) + this->mPushConstantsSize }; + } + /** + * Gets the current tensors that are used in the algorithm. + * + * @returns The list of tensors used in the algorithm. + */ + const std::vector>& getTensors(); + void setTensors(const std::vector>& tensors); + + void destroy(); + + private: + // -------------- NEVER OWNED RESOURCES + std::shared_ptr mDevice; + std::vector> mTensors; + + // -------------- OPTIONALLY OWNED RESOURCES + std::shared_ptr mDescriptorSetLayout; + bool mFreeDescriptorSetLayout = false; + vk::DescriptorPool *mDescriptorPool = nullptr; + std::shared_ptr mDescriptorSet; + bool mFreeDescriptorSet = false; + std::shared_ptr mShaderModule; + bool mFreeShaderModule = false; + std::shared_ptr mPipelineLayout; + bool mFreePipelineLayout = false; + vk::PipelineCache *mPipelineCache = nullptr; + std::shared_ptr mPipeline; + bool mFreePipeline = false; + + // -------------- ALWAYS OWNED RESOURCES + std::vector mSpirv; + void* mSpecializationConstantsData = nullptr; + uint32_t mSpecializationConstantsDataTypeMemorySize = 0; + uint32_t mSpecializationConstantsSize = 0; + void* mPushConstantsData = nullptr; + uint32_t mPushConstantsDataTypeMemorySize = 0; + uint32_t mPushConstantsSize = 0; + Workgroup mWorkgroup; + + // Create util functions + void createShaderModule(); + void createPipeline(); + + // Parameters + void freeParameters(); + void createParameters(); + void updateParameters(); +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/Core.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/Core.hpp new file mode 100644 index 0000000000000000000000000000000000000000..406e6b5d481d5253640321d77c1cbad652587260 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/Core.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include + +// Typedefs to simplify interaction with core types +namespace kp { +typedef std::array Workgroup; +typedef std::vector Constants; +} + +// Must be after vulkan is included +#ifndef KOMPUTE_VK_API_VERSION +#ifndef KOMPUTE_VK_API_MAJOR_VERSION +#define KOMPUTE_VK_API_MAJOR_VERSION 1 +#endif // KOMPUTE_VK_API_MAJOR_VERSION +#ifndef KOMPUTE_VK_API_MINOR_VERSION +#define KOMPUTE_VK_API_MINOR_VERSION 2 +#endif // KOMPUTE_VK_API_MINOR_VERSION +#define KOMPUTE_VK_API_VERSION \ + VK_MAKE_VERSION( \ + KOMPUTE_VK_API_MAJOR_VERSION, KOMPUTE_VK_API_MINOR_VERSION, 0) +#endif // KOMPUTE_VK_API_VERSION + +#if defined(KOMPUTE_BUILD_PYTHON) +#include +namespace py = pybind11; +// from python/src/main.cpp +extern py::object kp_trace, kp_debug, kp_info, kp_warning, kp_error; +#endif diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/Kompute.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/Kompute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..70e0dd433c44f567681a1d1361f017ffae5feae7 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/Kompute.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "Algorithm.hpp" +#include "Core.hpp" +#include "Manager.hpp" +#include "Sequence.hpp" +#include "Tensor.hpp" + +#include "operations/OpAlgoDispatch.hpp" +#include "operations/OpBase.hpp" +#include "operations/OpMemoryBarrier.hpp" +#include "operations/OpMult.hpp" +#include "operations/OpTensorCopy.hpp" +#include "operations/OpTensorSyncDevice.hpp" +#include "operations/OpTensorSyncLocal.hpp" +#include "operations/OpBufferSyncDevice.hpp" +#include "operations/OpBufferSyncLocal.hpp" +#include "operations/OpTensorFill.hpp" + +// Will be build by CMake and placed inside the build directory +#include "ShaderLogisticRegression.hpp" +#include "ShaderOpMult.hpp" diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/Manager.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/Manager.hpp new file mode 100644 index 0000000000000000000000000000000000000000..780c352ebc43ae4fe9ded852208b73cc81dd288e --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/Manager.hpp @@ -0,0 +1,284 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include +#include + +#include "kompute/Core.hpp" + +#include "kompute/Sequence.hpp" +#include "logger/Logger.hpp" + +#define KP_DEFAULT_SESSION "DEFAULT" + +namespace kp { + +/** + Base orchestrator which creates and manages device and child components +*/ +class Manager +{ + public: + /** + Base constructor. + */ + Manager(); + + /** + * Manager destructor which would ensure all owned resources are destroyed + * unless explicitly stated that resources should not be destroyed or freed. + */ + ~Manager(); + + bool hasInstance() const { + return this->mInstance.get(); + } + + bool hasDevice() const { + return this->mDevice.get(); + } + + bool hasVulkan() const { + return this->mDynamicLoader.get(); + } + + /** + * Initialize a device. + * + * @param physicalDeviceIndex The index of the physical device to use + * @param familyQueueIndices (Optional) List of queue indices to add for + * explicit allocation + * @param desiredExtensions The desired extensions to load from + * physicalDevice + */ + void initializeDevice(uint32_t physicalDeviceIndex, + const std::vector& familyQueueIndices = {}, + const std::vector& desiredExtensions = {}); + + /** + * Create a managed sequence that will be destroyed by this manager + * if it hasn't been destroyed by its reference count going to zero. + * + * @param queueIndex The queue to use from the available queues + * @param nrOfTimestamps The maximum number of timestamps to allocate. + * If zero (default), disables latching of timestamps. + * @returns Shared pointer with initialised sequence + */ + std::shared_ptr sequence(uint32_t queueIndex = 0, + uint32_t totalTimestamps = 0); + + /** + * Create a managed tensor that will be destroyed by this manager + * if it hasn't been destroyed by its reference count going to zero. + * + * @param data The data to initialize the tensor with + * @param tensorType The type of tensor to initialize + * @returns Shared pointer with initialised tensor + */ + template + std::shared_ptr> tensorT( + const std::vector& data, + vk::DeviceMemory *primaryMemory, + vk::Buffer *primaryBuffer, + vk::DeviceMemory *stagingMemory, + vk::Buffer *stagingBuffer, + Tensor::TensorTypes tensorType = Tensor::TensorTypes::eDevice) + { + KP_LOG_DEBUG("Kompute Manager tensor creation triggered"); + + std::shared_ptr> tensor{ new kp::TensorT( + this->mPhysicalDevice, this->mDevice, data, primaryMemory, primaryBuffer, stagingMemory, stagingBuffer, tensorType) }; + + if (this->mManageResources) { + this->mManagedTensors.push_back(tensor); + } + + return tensor; + } + + std::shared_ptr tensor( + void* data, + uint32_t elementTotalCount, + uint64_t memorySize, + const Tensor::TensorDataTypes& dataType, + vk::DeviceMemory *primaryMemory, + vk::Buffer *primaryBuffer, + vk::DeviceMemory *stagingMemory, + vk::Buffer *stagingBuffer, + vk::DeviceSize offset, + Tensor::TensorTypes tensorType = Tensor::TensorTypes::eDevice) + { + std::shared_ptr tensor{ new kp::Tensor(this->mPhysicalDevice, + this->mDevice, + data, + elementTotalCount, + memorySize, + dataType, + primaryMemory, + primaryBuffer, + stagingMemory, + stagingBuffer, + offset, + tensorType) }; + + if (this->mManageResources) { + this->mManagedTensors.push_back(tensor); + } + + return tensor; + } + + /** + * Default non-template function that can be used to create algorithm + * objects which provides default types to the push and spec constants as + * floats. + * + * @param tensors (optional) The tensors to initialise the algorithm with + * @param spirv (optional) The SPIRV bytes for the algorithm to dispatch + * @param workgroup (optional) kp::Workgroup for algorithm to use, and + * defaults to (tensor[0].size(), 1, 1) + * @param specializationConstants (optional) float vector to use for + * specialization constants, and defaults to an empty constant + * @param pushConstants (optional) float vector to use for push constants, + * and defaults to an empty constant + * @returns Shared pointer with initialised algorithm + */ + std::shared_ptr algorithm( + const std::string &name, + vk::DescriptorPool *pool, + const std::vector>& tensors = {}, + const std::vector& spirv = {}, + const Workgroup& workgroup = {}, + const std::vector& specializationConstants = {}, + const std::vector& pushConstants = {}) + { + return this->algorithm<>( + name, pool, tensors, spirv, workgroup, specializationConstants, pushConstants); + } + + /** + * Create a managed algorithm that will be destroyed by this manager + * if it hasn't been destroyed by its reference count going to zero. + * + * @param tensors (optional) The tensors to initialise the algorithm with + * @param spirv (optional) The SPIRV bytes for the algorithm to dispatch + * @param workgroup (optional) kp::Workgroup for algorithm to use, and + * defaults to (tensor[0].size(), 1, 1) + * @param specializationConstants (optional) templatable vector parameter to + * use for specialization constants, and defaults to an empty constant + * @param pushConstants (optional) templatable vector parameter to use for + * push constants, and defaults to an empty constant + * @returns Shared pointer with initialised algorithm + */ + template + std::shared_ptr algorithm( + const std::string &name, + vk::DescriptorPool *pool, + const std::vector>& tensors, + const std::vector& spirv, + const Workgroup& workgroup, + const std::vector& specializationConstants, + const std::vector

& pushConstants) + { + + KP_LOG_DEBUG("Kompute Manager algorithm creation triggered"); + + std::shared_ptr algorithm{ new kp::Algorithm( + this->mDevice, + mPipelineCache.get(), + pool, + tensors, + spirv, + workgroup, + specializationConstants, + pushConstants) }; + + if (this->mManageResources) { + this->mManagedAlgorithmsMap.insert({name, algorithm}); + } + + return algorithm; + } + + bool hasAlgorithm(const std::string &name) const { + return mManagedAlgorithmsMap.find(name) != mManagedAlgorithmsMap.end(); + } + + std::shared_ptr getAlgorithm(const std::string &name) const { + auto it = mManagedAlgorithmsMap.find(name); + if (it != mManagedAlgorithmsMap.end()) { + return it->second; + } + return nullptr; + } + + /** + * Destroy the GPU resources and all managed resources by manager. + **/ + void destroy(); + /** + * Run a pseudo-garbage collection to release all the managed resources + * that have been already freed due to these reaching to zero ref count. + **/ + void clear(); + + /** + * Information about the current device. + * + * @return vk::PhysicalDeviceProperties containing information about the + *device + **/ + vk::PhysicalDeviceProperties getDeviceProperties() const; + + /** + * List the devices available in the current vulkan instance. + * + * @return vector of physical devices containing their respective properties + **/ + std::vector listDevices() const; + + /** + * The current Vulkan instance. + * + * @return a shared pointer to the current Vulkan instance held by this + *object + **/ + std::shared_ptr getVkInstance() const; + + std::shared_ptr device() const { return mDevice; } + std::shared_ptr physicalDevice() const { return mPhysicalDevice; } + std::shared_ptr pipelineCache() const { return mPipelineCache; } + + private: + // -------------- OPTIONALLY OWNED RESOURCES + std::shared_ptr mInstance = nullptr; + bool mFreeInstance = false; + std::shared_ptr mPhysicalDevice = nullptr; + std::shared_ptr mDevice = nullptr; + std::shared_ptr mDynamicLoader = nullptr; + bool mFreeDevice = false; + + // -------------- ALWAYS OWNED RESOURCES + std::vector> mManagedTensors; + std::vector> mManagedSequences; + std::unordered_map> mManagedAlgorithmsMap; + + std::vector mComputeQueueFamilyIndices; + std::vector> mComputeQueues; + std::shared_ptr mPipelineCache; + + bool mManageResources = false; + +#ifndef KOMPUTE_DISABLE_VK_DEBUG_LAYERS + vk::DebugReportCallbackEXT mDebugReportCallback; + vk::DispatchLoaderDynamic mDebugDispatcher; +#endif + + // Create functions + void createInstance(); + void createDevice(const std::vector& familyQueueIndices = {}, + uint32_t physicalDeviceIndex = 0, + const std::vector& desiredExtensions = {}); +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/Sequence.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/Sequence.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3b29a6e2e66ae237b2cb64fa096a9ff90f7fa1e8 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/Sequence.hpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/Core.hpp" + +#include "kompute/operations/OpAlgoDispatch.hpp" +#include "kompute/operations/OpBase.hpp" + +namespace kp { + +/** + * Container of operations that can be sent to GPU as batch + */ +class Sequence : public std::enable_shared_from_this +{ + public: + /** + * Main constructor for sequence which requires core vulkan components to + * generate all dependent resources. + * + * @param physicalDevice Vulkan physical device + * @param device Vulkan logical device + * @param computeQueue Vulkan compute queue + * @param queueIndex Vulkan compute queue index in device + * @param totalTimestamps Maximum number of timestamps to allocate + */ + Sequence(std::shared_ptr physicalDevice, + std::shared_ptr device, + std::shared_ptr computeQueue, + uint32_t queueIndex, + uint32_t totalTimestamps = 0); + /** + * Destructor for sequence which is responsible for cleaning all subsequent + * owned operations. + */ + ~Sequence(); + + /** + * Record function for operation to be added to the GPU queue in batch. This + * template requires classes to be derived from the OpBase class. This + * function also requires the Sequence to be recording, otherwise it will + * not be able to add the operation. + * + * @param op Object derived from kp::BaseOp that will be recoreded by the + * sequence which will be used when the operation is evaluated. + * @return shared_ptr of the Sequence class itself + */ + std::shared_ptr record(std::shared_ptr op); + + /** + * Record function for operation to be added to the GPU queue in batch. This + * template requires classes to be derived from the OpBase class. This + * function also requires the Sequence to be recording, otherwise it will + * not be able to add the operation. + * + * @param tensors Vector of tensors to use for the operation + * @param TArgs Template parameters that are used to initialise operation + * which allows for extensible configurations on initialisation. + * @return shared_ptr of the Sequence class itself + */ + template + std::shared_ptr record( + std::vector> tensors, + TArgs&&... params) + { + std::shared_ptr op{ new T(tensors, std::forward(params)...) }; + return this->record(op); + } + /** + * Record function for operation to be added to the GPU queue in batch. This + * template requires classes to be derived from the OpBase class. This + * function also requires the Sequence to be recording, otherwise it will + * not be able to add the operation. + * + * @param algorithm Algorithm to use for the record often used for OpAlgo + * operations + * @param TArgs Template parameters that are used to initialise operation + * which allows for extensible configurations on initialisation. + * @return shared_ptr of the Sequence class itself + */ + template + std::shared_ptr record(std::shared_ptr algorithm, + TArgs&&... params) + { + std::shared_ptr op{ new T(algorithm, + std::forward(params)...) }; + return this->record(op); + } + + /** + * Eval sends all the recorded and stored operations in the vector of + * operations into the gpu as a submit job synchronously (with a barrier). + * + * @return shared_ptr of the Sequence class itself + */ + std::shared_ptr eval(); + + /** + * Resets all the recorded and stored operations, records the operation + * provided and submits into the gpu as a submit job synchronously (with a + * barrier). + * + * @return shared_ptr of the Sequence class itself + */ + std::shared_ptr eval(std::shared_ptr op); + + /** + * Eval sends all the recorded and stored operations in the vector of + * operations into the gpu as a submit job with a barrier. + * + * @param tensors Vector of tensors to use for the operation + * @param TArgs Template parameters that are used to initialise operation + * which allows for extensible configurations on initialisation. + * @return shared_ptr of the Sequence class itself + */ + template + std::shared_ptr eval(std::vector> tensors, + TArgs&&... params) + { + std::shared_ptr op{ new T(tensors, std::forward(params)...) }; + return this->eval(op); + } + + template + std::shared_ptr eval(vk::Buffer *primaryBuffer, + vk::Buffer *stagingBuffer, + vk::DeviceSize size, + TArgs&&... params) + { + std::shared_ptr op{ new T(primaryBuffer, stagingBuffer, size, std::forward(params)...) }; + return this->eval(op); + } + + /** + * Eval sends all the recorded and stored operations in the vector of + * operations into the gpu as a submit job with a barrier. + * + * @param algorithm Algorithm to use for the record often used for OpAlgo + * operations + * @param TArgs Template parameters that are used to initialise operation + * which allows for extensible configurations on initialisation. + * @return shared_ptr of the Sequence class itself + */ + template + std::shared_ptr eval(std::shared_ptr algorithm, + TArgs&&... params) + { + std::shared_ptr op{ new T(algorithm, + std::forward(params)...) }; + return this->eval(op); + } + + /** + * Eval Async sends all the recorded and stored operations in the vector of + * operations into the gpu as a submit job without a barrier. EvalAwait() + * must ALWAYS be called after to ensure the sequence is terminated + * correctly. + * + * @return Boolean stating whether execution was successful. + */ + std::shared_ptr evalAsync(); + /** + * Clears currnet operations to record provided one in the vector of + * operations into the gpu as a submit job without a barrier. EvalAwait() + * must ALWAYS be called after to ensure the sequence is terminated + * correctly. + * + * @return Boolean stating whether execution was successful. + */ + std::shared_ptr evalAsync(std::shared_ptr op); + /** + * Eval sends all the recorded and stored operations in the vector of + * operations into the gpu as a submit job with a barrier. + * + * @param tensors Vector of tensors to use for the operation + * @param TArgs Template parameters that are used to initialise operation + * which allows for extensible configurations on initialisation. + * @return shared_ptr of the Sequence class itself + */ + template + std::shared_ptr evalAsync( + std::vector> tensors, + TArgs&&... params) + { + std::shared_ptr op{ new T(tensors, std::forward(params)...) }; + return this->evalAsync(op); + } + /** + * Eval sends all the recorded and stored operations in the vector of + * operations into the gpu as a submit job with a barrier. + * + * @param algorithm Algorithm to use for the record often used for OpAlgo + * operations + * @param TArgs Template parameters that are used to initialise operation + * which allows for extensible configurations on initialisation. + * @return shared_ptr of the Sequence class itself + */ + template + std::shared_ptr evalAsync(std::shared_ptr algorithm, + TArgs&&... params) + { + std::shared_ptr op{ new T(algorithm, + std::forward(params)...) }; + return this->evalAsync(op); + } + + /** + * Eval Await waits for the fence to finish processing and then once it + * finishes, it runs the postEval of all operations. + * + * @param waitFor Number of milliseconds to wait before timing out. + * @return shared_ptr of the Sequence class itself + */ + std::shared_ptr evalAwait(uint64_t waitFor = UINT64_MAX); + + /** + * Clear function clears all operations currently recorded and starts + * recording again. + */ + void clear(); + + /** + * Return the timestamps that were latched at the beginning and + * after each operation during the last eval() call. + */ + std::vector getTimestamps(); + + /** + * Begins recording commands for commands to be submitted into the command + * buffer. + */ + void begin(); + + /** + * Ends the recording and stops recording commands when the record command + * is sent. + */ + void end(); + + /** + * Returns true if the sequence is currently in recording activated. + * + * @return Boolean stating if recording ongoing. + */ + bool isRecording() const; + + /** + * Returns true if the sequence has been initialised, and it's based on the + * GPU resources being referenced. + * + * @return Boolean stating if is initialized + */ + bool isInit() const; + + /** + * Clears command buffer and triggers re-record of all the current + * operations saved, which is useful if the underlying kp::Tensors or + * kp::Algorithms are modified and need to be re-recorded. + */ + void rerecord(); + + /** + * Returns true if the sequence is currently running - mostly used for async + * workloads. + * + * @return Boolean stating if currently running. + */ + bool isRunning() const; + + /** + * Destroys and frees the GPU resources which include the buffer and memory + * and sets the sequence as init=False. + */ + void destroy(); + + private: + // -------------- NEVER OWNED RESOURCES + std::shared_ptr mPhysicalDevice = nullptr; + std::shared_ptr mDevice = nullptr; + std::shared_ptr mComputeQueue = nullptr; + uint32_t mQueueIndex = -1; + + // -------------- OPTIONALLY OWNED RESOURCES + std::shared_ptr mCommandPool = nullptr; + bool mFreeCommandPool = false; + std::shared_ptr mCommandBuffer = nullptr; + bool mFreeCommandBuffer = false; + + // -------------- ALWAYS OWNED RESOURCES + vk::Fence mFence; + std::vector> mOperations{}; + std::shared_ptr timestampQueryPool = nullptr; + + // State + bool mRecording = false; + bool mIsRunning = false; + + // Create functions + void createCommandPool(); + void createCommandBuffer(); + void createTimestampQueryPool(uint32_t totalTimestamps); +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/Tensor.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/Tensor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..20939093da7af21254ae3a535342fec768093d35 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/Tensor.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/Core.hpp" +#include "logger/Logger.hpp" +#include +#include + +namespace kp { + +/** + * Structured data used in GPU operations. + * + * Tensors are the base building block in Kompute to perform operations across + * GPUs. Each tensor would have a respective Vulkan memory and buffer, which + * would be used to store their respective data. The tensors can be used for GPU + * data storage or transfer. + */ +class Tensor +{ + public: + /** + * Type for tensors created: Device allows memory to be transferred from + * staging buffers. Staging are host memory visible. Storage are device + * visible but are not set up to transfer or receive data (only for shader + * storage). + */ + enum class TensorTypes + { + eDevice = 0, ///< Type is device memory, source and destination + eHost = 1, ///< Type is host memory, source and destination + eStorage = 2, ///< Type is Device memory (only) + }; + enum class TensorDataTypes + { + eBool = 0, + eInt = 1, + eUnsignedInt = 2, + eFloat = 3, + eDouble = 4, + }; + + static std::string toString(TensorDataTypes dt); + static std::string toString(TensorTypes dt); + + /** + * Constructor with data provided which would be used to create the + * respective vulkan buffer and memory. + * + * @param physicalDevice The physical device to use to fetch properties + * @param device The device to use to create the buffer and memory from + * @param data Non-zero-sized vector of data that will be used by the + * tensor + * @param tensorTypes Type for the tensor which is of type TensorTypes + */ + Tensor(std::shared_ptr physicalDevice, + std::shared_ptr device, + void* data, + uint32_t elementTotalCount, + uint32_t memorySize, + const TensorDataTypes& dataType, + vk::DeviceMemory *primaryMemory, + vk::Buffer *primaryBuffer, + vk::DeviceMemory *stagingMemory, + vk::Buffer *stagingBuffer, + vk::DeviceSize offset, + const TensorTypes& tensorType = TensorTypes::eDevice); + + /** + * Destructor which is in charge of freeing vulkan resources unless they + * have been provided externally. + */ + virtual ~Tensor(); + + /** + * Function to trigger reinitialisation of the tensor buffer and memory with + * new data as well as new potential device type. + * + * @param data Vector of data to use to initialise vector from + * @param tensorType The type to use for the tensor + */ + void rebuild(void* data, + uint32_t elementTotalCount, + uint64_t memorySize, + vk::DeviceMemory *primaryMemory, + vk::Buffer *primaryBuffer, + vk::DeviceMemory *stagingMemory, + vk::Buffer *stagingBuffer, + vk::DeviceSize offset); + + /** + * Destroys and frees the GPU resources which include the buffer and memory. + */ + void destroy(); + + /** + * Check whether tensor is initialized based on the created gpu resources. + * + * @returns Boolean stating whether tensor is initialized + */ + bool isInit(); + + /** + * Retrieve the tensor type of the Tensor + * + * @return Tensor type of tensor + */ + TensorTypes tensorType(); + + /** + * Records a copy from the memory of the tensor provided to the current + * thensor. This is intended to pass memory into a processing, to perform + * a staging buffer transfer, or to gather output (between others). + * + * @param commandBuffer Vulkan Command Buffer to record the commands into + * @param copyFromTensor Tensor to copy the data from + */ + void recordCopyFrom(const vk::CommandBuffer& commandBuffer, + std::shared_ptr copyFromTensor); + + void recordFill(const vk::CommandBuffer &commandBuffer, + uint32_t fill); + + /** + * Records a copy from the internal staging memory to the device memory + * using an optional barrier to wait for the operation. This function would + * only be relevant for kp::Tensors of type eDevice. + * + * @param commandBuffer Vulkan Command Buffer to record the commands into + */ + void recordCopyFromStagingToDevice(const vk::CommandBuffer& commandBuffer); + + /** + * Records a copy from the internal device memory to the staging memory + * using an optional barrier to wait for the operation. This function would + * only be relevant for kp::Tensors of type eDevice. + * + * @param commandBuffer Vulkan Command Buffer to record the commands into + */ + void recordCopyFromDeviceToStaging(const vk::CommandBuffer& commandBuffer); + + /** + * Records the buffer memory barrier into the primary buffer and command + * buffer which ensures that relevant data transfers are carried out + * correctly. + * + * @param commandBuffer Vulkan Command Buffer to record the commands into + * @param srcAccessMask Access flags for source access mask + * @param dstAccessMask Access flags for destination access mask + * @param scrStageMask Pipeline stage flags for source stage mask + * @param dstStageMask Pipeline stage flags for destination stage mask + */ + void recordPrimaryBufferMemoryBarrier( + const vk::CommandBuffer& commandBuffer, + vk::AccessFlagBits srcAccessMask, + vk::AccessFlagBits dstAccessMask, + vk::PipelineStageFlagBits srcStageMask, + vk::PipelineStageFlagBits dstStageMask); + /** + * Records the buffer memory barrier into the staging buffer and command + * buffer which ensures that relevant data transfers are carried out + * correctly. + * + * @param commandBuffer Vulkan Command Buffer to record the commands into + * @param srcAccessMask Access flags for source access mask + * @param dstAccessMask Access flags for destination access mask + * @param scrStageMask Pipeline stage flags for source stage mask + * @param dstStageMask Pipeline stage flags for destination stage mask + */ + void recordStagingBufferMemoryBarrier( + const vk::CommandBuffer& commandBuffer, + vk::AccessFlagBits srcAccessMask, + vk::AccessFlagBits dstAccessMask, + vk::PipelineStageFlagBits srcStageMask, + vk::PipelineStageFlagBits dstStageMask); + + /** + * Constructs a vulkan descriptor buffer info which can be used to specify + * and reference the underlying buffer component of the tensor without + * exposing it. + * + * @return Descriptor buffer info with own buffer + */ + vk::DescriptorBufferInfo constructDescriptorBufferInfo(); + + /** + * Returns the size/magnitude of the Tensor, which will be the total number + * of elements across all dimensions + * + * @return Unsigned integer representing the total number of elements + */ + uint32_t size(); + + /** + * Returns the total memory size of the data contained by the Tensor object + * + * @return Unsigned integer representing the memory of the tensor in bytes. + */ + uint64_t memorySize(); + + /** + * Retrieve the data type of the tensor (host, device, storage) + * + * @return Data type of tensor of type kp::Tensor::TensorDataTypes + */ + TensorDataTypes dataType(); + + /** + * Retrieve the raw data via the pointer to the memory that contains the raw + * memory of this current tensor. This tensor gets changed to a nullptr when + * the Tensor is removed. + * + * @return Pointer to raw memory containing raw bytes data of Tensor. + */ + void* rawData(); + + /** + * Sets / resets the data of the tensor which is directly done on the GPU + * host visible memory available by the tensor. + */ + void setRawData(const void* data); + + /** + * Template to return the pointer data converted by specific type, which + * would be any of the supported types including float, double, int32, + * uint32 and bool. + * + * @return Pointer to raw memory containing raw bytes data of Tensor. + */ + template + T* data() + { + return (T*)this->mRawData; + } + + /** + * Template to get the data of the current tensor as a vector of specific + * type, which would be any of the supported types including float, double, + * int32, uint32 and bool. + * + * @return Vector of type provided by template. + */ + template + std::vector vector() + { + return { (T*)this->mRawData, ((T*)this->mRawData) + this->size() }; + } + + protected: + // -------------- ALWAYS OWNED RESOURCES + TensorTypes mTensorType; + TensorDataTypes mDataType; + uint32_t mSize = 0; + uint64_t mMemorySize = 0; + vk::DeviceSize mOffset = 0; + void* mRawData = nullptr; + + private: + // -------------- NEVER OWNED RESOURCES + std::shared_ptr mPhysicalDevice; + std::shared_ptr mDevice; + vk::Buffer *mPrimaryBuffer = nullptr; + vk::Buffer *mStagingBuffer = nullptr; + vk::DeviceMemory *mPrimaryMemory = nullptr; + vk::DeviceMemory *mStagingMemory = nullptr; + + void setGPUResources(vk::DeviceMemory *primaryMemory, + vk::Buffer *primaryBuffer, + vk::DeviceMemory *stagingMemory, + vk::Buffer *stagingBuffer, + vk::DeviceSize offset); + void recordCopyBuffer(const vk::CommandBuffer& commandBuffer, + vk::Buffer *bufferFrom, + vk::Buffer *bufferTo, + vk::DeviceSize bufferSize, + vk::BufferCopy copyRegion); + + void recordBufferMemoryBarrier(const vk::CommandBuffer& commandBuffer, + const vk::Buffer& buffer, + vk::AccessFlagBits srcAccessMask, + vk::AccessFlagBits dstAccessMask, + vk::PipelineStageFlagBits srcStageMask, + vk::PipelineStageFlagBits dstStageMask); + + // Private util functions + vk::BufferUsageFlags getPrimaryBufferUsageFlags(); + vk::MemoryPropertyFlags getPrimaryMemoryPropertyFlags(); + vk::BufferUsageFlags getStagingBufferUsageFlags(); + vk::MemoryPropertyFlags getStagingMemoryPropertyFlags(); +}; + +template +class TensorT : public Tensor +{ + + public: + ~TensorT() { KP_LOG_DEBUG("Kompute TensorT destructor"); } + + TensorDataTypes dataType(); +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/logger/Logger.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/logger/Logger.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f97e95cf06c4c5b3960ef82cec57a0f4a82c591b --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/logger/Logger.hpp @@ -0,0 +1,197 @@ +#pragma once + +#define KOMPUTE_LOG_LEVEL_TRACE 0 +#define KOMPUTE_LOG_LEVEL_DEBUG 1 +#define KOMPUTE_LOG_LEVEL_INFO 2 +#define KOMPUTE_LOG_LEVEL_WARN 3 +#define KOMPUTE_LOG_LEVEL_ERROR 4 +#define KOMPUTE_LOG_LEVEL_CRITICAL 5 +#define KOMPUTE_LOG_LEVEL_OFF 6 + +// Logging is disabled entirely. +#if KOMPUTE_OPT_LOG_LEVEL_DISABLED +#define KP_LOG_TRACE(...) +#define KP_LOG_DEBUG(...) +#define KP_LOG_INFO(...) +#define KP_LOG_WARN(...) +#define KP_LOG_ERROR(...) +#else + +#if !KOMPUTE_OPT_USE_SPDLOG +#if VK_USE_PLATFORM_ANDROID_KHR +#include +#include +static const char* KOMPUTE_LOG_TAG = "KomputeLog"; +#else +#if KOMPUTE_BUILD_PYTHON +#include +namespace py = pybind11; +// from python/src/main.cpp +extern py::object kp_trace, kp_debug, kp_info, kp_warning, kp_error; +#else +#include +#endif // KOMPUTE_BUILD_PYTHON +#endif // VK_USE_PLATFORM_ANDROID_KHR +#else +#include +#endif // !KOMPUTE_OPT_USE_SPDLOG +#include +#include +#include +namespace logger { +// Setup the logger, note the loglevel can not be set below the CMake log level +// (To change this use -DKOMPUTE_OPT_LOG_LEVEL=...) +void +setupLogger(); + +// Logging is enabled, but we do not use Spdlog. So we use fmt in case nothing +// else is defined, overriding logging. +#if !KOMPUTE_OPT_USE_SPDLOG + +#ifndef KP_LOG_TRACE +#if KOMPUTE_OPT_ACTIVE_LOG_LEVEL <= KOMPUTE_LOG_LEVEL_TRACE +#if VK_USE_PLATFORM_ANDROID_KHR +#define KP_LOG_TRACE(...) \ + ((void)__android_log_write( \ + ANDROID_LOG_VERBOSE, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__).c_str())) +#else +#if KOMPUTE_BUILD_PYTHON +#define KP_LOG_DEBUG(...) kp_trace(fmt::format(__VA_ARGS__)) +#else +#define KP_LOG_TRACE(...) \ + fmt::print("[{} {}] [trace] [{}:{}] {}\n", \ + __DATE__, \ + __TIME__, \ + __FILE__, \ + __LINE__, \ + fmt::format(__VA_ARGS__)) +#endif // KOMPUTE_BUILD_PYTHON +#endif // VK_USE_PLATFORM_ANDROID_KHR +#else +#define KP_LOG_TRACE(...) +#endif +#endif // !KP_LOG_TRACE + +#ifndef KP_LOG_DEBUG +#if KOMPUTE_OPT_ACTIVE_LOG_LEVEL <= KOMPUTE_LOG_LEVEL_DEBUG +#if VK_USE_PLATFORM_ANDROID_KHR +#define KP_LOG_DEBUG(...) \ + ((void)__android_log_write( \ + ANDROID_LOG_DEBUG, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__).c_str())) +#else +#if KOMPUTE_BUILD_PYTHON +#define KP_LOG_DEBUG(...) kp_debug(fmt::format(__VA_ARGS__)) +#else +#ifdef __FILE_NAME__ // gcc 12 provides only file name without path +#define KP_LOG_DEBUG(...) \ + fmt::print("[{} {}] [debug] [{}:{}] {}\n", \ + __DATE__, \ + __TIME__, \ + __FILE_NAME__, \ + __LINE__, \ + fmt::format(__VA_ARGS__)) +#else +#define KP_LOG_DEBUG(...) \ + fmt::print("[{} {}] [debug] [{}:{}] {}\n", \ + __DATE__, \ + __TIME__, \ + __FILE__, \ + __LINE__, \ + fmt::format(__VA_ARGS__)) +#endif // __FILE__NAME__ +#endif // KOMPUTE_BUILD_PYTHON +#endif // VK_USE_PLATFORM_ANDROID_KHR +#else +#define KP_LOG_DEBUG(...) +#endif +#endif // !KP_LOG_DEBUG + +#ifndef KP_LOG_INFO +#if KOMPUTE_OPT_ACTIVE_LOG_LEVEL <= KOMPUTE_LOG_LEVEL_INFO +#if VK_USE_PLATFORM_ANDROID_KHR +#define KP_LOG_INFO(...) \ + ((void)__android_log_write( \ + ANDROID_LOG_INFO, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__).c_str())) +#else +#if KOMPUTE_BUILD_PYTHON +#define KP_LOG_DEBUG(...) kp_info(fmt::format(__VA_ARGS__)) +#else +#define KP_LOG_INFO(...) \ + fmt::print("[{} {}] [info] [{}:{}] {}\n", \ + __DATE__, \ + __TIME__, \ + __FILE__, \ + __LINE__, \ + fmt::format(__VA_ARGS__)) +#endif // KOMPUTE_BUILD_PYTHON +#endif // VK_USE_PLATFORM_ANDROID_KHR +#else +#define KP_LOG_INFO(...) +#endif +#endif // !KP_LOG_INFO + +#ifndef KP_LOG_WARN +#if KOMPUTE_OPT_ACTIVE_LOG_LEVEL <= KOMPUTE_LOG_LEVEL_WARN +#if VK_USE_PLATFORM_ANDROID_KHR +#define KP_LOG_WARN(...) \ + ((void)__android_log_write( \ + ANDROID_LOG_WARN, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__).c_str())) +#else +#if KOMPUTE_BUILD_PYTHON +#define KP_LOG_DEBUG(...) kp_warning(fmt::format(__VA_ARGS__)) +#else +#define KP_LOG_WARN(...) \ + fmt::print("[{} {}] [warn] [{}:{}] {}\n", \ + __DATE__, \ + __TIME__, \ + __FILE__, \ + __LINE__, \ + fmt::format(__VA_ARGS__)) +#endif // KOMPUTE_BUILD_PYTHON +#endif // VK_USE_PLATFORM_ANDROID_KHR +#else +#define KP_LOG_WARN(...) +#endif +#endif // !KP_LOG_WARN + +#ifndef KP_LOG_ERROR +#if KOMPUTE_OPT_ACTIVE_LOG_LEVEL <= KOMPUTE_LOG_LEVEL_ERROR +#if VK_USE_PLATFORM_ANDROID_KHR +#define KP_LOG_ERROR(...) \ + ((void)__android_log_write( \ + ANDROID_LOG_ERROR, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__).c_str())) +#else +#if KOMPUTE_BUILD_PYTHON +#define KP_LOG_DEBUG(...) kp_error(fmt::format(__VA_ARGS__)) +#else +#define KP_LOG_ERROR(...) \ + fmt::print("[{} {}] [error] [{}:{}] {}\n", \ + __DATE__, \ + __TIME__, \ + __FILE__, \ + __LINE__, \ + fmt::format(__VA_ARGS__)) +#endif // KOMPUTE_BUILD_PYTHON +#endif // VK_USE_PLATFORM_ANDROID_KHR +#else +#define KP_LOG_ERROR(...) +#endif +#endif // !KP_LOG_ERROR +#else + +#define KP_LOG_TRACE(...) SPDLOG_TRACE(__VA_ARGS__) +#define KP_LOG_DEBUG(...) SPDLOG_DEBUG(__VA_ARGS__) +#define KP_LOG_INFO(...) SPDLOG_INFO(__VA_ARGS__) +#define KP_LOG_WARN(...) SPDLOG_WARN(__VA_ARGS__) +#define KP_LOG_ERROR(...) SPDLOG_ERROR(__VA_ARGS__) + +void +setLogLevel(spdlog::level::level_enum level); + +spdlog::level::level_enum +getLogLevel(); + +#endif // !KOMPUTE_OPT_USE_SPDLOG +} // namespace logger + +#endif // KOMPUTE_OPT_LOG_LEVEL_DISABLED diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpAlgoDispatch.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpAlgoDispatch.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e91598f0562c2d2605f4f522a28f8384fd014fb7 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpAlgoDispatch.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/Algorithm.hpp" +#include "kompute/Core.hpp" +#include "kompute/Tensor.hpp" +#include "kompute/operations/OpBase.hpp" + +namespace kp { + +/** + * Operation that provides a general abstraction that simplifies the use of + * algorithm and parameter components which can be used with shaders. + * By default it enables the user to provide a dynamic number of tensors + * which are then passed as inputs. + */ +class OpAlgoDispatch : public OpBase +{ + public: + /** + * Constructor that stores the algorithm to use as well as the relevant + * push constants to override when recording. + * + * @param algorithm The algorithm object to use for dispatch + * @param pushConstants The push constants to use for override + */ + template + OpAlgoDispatch(const std::shared_ptr& algorithm, + const std::vector& pushConstants = {}) + { + KP_LOG_DEBUG("Kompute OpAlgoDispatch constructor"); + + this->mAlgorithm = algorithm; + + if (pushConstants.size()) { + uint32_t memorySize = sizeof(decltype(pushConstants.back())); + uint32_t size = pushConstants.size(); + uint32_t totalSize = size * memorySize; + this->mPushConstantsData = malloc(totalSize); + memcpy(this->mPushConstantsData, pushConstants.data(), totalSize); + this->mPushConstantsDataTypeMemorySize = memorySize; + this->mPushConstantsSize = size; + } + } + + /** + * Default destructor, which is in charge of destroying the algorithm + * components but does not destroy the underlying tensors + */ + virtual ~OpAlgoDispatch() override; + + /** + * This records the commands that are to be sent to the GPU. This includes + * the barriers that ensure the memory has been copied before going in and + * out of the shader, as well as the dispatch operation that sends the + * shader processing to the gpu. This function also records the GPU memory + * copy of the output data for the staging buffer so it can be read by the + * host. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void record(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any preEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void preEval(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any postEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void postEval(const vk::CommandBuffer& commandBuffer) override; + + private: + // -------------- ALWAYS OWNED RESOURCES + std::shared_ptr mAlgorithm; + void* mPushConstantsData = nullptr; + uint32_t mPushConstantsDataTypeMemorySize = 0; + uint32_t mPushConstantsSize = 0; +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpBase.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpBase.hpp new file mode 100644 index 0000000000000000000000000000000000000000..737670846350d091adfa764598497bc4b65eb8ca --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpBase.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/Algorithm.hpp" +#include "kompute/Core.hpp" +#include "kompute/Tensor.hpp" + +namespace kp { + +/** + * Base Operation which provides the high level interface that Kompute + * operations implement in order to perform a set of actions in the GPU. + * + * Operations can perform actions on tensors, and optionally can also own an + * Algorithm with respective parameters. kp::Operations with kp::Algorithms + * would inherit from kp::OpBaseAlgo. + */ +class OpBase +{ + public: + /** + * Default destructor for OpBase class. This OpBase destructor class should + * always be called to destroy and free owned resources unless it is + * intended to destroy the resources in the parent class. + */ + virtual ~OpBase() { KP_LOG_DEBUG("Kompute OpBase destructor started"); } + + /** + * The record function is intended to only send a record command or run + * commands that are expected to record operations that are to be submitted + * as a batch into the GPU. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void record(const vk::CommandBuffer& commandBuffer) = 0; + + /** + * Pre eval is called before the Sequence has called eval and submitted the + * commands to the GPU for processing, and can be used to perform any + * per-eval setup steps required as the computation iteration begins. It's + * worth noting that there are situations where eval can be called multiple + * times, so the resources that are created should be idempotent in case + * it's called multiple times in a row. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void preEval(const vk::CommandBuffer& commandBuffer) = 0; + + /** + * Post eval is called after the Sequence has called eval and submitted the + * commands to the GPU for processing, and can be used to perform any + * tear-down steps required as the computation iteration finishes. It's + * worth noting that there are situations where eval can be called multiple + * times, so the resources that are destroyed should not require a re-init + * unless explicitly provided by the user. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void postEval(const vk::CommandBuffer& commandBuffer) = 0; +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpBufferSyncDevice.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpBufferSyncDevice.hpp new file mode 100644 index 0000000000000000000000000000000000000000..50d8e97072412253db9a8fb8fb48723492928a1a --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpBufferSyncDevice.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/operations/OpBase.hpp" + +namespace kp { + +class OpBufferSyncDevice : public OpBase +{ + public: + OpBufferSyncDevice( + vk::Buffer *primaryBuffer, + vk::Buffer *stagingBuffer, + vk::DeviceSize size); + + /** + * Default destructor. This class does not manage memory so it won't be + * expecting the parent to perform a release. + */ + ~OpBufferSyncDevice() override; + + /** + * For device buffers, it records the copy command for the buffer to copy + * the data from its staging to device memory. + * + * @param commandBuffer The command buffer to record the command into. + */ + void record(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any preEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void preEval(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any postEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void postEval(const vk::CommandBuffer& commandBuffer) override; + + private: + vk::Buffer *mPrimaryBuffer; + vk::Buffer *mStagingBuffer; + vk::DeviceSize mSize; +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpBufferSyncLocal.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpBufferSyncLocal.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7db9971991c59929b07a3d9810b6fb288a0796ea --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpBufferSyncLocal.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/operations/OpBase.hpp" + +namespace kp { + +class OpBufferSyncLocal : public OpBase +{ + public: + OpBufferSyncLocal( + vk::Buffer *primaryBuffer, + vk::Buffer *stagingBuffer, + vk::DeviceSize size); + + /** + * Default destructor. This class does not manage memory so it won't be + * expecting the parent to perform a release. + */ + ~OpBufferSyncLocal() override; + + /** + * For device buffers, it records the copy command for the buffer to copy + * the data from its staging to device memory. + * + * @param commandBuffer The command buffer to record the command into. + */ + void record(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any preEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void preEval(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any postEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void postEval(const vk::CommandBuffer& commandBuffer) override; + + private: + vk::Buffer *mPrimaryBuffer; + vk::Buffer *mStagingBuffer; + vk::DeviceSize mSize; +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpMemoryBarrier.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpMemoryBarrier.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4a232232397cf38945ab15a5a205a8b2083d44bf --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpMemoryBarrier.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/Algorithm.hpp" +#include "kompute/Core.hpp" +#include "kompute/Tensor.hpp" +#include "kompute/operations/OpBase.hpp" + +namespace kp { + +/** + * Operation that provides a general abstraction that simplifies the use of + * algorithm and parameter components which can be used with shaders. + * It exposes the pipeline barrier functionality specifically for memory + * barriers that can be configured through the respective source and destination + * masks + */ +class OpMemoryBarrier : public OpBase +{ + public: + /** + * Constructor that stores tensors as well as memory barrier parameters to + * be used to create a pipeline barrier on the respective primary or staging + * tensor. + * + * @param tensors The tensors to apply the memory barriers on + * @param srcAccessMask The kp::AccessFlagBits for the source access mask + * @param dstAccessMask The kp::AccessFlagBits for the destination access + * mask + * @param srcStageMask The kp::PipelineStageFlagBits for the source stage + * mask + * @param dstStageMask The kp::PipelineStageFlagBits for the destination + * stage mask + * @param barrierOnPrimary Boolean to select primary or secondary buffers on + * tensors + */ + OpMemoryBarrier(const std::vector>& tensors, + const vk::AccessFlagBits& srcAccessMask, + const vk::AccessFlagBits& dstAccessMask, + const vk::PipelineStageFlagBits& srcStageMask, + const vk::PipelineStageFlagBits& dstStageMask, + bool barrierOnPrimary = true); + + /** + * Default destructor, which is in charge of destroying the reference to the + * tensors and all the relevant access / stage masks created + */ + virtual ~OpMemoryBarrier() override; + + /** + * This records the memory barrier with the access and stage masks provided + * across all relevant tensors. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void record(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any preEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void preEval(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any postEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void postEval(const vk::CommandBuffer& commandBuffer) override; + + private: + const vk::AccessFlagBits mSrcAccessMask; + const vk::AccessFlagBits mDstAccessMask; + const vk::PipelineStageFlagBits mSrcStageMask; + const vk::PipelineStageFlagBits mDstStageMask; + const bool mBarrierOnPrimary; + const std::vector> mTensors; +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpMult.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpMult.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f75ccc4fbb763ec615dce6f09d66ce306d048f91 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpMult.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include + +#include "kompute/Core.hpp" + +#include "ShaderOpMult.hpp" + +#include "kompute/Algorithm.hpp" +#include "kompute/Tensor.hpp" + +#include "kompute/operations/OpAlgoDispatch.hpp" + +namespace kp { + +/** + * Operation that performs multiplication on two tensors and outpus on third + * tensor. + */ +class OpMult : public OpAlgoDispatch +{ + public: + /** + * Default constructor with parameters that provides the bare minimum + * requirements for the operations to be able to create and manage their + * sub-components. + * + * @param tensors Tensors that are to be used in this operation + * @param algorithm An algorithm that will be overridden with the OpMult + * shader data and the tensors provided which are expected to be 3 + */ + OpMult(std::vector> tensors, + std::shared_ptr algorithm) + : OpAlgoDispatch(algorithm) + { + KP_LOG_DEBUG("Kompute OpMult constructor with params"); + + if (tensors.size() != 3) { + throw std::runtime_error( + "Kompute OpMult expected 3 tensors but got " + + std::to_string(tensors.size())); + } + + const std::vector spirv = std::vector( + SHADEROPMULT_COMP_SPV.begin(), SHADEROPMULT_COMP_SPV.end()); + + algorithm->rebuild<>(tensors, spirv); + } + + /** + * Default destructor, which is in charge of destroying the algorithm + * components but does not destroy the underlying tensors + */ + ~OpMult() override { KP_LOG_DEBUG("Kompute OpMult destructor started"); } +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorCopy.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorCopy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..968c1065a3388ebe521873d815f9357020493009 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorCopy.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/Core.hpp" + +#include "kompute/Tensor.hpp" + +#include "kompute/operations/OpBase.hpp" + +namespace kp { + +/** + * Operation that copies the data from the first tensor to the rest of the + * tensors provided, using a record command for all the vectors. This operation + * does not own/manage the memory of the tensors passed to it. The operation + * must only receive tensors of type + */ +class OpTensorCopy : public OpBase +{ + public: + /** + * Default constructor with parameters that provides the core vulkan + * resources and the tensors that will be used in the operation. + * + * @param tensors Tensors that will be used to create in operation. + */ + OpTensorCopy(const std::vector>& tensors); + + /** + * Default destructor. This class does not manage memory so it won't be + * expecting the parent to perform a release. + */ + ~OpTensorCopy() override; + + /** + * Records the copy commands from the first tensor into all the other + * tensors provided. Also optionally records a barrier. + * + * @param commandBuffer The command buffer to record the command into. + */ + void record(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any preEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void preEval(const vk::CommandBuffer& commandBuffer) override; + + /** + * Copies the local vectors for all the tensors to sync the data with the + * gpu. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void postEval(const vk::CommandBuffer& commandBuffer) override; + + private: + // -------------- ALWAYS OWNED RESOURCES + std::vector> mTensors; +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorFill.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorFill.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9a6bf131e88f0608b262bdd6c4342d6e5e16b56f --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorFill.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/Core.hpp" + +#include "kompute/Tensor.hpp" + +#include "kompute/operations/OpBase.hpp" + +namespace kp { + +/** + * Operation that fills the tensor + */ +class OpTensorFill : public OpBase +{ + public: + /** + * Default constructor with parameters that provides the core vulkan + * resources and the tensors that will be used in the operation. + * + * @param tensors Tensors that will be used to create in operation. + */ + OpTensorFill(const std::vector>& tensors); + + /** + * Default destructor. This class does not manage memory so it won't be + * expecting the parent to perform a release. + */ + ~OpTensorFill() override; + + /** + * Records the fill command for tensor. + * + * @param commandBuffer The command buffer to record the command into. + */ + void record(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any preEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void preEval(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any postEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void postEval(const vk::CommandBuffer& commandBuffer) override; + + private: + // -------------- ALWAYS OWNED RESOURCES + std::vector> mTensors; +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorSyncDevice.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorSyncDevice.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3a1792ac6217105ea75f545daa10cf41f35bf426 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorSyncDevice.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/Core.hpp" +#include "kompute/Tensor.hpp" +#include "kompute/operations/OpBase.hpp" + +namespace kp { + +/** + * Operation that syncs tensor's device by mapping local data into the device + * memory. For TensorTypes::eDevice it will use a record operation for the + * memory to be syncd into GPU memory which means that the operation will be + * done in sync with GPU commands. For TensorTypes::eHost it will only map the + * data into host memory which will happen during preEval before the recorded + * commands are dispatched. + */ +class OpTensorSyncDevice : public OpBase +{ + public: + /** + * Default constructor with parameters that provides the core vulkan + * resources and the tensors that will be used in the operation. The tensos + * provided cannot be of type TensorTypes::eStorage. + * + * @param tensors Tensors that will be used to create in operation. + */ + OpTensorSyncDevice(const std::vector>& tensors); + + /** + * Default destructor. This class does not manage memory so it won't be + * expecting the parent to perform a release. + */ + ~OpTensorSyncDevice() override; + + /** + * For device tensors, it records the copy command for the tensor to copy + * the data from its staging to device memory. + * + * @param commandBuffer The command buffer to record the command into. + */ + void record(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any preEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void preEval(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any postEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void postEval(const vk::CommandBuffer& commandBuffer) override; + + private: + // -------------- ALWAYS OWNED RESOURCES + std::vector> mTensors; +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorSyncLocal.hpp b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorSyncLocal.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4216003e530c58ad01e78fefefbba43562f29eef --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/include/kompute/operations/OpTensorSyncLocal.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "kompute/Core.hpp" + +#include "kompute/Tensor.hpp" + +#include "kompute/operations/OpBase.hpp" + +namespace kp { + +/** + * Operation that syncs tensor's local memory by mapping device data into the + * local CPU memory. For TensorTypes::eDevice it will use a record operation + * for the memory to be syncd into GPU memory which means that the operation + * will be done in sync with GPU commands. For TensorTypes::eHost it will + * only map the data into host memory which will happen during preEval before + * the recorded commands are dispatched. + */ +class OpTensorSyncLocal : public OpBase +{ + public: + /** + * Default constructor with parameters that provides the core vulkan + * resources and the tensors that will be used in the operation. The tensors + * provided cannot be of type TensorTypes::eStorage. + * + * @param tensors Tensors that will be used to create in operation. + */ + OpTensorSyncLocal(const std::vector>& tensors); + + /** + * Default destructor. This class does not manage memory so it won't be + * expecting the parent to perform a release. + */ + ~OpTensorSyncLocal() override; + + /** + * For device tensors, it records the copy command for the tensor to copy + * the data from its device to staging memory. + * + * @param commandBuffer The command buffer to record the command into. + */ + void record(const vk::CommandBuffer& commandBuffer) override; + + /** + * Does not perform any preEval commands. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void preEval(const vk::CommandBuffer& commandBuffer) override; + + /** + * For host tensors it performs the map command from the host memory into + * local memory. + * + * @param commandBuffer The command buffer to record the command into. + */ + virtual void postEval(const vk::CommandBuffer& commandBuffer) override; + + private: + // -------------- ALWAYS OWNED RESOURCES + std::vector> mTensors; +}; + +} // End namespace kp diff --git a/ggml/src/ggml-kompute/kompute/src/logger/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/src/logger/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1f8695acd26736f7c582fd0b09dccd209aca7253 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/logger/CMakeLists.txt @@ -0,0 +1,69 @@ +cmake_minimum_required(VERSION 3.20) + +set(LOGGER_SOURCES Logger.cpp) + +add_library(kp_logger STATIC ${LOGGER_SOURCES}) + +# Define log levels in code +add_compile_definitions(KOMPUTE_LOG_LEVEL_TRACE=0) +add_compile_definitions(KOMPUTE_LOG_LEVEL_DEBUG=1) +add_compile_definitions(KOMPUTE_LOG_LEVEL_INFO=2) +add_compile_definitions(KOMPUTE_LOG_LEVEL_WARN=3) +add_compile_definitions(KOMPUTE_LOG_LEVEL_ERROR=4) +add_compile_definitions(KOMPUTE_LOG_LEVEL_CRITICAL=5) +add_compile_definitions(KOMPUTE_LOG_LEVEL_OFF=6) + +if(KOMPUTE_OPT_BUILD_PYTHON AND KOMPUTE_OPT_USE_SPDLOG) + message(FATAL_ERROR "'KOMPUTE_OPT_BUILD_PYTHON' is incompatible with 'KOMPUTE_OPT_USE_SPDLOG'. To continue set either one option to 'OFF'.") +endif() + +if(KOMPUTE_OPT_ANDROID_BUILD AND KOMPUTE_OPT_USE_SPDLOG) + message(FATAL_ERROR "'KOMPUTE_OPT_ANDROID_BUILD' is incompatible with 'KOMPUTE_OPT_USE_SPDLOG'. To continue set either one option to 'OFF'.") +endif() + +if(${KOMPUTE_OPT_LOG_LEVEL} STREQUAL "Trace") + set(KOMPUTE_OPT_LOG_LEVEL TRACE) + message(STATUS "Using log level Trace") +elseif(${KOMPUTE_OPT_LOG_LEVEL} STREQUAL "Debug") + set(KOMPUTE_OPT_LOG_LEVEL DEBUG) + message(STATUS "Using log level Debug") +elseif(${KOMPUTE_OPT_LOG_LEVEL} STREQUAL "Info") + set(KOMPUTE_OPT_LOG_LEVEL INFO) + message(STATUS "Using log level Info") +elseif(${KOMPUTE_OPT_LOG_LEVEL} STREQUAL "Warn") + set(KOMPUTE_OPT_LOG_LEVEL WARN) + message(STATUS "Using log level Warn") +elseif(${KOMPUTE_OPT_LOG_LEVEL} STREQUAL "Error") + set(KOMPUTE_OPT_LOG_LEVEL ERROR) + message(STATUS "Using log level Error") +elseif(${KOMPUTE_OPT_LOG_LEVEL} STREQUAL "Critical") + set(KOMPUTE_OPT_LOG_LEVEL CRITICAL) + message(STATUS "Using log level Critical") +elseif(${KOMPUTE_OPT_LOG_LEVEL} STREQUAL "Off") + set(KOMPUTE_OPT_LOG_LEVEL OFF) + message(STATUS "Using log level Off") +elseif(${KOMPUTE_OPT_LOG_LEVEL} STREQUAL "Default") + set(KOMPUTE_OPT_LOG_LEVEL $,DEBUG,INFO>) + message(STATUS "Setting KOMPUTE_OPT_LOG_LEVEL to according to the build type") +else() + message(FATAL_ERROR "Log level '${KOMPUTE_OPT_LOG_LEVEL}' unknown, use -DKOMPUTE_OPT_LOG_LEVEL={Trace, Debug, Info, Warn, Error, Critical, Off, Default} to set it to a correct value.") +endif() + +# Always make sure we define the Kompute log level independent of the Spdlog log level +target_compile_definitions(kp_logger INTERFACE KOMPUTE_OPT_ACTIVE_LOG_LEVEL=KOMPUTE_LOG_LEVEL_${KOMPUTE_OPT_LOG_LEVEL}) + +# Link depending on how the logger should be setup +if(NOT KOMPUTE_OPT_LOG_LEVEL_DISABLED) + if(KOMPUTE_OPT_USE_SPDLOG) + target_link_libraries(kp_logger PUBLIC spdlog::spdlog) + target_compile_definitions(spdlog INTERFACE SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_${KOMPUTE_OPT_LOG_LEVEL}) + target_compile_definitions(kp_logger INTERFACE SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_${KOMPUTE_OPT_LOG_LEVEL}) + message(STATUS "setting SPDLOG_ACTIVE_LEVEL to SPDLOG_LEVEL_${KOMPUTE_OPT_LOG_LEVEL}") + + if(KOMPUTE_OPT_SPDLOG_ASYNC_MODE) + target_compile_definitions(kp_logger INTERFACE KOMPUTE_SPDLOG_ASYNC_LOGGING=1) + endif() + else() + target_link_libraries(kp_logger PUBLIC fmt::fmt) + endif() +endif() diff --git a/ggml/src/ggml-kompute/kompute/src/logger/Logger.cpp b/ggml/src/ggml-kompute/kompute/src/logger/Logger.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69df2b609610c0776008bcf9bf0b6a658584a801 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/logger/Logger.cpp @@ -0,0 +1,101 @@ +#include "kompute/logger/Logger.hpp" + +#if !KOMPUTE_OPT_LOG_LEVEL_DISABLED +#if !KOMPUTE_OPT_USE_SPDLOG +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif // !KOMPUTE_OPT_USE_SPDLOG + +namespace logger { +#if !KOMPUTE_OPT_USE_SPDLOG + +void +setupLogger() +{ +} + +#else +constexpr int THREAD_QUEUE_LENGTH = 8192; + +void +setupLogger() +{ + // Ensure we setup the logger only once + static bool setup = false; + static std::mutex setupMutex{}; + setupMutex.lock(); + if (setup) { + setupMutex.unlock(); + return; + } + setup = true; + setupMutex.unlock(); + + spdlog::init_thread_pool(THREAD_QUEUE_LENGTH, 1); + spdlog::sink_ptr console_sink = + std::make_shared(); +#if SPDLOG_ACTIVE_LEVEL < SPDLOG_LEVEL_INFO + console_sink->set_pattern("[%H:%M:%S %z] [%^%=9l%$] [%=21s] %v"); +#else + console_sink->set_pattern("[%H:%M:%S %z] [%^%=9l%$] [%=15s] %v"); +#endif + std::vector sinks{ console_sink }; + // TODO: Add flag in compile flags + std::shared_ptr logger = +#if KOMPUTE_SPDLOG_ASYNC_LOGGING + std::make_shared( + "", + sinks.begin(), + sinks.end(), + spdlog::thread_pool(), + spdlog::async_overflow_policy::block); +#else + std::make_shared( + "", + sinks.begin(), + sinks.end()); +#endif + + logger->set_level(getLogLevel()); + + spdlog::set_default_logger(logger); +} + +spdlog::level::level_enum +getLogLevel() +{ +#if SPDLOG_ACTIVE_LEVEL == SPDLOG_LEVEL_TRACE + return spdlog::level::trace; +#elif SPDLOG_ACTIVE_LEVEL == SPDLOG_LEVEL_DEBUG + return spdlog::level::debug; +#elif SPDLOG_ACTIVE_LEVEL == SPDLOG_LEVEL_INFO + return spdlog::level::info; +#elif SPDLOG_ACTIVE_LEVEL == SPDLOG_LEVEL_WARN + return spdlog::level::warn; +#elif SPDLOG_ACTIVE_LEVEL == SPDLOG_LEVEL_ERROR + return spdlog::level::error; +#elif SPDLOG_ACTIVE_LEVEL == SPDLOG_LEVEL_CRITICAL + return spdlog::level::critical; +#else + return spdlog::level::off; +#endif +} + +void +setLogLevel(const spdlog::level::level_enum level) +{ + spdlog::default_logger()->set_level(level); +} +#endif // !KOMPUTE_OPT_USE_SPDLOG +} // namespace logger + +#endif diff --git a/ggml/src/ggml-kompute/kompute/src/shaders/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/src/shaders/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..901bf3e8a8af296b9e5f7b6cf04fd4b92f723ef9 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/shaders/CMakeLists.txt @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# ###################### +cmake_minimum_required(VERSION 3.20) + +add_subdirectory(glsl) \ No newline at end of file diff --git a/ggml/src/ggml-kompute/kompute/src/shaders/glsl/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/src/shaders/glsl/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3101a2b17b7518586fcd800d9eae67102fd9cb2b --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/shaders/glsl/CMakeLists.txt @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# ###################### +cmake_minimum_required(VERSION 3.20) + +# Check if build shaders from source is enabled +if(KOMPUTE_OPT_BUILD_SHADERS) + vulkan_compile_shader(INFILE ShaderOpMult.comp + OUTFILE ShaderOpMult.hpp + NAMESPACE "kp") + + vulkan_compile_shader(INFILE ShaderLogisticRegression.comp + OUTFILE ShaderLogisticRegression.hpp + NAMESPACE "kp") +else() # Else we will use our precompiled versions + add_custom_command(OUTPUT $/ShaderOpMult.hpp COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/ShaderOpMult.hpp.in $/ShaderOpMult.hpp) + add_custom_command(OUTPUT $/ShaderLogisticRegression.hpp COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/ShaderLogisticRegression.hpp.in $/ShaderLogisticRegression.hpp) +endif() + +add_library(kp_shader INTERFACE "${CMAKE_CURRENT_BINARY_DIR}/ShaderOpMult.hpp" + "${CMAKE_CURRENT_BINARY_DIR}/ShaderLogisticRegression.hpp") + +target_include_directories(kp_shader INTERFACE $) + +# Make sure we install shaders: +install(FILES $/ShaderOpMult.hpp DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) +install(FILES $/ShaderLogisticRegression.hpp DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) diff --git a/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderLogisticRegression.comp b/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderLogisticRegression.comp new file mode 100644 index 0000000000000000000000000000000000000000..5a1c5d94867548a59dc35683fffff596066baab9 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderLogisticRegression.comp @@ -0,0 +1,52 @@ +#version 450 + +layout (constant_id = 0) const float m = 0; + +layout (local_size_x = 1) in; + +layout(set = 0, binding = 0) buffer bxi { float xi[]; }; +layout(set = 0, binding = 1) buffer bxj { float xj[]; }; +layout(set = 0, binding = 2) buffer by { float y[]; }; +layout(set = 0, binding = 3) buffer bwin { float win[]; }; +layout(set = 0, binding = 4) buffer bwouti { float wouti[]; }; +layout(set = 0, binding = 5) buffer bwoutj { float woutj[]; }; +layout(set = 0, binding = 6) buffer bbin { float bin[]; }; +layout(set = 0, binding = 7) buffer bbout { float bout[]; }; +layout(set = 0, binding = 8) buffer blout { float lout[]; }; + +float sigmoid(float z) { + return 1.0 / (1.0 + exp(-z)); +} + +float inference(vec2 x, vec2 w, float b) { + // Compute the linear mapping function + float z = dot(w, x) + b; + // Calculate the y-hat with sigmoid + float yHat = sigmoid(z); + return yHat; +} + +float calculateLoss(float yHat, float y) { + return -(y * log(yHat) + (1.0 - y) * log(1.0 - yHat)); +} + +void main() { + uint idx = gl_GlobalInvocationID.x; + + vec2 wCurr = vec2(win[0], win[1]); + float bCurr = bin[0]; + + vec2 xCurr = vec2(xi[idx], xj[idx]); + float yCurr = y[idx]; + + float yHat = inference(xCurr, wCurr, bCurr); + + float dZ = yHat - yCurr; + vec2 dW = (1. / m) * xCurr * dZ; + float dB = (1. / m) * dZ; + wouti[idx] = dW.x; + woutj[idx] = dW.y; + bout[idx] = dB; + + lout[idx] = calculateLoss(yHat, yCurr); +} diff --git a/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderLogisticRegression.hpp.in b/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderLogisticRegression.hpp.in new file mode 100644 index 0000000000000000000000000000000000000000..bfe7792c6c8d906ccac732f6c1744deff0d9372c --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderLogisticRegression.hpp.in @@ -0,0 +1,310 @@ +#pragma once +#include +#include + +namespace kp { +const std::array SHADERLOGISTICREGRESSION_COMP_SPV = { +0x07230203, 0x00010000, 0x0008000a, 0x000000ae, +0x00000000, 0x00020011, 0x00000001, 0x0006000b, +0x00000001, 0x4c534c47, 0x6474732e, 0x3035342e, +0x00000000, 0x0003000e, 0x00000000, 0x00000001, +0x0006000f, 0x00000005, 0x00000004, 0x6e69616d, +0x00000000, 0x00000041, 0x00060010, 0x00000004, +0x00000011, 0x00000001, 0x00000001, 0x00000001, +0x00030003, 0x00000002, 0x000001c2, 0x00040005, +0x00000004, 0x6e69616d, 0x00000000, 0x00050005, +0x0000000a, 0x6d676973, 0x2864696f, 0x003b3166, +0x00030005, 0x00000009, 0x0000007a, 0x00080005, +0x00000012, 0x65666e69, 0x636e6572, 0x66762865, +0x66763b32, 0x31663b32, 0x0000003b, 0x00030005, +0x0000000f, 0x00000078, 0x00030005, 0x00000010, +0x00000077, 0x00030005, 0x00000011, 0x00000062, +0x00080005, 0x00000017, 0x636c6163, 0x74616c75, +0x736f4c65, 0x31662873, 0x3b31663b, 0x00000000, +0x00040005, 0x00000015, 0x74614879, 0x00000000, +0x00030005, 0x00000016, 0x00000079, 0x00030005, +0x00000021, 0x0000007a, 0x00040005, 0x00000027, +0x74614879, 0x00000000, 0x00040005, 0x00000028, +0x61726170, 0x0000006d, 0x00030005, 0x0000003e, +0x00786469, 0x00080005, 0x00000041, 0x475f6c67, +0x61626f6c, 0x766e496c, 0x7461636f, 0x496e6f69, +0x00000044, 0x00040005, 0x00000046, 0x72754377, +0x00000072, 0x00040005, 0x00000048, 0x6e697762, +0x00000000, 0x00040006, 0x00000048, 0x00000000, +0x006e6977, 0x00030005, 0x0000004a, 0x00000000, +0x00040005, 0x00000054, 0x72754362, 0x00000072, +0x00040005, 0x00000056, 0x6e696262, 0x00000000, +0x00040006, 0x00000056, 0x00000000, 0x006e6962, +0x00030005, 0x00000058, 0x00000000, 0x00040005, +0x0000005b, 0x72754378, 0x00000072, 0x00030005, +0x0000005d, 0x00697862, 0x00040006, 0x0000005d, +0x00000000, 0x00006978, 0x00030005, 0x0000005f, +0x00000000, 0x00030005, 0x00000064, 0x006a7862, +0x00040006, 0x00000064, 0x00000000, 0x00006a78, +0x00030005, 0x00000066, 0x00000000, 0x00040005, +0x0000006b, 0x72754379, 0x00000072, 0x00030005, +0x0000006d, 0x00007962, 0x00040006, 0x0000006d, +0x00000000, 0x00000079, 0x00030005, 0x0000006f, +0x00000000, 0x00040005, 0x00000073, 0x74614879, +0x00000000, 0x00040005, 0x00000074, 0x61726170, +0x0000006d, 0x00040005, 0x00000076, 0x61726170, +0x0000006d, 0x00040005, 0x00000078, 0x61726170, +0x0000006d, 0x00030005, 0x0000007b, 0x00005a64, +0x00030005, 0x0000007f, 0x00005764, 0x00030005, +0x00000080, 0x0000006d, 0x00030005, 0x00000086, +0x00004264, 0x00040005, 0x0000008b, 0x756f7762, +0x00006974, 0x00050006, 0x0000008b, 0x00000000, +0x74756f77, 0x00000069, 0x00030005, 0x0000008d, +0x00000000, 0x00040005, 0x00000093, 0x756f7762, +0x00006a74, 0x00050006, 0x00000093, 0x00000000, +0x74756f77, 0x0000006a, 0x00030005, 0x00000095, +0x00000000, 0x00040005, 0x0000009c, 0x756f6262, +0x00000074, 0x00050006, 0x0000009c, 0x00000000, +0x74756f62, 0x00000000, 0x00030005, 0x0000009e, +0x00000000, 0x00040005, 0x000000a3, 0x756f6c62, +0x00000074, 0x00050006, 0x000000a3, 0x00000000, +0x74756f6c, 0x00000000, 0x00030005, 0x000000a5, +0x00000000, 0x00040005, 0x000000a7, 0x61726170, +0x0000006d, 0x00040005, 0x000000a9, 0x61726170, +0x0000006d, 0x00040047, 0x00000041, 0x0000000b, +0x0000001c, 0x00040047, 0x00000047, 0x00000006, +0x00000004, 0x00050048, 0x00000048, 0x00000000, +0x00000023, 0x00000000, 0x00030047, 0x00000048, +0x00000003, 0x00040047, 0x0000004a, 0x00000022, +0x00000000, 0x00040047, 0x0000004a, 0x00000021, +0x00000003, 0x00040047, 0x00000055, 0x00000006, +0x00000004, 0x00050048, 0x00000056, 0x00000000, +0x00000023, 0x00000000, 0x00030047, 0x00000056, +0x00000003, 0x00040047, 0x00000058, 0x00000022, +0x00000000, 0x00040047, 0x00000058, 0x00000021, +0x00000006, 0x00040047, 0x0000005c, 0x00000006, +0x00000004, 0x00050048, 0x0000005d, 0x00000000, +0x00000023, 0x00000000, 0x00030047, 0x0000005d, +0x00000003, 0x00040047, 0x0000005f, 0x00000022, +0x00000000, 0x00040047, 0x0000005f, 0x00000021, +0x00000000, 0x00040047, 0x00000063, 0x00000006, +0x00000004, 0x00050048, 0x00000064, 0x00000000, +0x00000023, 0x00000000, 0x00030047, 0x00000064, +0x00000003, 0x00040047, 0x00000066, 0x00000022, +0x00000000, 0x00040047, 0x00000066, 0x00000021, +0x00000001, 0x00040047, 0x0000006c, 0x00000006, +0x00000004, 0x00050048, 0x0000006d, 0x00000000, +0x00000023, 0x00000000, 0x00030047, 0x0000006d, +0x00000003, 0x00040047, 0x0000006f, 0x00000022, +0x00000000, 0x00040047, 0x0000006f, 0x00000021, +0x00000002, 0x00040047, 0x00000080, 0x00000001, +0x00000000, 0x00040047, 0x0000008a, 0x00000006, +0x00000004, 0x00050048, 0x0000008b, 0x00000000, +0x00000023, 0x00000000, 0x00030047, 0x0000008b, +0x00000003, 0x00040047, 0x0000008d, 0x00000022, +0x00000000, 0x00040047, 0x0000008d, 0x00000021, +0x00000004, 0x00040047, 0x00000092, 0x00000006, +0x00000004, 0x00050048, 0x00000093, 0x00000000, +0x00000023, 0x00000000, 0x00030047, 0x00000093, +0x00000003, 0x00040047, 0x00000095, 0x00000022, +0x00000000, 0x00040047, 0x00000095, 0x00000021, +0x00000005, 0x00040047, 0x0000009b, 0x00000006, +0x00000004, 0x00050048, 0x0000009c, 0x00000000, +0x00000023, 0x00000000, 0x00030047, 0x0000009c, +0x00000003, 0x00040047, 0x0000009e, 0x00000022, +0x00000000, 0x00040047, 0x0000009e, 0x00000021, +0x00000007, 0x00040047, 0x000000a2, 0x00000006, +0x00000004, 0x00050048, 0x000000a3, 0x00000000, +0x00000023, 0x00000000, 0x00030047, 0x000000a3, +0x00000003, 0x00040047, 0x000000a5, 0x00000022, +0x00000000, 0x00040047, 0x000000a5, 0x00000021, +0x00000008, 0x00040047, 0x000000ad, 0x0000000b, +0x00000019, 0x00020013, 0x00000002, 0x00030021, +0x00000003, 0x00000002, 0x00030016, 0x00000006, +0x00000020, 0x00040020, 0x00000007, 0x00000007, +0x00000006, 0x00040021, 0x00000008, 0x00000006, +0x00000007, 0x00040017, 0x0000000c, 0x00000006, +0x00000002, 0x00040020, 0x0000000d, 0x00000007, +0x0000000c, 0x00060021, 0x0000000e, 0x00000006, +0x0000000d, 0x0000000d, 0x00000007, 0x00050021, +0x00000014, 0x00000006, 0x00000007, 0x00000007, +0x0004002b, 0x00000006, 0x00000019, 0x3f800000, +0x00040015, 0x0000003c, 0x00000020, 0x00000000, +0x00040020, 0x0000003d, 0x00000007, 0x0000003c, +0x00040017, 0x0000003f, 0x0000003c, 0x00000003, +0x00040020, 0x00000040, 0x00000001, 0x0000003f, +0x0004003b, 0x00000040, 0x00000041, 0x00000001, +0x0004002b, 0x0000003c, 0x00000042, 0x00000000, +0x00040020, 0x00000043, 0x00000001, 0x0000003c, +0x0003001d, 0x00000047, 0x00000006, 0x0003001e, +0x00000048, 0x00000047, 0x00040020, 0x00000049, +0x00000002, 0x00000048, 0x0004003b, 0x00000049, +0x0000004a, 0x00000002, 0x00040015, 0x0000004b, +0x00000020, 0x00000001, 0x0004002b, 0x0000004b, +0x0000004c, 0x00000000, 0x00040020, 0x0000004d, +0x00000002, 0x00000006, 0x0004002b, 0x0000004b, +0x00000050, 0x00000001, 0x0003001d, 0x00000055, +0x00000006, 0x0003001e, 0x00000056, 0x00000055, +0x00040020, 0x00000057, 0x00000002, 0x00000056, +0x0004003b, 0x00000057, 0x00000058, 0x00000002, +0x0003001d, 0x0000005c, 0x00000006, 0x0003001e, +0x0000005d, 0x0000005c, 0x00040020, 0x0000005e, +0x00000002, 0x0000005d, 0x0004003b, 0x0000005e, +0x0000005f, 0x00000002, 0x0003001d, 0x00000063, +0x00000006, 0x0003001e, 0x00000064, 0x00000063, +0x00040020, 0x00000065, 0x00000002, 0x00000064, +0x0004003b, 0x00000065, 0x00000066, 0x00000002, +0x0003001d, 0x0000006c, 0x00000006, 0x0003001e, +0x0000006d, 0x0000006c, 0x00040020, 0x0000006e, +0x00000002, 0x0000006d, 0x0004003b, 0x0000006e, +0x0000006f, 0x00000002, 0x00040032, 0x00000006, +0x00000080, 0x00000000, 0x0003001d, 0x0000008a, +0x00000006, 0x0003001e, 0x0000008b, 0x0000008a, +0x00040020, 0x0000008c, 0x00000002, 0x0000008b, +0x0004003b, 0x0000008c, 0x0000008d, 0x00000002, +0x0003001d, 0x00000092, 0x00000006, 0x0003001e, +0x00000093, 0x00000092, 0x00040020, 0x00000094, +0x00000002, 0x00000093, 0x0004003b, 0x00000094, +0x00000095, 0x00000002, 0x0004002b, 0x0000003c, +0x00000097, 0x00000001, 0x0003001d, 0x0000009b, +0x00000006, 0x0003001e, 0x0000009c, 0x0000009b, +0x00040020, 0x0000009d, 0x00000002, 0x0000009c, +0x0004003b, 0x0000009d, 0x0000009e, 0x00000002, +0x0003001d, 0x000000a2, 0x00000006, 0x0003001e, +0x000000a3, 0x000000a2, 0x00040020, 0x000000a4, +0x00000002, 0x000000a3, 0x0004003b, 0x000000a4, +0x000000a5, 0x00000002, 0x0006002c, 0x0000003f, +0x000000ad, 0x00000097, 0x00000097, 0x00000097, +0x00050036, 0x00000002, 0x00000004, 0x00000000, +0x00000003, 0x000200f8, 0x00000005, 0x0004003b, +0x0000003d, 0x0000003e, 0x00000007, 0x0004003b, +0x0000000d, 0x00000046, 0x00000007, 0x0004003b, +0x00000007, 0x00000054, 0x00000007, 0x0004003b, +0x0000000d, 0x0000005b, 0x00000007, 0x0004003b, +0x00000007, 0x0000006b, 0x00000007, 0x0004003b, +0x00000007, 0x00000073, 0x00000007, 0x0004003b, +0x0000000d, 0x00000074, 0x00000007, 0x0004003b, +0x0000000d, 0x00000076, 0x00000007, 0x0004003b, +0x00000007, 0x00000078, 0x00000007, 0x0004003b, +0x00000007, 0x0000007b, 0x00000007, 0x0004003b, +0x0000000d, 0x0000007f, 0x00000007, 0x0004003b, +0x00000007, 0x00000086, 0x00000007, 0x0004003b, +0x00000007, 0x000000a7, 0x00000007, 0x0004003b, +0x00000007, 0x000000a9, 0x00000007, 0x00050041, +0x00000043, 0x00000044, 0x00000041, 0x00000042, +0x0004003d, 0x0000003c, 0x00000045, 0x00000044, +0x0003003e, 0x0000003e, 0x00000045, 0x00060041, +0x0000004d, 0x0000004e, 0x0000004a, 0x0000004c, +0x0000004c, 0x0004003d, 0x00000006, 0x0000004f, +0x0000004e, 0x00060041, 0x0000004d, 0x00000051, +0x0000004a, 0x0000004c, 0x00000050, 0x0004003d, +0x00000006, 0x00000052, 0x00000051, 0x00050050, +0x0000000c, 0x00000053, 0x0000004f, 0x00000052, +0x0003003e, 0x00000046, 0x00000053, 0x00060041, +0x0000004d, 0x00000059, 0x00000058, 0x0000004c, +0x0000004c, 0x0004003d, 0x00000006, 0x0000005a, +0x00000059, 0x0003003e, 0x00000054, 0x0000005a, +0x0004003d, 0x0000003c, 0x00000060, 0x0000003e, +0x00060041, 0x0000004d, 0x00000061, 0x0000005f, +0x0000004c, 0x00000060, 0x0004003d, 0x00000006, +0x00000062, 0x00000061, 0x0004003d, 0x0000003c, +0x00000067, 0x0000003e, 0x00060041, 0x0000004d, +0x00000068, 0x00000066, 0x0000004c, 0x00000067, +0x0004003d, 0x00000006, 0x00000069, 0x00000068, +0x00050050, 0x0000000c, 0x0000006a, 0x00000062, +0x00000069, 0x0003003e, 0x0000005b, 0x0000006a, +0x0004003d, 0x0000003c, 0x00000070, 0x0000003e, +0x00060041, 0x0000004d, 0x00000071, 0x0000006f, +0x0000004c, 0x00000070, 0x0004003d, 0x00000006, +0x00000072, 0x00000071, 0x0003003e, 0x0000006b, +0x00000072, 0x0004003d, 0x0000000c, 0x00000075, +0x0000005b, 0x0003003e, 0x00000074, 0x00000075, +0x0004003d, 0x0000000c, 0x00000077, 0x00000046, +0x0003003e, 0x00000076, 0x00000077, 0x0004003d, +0x00000006, 0x00000079, 0x00000054, 0x0003003e, +0x00000078, 0x00000079, 0x00070039, 0x00000006, +0x0000007a, 0x00000012, 0x00000074, 0x00000076, +0x00000078, 0x0003003e, 0x00000073, 0x0000007a, +0x0004003d, 0x00000006, 0x0000007c, 0x00000073, +0x0004003d, 0x00000006, 0x0000007d, 0x0000006b, +0x00050083, 0x00000006, 0x0000007e, 0x0000007c, +0x0000007d, 0x0003003e, 0x0000007b, 0x0000007e, +0x00050088, 0x00000006, 0x00000081, 0x00000019, +0x00000080, 0x0004003d, 0x0000000c, 0x00000082, +0x0000005b, 0x0005008e, 0x0000000c, 0x00000083, +0x00000082, 0x00000081, 0x0004003d, 0x00000006, +0x00000084, 0x0000007b, 0x0005008e, 0x0000000c, +0x00000085, 0x00000083, 0x00000084, 0x0003003e, +0x0000007f, 0x00000085, 0x00050088, 0x00000006, +0x00000087, 0x00000019, 0x00000080, 0x0004003d, +0x00000006, 0x00000088, 0x0000007b, 0x00050085, +0x00000006, 0x00000089, 0x00000087, 0x00000088, +0x0003003e, 0x00000086, 0x00000089, 0x0004003d, +0x0000003c, 0x0000008e, 0x0000003e, 0x00050041, +0x00000007, 0x0000008f, 0x0000007f, 0x00000042, +0x0004003d, 0x00000006, 0x00000090, 0x0000008f, +0x00060041, 0x0000004d, 0x00000091, 0x0000008d, +0x0000004c, 0x0000008e, 0x0003003e, 0x00000091, +0x00000090, 0x0004003d, 0x0000003c, 0x00000096, +0x0000003e, 0x00050041, 0x00000007, 0x00000098, +0x0000007f, 0x00000097, 0x0004003d, 0x00000006, +0x00000099, 0x00000098, 0x00060041, 0x0000004d, +0x0000009a, 0x00000095, 0x0000004c, 0x00000096, +0x0003003e, 0x0000009a, 0x00000099, 0x0004003d, +0x0000003c, 0x0000009f, 0x0000003e, 0x0004003d, +0x00000006, 0x000000a0, 0x00000086, 0x00060041, +0x0000004d, 0x000000a1, 0x0000009e, 0x0000004c, +0x0000009f, 0x0003003e, 0x000000a1, 0x000000a0, +0x0004003d, 0x0000003c, 0x000000a6, 0x0000003e, +0x0004003d, 0x00000006, 0x000000a8, 0x00000073, +0x0003003e, 0x000000a7, 0x000000a8, 0x0004003d, +0x00000006, 0x000000aa, 0x0000006b, 0x0003003e, +0x000000a9, 0x000000aa, 0x00060039, 0x00000006, +0x000000ab, 0x00000017, 0x000000a7, 0x000000a9, +0x00060041, 0x0000004d, 0x000000ac, 0x000000a5, +0x0000004c, 0x000000a6, 0x0003003e, 0x000000ac, +0x000000ab, 0x000100fd, 0x00010038, 0x00050036, +0x00000006, 0x0000000a, 0x00000000, 0x00000008, +0x00030037, 0x00000007, 0x00000009, 0x000200f8, +0x0000000b, 0x0004003d, 0x00000006, 0x0000001a, +0x00000009, 0x0004007f, 0x00000006, 0x0000001b, +0x0000001a, 0x0006000c, 0x00000006, 0x0000001c, +0x00000001, 0x0000001b, 0x0000001b, 0x00050081, +0x00000006, 0x0000001d, 0x00000019, 0x0000001c, +0x00050088, 0x00000006, 0x0000001e, 0x00000019, +0x0000001d, 0x000200fe, 0x0000001e, 0x00010038, +0x00050036, 0x00000006, 0x00000012, 0x00000000, +0x0000000e, 0x00030037, 0x0000000d, 0x0000000f, +0x00030037, 0x0000000d, 0x00000010, 0x00030037, +0x00000007, 0x00000011, 0x000200f8, 0x00000013, +0x0004003b, 0x00000007, 0x00000021, 0x00000007, +0x0004003b, 0x00000007, 0x00000027, 0x00000007, +0x0004003b, 0x00000007, 0x00000028, 0x00000007, +0x0004003d, 0x0000000c, 0x00000022, 0x00000010, +0x0004003d, 0x0000000c, 0x00000023, 0x0000000f, +0x00050094, 0x00000006, 0x00000024, 0x00000022, +0x00000023, 0x0004003d, 0x00000006, 0x00000025, +0x00000011, 0x00050081, 0x00000006, 0x00000026, +0x00000024, 0x00000025, 0x0003003e, 0x00000021, +0x00000026, 0x0004003d, 0x00000006, 0x00000029, +0x00000021, 0x0003003e, 0x00000028, 0x00000029, +0x00050039, 0x00000006, 0x0000002a, 0x0000000a, +0x00000028, 0x0003003e, 0x00000027, 0x0000002a, +0x0004003d, 0x00000006, 0x0000002b, 0x00000027, +0x000200fe, 0x0000002b, 0x00010038, 0x00050036, +0x00000006, 0x00000017, 0x00000000, 0x00000014, +0x00030037, 0x00000007, 0x00000015, 0x00030037, +0x00000007, 0x00000016, 0x000200f8, 0x00000018, +0x0004003d, 0x00000006, 0x0000002e, 0x00000016, +0x0004003d, 0x00000006, 0x0000002f, 0x00000015, +0x0006000c, 0x00000006, 0x00000030, 0x00000001, +0x0000001c, 0x0000002f, 0x00050085, 0x00000006, +0x00000031, 0x0000002e, 0x00000030, 0x0004003d, +0x00000006, 0x00000032, 0x00000016, 0x00050083, +0x00000006, 0x00000033, 0x00000019, 0x00000032, +0x0004003d, 0x00000006, 0x00000034, 0x00000015, +0x00050083, 0x00000006, 0x00000035, 0x00000019, +0x00000034, 0x0006000c, 0x00000006, 0x00000036, +0x00000001, 0x0000001c, 0x00000035, 0x00050085, +0x00000006, 0x00000037, 0x00000033, 0x00000036, +0x00050081, 0x00000006, 0x00000038, 0x00000031, +0x00000037, 0x0004007f, 0x00000006, 0x00000039, +0x00000038, 0x000200fe, 0x00000039, 0x00010038 }; +} // namespace kp + + diff --git a/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderOpMult.comp b/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderOpMult.comp new file mode 100644 index 0000000000000000000000000000000000000000..d5486503760c1e4b868b70de277c2b8ee5878935 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderOpMult.comp @@ -0,0 +1,28 @@ +#version 450 + +layout(set = 0, binding = 0) buffer tensorLhs { + float valuesLhs[ ]; +}; + +layout(set = 0, binding = 1) buffer tensorRhs { + float valuesRhs[ ]; +}; + +layout(set = 0, binding = 2) buffer tensorOutput { + float valuesOutput[ ]; +}; + +layout (constant_id = 0) const uint LEN_LHS = 0; +layout (constant_id = 1) const uint LEN_RHS = 0; +layout (constant_id = 2) const uint LEN_OUT = 0; + +layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +void main() +{ + uint index = gl_GlobalInvocationID.x; + + valuesOutput[index] = valuesLhs[index] * valuesRhs[index]; +} + + diff --git a/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderOpMult.hpp.in b/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderOpMult.hpp.in new file mode 100644 index 0000000000000000000000000000000000000000..5af29c66d12143c6e368974efbab160f4b886660 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/shaders/glsl/ShaderOpMult.hpp.in @@ -0,0 +1,101 @@ +#pragma once +#include +#include + +namespace kp { +const std::array SHADEROPMULT_COMP_SPV = { +0x07230203, 0x00010000, 0x0008000a, 0x0000002e, +0x00000000, 0x00020011, 0x00000001, 0x0006000b, +0x00000001, 0x4c534c47, 0x6474732e, 0x3035342e, +0x00000000, 0x0003000e, 0x00000000, 0x00000001, +0x0006000f, 0x00000005, 0x00000004, 0x6e69616d, +0x00000000, 0x0000000b, 0x00060010, 0x00000004, +0x00000011, 0x00000001, 0x00000001, 0x00000001, +0x00030003, 0x00000002, 0x000001c2, 0x00040005, +0x00000004, 0x6e69616d, 0x00000000, 0x00040005, +0x00000008, 0x65646e69, 0x00000078, 0x00080005, +0x0000000b, 0x475f6c67, 0x61626f6c, 0x766e496c, +0x7461636f, 0x496e6f69, 0x00000044, 0x00060005, +0x00000012, 0x736e6574, 0x754f726f, 0x74757074, +0x00000000, 0x00070006, 0x00000012, 0x00000000, +0x756c6176, 0x754f7365, 0x74757074, 0x00000000, +0x00030005, 0x00000014, 0x00000000, 0x00050005, +0x00000019, 0x736e6574, 0x684c726f, 0x00000073, +0x00060006, 0x00000019, 0x00000000, 0x756c6176, +0x684c7365, 0x00000073, 0x00030005, 0x0000001b, +0x00000000, 0x00050005, 0x00000021, 0x736e6574, +0x6852726f, 0x00000073, 0x00060006, 0x00000021, +0x00000000, 0x756c6176, 0x68527365, 0x00000073, +0x00030005, 0x00000023, 0x00000000, 0x00040005, +0x00000029, 0x5f4e454c, 0x0053484c, 0x00040005, +0x0000002a, 0x5f4e454c, 0x00534852, 0x00040005, +0x0000002b, 0x5f4e454c, 0x0054554f, 0x00040047, +0x0000000b, 0x0000000b, 0x0000001c, 0x00040047, +0x00000011, 0x00000006, 0x00000004, 0x00050048, +0x00000012, 0x00000000, 0x00000023, 0x00000000, +0x00030047, 0x00000012, 0x00000003, 0x00040047, +0x00000014, 0x00000022, 0x00000000, 0x00040047, +0x00000014, 0x00000021, 0x00000002, 0x00040047, +0x00000018, 0x00000006, 0x00000004, 0x00050048, +0x00000019, 0x00000000, 0x00000023, 0x00000000, +0x00030047, 0x00000019, 0x00000003, 0x00040047, +0x0000001b, 0x00000022, 0x00000000, 0x00040047, +0x0000001b, 0x00000021, 0x00000000, 0x00040047, +0x00000020, 0x00000006, 0x00000004, 0x00050048, +0x00000021, 0x00000000, 0x00000023, 0x00000000, +0x00030047, 0x00000021, 0x00000003, 0x00040047, +0x00000023, 0x00000022, 0x00000000, 0x00040047, +0x00000023, 0x00000021, 0x00000001, 0x00040047, +0x00000029, 0x00000001, 0x00000000, 0x00040047, +0x0000002a, 0x00000001, 0x00000001, 0x00040047, +0x0000002b, 0x00000001, 0x00000002, 0x00040047, +0x0000002d, 0x0000000b, 0x00000019, 0x00020013, +0x00000002, 0x00030021, 0x00000003, 0x00000002, +0x00040015, 0x00000006, 0x00000020, 0x00000000, +0x00040020, 0x00000007, 0x00000007, 0x00000006, +0x00040017, 0x00000009, 0x00000006, 0x00000003, +0x00040020, 0x0000000a, 0x00000001, 0x00000009, +0x0004003b, 0x0000000a, 0x0000000b, 0x00000001, +0x0004002b, 0x00000006, 0x0000000c, 0x00000000, +0x00040020, 0x0000000d, 0x00000001, 0x00000006, +0x00030016, 0x00000010, 0x00000020, 0x0003001d, +0x00000011, 0x00000010, 0x0003001e, 0x00000012, +0x00000011, 0x00040020, 0x00000013, 0x00000002, +0x00000012, 0x0004003b, 0x00000013, 0x00000014, +0x00000002, 0x00040015, 0x00000015, 0x00000020, +0x00000001, 0x0004002b, 0x00000015, 0x00000016, +0x00000000, 0x0003001d, 0x00000018, 0x00000010, +0x0003001e, 0x00000019, 0x00000018, 0x00040020, +0x0000001a, 0x00000002, 0x00000019, 0x0004003b, +0x0000001a, 0x0000001b, 0x00000002, 0x00040020, +0x0000001d, 0x00000002, 0x00000010, 0x0003001d, +0x00000020, 0x00000010, 0x0003001e, 0x00000021, +0x00000020, 0x00040020, 0x00000022, 0x00000002, +0x00000021, 0x0004003b, 0x00000022, 0x00000023, +0x00000002, 0x00040032, 0x00000006, 0x00000029, +0x00000000, 0x00040032, 0x00000006, 0x0000002a, +0x00000000, 0x00040032, 0x00000006, 0x0000002b, +0x00000000, 0x0004002b, 0x00000006, 0x0000002c, +0x00000001, 0x0006002c, 0x00000009, 0x0000002d, +0x0000002c, 0x0000002c, 0x0000002c, 0x00050036, +0x00000002, 0x00000004, 0x00000000, 0x00000003, +0x000200f8, 0x00000005, 0x0004003b, 0x00000007, +0x00000008, 0x00000007, 0x00050041, 0x0000000d, +0x0000000e, 0x0000000b, 0x0000000c, 0x0004003d, +0x00000006, 0x0000000f, 0x0000000e, 0x0003003e, +0x00000008, 0x0000000f, 0x0004003d, 0x00000006, +0x00000017, 0x00000008, 0x0004003d, 0x00000006, +0x0000001c, 0x00000008, 0x00060041, 0x0000001d, +0x0000001e, 0x0000001b, 0x00000016, 0x0000001c, +0x0004003d, 0x00000010, 0x0000001f, 0x0000001e, +0x0004003d, 0x00000006, 0x00000024, 0x00000008, +0x00060041, 0x0000001d, 0x00000025, 0x00000023, +0x00000016, 0x00000024, 0x0004003d, 0x00000010, +0x00000026, 0x00000025, 0x00050085, 0x00000010, +0x00000027, 0x0000001f, 0x00000026, 0x00060041, +0x0000001d, 0x00000028, 0x00000014, 0x00000016, +0x00000017, 0x0003003e, 0x00000028, 0x00000027, +0x000100fd, 0x00010038 }; +} // namespace kp + + diff --git a/ggml/src/ggml-kompute/kompute/src/shaders/hlsl/computeheadless.comp b/ggml/src/ggml-kompute/kompute/src/shaders/hlsl/computeheadless.comp new file mode 100644 index 0000000000000000000000000000000000000000..ee3cd024f0466cbd80c1967955febbe974596393 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/src/shaders/hlsl/computeheadless.comp @@ -0,0 +1,29 @@ +// Copyright 2020 Google LLC + +RWStructuredBuffer values : register(u0); +[[vk::constant_id(0)]] const uint BUFFER_ELEMENTS = 32; + +uint fibonacci(uint n) { + if(n <= 1){ + return n; + } + uint curr = 1; + uint prev = 1; + for(uint i = 2; i < n; ++i) { + uint temp = curr; + curr += prev; + prev = temp; + } + return curr; +} + +[numthreads(1, 1, 1)] +void main(uint3 GlobalInvocationID : SV_DispatchThreadID) +{ + uint index = GlobalInvocationID.x; + if (index >= BUFFER_ELEMENTS) + return; + values[index] = fibonacci(values[index]); +} + + diff --git a/ggml/src/ggml-kompute/kompute/test/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0564458c661ef2561058e2622aef9501e7be0fb3 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/CMakeLists.txt @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# ###################### +cmake_minimum_required(VERSION 3.20) + +# #################################################### +# Shaders +# #################################################### +add_subdirectory(shaders) + +# #################################################### +# Tests +# #################################################### +add_executable(kompute_tests TestAsyncOperations.cpp + TestDestroy.cpp + TestLogisticRegression.cpp + TestManager.cpp + TestMultipleAlgoExecutions.cpp + TestOpShadersFromStringAndFile.cpp + TestOpTensorCopy.cpp + TestOpTensorCreate.cpp + TestPushConstant.cpp + TestSequence.cpp + TestSpecializationConstant.cpp + TestWorkgroup.cpp) + +target_link_libraries(kompute_tests PRIVATE GTest::gtest_main + kompute::kompute + kp_logger + test_shaders + test_shaders_glsl) +add_test(NAME kompute_tests COMMAND kompute_tests) + +# Group under the "tests" project folder in IDEs such as Visual Studio. +set_property(TARGET kompute_tests PROPERTY FOLDER "tests") + +if(WIN32 AND BUILD_SHARED_LIBS) # Install dlls in the same directory as the executable on Windows so one can simply double click them + add_custom_command(TARGET kompute_tests POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy $ $) + add_custom_command(TARGET kompute_tests POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy $ $) + add_custom_command(TARGET kompute_tests POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy $ $) + add_custom_command(TARGET kompute_tests POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy $ $) +endif() diff --git a/ggml/src/ggml-kompute/kompute/test/TestAsyncOperations.cpp b/ggml/src/ggml-kompute/kompute/test/TestAsyncOperations.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d549dda41e5efe576e520254061cd2460f3baa56 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestAsyncOperations.cpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" +#include "shaders/Utils.hpp" + +TEST(TestAsyncOperations, TestManagerParallelExecution) +{ + // This test is built for NVIDIA 1650. It assumes: + // * Queue family 0 and 2 have compute capabilities + // * GPU is able to process parallel shader code across different families + uint32_t size = 10; + + uint32_t numParallel = 2; + + std::string shader(R"( + #version 450 + + layout (local_size_x = 1) in; + + layout(set = 0, binding = 0) buffer b { float pb[]; }; + + shared uint sharedTotal[1]; + + void main() { + uint index = gl_GlobalInvocationID.x; + + sharedTotal[0] = 0; + + for (int i = 0; i < 100000000; i++) + { + atomicAdd(sharedTotal[0], 1); + } + + pb[index] = sharedTotal[0]; + } + )"); + + std::vector spirv = compileSource(shader); + + std::vector data(size, 0.0); + std::vector resultSync(size, 100000000); + std::vector resultAsync(size, 100000000); + + kp::Manager mgr; + + std::shared_ptr sq = mgr.sequence(); + + std::vector> inputsSyncB; + std::vector> algorithms; + + for (uint32_t i = 0; i < numParallel; i++) { + inputsSyncB.push_back(mgr.tensor(data)); + algorithms.push_back(mgr.algorithm({ inputsSyncB[i] }, spirv)); + } + + sq->eval(inputsSyncB); + + mgr.sequence()->eval(inputsSyncB); + + auto startSync = std::chrono::high_resolution_clock::now(); + + for (uint32_t i = 0; i < numParallel; i++) { + sq->eval(algorithms[i]); + } + + auto endSync = std::chrono::high_resolution_clock::now(); + auto durationSync = + std::chrono::duration_cast(endSync - startSync) + .count(); + + sq->eval(inputsSyncB); + + for (uint32_t i = 0; i < numParallel; i++) { + EXPECT_EQ(inputsSyncB[i]->vector(), resultSync); + } + + kp::Manager mgrAsync(0, { 0, 2 }); + + std::vector> inputsAsyncB; + + std::vector> algosAsync; + + for (uint32_t i = 0; i < numParallel; i++) { + inputsAsyncB.push_back(mgr.tensor(data)); + algosAsync.push_back(mgr.algorithm({ inputsAsyncB[i] }, spirv)); + } + + std::vector> sqs; + + for (uint32_t i = 0; i < numParallel; i++) { + sqs.push_back(mgrAsync.sequence(i)); + } + + auto startAsync = std::chrono::high_resolution_clock::now(); + + for (uint32_t i = 0; i < numParallel; i++) { + sqs[i]->evalAsync(algosAsync[i]); + } + + for (uint32_t i = 0; i < numParallel; i++) { + sqs[i]->evalAwait(); + } + + auto endAsync = std::chrono::high_resolution_clock::now(); + auto durationAsync = std::chrono::duration_cast( + endAsync - startAsync) + .count(); + + sq->eval({ inputsAsyncB }); + + for (uint32_t i = 0; i < numParallel; i++) { + EXPECT_EQ((inputsAsyncB[i]->vector()), resultAsync); + } + + // The speedup should be at least 40% + EXPECT_LT(durationAsync, durationSync * 0.6); +} + +TEST(TestAsyncOperations, TestManagerAsyncExecution) +{ + uint32_t size = 10; + + std::string shader(R"( + #version 450 + + layout (local_size_x = 1) in; + + layout(set = 0, binding = 0) buffer b { float pb[]; }; + + shared uint sharedTotal[1]; + + void main() { + uint index = gl_GlobalInvocationID.x; + + sharedTotal[0] = 0; + + for (int i = 0; i < 100000000; i++) + { + atomicAdd(sharedTotal[0], 1); + } + + pb[index] = sharedTotal[0]; + } + )"); + + std::vector spirv = compileSource(shader); + + std::vector data(size, 0.0); + std::vector resultAsync(size, 100000000); + + kp::Manager mgr; + + std::shared_ptr> tensorA = mgr.tensor(data); + std::shared_ptr> tensorB = mgr.tensor(data); + + std::shared_ptr sq1 = mgr.sequence(); + std::shared_ptr sq2 = mgr.sequence(); + + sq1->eval({ tensorA, tensorB }); + + std::shared_ptr algo1 = mgr.algorithm({ tensorA }, spirv); + std::shared_ptr algo2 = mgr.algorithm({ tensorB }, spirv); + + // AMD Drivers in Windows may see an error in this line due to timeout. + // In order to fix this, it requires a change on Windows registries. + // More details on this can be found here: + // https://docs.substance3d.com/spdoc/gpu-drivers-crash-with-long-computations-128745489.html + // Context on solution discussed in github: + // https://github.com/KomputeProject/kompute/issues/196#issuecomment-808866505 + sq1->evalAsync(algo1); + sq2->evalAsync(algo2); + + sq1->evalAwait(); + sq2->evalAwait(); + + sq1->evalAsync({ tensorA, tensorB }); + sq1->evalAwait(); + + EXPECT_EQ(tensorA->vector(), resultAsync); + EXPECT_EQ(tensorB->vector(), resultAsync); +} + +TEST(TestAsyncOperations, TestManagerAsyncExecutionTimeout) +{ + uint32_t size = 10; + + std::string shader(R"( + #version 450 + + layout (local_size_x = 1) in; + + layout(set = 0, binding = 0) buffer b { float pb[]; }; + + shared uint sharedTotal[1]; + + void main() { + uint index = gl_GlobalInvocationID.x; + + sharedTotal[0] = 0; + + for (int i = 0; i < 100000000; i++) + { + atomicAdd(sharedTotal[0], 1); + } + + pb[index] = sharedTotal[0]; + } + )"); + + std::vector spirv = compileSource(shader); + + std::vector data(size, 0.0); + std::vector resultAsync(size, 100000000); + + kp::Manager mgr; + + std::shared_ptr> tensorA = mgr.tensor(data); + std::shared_ptr> tensorB = mgr.tensor(data); + + std::shared_ptr sq1 = mgr.sequence(); + std::shared_ptr sq2 = mgr.sequence(); + + sq1->eval({ tensorA, tensorB }); + + std::shared_ptr algo1 = mgr.algorithm({ tensorA }, spirv); + std::shared_ptr algo2 = mgr.algorithm({ tensorB }, spirv); + + auto startSync = std::chrono::high_resolution_clock::now(); + + // AMD Drivers in Windows may see an error in this line due to timeout. + // In order to fix this, it requires a change on Windows registries. + // More details on this can be found here: + // https://docs.substance3d.com/spdoc/gpu-drivers-crash-with-long-computations-128745489.html + // Context on solution discussed in github: + // https://github.com/KomputeProject/kompute/issues/196#issuecomment-808866505 + sq1->evalAsync(algo1); + sq2->evalAsync(algo2); + + sq1->evalAwait(1); + sq2->evalAwait(1); + + auto endSync = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(endSync - startSync) + .count(); + + // The time should several orders of magnitude smaller (in this 100k instead + // of 1m ns) + EXPECT_LT(duration, 100000); + + sq1->evalAsync({ tensorA, tensorB }); + sq1->evalAwait(); + + EXPECT_EQ(tensorA->vector(), resultAsync); + EXPECT_EQ(tensorB->vector(), resultAsync); +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestDestroy.cpp b/ggml/src/ggml-kompute/kompute/test/TestDestroy.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f6b92cd902599b233d07a34cd28aa47bbfe7897d --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestDestroy.cpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +#include "shaders/Utils.hpp" + +TEST(TestDestroy, TestDestroyTensorSingle) +{ + std::shared_ptr> tensorA = nullptr; + + std::string shader(R"( + #version 450 + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = pa[index] + 1; + })"); + + std::vector spirv = compileSource(shader); + + { + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + const std::vector initialValues = { 0.0f, 0.0f, 0.0f }; + + tensorA = mgr.tensor(initialValues); + + std::shared_ptr algo = + mgr.algorithm({ tensorA }, spirv); + + // Sync values to and from device + mgr.sequence()->eval(algo->getTensors()); + + EXPECT_EQ(tensorA->vector(), initialValues); + + mgr.sequence() + ->record(algo) + ->eval() + ->eval(algo->getTensors()); + + const std::vector expectedFinalValues = { 1.0f, 1.0f, 1.0f }; + EXPECT_EQ(tensorA->vector(), expectedFinalValues); + + tensorA->destroy(); + EXPECT_FALSE(tensorA->isInit()); + } + EXPECT_FALSE(tensorA->isInit()); + } +} + +TEST(TestDestroy, TestDestroyTensorVector) +{ + std::shared_ptr> tensorA = nullptr; + std::shared_ptr> tensorB = nullptr; + + std::string shader(R"( + #version 450 + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + layout(set = 0, binding = 1) buffer b { float pb[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = pa[index] + 1; + pb[index] = pb[index] + 2; + })"); + std::vector spirv = compileSource(shader); + + { + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + tensorA = mgr.tensor({ 1, 1, 1 }); + tensorB = mgr.tensor({ 1, 1, 1 }); + + std::shared_ptr algo = + mgr.algorithm({ tensorA, tensorB }, spirv); + + mgr.sequence() + ->record(algo->getTensors()) + ->record(algo) + ->record(algo->getTensors()) + ->eval(); + + EXPECT_EQ(tensorA->vector(), std::vector({ 2, 2, 2 })); + EXPECT_EQ(tensorB->vector(), std::vector({ 3, 3, 3 })); + + tensorA->destroy(); + tensorB->destroy(); + + EXPECT_FALSE(tensorA->isInit()); + EXPECT_FALSE(tensorB->isInit()); + } + } +} + +TEST(TestDestroy, TestDestroySequenceSingle) +{ + std::shared_ptr> tensorA = nullptr; + + std::string shader(R"( + #version 450 + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = pa[index] + 1; + })"); + + std::vector spirv = compileSource(shader); + + { + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + tensorA = mgr.tensor({ 0, 0, 0 }); + + sq = + mgr.sequence() + ->record({ tensorA }) + ->record(mgr.algorithm({ tensorA }, spirv)) + ->record({ tensorA }) + ->eval(); + + sq->destroy(); + + EXPECT_FALSE(sq->isInit()); + + EXPECT_EQ(tensorA->vector(), std::vector({ 1, 1, 1 })); + } + } +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestLogisticRegression.cpp b/ggml/src/ggml-kompute/kompute/test/TestLogisticRegression.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b417a94c9d35371fb0beaf27a5cb84e8fd8b630 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestLogisticRegression.cpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +#include "test_logistic_regression_shader.hpp" + +TEST(TestLogisticRegression, TestMainLogisticRegression) +{ + + uint32_t ITERATIONS = 100; + float learningRate = 0.1; + + { + kp::Manager mgr; + + std::shared_ptr> xI = mgr.tensor({ 0, 1, 1, 1, 1 }); + std::shared_ptr> xJ = mgr.tensor({ 0, 0, 0, 1, 1 }); + + std::shared_ptr> y = mgr.tensor({ 0, 0, 0, 1, 1 }); + + std::shared_ptr> wIn = mgr.tensor({ 0.001, 0.001 }); + std::shared_ptr> wOutI = + mgr.tensor({ 0, 0, 0, 0, 0 }); + std::shared_ptr> wOutJ = + mgr.tensor({ 0, 0, 0, 0, 0 }); + + std::shared_ptr> bIn = mgr.tensor({ 0 }); + std::shared_ptr> bOut = + mgr.tensor({ 0, 0, 0, 0, 0 }); + + std::shared_ptr> lOut = + mgr.tensor({ 0, 0, 0, 0, 0 }); + + std::vector> params = { xI, xJ, y, + wIn, wOutI, wOutJ, + bIn, bOut, lOut }; + + mgr.sequence()->eval(params); + + std::vector spirv2{ 0x1, 0x2 }; + + std::vector spirv( + kp::TEST_LOGISTIC_REGRESSION_SHADER_COMP_SPV.begin(), + kp::TEST_LOGISTIC_REGRESSION_SHADER_COMP_SPV.end()); + + std::shared_ptr algorithm = mgr.algorithm( + params, spirv, kp::Workgroup({ 5 }), std::vector({ 5.0 })); + + std::shared_ptr sq = + mgr.sequence() + ->record({ wIn, bIn }) + ->record(algorithm) + ->record({ wOutI, wOutJ, bOut, lOut }); + + // Iterate across all expected iterations + for (size_t i = 0; i < ITERATIONS; i++) { + sq->eval(); + + for (size_t j = 0; j < bOut->size(); j++) { + wIn->data()[0] -= learningRate * wOutI->data()[j]; + wIn->data()[1] -= learningRate * wOutJ->data()[j]; + bIn->data()[0] -= learningRate * bOut->data()[j]; + } + } + + // Based on the inputs the outputs should be at least: + // * wi < 0.01 + // * wj > 1.0 + // * b < 0 + // TODO: Add EXPECT_DOUBLE_EQ instead + EXPECT_LT(wIn->data()[0], 0.01); + EXPECT_GT(wIn->data()[1], 1.0); + EXPECT_LT(bIn->data()[0], 0.0); + + KP_LOG_WARN("Result wIn i: {}, wIn j: {}, bIn: {}", + wIn->data()[0], + wIn->data()[1], + bIn->data()[0]); + } +} + +TEST(TestLogisticRegression, TestMainLogisticRegressionManualCopy) +{ + + uint32_t ITERATIONS = 100; + float learningRate = 0.1; + + { + kp::Manager mgr; + + std::shared_ptr> xI = mgr.tensor({ 0, 1, 1, 1, 1 }); + std::shared_ptr> xJ = mgr.tensor({ 0, 0, 0, 1, 1 }); + + std::shared_ptr> y = mgr.tensor({ 0, 0, 0, 1, 1 }); + + std::shared_ptr> wIn = + mgr.tensor({ 0.001, 0.001 }, kp::Tensor::TensorTypes::eHost); + std::shared_ptr> wOutI = + mgr.tensor({ 0, 0, 0, 0, 0 }); + std::shared_ptr> wOutJ = + mgr.tensor({ 0, 0, 0, 0, 0 }); + + std::shared_ptr> bIn = + mgr.tensor({ 0 }, kp::Tensor::TensorTypes::eHost); + std::shared_ptr> bOut = + mgr.tensor({ 0, 0, 0, 0, 0 }); + + std::shared_ptr> lOut = + mgr.tensor({ 0, 0, 0, 0, 0 }); + + std::vector> params = { xI, xJ, y, + wIn, wOutI, wOutJ, + bIn, bOut, lOut }; + + mgr.sequence()->record(params)->eval(); + + std::vector spirv( + kp::TEST_LOGISTIC_REGRESSION_SHADER_COMP_SPV.begin(), + kp::TEST_LOGISTIC_REGRESSION_SHADER_COMP_SPV.end()); + + std::shared_ptr algorithm = mgr.algorithm( + params, spirv, kp::Workgroup(), std::vector({ 5.0 })); + + std::shared_ptr sq = + mgr.sequence() + ->record({ wIn, bIn }) + ->record(algorithm) + ->record({ wOutI, wOutJ, bOut, lOut }); + + // Iterate across all expected iterations + for (size_t i = 0; i < ITERATIONS; i++) { + + sq->eval(); + + for (size_t j = 0; j < bOut->size(); j++) { + wIn->data()[0] -= learningRate * wOutI->data()[j]; + wIn->data()[1] -= learningRate * wOutJ->data()[j]; + bIn->data()[0] -= learningRate * bOut->data()[j]; + } + } + + // Based on the inputs the outputs should be at least: + // * wi < 0.01 + // * wj > 1.0 + // * b < 0 + // TODO: Add EXPECT_DOUBLE_EQ instead + EXPECT_LT(wIn->data()[0], 0.01); + EXPECT_GT(wIn->data()[1], 1.0); + EXPECT_LT(bIn->data()[0], 0.0); + + KP_LOG_WARN("Result wIn i: {}, wIn j: {}, bIn: {}", + wIn->data()[0], + wIn->data()[1], + bIn->data()[0]); + } +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestManager.cpp b/ggml/src/ggml-kompute/kompute/test/TestManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a7e488d11b16df08df6a253c9363f4b74282fda8 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestManager.cpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +TEST(TestManager, EndToEndOpMultEvalFlow) +{ + kp::Manager mgr; + + std::shared_ptr> tensorLHS = mgr.tensor({ 0, 1, 2 }); + std::shared_ptr> tensorRHS = mgr.tensor({ 2, 4, 6 }); + std::shared_ptr> tensorOutput = mgr.tensor({ 0, 0, 0 }); + + std::vector> params = { tensorLHS, + tensorRHS, + tensorOutput }; + + mgr.sequence() + ->eval(params) + ->eval(params, mgr.algorithm()) + ->eval(params); + + EXPECT_EQ(tensorOutput->vector(), std::vector({ 0, 4, 12 })); +} + +TEST(TestManager, EndToEndOpMultSeqFlow) +{ + kp::Manager mgr; + + std::shared_ptr> tensorLHS = mgr.tensor({ 0, 1, 2 }); + std::shared_ptr> tensorRHS = mgr.tensor({ 2, 4, 6 }); + std::shared_ptr> tensorOutput = mgr.tensor({ 0, 0, 0 }); + + std::vector> params = { tensorLHS, + tensorRHS, + tensorOutput }; + + mgr.sequence() + ->record(params) + ->record(params, mgr.algorithm()) + ->record(params) + ->eval(); + + EXPECT_EQ(tensorOutput->vector(), std::vector({ 0, 4, 12 })); +} + +TEST(TestManager, TestMultipleSequences) +{ + kp::Manager mgr; + + std::shared_ptr> tensorLHS = mgr.tensor({ 0, 1, 2 }); + std::shared_ptr> tensorRHS = mgr.tensor({ 2, 4, 6 }); + std::shared_ptr> tensorOutput = mgr.tensor({ 0, 0, 0 }); + + std::vector> params = { tensorLHS, + tensorRHS, + tensorOutput }; + + mgr.sequence()->eval(params); + mgr.sequence()->eval(params, mgr.algorithm()); + mgr.sequence()->eval(params); + + EXPECT_EQ(tensorOutput->vector(), std::vector({ 0, 4, 12 })); +} + +TEST(TestManager, TestDeviceProperties) +{ + kp::Manager mgr; + const vk::PhysicalDeviceProperties properties = mgr.getDeviceProperties(); + EXPECT_GT(properties.deviceName.size(), 0); +} + +TEST(TestManager, TestListDevices) +{ + kp::Manager mgr; + const std::vector devices = mgr.listDevices(); + EXPECT_GT(devices.size(), 0); + EXPECT_GT(devices[0].getProperties().deviceName.size(), 0); +} + +TEST(TestManager, TestClearDestroy) +{ + kp::Manager mgr; + + // Running within scope to run clear + { + std::shared_ptr> tensorLHS = mgr.tensor({ 0, 1, 2 }); + std::shared_ptr> tensorRHS = mgr.tensor({ 2, 4, 6 }); + std::shared_ptr> tensorOutput = + mgr.tensor({ 0, 0, 0 }); + + std::vector> params = { tensorLHS, + tensorRHS, + tensorOutput }; + + mgr.sequence()->eval(params); + mgr.sequence()->eval(params, mgr.algorithm()); + mgr.sequence()->eval(params); + + EXPECT_EQ(tensorOutput->vector(), std::vector({ 0, 4, 12 })); + } + + mgr.clear(); + + mgr.destroy(); +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestMultipleAlgoExecutions.cpp b/ggml/src/ggml-kompute/kompute/test/TestMultipleAlgoExecutions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f3d7315bafa82631c26503f16de6700c9cd94522 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestMultipleAlgoExecutions.cpp @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +#include "shaders/Utils.hpp" + +TEST(TestMultipleAlgoExecutions, TestEndToEndFunctionality) +{ + + kp::Manager mgr; + + // Default tensor constructor simplifies creation of float values + auto tensorInA = mgr.tensor({ 2., 2., 2. }); + auto tensorInB = mgr.tensor({ 1., 2., 3. }); + // Explicit type constructor supports int, in32, double, float and int + auto tensorOutA = mgr.tensorT({ 0, 0, 0 }); + auto tensorOutB = mgr.tensorT({ 0, 0, 0 }); + + std::string shader = (R"( + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer buf_in_a { float in_a[]; }; + layout(set = 0, binding = 1) buffer buf_in_b { float in_b[]; }; + layout(set = 0, binding = 2) buffer buf_out_a { uint out_a[]; }; + layout(set = 0, binding = 3) buffer buf_out_b { uint out_b[]; }; + + // Kompute supports push constants updated on dispatch + layout(push_constant) uniform PushConstants { + float val; + } push_const; + + // Kompute also supports spec constants on initalization + layout(constant_id = 0) const float const_one = 0; + + void main() { + uint index = gl_GlobalInvocationID.x; + out_a[index] += uint( in_a[index] * in_b[index] ); + out_b[index] += uint( const_one * push_const.val ); + } + )"); + + std::vector> params = { + tensorInA, tensorInB, tensorOutA, tensorOutB + }; + + kp::Workgroup workgroup({ 3, 1, 1 }); + std::vector specConsts({ 2 }); + std::vector pushConstsA({ 2.0 }); + std::vector pushConstsB({ 3.0 }); + + auto algorithm = mgr.algorithm( + params, compileSource(shader), workgroup, specConsts, pushConstsA); + + // 3. Run operation with string shader synchronously + mgr.sequence() + ->record(params) + ->record(algorithm) + ->eval() + ->record(algorithm, pushConstsB) + ->eval(); + + auto sq = mgr.sequence(); + sq->evalAsync(params); + + sq->evalAwait(); + + EXPECT_EQ(tensorOutA->vector(), std::vector({ 4, 8, 12 })); + EXPECT_EQ(tensorOutB->vector(), std::vector({ 10, 10, 10 })); +} + +TEST(TestMultipleAlgoExecutions, SingleSequenceRecord) +{ + + kp::Manager mgr; + + std::shared_ptr> tensorA = mgr.tensor({ 0, 0, 0 }); + + std::string shader(R"( + #version 450 + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = pa[index] + 1; + })"); + + std::vector spirv = compileSource(shader); + + { + // A sharedMemoryBarrier is required as the shader is not thread-safe:w + std::shared_ptr shaderBarrier{ + new kp::OpMemoryBarrier({ tensorA }, + vk::AccessFlagBits::eTransferRead, + vk::AccessFlagBits::eShaderWrite, + vk::PipelineStageFlagBits::eComputeShader, + vk::PipelineStageFlagBits::eComputeShader) + }; + + mgr.sequence() + ->record({ tensorA }) + ->record(mgr.algorithm({ tensorA }, spirv)) + ->record(shaderBarrier) + ->record(mgr.algorithm({ tensorA }, spirv)) + ->record(shaderBarrier) + ->record(mgr.algorithm({ tensorA }, spirv)) + ->record({ tensorA }) + ->eval(); + } + + EXPECT_EQ(tensorA->vector(), std::vector({ 3, 3, 3 })); +} + +TEST(TestMultipleAlgoExecutions, MultipleCmdBufRecords) +{ + kp::Manager mgr; + + std::shared_ptr> tensorA = mgr.tensor({ 0, 0, 0 }); + + std::string shader(R"( + #version 450 + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = pa[index] + 1; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr algorithm = + mgr.algorithm({ tensorA }, spirv); + + std::shared_ptr sq = mgr.sequence(); + + mgr.sequence()->record({ tensorA })->eval(); + + mgr.sequence()->record(algorithm)->eval(); + + mgr.sequence()->record(algorithm)->eval(); + + mgr.sequence()->record(algorithm)->eval(); + + mgr.sequence()->record({ tensorA })->eval(); + + EXPECT_EQ(tensorA->vector(), std::vector({ 3, 3, 3 })); +} + +TEST(TestMultipleAlgoExecutions, MultipleSequences) +{ + + kp::Manager mgr; + + std::shared_ptr> tensorA = mgr.tensor({ 0, 0, 0 }); + + std::string shader(R"( + #version 450 + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = pa[index] + 1; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr algorithm = + mgr.algorithm({ tensorA }, spirv); + + std::shared_ptr sq = mgr.sequence(); + + sq->record({ tensorA })->eval(); + + sq->record(algorithm)->eval(); + + sq->record(algorithm)->eval(); + + sq->record(algorithm)->eval(); + + sq->record({ tensorA })->eval(); + + EXPECT_EQ(tensorA->vector(), std::vector({ 3, 3, 3 })); +} + +TEST(TestMultipleAlgoExecutions, SingleRecordMultipleEval) +{ + kp::Manager mgr; + + std::shared_ptr> tensorA = mgr.tensor({ 0, 0, 0 }); + + std::string shader(R"( + #version 450 + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = pa[index] + 1; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr algorithm = + mgr.algorithm({ tensorA }, spirv); + + std::shared_ptr sq = mgr.sequence(); + + sq->record({ tensorA })->eval(); + + sq->record(algorithm)->eval()->eval()->eval(); + + sq->record({ tensorA })->eval(); + + EXPECT_EQ(tensorA->vector(), std::vector({ 3, 3, 3 })); +} + +TEST(TestMultipleAlgoExecutions, TestAlgorithmUtilFunctions) +{ + + kp::Manager mgr; + + // Default tensor constructor simplifies creation of float values + auto tensorInA = mgr.tensor({ 2., 2., 2. }); + auto tensorInB = mgr.tensor({ 1., 2., 3. }); + // Explicit type constructor supports int, in32, double, float and int + auto tensorOutA = mgr.tensorT({ 0, 0, 0 }); + auto tensorOutB = mgr.tensorT({ 0, 0, 0 }); + + std::string shader = (R"( + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer buf_in_a { float in_a[]; }; + layout(set = 0, binding = 1) buffer buf_in_b { float in_b[]; }; + layout(set = 0, binding = 2) buffer buf_out_a { uint out_a[]; }; + layout(set = 0, binding = 3) buffer buf_out_b { uint out_b[]; }; + + // Kompute supports push constants updated on dispatch + layout(push_constant) uniform PushConstants { + float val; + } push_const; + + // Kompute also supports spec constants on initalization + layout(constant_id = 0) const float const_one = 0; + + void main() { + uint index = gl_GlobalInvocationID.x; + out_a[index] += uint( in_a[index] * in_b[index] ); + out_b[index] += uint( const_one * push_const.val ); + } + )"); + + std::vector> params = { + tensorInA, tensorInB, tensorOutA, tensorOutB + }; + + kp::Workgroup workgroup({ 3, 1, 1 }); + std::vector specConsts({ 2 }); + std::vector pushConsts({ 2.0 }); + + auto algorithm = mgr.algorithm( + params, compileSource(shader), workgroup, specConsts, pushConsts); + + EXPECT_EQ(algorithm->getWorkgroup(), workgroup); + EXPECT_EQ(algorithm->getPushConstants(), pushConsts); + EXPECT_EQ(algorithm->getSpecializationConstants(), specConsts); +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestOpShadersFromStringAndFile.cpp b/ggml/src/ggml-kompute/kompute/test/TestOpShadersFromStringAndFile.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c95f2f5161b015b1738b1cb141c3fbcfc5881ae2 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestOpShadersFromStringAndFile.cpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +#include "shaders/Utils.hpp" +#include "test_op_custom_shader.hpp" +#include "test_shader.hpp" + +TEST(TestShaderEndianness, ShaderRawDataFromConstructor) +{ + std::string shader(R"( + #version 450 + + layout (local_size_x = 1) in; + + layout(set = 0, binding = 0) buffer a { float pa[]; }; + layout(set = 0, binding = 1) buffer b { float pb[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + pb[index] = pa[index]; + pa[index] = index; + } + )"); + + std::vector spirv = compileSource(shader); + std::vector spirv2(kp::TEST_SHADER_COMP_SPV.begin(), + kp::TEST_SHADER_COMP_SPV.end()); + EXPECT_EQ(spirv.size(), spirv2.size()); + for (size_t i = 0; i < spirv.size(); i++) { + EXPECT_EQ(spirv[i], spirv2[i]); + } +} + +TEST(TestOpAlgoCreate, ShaderRawDataFromConstructor) +{ + kp::Manager mgr; + + std::shared_ptr> tensorA = mgr.tensor({ 3, 4, 5 }); + std::shared_ptr> tensorB = mgr.tensor({ 0, 0, 0 }); + + std::string shader(R"( + #version 450 + + layout (local_size_x = 1) in; + + layout(set = 0, binding = 0) buffer a { float pa[]; }; + layout(set = 0, binding = 1) buffer b { float pb[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + pb[index] = pa[index]; + pa[index] = index; + } + )"); + + std::vector spirv = compileSource(shader); + + std::vector> params = { tensorA, tensorB }; + + mgr.sequence() + ->eval(params) + ->eval(mgr.algorithm(params, spirv)) + ->eval(params); + + EXPECT_EQ(tensorA->vector(), std::vector({ 0, 1, 2 })); + EXPECT_EQ(tensorB->vector(), std::vector({ 3, 4, 5 })); +} + +TEST(TestOpAlgoCreate, ShaderCompiledDataFromConstructor) +{ + kp::Manager mgr; + + std::shared_ptr> tensorA = mgr.tensor({ 3, 4, 5 }); + std::shared_ptr> tensorB = mgr.tensor({ 0, 0, 0 }); + + std::vector spirv(kp::TEST_OP_CUSTOM_SHADER_COMP_SPV.begin(), + kp::TEST_OP_CUSTOM_SHADER_COMP_SPV.end()); + std::vector> params = { tensorA, tensorB }; + + mgr.sequence() + ->eval(params) + ->eval(mgr.algorithm(params, spirv)) + ->eval(params); + + EXPECT_EQ(tensorA->vector(), std::vector({ 0, 1, 2 })); + EXPECT_EQ(tensorB->vector(), std::vector({ 3, 4, 5 })); +} + +// TODO: Add support to read from file for shader +// TEST(TestOpAlgoCreate, ShaderCompiledDataFromFile) +//{ +// kp::Manager mgr; +// +// std::shared_ptr> tensorA{ new kp::Tensor({ 3, 4, 5 }) +// }; std::shared_ptr> tensorB{ new kp::Tensor({ 0, 0, 0 +// }) }; mgr.rebuild({ tensorA, tensorB }); +// +// mgr.evalOpDefault( +// { tensorA, tensorB }, +// "test/shaders/glsl/test_op_custom_shader.comp.spv"); +// +// mgr.evalOpDefault({ tensorA, tensorB }); +// +// EXPECT_EQ(tensorA->vector(), std::vector({ 0, 1, 2 })); +// EXPECT_EQ(tensorB->vector(), std::vector({ 3, 4, 5 })); +//} diff --git a/ggml/src/ggml-kompute/kompute/test/TestOpTensorCopy.cpp b/ggml/src/ggml-kompute/kompute/test/TestOpTensorCopy.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f8de608e93c67ed6e4bcbf2cb0791f483f9b7cc --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestOpTensorCopy.cpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +#include "shaders/Utils.hpp" + +TEST(TestOpTensorCopy, CopyDeviceToDeviceTensor) +{ + + kp::Manager mgr; + + std::vector testVecA{ 1, 2, 3 }; + std::vector testVecB{ 0, 0, 0 }; + + std::shared_ptr> tensorA = mgr.tensor(testVecA); + std::shared_ptr> tensorB = mgr.tensor(testVecB); + + EXPECT_TRUE(tensorA->isInit()); + EXPECT_TRUE(tensorB->isInit()); + + mgr.sequence() + ->eval({ tensorA, tensorB }) + ->eval({ tensorA, tensorB }) + ->eval({ tensorA, tensorB }); + + // Making sure the GPU holds the same vector + EXPECT_EQ(tensorA->vector(), tensorB->vector()); +} + +TEST(TestOpTensorCopy, CopyDeviceToDeviceTensorMulti) +{ + + kp::Manager mgr; + + std::vector testVecA{ 2, 3, 4 }; + std::vector testVecB{ 0, 0, 0 }; + std::vector testVecC{ 0, 0, 0 }; + + std::shared_ptr> tensorA = mgr.tensor(testVecA); + std::shared_ptr> tensorB = mgr.tensor(testVecB); + std::shared_ptr> tensorC = mgr.tensor(testVecC); + + EXPECT_TRUE(tensorA->isInit()); + EXPECT_TRUE(tensorB->isInit()); + EXPECT_TRUE(tensorC->isInit()); + + mgr.sequence() + ->eval({ tensorA, tensorB, tensorC }) + ->eval({ tensorA, tensorB, tensorC }); + + EXPECT_EQ(tensorA->vector(), tensorB->vector()); + EXPECT_EQ(tensorA->vector(), tensorC->vector()); + + // Making sure the GPU holds the same vector + mgr.sequence()->eval({ tensorB, tensorC }); + + EXPECT_EQ(tensorA->vector(), tensorB->vector()); + EXPECT_EQ(tensorA->vector(), tensorC->vector()); +} + +TEST(TestOpTensorCopy, CopyDeviceToHostTensor) +{ + + kp::Manager mgr; + + std::vector testVecA{ 3, 4, 5 }; + std::vector testVecB{ 0, 0, 0 }; + + std::shared_ptr> tensorA = mgr.tensor(testVecA); + std::shared_ptr> tensorB = + mgr.tensor(testVecB, kp::Tensor::TensorTypes::eHost); + + // Only calling sync on device type tensor + mgr.sequence()->eval({ tensorA }); + + EXPECT_TRUE(tensorA->isInit()); + EXPECT_TRUE(tensorB->isInit()); + + mgr.sequence()->eval({ tensorA, tensorB }); + + EXPECT_EQ(tensorA->vector(), tensorB->vector()); + + // Making sure the GPU holds the same vector + mgr.sequence()->eval({ tensorB }); + EXPECT_EQ(tensorA->vector(), tensorB->vector()); +} + +TEST(TestOpTensorCopy, CopyHostToDeviceTensor) +{ + + kp::Manager mgr; + + std::vector testVecA{ 4, 5, 6 }; + std::vector testVecB{ 0, 0, 0 }; + + std::shared_ptr> tensorA = + mgr.tensor(testVecA, kp::Tensor::TensorTypes::eHost); + std::shared_ptr> tensorB = mgr.tensor(testVecB); + + // Only calling sync on device type tensor + mgr.sequence()->eval({ tensorA, tensorB }); + + EXPECT_TRUE(tensorA->isInit()); + EXPECT_TRUE(tensorB->isInit()); + + mgr.sequence()->eval({ tensorA, tensorB }); + + EXPECT_EQ(tensorA->vector(), tensorB->vector()); + + // Making sure the GPU holds the same vector + mgr.sequence()->eval({ tensorB }); + EXPECT_EQ(tensorA->vector(), tensorB->vector()); +} + +TEST(TestOpTensorCopy, CopyHostToHostTensor) +{ + + kp::Manager mgr; + + std::vector testVecA{ 5, 6, 7 }; + std::vector testVecB{ 0, 0, 0 }; + + std::shared_ptr> tensorA = + mgr.tensor(testVecA, kp::Tensor::TensorTypes::eHost); + std::shared_ptr> tensorB = + mgr.tensor(testVecB, kp::Tensor::TensorTypes::eHost); + + EXPECT_TRUE(tensorA->isInit()); + EXPECT_TRUE(tensorB->isInit()); + + mgr.sequence() + ->eval({ tensorA }) + ->eval({ tensorA, tensorB }); + + EXPECT_EQ(tensorA->vector(), tensorB->vector()); + + // Making sure the GPU holds the same vector + mgr.sequence()->eval({ tensorB }); + EXPECT_EQ(tensorA->vector(), tensorB->vector()); +} + +TEST(TestOpTensorCopy, SingleTensorShouldFail) +{ + + kp::Manager mgr; + + std::vector testVecA{ 6, 7, 8 }; + + std::shared_ptr> tensorA = + mgr.tensor(testVecA, kp::Tensor::TensorTypes::eHost); + + EXPECT_TRUE(tensorA->isInit()); + + EXPECT_THROW(mgr.sequence()->eval({ tensorA }), + std::runtime_error); +} + +TEST(TestOpTensorCopy, CopyThroughStorageTensor) +{ + kp::Manager mgr; + + std::vector testVecIn{ 9, 1, 3 }; + std::vector testVecOut{ 0, 0, 0 }; + + std::shared_ptr> tensorIn = mgr.tensor(testVecIn); + std::shared_ptr> tensorOut = mgr.tensor(testVecOut); + // Tensor storage requires a vector to be passed only to reflect size + std::shared_ptr> tensorStorage = + mgr.tensor({ 0, 0, 0 }, kp::Tensor::TensorTypes::eStorage); + + mgr.sequence() + ->eval({ tensorIn, tensorOut }) + ->eval({ tensorIn, tensorStorage }) + ->eval({ tensorStorage, tensorOut }) + ->eval({ tensorIn, tensorOut }); + + // Making sure the GPU holds the same vector + EXPECT_EQ(tensorIn->vector(), tensorOut->vector()); +} + +TEST(TestOpTensorCopy, CopyTensorThroughStorageViaAlgorithms) +{ + kp::Manager mgr; + + std::vector testVecIn{ 9, 1, 3 }; + std::vector testVecOut{ 0, 0, 0 }; + + std::shared_ptr> tensorIn = mgr.tensor(testVecIn); + std::shared_ptr> tensorOut = mgr.tensor(testVecOut); + // Tensor storage requires a vector to be passed only to reflect size + std::shared_ptr> tensorStorage = + mgr.tensor({ 0, 0, 0 }, kp::Tensor::TensorTypes::eStorage); + + EXPECT_TRUE(tensorIn->isInit()); + EXPECT_TRUE(tensorOut->isInit()); + + // Copy to storage tensor through algorithm + std::string shaderA = (R"( + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer buf_in { float t_in[]; }; + layout(set = 0, binding = 1) buffer buf_st { float t_st[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + t_st[index] = t_in[index]; + } + )"); + + auto algoA = mgr.algorithm( + { tensorIn, tensorStorage }, + compileSource(shaderA)); + + // Copy from storage tensor to output tensor + std::string shaderB = (R"( + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer buf_st { float t_st[]; }; + layout(set = 0, binding = 1) buffer buf_out { float t_out[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + t_out[index] = t_st[index]; + } + )"); + + auto algoB = mgr.algorithm( + { tensorStorage, tensorOut }, + compileSource(shaderB)); + + mgr.sequence() + ->eval({ tensorIn }) + ->eval(algoA) + ->eval(algoB) + ->eval({ tensorOut }); + + // Making sure the GPU holds the same vector + EXPECT_EQ(tensorIn->vector(), tensorOut->vector()); +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestOpTensorCreate.cpp b/ggml/src/ggml-kompute/kompute/test/TestOpTensorCreate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c3508ff89c7dbb4ab5cd84f2e6830379208beaf0 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestOpTensorCreate.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +TEST(TestOpTensorCreate, CreateSingleTensorSingleOp) +{ + std::vector testVecA{ 9, 8, 7 }; + std::shared_ptr> tensorA = nullptr; + + { + kp::Manager mgr; + + tensorA = mgr.tensor(testVecA); + + EXPECT_TRUE(tensorA->isInit()); + + EXPECT_EQ(tensorA->vector(), testVecA); + } + + EXPECT_FALSE(tensorA->isInit()); +} + +TEST(TestOpTensorCreate, NoErrorIfTensorFreedBefore) +{ + + std::vector testVecA{ 9, 8, 7 }; + std::vector testVecB{ 6, 5, 4 }; + + kp::Manager mgr; + + std::shared_ptr> tensorA = mgr.tensor(testVecA); + std::shared_ptr> tensorB = mgr.tensor(testVecB); + + EXPECT_EQ(tensorA->vector(), testVecA); + EXPECT_EQ(tensorB->vector(), testVecB); + + tensorA->destroy(); + tensorB->destroy(); + + EXPECT_FALSE(tensorA->isInit()); + EXPECT_FALSE(tensorB->isInit()); +} + +TEST(TestOpTensorCreate, ExceptionOnZeroSizeTensor) +{ + std::vector testVecA; + + kp::Manager mgr; + + try { + std::shared_ptr> tensorA = mgr.tensor(testVecA); + } catch (const std::runtime_error& err) { + // check exception + ASSERT_TRUE(std::string(err.what()).find("zero-sized") != + std::string::npos); + } +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestOpTensorSync.cpp b/ggml/src/ggml-kompute/kompute/test/TestOpTensorSync.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b59ee14813d2e922b29cf11bd1ffaaeefe83e475 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestOpTensorSync.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +TEST(TestOpTensorSync, SyncToDeviceMemorySingleTensor) +{ + + kp::Manager mgr; + + std::vector testVecPreA{ 0, 0, 0 }; + std::vector testVecPostA{ 9, 8, 7 }; + + std::shared_ptr> tensorA = mgr.tensor(testVecPreA); + + EXPECT_TRUE(tensorA->isInit()); + + tensorA->setData(testVecPostA); + + mgr.sequence()->eval({ tensorA }); + + mgr.sequence()->eval({ tensorA }); + + EXPECT_EQ(tensorA->vector(), testVecPostA); +} + +TEST(TestOpTensorSync, SyncToDeviceMemoryMultiTensor) +{ + + kp::Manager mgr; + + std::vector testVec{ 9, 8, 7 }; + + std::shared_ptr> tensorA = mgr.tensor({ 0, 0, 0 }); + std::shared_ptr> tensorB = mgr.tensor({ 0, 0, 0 }); + std::shared_ptr> tensorC = mgr.tensor({ 0, 0, 0 }); + + EXPECT_TRUE(tensorA->isInit()); + EXPECT_TRUE(tensorB->isInit()); + EXPECT_TRUE(tensorC->isInit()); + + tensorA->setData(testVec); + + mgr.sequence()->eval({ tensorA }); + + mgr.sequence()->eval({ tensorA, tensorB, tensorC }); + + mgr.sequence()->eval({ tensorA, tensorB, tensorC }); + + EXPECT_EQ(tensorA->vector(), testVec); + EXPECT_EQ(tensorB->vector(), testVec); + EXPECT_EQ(tensorC->vector(), testVec); +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestPushConstant.cpp b/ggml/src/ggml-kompute/kompute/test/TestPushConstant.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1356425cfc67e6f948698bc58367f42a1e442492 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestPushConstant.cpp @@ -0,0 +1,392 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +#include "shaders/Utils.hpp" + +TEST(TestPushConstants, TestConstantsAlgoDispatchOverride) +{ + { + std::string shader(R"( + #version 450 + layout(push_constant) uniform PushConstants { + float x; + float y; + float z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y; + pa[2] += pcs.z; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr> tensor = + mgr.tensor({ 0, 0, 0 }); + + std::shared_ptr algo = mgr.algorithm( + { tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.0, 0.0, 0.0 }); + + sq = mgr.sequence()->eval({ tensor }); + + // We need to run this in sequence to avoid race condition + // We can't use atomicAdd as swiftshader doesn't support it for + // float + sq->eval(algo, + std::vector{ 0.1, 0.2, 0.3 }); + sq->eval(algo, + std::vector{ 0.3, 0.2, 0.1 }); + sq->eval({ tensor }); + + EXPECT_EQ(tensor->vector(), std::vector({ 0.4, 0.4, 0.4 })); + } + } +} + +TEST(TestPushConstants, TestConstantsAlgoDispatchNoOverride) +{ + { + std::string shader(R"( + #version 450 + layout(push_constant) uniform PushConstants { + float x; + float y; + float z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y; + pa[2] += pcs.z; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr> tensor = + mgr.tensor({ 0, 0, 0 }); + + std::shared_ptr algo = mgr.algorithm( + { tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.1, 0.2, 0.3 }); + + sq = mgr.sequence()->eval({ tensor }); + + // We need to run this in sequence to avoid race condition + // We can't use atomicAdd as swiftshader doesn't support it for + // float + sq->eval(algo); + sq->eval(algo, + std::vector{ 0.3, 0.2, 0.1 }); + sq->eval({ tensor }); + + EXPECT_EQ(tensor->vector(), std::vector({ 0.4, 0.4, 0.4 })); + } + } +} + +TEST(TestPushConstants, TestConstantsWrongSize) +{ + { + std::string shader(R"( + #version 450 + layout(push_constant) uniform PushConstants { + float x; + float y; + float z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y; + pa[2] += pcs.z; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr> tensor = + mgr.tensor({ 0, 0, 0 }); + + std::shared_ptr algo = mgr.algorithm( + { tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.0 }); + + sq = mgr.sequence()->record({ tensor }); + + EXPECT_THROW(sq->record( + algo, std::vector{ 0.1, 0.2, 0.3 }), + std::runtime_error); + } + } +} + +// TODO: Ensure different types are considered for push constants +// TEST(TestPushConstants, TestConstantsWrongType) +// { +// { +// std::string shader(R"( +// #version 450 +// layout(push_constant) uniform PushConstants { +// float x; +// float y; +// float z; +// } pcs; +// layout (local_size_x = 1) in; +// layout(set = 0, binding = 0) buffer a { float pa[]; }; +// void main() { +// pa[0] += pcs.x; +// pa[1] += pcs.y; +// pa[2] += pcs.z; +// })"); +// +// std::vector spirv = compileSource(shader); +// +// std::shared_ptr sq = nullptr; +// +// { +// kp::Manager mgr; +// +// std::shared_ptr> tensor = +// mgr.tensor({ 0, 0, 0 }); +// +// std::shared_ptr algo = mgr.algorithm( +// { tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.0 }); +// +// sq = mgr.sequence()->record({ tensor }); +// +// EXPECT_THROW(sq->record( +// algo, std::vector{ 1, 2, 3 }), +// std::runtime_error); +// } +// } +// } + +TEST(TestPushConstants, TestConstantsMixedTypes) +{ + { + std::string shader(R"( + #version 450 + layout(push_constant) uniform PushConstants { + float x; + uint y; + int z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y - 2147483000; + pa[2] += pcs.z; + })"); + + struct TestConsts + { + float x; + uint32_t y; + int32_t z; + }; + + std::vector spirv = compileSource(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr> tensor = + mgr.tensorT({ 0, 0, 0 }); + + std::shared_ptr algo = + mgr.algorithm( + { tensor }, spirv, kp::Workgroup({ 1 }), {}, { { 0, 0, 0 } }); + + sq = mgr.sequence()->eval({ tensor }); + + // We need to run this in sequence to avoid race condition + // We can't use atomicAdd as swiftshader doesn't support it for + // float + sq->eval( + algo, std::vector{ { 15.32, 2147483650, 10 } }); + sq->eval( + algo, std::vector{ { 30.32, 2147483650, -3 } }); + sq->eval({ tensor }); + + EXPECT_EQ(tensor->vector(), std::vector({ 45.64, 1300, 7 })); + } + } +} + +TEST(TestPushConstants, TestConstantsInt) +{ + { + std::string shader(R"( + #version 450 + layout(push_constant) uniform PushConstants { + int x; + int y; + int z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { int pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y; + pa[2] += pcs.z; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr> tensor = + mgr.tensorT({ -1, -1, -1 }); + + std::shared_ptr algo = + mgr.algorithm( + { tensor }, spirv, kp::Workgroup({ 1 }), {}, { { 0, 0, 0 } }); + + sq = mgr.sequence()->eval({ tensor }); + + // We need to run this in sequence to avoid race condition + // We can't use atomicAdd as swiftshader doesn't support it for + // float + sq->eval( + algo, std::vector{ { -1, -1, -1 } }); + sq->eval( + algo, std::vector{ { -1, -1, -1 } }); + sq->eval({ tensor }); + + EXPECT_EQ(tensor->vector(), std::vector({ -3, -3, -3 })); + } + } +} + +TEST(TestPushConstants, TestConstantsUnsignedInt) +{ + { + std::string shader(R"( + #version 450 + layout(push_constant) uniform PushConstants { + uint x; + uint y; + uint z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { uint pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y; + pa[2] += pcs.z; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr> tensor = + mgr.tensorT({ 0, 0, 0 }); + + std::shared_ptr algo = + mgr.algorithm( + { tensor }, spirv, kp::Workgroup({ 1 }), {}, { { 0, 0, 0 } }); + + sq = mgr.sequence()->eval({ tensor }); + + // We need to run this in sequence to avoid race condition + // We can't use atomicAdd as swiftshader doesn't support it for + // float + sq->eval( + algo, + std::vector{ { 2147483650, 2147483650, 2147483650 } }); + sq->eval(algo, + std::vector{ { 5, 5, 5 } }); + sq->eval({ tensor }); + + EXPECT_EQ( + tensor->vector(), + std::vector({ 2147483655, 2147483655, 2147483655 })); + } + } +} + +TEST(TestPushConstants, TestConstantsDouble) +{ + { + std::string shader(R"( + #version 450 + layout(push_constant) uniform PushConstants { + double x; + double y; + double z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { double pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y; + pa[2] += pcs.z; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr> tensor = + mgr.tensorT({ 0, 0, 0 }); + + std::shared_ptr algo = mgr.algorithm( + { tensor }, spirv, kp::Workgroup({ 1 }), {}, { { 0, 0, 0 } }); + + sq = mgr.sequence()->eval({ tensor }); + + // We need to run this in sequence to avoid race condition + // We can't use atomicAdd as swiftshader doesn't support it for + // float + sq->eval( + algo, + std::vector{ { 1.1111222233334444, + 2.1111222233334444, + 3.1111222233334444 } }); + sq->eval( + algo, + std::vector{ { 1.1111222233334444, + 2.1111222233334444, + 3.1111222233334444 } }); + sq->eval({ tensor }); + + EXPECT_EQ(tensor->vector(), + std::vector({ 2.2222444466668888, + 4.2222444466668888, + 6.2222444466668888 })); + } + } +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestSequence.cpp b/ggml/src/ggml-kompute/kompute/test/TestSequence.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d99107a96a53d05a3405506ed38a24a2dbe0fac6 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestSequence.cpp @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +#include "shaders/Utils.hpp" + +TEST(TestSequence, SequenceDestructorViaManager) +{ + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + sq = mgr.sequence(); + + EXPECT_TRUE(sq->isInit()); + } + + EXPECT_FALSE(sq->isInit()); +} + +TEST(TestSequence, SequenceDestructorOutsideManagerExplicit) +{ + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + sq = mgr.sequence(); + + EXPECT_TRUE(sq->isInit()); + + sq->destroy(); + + EXPECT_FALSE(sq->isInit()); + } + + EXPECT_FALSE(sq->isInit()); +} + +TEST(TestSequence, SequenceDestructorOutsideManagerImplicit) +{ + kp::Manager mgr; + + std::weak_ptr sqWeak; + + { + std::shared_ptr sq = mgr.sequence(); + + sqWeak = sq; + + EXPECT_TRUE(sq->isInit()); + } + + EXPECT_FALSE(sqWeak.lock()); +} + +TEST(TestSequence, RerecordSequence) +{ + kp::Manager mgr; + + std::shared_ptr sq = mgr.sequence(); + + std::shared_ptr> tensorA = mgr.tensor({ 1, 2, 3 }); + std::shared_ptr> tensorB = mgr.tensor({ 2, 2, 2 }); + std::shared_ptr> tensorOut = mgr.tensor({ 0, 0, 0 }); + + sq->eval({ tensorA, tensorB, tensorOut }); + + std::vector spirv = compileSource(R"( + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer bina { float tina[]; }; + layout(set = 0, binding = 1) buffer binb { float tinb[]; }; + layout(set = 0, binding = 2) buffer bout { float tout[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + tout[index] = tina[index] * tinb[index]; + } + )"); + + std::shared_ptr algo = + mgr.algorithm({ tensorA, tensorB, tensorOut }, spirv); + + sq->record(algo)->record( + { tensorA, tensorB, tensorOut }); + + sq->eval(); + + EXPECT_EQ(tensorOut->vector(), std::vector({ 2, 4, 6 })); + + algo->rebuild({ tensorOut, tensorA, tensorB }, spirv); + + // Refresh and trigger a rerecord + sq->rerecord(); + sq->eval(); + + EXPECT_EQ(tensorB->vector(), std::vector({ 2, 8, 18 })); +} + +TEST(TestSequence, SequenceTimestamps) +{ + kp::Manager mgr; + + std::shared_ptr tensorA = mgr.tensor({ 0, 0, 0 }); + + std::string shader(R"( + #version 450 + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = pa[index] + 1; + })"); + + std::vector spirv = compileSource(shader); + + auto seq = mgr.sequence(0, 100); // 100 timestamps + seq->record({ tensorA }) + ->record(mgr.algorithm({ tensorA }, spirv)) + ->record(mgr.algorithm({ tensorA }, spirv)) + ->record(mgr.algorithm({ tensorA }, spirv)) + ->record({ tensorA }) + ->eval(); + const std::vector timestamps = seq->getTimestamps(); + + EXPECT_EQ(timestamps.size(), + 6); // 1 timestamp at start + 1 after each operation +} + +TEST(TestSequence, UtilsClearRecordingRunning) +{ + kp::Manager mgr; + + std::shared_ptr sq = mgr.sequence(); + + std::shared_ptr> tensorA = mgr.tensor({ 1, 2, 3 }); + std::shared_ptr> tensorB = mgr.tensor({ 2, 2, 2 }); + std::shared_ptr> tensorOut = mgr.tensor({ 0, 0, 0 }); + + sq->eval({ tensorA, tensorB, tensorOut }); + + std::vector spirv = compileSource(R"( + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer bina { float tina[]; }; + layout(set = 0, binding = 1) buffer binb { float tinb[]; }; + layout(set = 0, binding = 2) buffer bout { float tout[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + tout[index] = tina[index] * tinb[index]; + } + )"); + + std::shared_ptr algo = + mgr.algorithm({ tensorA, tensorB, tensorOut }, spirv); + + sq->record(algo)->record( + { tensorA, tensorB, tensorOut }); + + EXPECT_TRUE(sq->isRecording()); + + // Running clear to confirm it clears + sq->clear(); + + EXPECT_FALSE(sq->isRecording()); + + sq->evalAsync(); + + EXPECT_TRUE(sq->isRunning()); + + sq->evalAwait(); + + EXPECT_FALSE(sq->isRunning()); + + EXPECT_EQ(tensorOut->vector(), std::vector({ 2, 4, 6 })); +} + +TEST(TestSequence, CorrectSequenceRunningError) +{ + kp::Manager mgr; + + std::shared_ptr sq = mgr.sequence(); + + std::shared_ptr> tensorA = mgr.tensor({ 1, 2, 3 }); + std::shared_ptr> tensorB = mgr.tensor({ 2, 2, 2 }); + std::shared_ptr> tensorOut = mgr.tensor({ 0, 0, 0 }); + + sq->eval({ tensorA, tensorB, tensorOut }); + + std::vector spirv = compileSource(R"( + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer bina { float tina[]; }; + layout(set = 0, binding = 1) buffer binb { float tinb[]; }; + layout(set = 0, binding = 2) buffer bout { float tout[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + tout[index] = tina[index] * tinb[index]; + } + )"); + + std::shared_ptr algo = + mgr.algorithm({ tensorA, tensorB, tensorOut }, spirv); + + sq->record(algo)->record( + { tensorA, tensorB, tensorOut }); + + EXPECT_TRUE(sq->isRecording()); + + sq->evalAsync(); + + EXPECT_TRUE(sq->isRunning()); + + // Sequence should throw when running + EXPECT_ANY_THROW(sq->begin()); + EXPECT_ANY_THROW(sq->end()); + EXPECT_ANY_THROW(sq->evalAsync()); + + // Errors should still not get into inconsystent state + sq->evalAwait(); + + // Sequence should not throw when finished + EXPECT_NO_THROW(sq->evalAwait()); + EXPECT_NO_THROW(sq->evalAwait(10)); + + EXPECT_FALSE(sq->isRunning()); + + EXPECT_EQ(tensorOut->vector(), std::vector({ 2, 4, 6 })); +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestSpecializationConstant.cpp b/ggml/src/ggml-kompute/kompute/test/TestSpecializationConstant.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7f15f84435cc5871ebe235c17f2774695254509f --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestSpecializationConstant.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +#include "shaders/Utils.hpp" + +TEST(TestSpecializationConstants, TestTwoConstants) +{ + { + std::string shader(R"( + #version 450 + layout (constant_id = 0) const float cOne = 1; + layout (constant_id = 1) const float cTwo = 1; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + layout(set = 0, binding = 1) buffer b { float pb[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = cOne; + pb[index] = cTwo; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr> tensorA = + mgr.tensor({ 0, 0, 0 }); + std::shared_ptr> tensorB = + mgr.tensor({ 0, 0, 0 }); + + std::vector> params = { tensorA, + tensorB }; + + std::vector spec = std::vector({ 5.0, 0.3 }); + + std::shared_ptr algo = + mgr.algorithm(params, spirv, {}, spec); + + sq = mgr.sequence() + ->record(params) + ->record(algo) + ->record(params) + ->eval(); + + EXPECT_EQ(tensorA->vector(), std::vector({ 5, 5, 5 })); + EXPECT_EQ(tensorB->vector(), std::vector({ 0.3, 0.3, 0.3 })); + } + } +} + +TEST(TestSpecializationConstants, TestConstantsInt) +{ + { + std::string shader(R"( + #version 450 + layout (constant_id = 0) const int cOne = 1; + layout (constant_id = 1) const int cTwo = 1; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { int pa[]; }; + layout(set = 0, binding = 1) buffer b { int pb[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = cOne; + pb[index] = cTwo; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr> tensorA = + mgr.tensorT({ 0, 0, 0 }); + std::shared_ptr> tensorB = + mgr.tensorT({ 0, 0, 0 }); + + std::vector> params = { tensorA, + tensorB }; + + std::vector spec({ -1, -2 }); + + std::shared_ptr algo = + mgr.algorithm(params, spirv, {}, spec, {}); + + sq = mgr.sequence() + ->record(params) + ->record(algo) + ->record(params) + ->eval(); + + EXPECT_EQ(tensorA->vector(), std::vector({ -1, -1, -1 })); + EXPECT_EQ(tensorB->vector(), std::vector({ -2, -2, -2 })); + } + } +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestTensor.cpp b/ggml/src/ggml-kompute/kompute/test/TestTensor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7eeff4af3e31354ee8f123b10ac2698999f2b2fa --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestTensor.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +TEST(TestTensor, ConstructorData) +{ + kp::Manager mgr; + std::vector vec{ 0, 1, 2 }; + std::shared_ptr> tensor = mgr.tensor(vec); + EXPECT_EQ(tensor->size(), vec.size()); + EXPECT_EQ(tensor->dataTypeMemorySize(), sizeof(float)); + EXPECT_EQ(tensor->vector(), vec); +} + +TEST(TestTensor, DataTypes) +{ + kp::Manager mgr; + + { + std::vector vec{ 0, 1, 2 }; + std::shared_ptr> tensor = mgr.tensor(vec); + EXPECT_EQ(tensor->dataType(), kp::Tensor::TensorDataTypes::eFloat); + } + + { + std::vector vec{ 0, 1, 2 }; + std::shared_ptr> tensor = mgr.tensorT(vec); + EXPECT_EQ(tensor->dataType(), kp::Tensor::TensorDataTypes::eInt); + } + + { + std::vector vec{ 0, 1, 2 }; + std::shared_ptr> tensor = mgr.tensorT(vec); + EXPECT_EQ(tensor->dataType(), + kp::Tensor::TensorDataTypes::eUnsignedInt); + } + + { + std::vector vec{ 0, 1, 2 }; + std::shared_ptr> tensor = mgr.tensorT(vec); + EXPECT_EQ(tensor->dataType(), kp::Tensor::TensorDataTypes::eDouble); + } +} diff --git a/ggml/src/ggml-kompute/kompute/test/TestWorkgroup.cpp b/ggml/src/ggml-kompute/kompute/test/TestWorkgroup.cpp new file mode 100644 index 0000000000000000000000000000000000000000..914ee721ea3b734e1ce827f41ae114f3b83ce999 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/TestWorkgroup.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "kompute/Kompute.hpp" +#include "kompute/logger/Logger.hpp" + +#include "test_workgroup_shader.hpp" + +TEST(TestWorkgroup, TestSimpleWorkgroup) +{ + std::shared_ptr> tensorA = nullptr; + std::shared_ptr> tensorB = nullptr; + { + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + tensorA = mgr.tensor(std::vector(16 * 8)); + tensorB = mgr.tensor(std::vector(16 * 8)); + + std::vector> params = { tensorA, + tensorB }; + std::vector spirv( + kp::TEST_WORKGROUP_SHADER_COMP_SPV.begin(), + kp::TEST_WORKGROUP_SHADER_COMP_SPV.end()); + + kp::Workgroup workgroup = { 16, 8, 1 }; + + std::shared_ptr algorithm = + mgr.algorithm(params, spirv, workgroup); + + sq = mgr.sequence(); + sq->record(params); + sq->record(algorithm); + sq->record(params); + sq->eval(); + + std::vector expectedA = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, + 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, + 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15 + }; + + std::vector expectedB = { + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, + 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, + 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, + 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, + 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, + 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, + 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7 + }; + + EXPECT_EQ(tensorA->vector(), expectedA); + EXPECT_EQ(tensorB->vector(), expectedB); + } + } +} diff --git a/ggml/src/ggml-kompute/kompute/test/compiled_shaders_include/kompute_test/shaders/shadertest_logistic_regression.hpp b/ggml/src/ggml-kompute/kompute/test/compiled_shaders_include/kompute_test/shaders/shadertest_logistic_regression.hpp new file mode 100644 index 0000000000000000000000000000000000000000..342861429d520178ea87a8cccd090f07c98ebe32 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/compiled_shaders_include/kompute_test/shaders/shadertest_logistic_regression.hpp @@ -0,0 +1,433 @@ +/* + THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT + + --- + + Copyright 2020 The Institute for Ethical AI & Machine Learning + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#ifndef SHADEROP_SHADERTEST_LOGISTIC_REGRESSION_HPP +#define SHADEROP_SHADERTEST_LOGISTIC_REGRESSION_HPP + +namespace kp { +namespace shader_data { +static const unsigned char test_shaders_glsl_test_logistic_regression_comp_spv[] = { + 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x00, 0x08, 0x00, + 0xae, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x00, 0x02, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x47, 0x4c, 0x53, 0x4c, 0x2e, 0x73, 0x74, 0x64, 0x2e, 0x34, 0x35, 0x30, + 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x06, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x00, 0x00, 0x10, 0x00, 0x06, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, + 0xc2, 0x01, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x05, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x73, 0x69, 0x67, 0x6d, 0x6f, 0x69, 0x64, 0x28, + 0x66, 0x31, 0x3b, 0x00, 0x05, 0x00, 0x03, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x7a, 0x00, 0x00, 0x00, 0x05, 0x00, 0x08, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x69, 0x6e, 0x66, 0x65, 0x72, 0x65, 0x6e, 0x63, 0x65, 0x28, 0x76, 0x66, + 0x32, 0x3b, 0x76, 0x66, 0x32, 0x3b, 0x66, 0x31, 0x3b, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x10, 0x00, 0x00, 0x00, 0x77, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x11, 0x00, 0x00, 0x00, 0x62, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x08, 0x00, 0x17, 0x00, 0x00, 0x00, 0x63, 0x61, 0x6c, 0x63, + 0x75, 0x6c, 0x61, 0x74, 0x65, 0x4c, 0x6f, 0x73, 0x73, 0x28, 0x66, 0x31, + 0x3b, 0x66, 0x31, 0x3b, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x15, 0x00, 0x00, 0x00, 0x79, 0x48, 0x61, 0x74, 0x00, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x16, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x21, 0x00, 0x00, 0x00, 0x7a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x27, 0x00, 0x00, 0x00, 0x79, 0x48, 0x61, 0x74, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x28, 0x00, 0x00, 0x00, + 0x70, 0x61, 0x72, 0x61, 0x6d, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, + 0x3e, 0x00, 0x00, 0x00, 0x69, 0x64, 0x78, 0x00, 0x05, 0x00, 0x08, 0x00, + 0x41, 0x00, 0x00, 0x00, 0x67, 0x6c, 0x5f, 0x47, 0x6c, 0x6f, 0x62, 0x61, + 0x6c, 0x49, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, + 0x44, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x46, 0x00, 0x00, 0x00, + 0x77, 0x43, 0x75, 0x72, 0x72, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x62, 0x77, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x04, 0x00, 0x48, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x77, 0x69, 0x6e, 0x00, 0x05, 0x00, 0x03, 0x00, 0x4a, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x54, 0x00, 0x00, 0x00, + 0x62, 0x43, 0x75, 0x72, 0x72, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x56, 0x00, 0x00, 0x00, 0x62, 0x62, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x04, 0x00, 0x56, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x62, 0x69, 0x6e, 0x00, 0x05, 0x00, 0x03, 0x00, 0x58, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x5b, 0x00, 0x00, 0x00, + 0x78, 0x43, 0x75, 0x72, 0x72, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, + 0x5d, 0x00, 0x00, 0x00, 0x62, 0x78, 0x69, 0x00, 0x06, 0x00, 0x04, 0x00, + 0x5d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x78, 0x69, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x5f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x64, 0x00, 0x00, 0x00, 0x62, 0x78, 0x6a, 0x00, + 0x06, 0x00, 0x04, 0x00, 0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x78, 0x6a, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, 0x66, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x6b, 0x00, 0x00, 0x00, + 0x79, 0x43, 0x75, 0x72, 0x72, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, + 0x6d, 0x00, 0x00, 0x00, 0x62, 0x79, 0x00, 0x00, 0x06, 0x00, 0x04, 0x00, + 0x6d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x73, 0x00, 0x00, 0x00, 0x79, 0x48, 0x61, 0x74, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x74, 0x00, 0x00, 0x00, + 0x70, 0x61, 0x72, 0x61, 0x6d, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x76, 0x00, 0x00, 0x00, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x78, 0x00, 0x00, 0x00, 0x70, 0x61, 0x72, 0x61, + 0x6d, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, 0x7b, 0x00, 0x00, 0x00, + 0x64, 0x5a, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, 0x7f, 0x00, 0x00, 0x00, + 0x64, 0x57, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, 0x80, 0x00, 0x00, 0x00, + 0x6d, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, 0x86, 0x00, 0x00, 0x00, + 0x64, 0x42, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x8b, 0x00, 0x00, 0x00, + 0x62, 0x77, 0x6f, 0x75, 0x74, 0x69, 0x00, 0x00, 0x06, 0x00, 0x05, 0x00, + 0x8b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x77, 0x6f, 0x75, 0x74, + 0x69, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, 0x8d, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x93, 0x00, 0x00, 0x00, + 0x62, 0x77, 0x6f, 0x75, 0x74, 0x6a, 0x00, 0x00, 0x06, 0x00, 0x05, 0x00, + 0x93, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x77, 0x6f, 0x75, 0x74, + 0x6a, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, 0x95, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x9c, 0x00, 0x00, 0x00, + 0x62, 0x62, 0x6f, 0x75, 0x74, 0x00, 0x00, 0x00, 0x06, 0x00, 0x05, 0x00, + 0x9c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x62, 0x6f, 0x75, 0x74, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, 0x9e, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0xa3, 0x00, 0x00, 0x00, + 0x62, 0x6c, 0x6f, 0x75, 0x74, 0x00, 0x00, 0x00, 0x06, 0x00, 0x05, 0x00, + 0xa3, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6c, 0x6f, 0x75, 0x74, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, 0xa5, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0xa7, 0x00, 0x00, 0x00, + 0x70, 0x61, 0x72, 0x61, 0x6d, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0xa9, 0x00, 0x00, 0x00, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x41, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x47, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, 0x48, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x4a, 0x00, 0x00, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x4a, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x55, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x56, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0x56, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x58, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x58, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x5c, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x05, 0x00, 0x5d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, + 0x5d, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x5f, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x5f, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x63, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, + 0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, 0x64, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x66, 0x00, 0x00, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x66, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x6c, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x6d, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0x6d, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x6f, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x6f, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x80, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x8a, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x8b, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0x8b, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x8d, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x8d, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x92, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x05, 0x00, 0x93, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, + 0x93, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x95, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x95, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x9b, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, + 0x9c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, 0x9c, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x9e, 0x00, 0x00, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x9e, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0xa2, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0xa3, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0xa5, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0xa5, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0xad, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, + 0x13, 0x00, 0x02, 0x00, 0x02, 0x00, 0x00, 0x00, 0x21, 0x00, 0x03, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x16, 0x00, 0x03, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x06, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x05, 0x00, 0x14, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, + 0x15, 0x00, 0x04, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x3d, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00, + 0x3f, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x40, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x3f, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x43, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, 0x47, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, 0x48, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x49, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x49, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x04, 0x00, 0x4b, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x4b, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x4d, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x2b, 0x00, 0x04, 0x00, 0x4b, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, 0x55, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, 0x56, 0x00, 0x00, 0x00, + 0x55, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x57, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x57, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x1d, 0x00, 0x03, 0x00, 0x5c, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x03, 0x00, 0x5d, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x5e, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x5d, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x5e, 0x00, 0x00, 0x00, + 0x5f, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, + 0x63, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, + 0x64, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x65, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x65, 0x00, 0x00, 0x00, 0x66, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, 0x6c, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, 0x6d, 0x00, 0x00, 0x00, + 0x6c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x6e, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x6d, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x6e, 0x00, 0x00, 0x00, 0x6f, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x32, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, 0x8a, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, 0x8b, 0x00, 0x00, 0x00, + 0x8a, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x8c, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x8b, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x8c, 0x00, 0x00, 0x00, 0x8d, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x1d, 0x00, 0x03, 0x00, 0x92, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x03, 0x00, 0x93, 0x00, 0x00, 0x00, 0x92, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x94, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x93, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x94, 0x00, 0x00, 0x00, + 0x95, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x97, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x1d, 0x00, 0x03, 0x00, 0x9b, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x03, 0x00, 0x9c, 0x00, 0x00, 0x00, 0x9b, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x9d, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x9c, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x9d, 0x00, 0x00, 0x00, + 0x9e, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, + 0xa2, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, + 0xa3, 0x00, 0x00, 0x00, 0xa2, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0xa4, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0xa4, 0x00, 0x00, 0x00, 0xa5, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x06, 0x00, 0x3f, 0x00, 0x00, 0x00, + 0xad, 0x00, 0x00, 0x00, 0x97, 0x00, 0x00, 0x00, 0x97, 0x00, 0x00, 0x00, + 0x97, 0x00, 0x00, 0x00, 0x36, 0x00, 0x05, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0xf8, 0x00, 0x02, 0x00, 0x05, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x3d, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x46, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x54, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x5b, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0x6b, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x73, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x76, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x78, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x7b, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x7f, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x86, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x07, 0x00, 0x00, 0x00, 0xa7, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0xa9, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x43, 0x00, 0x00, 0x00, + 0x44, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x45, 0x00, 0x00, 0x00, + 0x44, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x3e, 0x00, 0x00, 0x00, + 0x45, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x4d, 0x00, 0x00, 0x00, + 0x4e, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x4f, 0x00, 0x00, 0x00, 0x4e, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, + 0x4d, 0x00, 0x00, 0x00, 0x51, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x52, 0x00, 0x00, 0x00, 0x51, 0x00, 0x00, 0x00, + 0x50, 0x00, 0x05, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x53, 0x00, 0x00, 0x00, + 0x4f, 0x00, 0x00, 0x00, 0x52, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x46, 0x00, 0x00, 0x00, 0x53, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, + 0x4d, 0x00, 0x00, 0x00, 0x59, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x5a, 0x00, 0x00, 0x00, 0x59, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x54, 0x00, 0x00, 0x00, 0x5a, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x4d, 0x00, 0x00, 0x00, + 0x61, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x62, 0x00, 0x00, 0x00, 0x61, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x67, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x06, 0x00, 0x4d, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, + 0x66, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x67, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x69, 0x00, 0x00, 0x00, + 0x68, 0x00, 0x00, 0x00, 0x50, 0x00, 0x05, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x6a, 0x00, 0x00, 0x00, 0x62, 0x00, 0x00, 0x00, 0x69, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x5b, 0x00, 0x00, 0x00, 0x6a, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x4d, 0x00, 0x00, 0x00, + 0x71, 0x00, 0x00, 0x00, 0x6f, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, + 0x70, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x72, 0x00, 0x00, 0x00, 0x71, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x6b, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x75, 0x00, 0x00, 0x00, 0x5b, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x74, 0x00, 0x00, 0x00, 0x75, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x77, 0x00, 0x00, 0x00, + 0x46, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x76, 0x00, 0x00, 0x00, + 0x77, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x79, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x78, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, 0x39, 0x00, 0x07, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x7a, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x74, 0x00, 0x00, 0x00, 0x76, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x73, 0x00, 0x00, 0x00, 0x7a, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x7c, 0x00, 0x00, 0x00, + 0x73, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x7d, 0x00, 0x00, 0x00, 0x6b, 0x00, 0x00, 0x00, 0x83, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x7e, 0x00, 0x00, 0x00, 0x7c, 0x00, 0x00, 0x00, + 0x7d, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x7b, 0x00, 0x00, 0x00, + 0x7e, 0x00, 0x00, 0x00, 0x88, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x81, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x82, 0x00, 0x00, 0x00, + 0x5b, 0x00, 0x00, 0x00, 0x8e, 0x00, 0x05, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x83, 0x00, 0x00, 0x00, 0x82, 0x00, 0x00, 0x00, 0x81, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, + 0x7b, 0x00, 0x00, 0x00, 0x8e, 0x00, 0x05, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x85, 0x00, 0x00, 0x00, 0x83, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x7f, 0x00, 0x00, 0x00, 0x85, 0x00, 0x00, 0x00, + 0x88, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x87, 0x00, 0x00, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x88, 0x00, 0x00, 0x00, 0x7b, 0x00, 0x00, 0x00, + 0x85, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x89, 0x00, 0x00, 0x00, + 0x87, 0x00, 0x00, 0x00, 0x88, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x86, 0x00, 0x00, 0x00, 0x89, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x8e, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x07, 0x00, 0x00, 0x00, 0x8f, 0x00, 0x00, 0x00, + 0x7f, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x90, 0x00, 0x00, 0x00, 0x8f, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x06, 0x00, 0x4d, 0x00, 0x00, 0x00, 0x91, 0x00, 0x00, 0x00, + 0x8d, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x8e, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x91, 0x00, 0x00, 0x00, 0x90, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x96, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x98, 0x00, 0x00, 0x00, 0x7f, 0x00, 0x00, 0x00, 0x97, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x99, 0x00, 0x00, 0x00, + 0x98, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x4d, 0x00, 0x00, 0x00, + 0x9a, 0x00, 0x00, 0x00, 0x95, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, + 0x96, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x9a, 0x00, 0x00, 0x00, + 0x99, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x3c, 0x00, 0x00, 0x00, + 0x9f, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x86, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x06, 0x00, 0x4d, 0x00, 0x00, 0x00, 0xa1, 0x00, 0x00, 0x00, + 0x9e, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x9f, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0xa1, 0x00, 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x3c, 0x00, 0x00, 0x00, 0xa6, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0xa8, 0x00, 0x00, 0x00, 0x73, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0xa7, 0x00, 0x00, 0x00, 0xa8, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0xaa, 0x00, 0x00, 0x00, 0x6b, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0xa9, 0x00, 0x00, 0x00, 0xaa, 0x00, 0x00, 0x00, + 0x39, 0x00, 0x06, 0x00, 0x06, 0x00, 0x00, 0x00, 0xab, 0x00, 0x00, 0x00, + 0x17, 0x00, 0x00, 0x00, 0xa7, 0x00, 0x00, 0x00, 0xa9, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x06, 0x00, 0x4d, 0x00, 0x00, 0x00, 0xac, 0x00, 0x00, 0x00, + 0xa5, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0xa6, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0xac, 0x00, 0x00, 0x00, 0xab, 0x00, 0x00, 0x00, + 0xfd, 0x00, 0x01, 0x00, 0x38, 0x00, 0x01, 0x00, 0x36, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x37, 0x00, 0x03, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x7f, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x06, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x81, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x88, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x00, 0x00, + 0xfe, 0x00, 0x02, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x38, 0x00, 0x01, 0x00, + 0x36, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x37, 0x00, 0x03, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x37, 0x00, 0x03, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x37, 0x00, 0x03, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, + 0x13, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x27, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, + 0x94, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, + 0x81, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x28, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, + 0x39, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x27, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, 0x27, 0x00, 0x00, 0x00, + 0xfe, 0x00, 0x02, 0x00, 0x2b, 0x00, 0x00, 0x00, 0x38, 0x00, 0x01, 0x00, + 0x36, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x37, 0x00, 0x03, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x37, 0x00, 0x03, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x2f, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x06, 0x00, 0x06, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x2f, 0x00, 0x00, 0x00, + 0x85, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, + 0x83, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x33, 0x00, 0x00, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, + 0x83, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x35, 0x00, 0x00, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x06, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x36, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x35, 0x00, 0x00, 0x00, 0x85, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x37, 0x00, 0x00, 0x00, 0x33, 0x00, 0x00, 0x00, + 0x36, 0x00, 0x00, 0x00, 0x81, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x38, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x37, 0x00, 0x00, 0x00, + 0x7f, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x39, 0x00, 0x00, 0x00, + 0x38, 0x00, 0x00, 0x00, 0xfe, 0x00, 0x02, 0x00, 0x39, 0x00, 0x00, 0x00, + 0x38, 0x00, 0x01, 0x00 +}; +static const unsigned int test_shaders_glsl_test_logistic_regression_comp_spv_len = 4816; +} +} +#endif // define SHADEROP_SHADERTEST_LOGISTIC_REGRESSION_HPP diff --git a/ggml/src/ggml-kompute/kompute/test/compiled_shaders_include/kompute_test/shaders/shadertest_op_custom_shader.hpp b/ggml/src/ggml-kompute/kompute/test/compiled_shaders_include/kompute_test/shaders/shadertest_op_custom_shader.hpp new file mode 100644 index 0000000000000000000000000000000000000000..edbf2eed77d7b0717d527794b84861032d0eaa86 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/compiled_shaders_include/kompute_test/shaders/shadertest_op_custom_shader.hpp @@ -0,0 +1,123 @@ +/* + THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT + + --- + + Copyright 2020 The Institute for Ethical AI & Machine Learning + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#ifndef SHADEROP_SHADERTEST_OP_CUSTOM_SHADER_HPP +#define SHADEROP_SHADERTEST_OP_CUSTOM_SHADER_HPP + +namespace kp { +namespace shader_data { +static const unsigned char test_shaders_glsl_test_op_custom_shader_comp_spv[] = { + 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x00, 0x08, 0x00, + 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x00, 0x02, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x47, 0x4c, 0x53, 0x4c, 0x2e, 0x73, 0x74, 0x64, 0x2e, 0x34, 0x35, 0x30, + 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x06, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x10, 0x00, 0x06, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, + 0xc2, 0x01, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x08, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x67, 0x6c, 0x5f, 0x47, + 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x49, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x49, 0x44, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x62, 0x00, 0x00, 0x00, 0x06, 0x00, 0x04, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x70, 0x62, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x19, 0x00, 0x00, 0x00, 0x61, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x04, 0x00, 0x19, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x70, 0x61, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, 0x1b, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x05, 0x00, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, 0x19, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x1b, 0x00, 0x00, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x26, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x13, 0x00, 0x02, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x03, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x16, 0x00, 0x03, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, 0x11, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x13, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x13, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x04, 0x00, 0x15, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x15, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x1d, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x06, 0x00, 0x09, 0x00, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, + 0x25, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, + 0x36, 0x00, 0x05, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x1d, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x1f, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, + 0x1d, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x70, 0x00, 0x04, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x23, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, + 0x1d, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0xfd, 0x00, 0x01, 0x00, + 0x38, 0x00, 0x01, 0x00 +}; +static const unsigned int test_shaders_glsl_test_op_custom_shader_comp_spv_len = 1096; +} +} +#endif // define SHADEROP_SHADERTEST_OP_CUSTOM_SHADER_HPP diff --git a/ggml/src/ggml-kompute/kompute/test/compiled_shaders_include/kompute_test/shaders/shadertest_workgroup.hpp b/ggml/src/ggml-kompute/kompute/test/compiled_shaders_include/kompute_test/shaders/shadertest_workgroup.hpp new file mode 100644 index 0000000000000000000000000000000000000000..15d62347822ba4dad1d3bd080d459ac0d77fe193 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/compiled_shaders_include/kompute_test/shaders/shadertest_workgroup.hpp @@ -0,0 +1,148 @@ +/* + THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT + + --- + + Copyright 2020 The Institute for Ethical AI & Machine Learning + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#ifndef SHADEROP_SHADERTEST_WORKGROUP_HPP +#define SHADEROP_SHADERTEST_WORKGROUP_HPP + +namespace kp { +namespace shader_data { +static const unsigned char test_shaders_glsl_test_workgroup_comp_spv[] = { + 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x00, 0x08, 0x00, + 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x00, 0x02, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x47, 0x4c, 0x53, 0x4c, 0x2e, 0x73, 0x74, 0x64, 0x2e, 0x34, 0x35, 0x30, + 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x06, 0x00, 0x04, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0xc2, 0x01, 0x00, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x69, 0x6e, 0x64, 0x65, 0x78, 0x00, 0x00, 0x00, 0x05, 0x00, 0x06, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x67, 0x6c, 0x5f, 0x57, 0x6f, 0x72, 0x6b, 0x47, + 0x72, 0x6f, 0x75, 0x70, 0x49, 0x44, 0x00, 0x00, 0x05, 0x00, 0x07, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x67, 0x6c, 0x5f, 0x4e, 0x75, 0x6d, 0x57, 0x6f, + 0x72, 0x6b, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x00, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x62, 0x6f, 0x75, 0x74, + 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x05, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x74, 0x6f, 0x75, 0x74, 0x78, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x08, 0x00, 0x20, 0x00, 0x00, 0x00, 0x67, 0x6c, 0x5f, 0x47, + 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x49, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x49, 0x44, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x27, 0x00, 0x00, 0x00, 0x62, 0x6f, 0x75, 0x74, 0x32, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x05, 0x00, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x74, 0x6f, 0x75, 0x74, 0x79, 0x00, 0x00, 0x00, 0x05, 0x00, 0x03, 0x00, + 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x19, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x04, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x05, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x26, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x04, 0x00, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x27, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0x27, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x29, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x29, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x2f, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, + 0x13, 0x00, 0x02, 0x00, 0x02, 0x00, 0x00, 0x00, 0x21, 0x00, 0x03, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x15, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x16, 0x00, 0x03, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x15, 0x00, 0x04, 0x00, 0x1d, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x1d, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, + 0x26, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, + 0x27, 0x00, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x27, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x28, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x06, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x2f, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x36, 0x00, 0x05, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0xf8, 0x00, 0x02, 0x00, 0x05, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x84, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x17, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x1f, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x70, 0x00, 0x04, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x06, 0x00, 0x24, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x1f, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x25, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x2b, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, + 0x2b, 0x00, 0x00, 0x00, 0x70, 0x00, 0x04, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x2d, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x2d, 0x00, 0x00, 0x00, 0xfd, 0x00, 0x01, 0x00, + 0x38, 0x00, 0x01, 0x00 +}; +static const unsigned int test_shaders_glsl_test_workgroup_comp_spv_len = 1396; +} +} +#endif // define SHADEROP_SHADERTEST_WORKGROUP_HPP diff --git a/ggml/src/ggml-kompute/kompute/test/shaders/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/test/shaders/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9b134bac9f49ed381c7a52f1f058cc7664bbc4c8 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/shaders/CMakeLists.txt @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# ###################### +cmake_minimum_required(VERSION 3.20) + +add_library(test_shaders "Utils.cpp" + "Utils.hpp") + +add_subdirectory(glsl) \ No newline at end of file diff --git a/ggml/src/ggml-kompute/kompute/test/shaders/Utils.cpp b/ggml/src/ggml-kompute/kompute/test/shaders/Utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..85a38b56800569812624c4582db1f0c830f6e9c4 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/shaders/Utils.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "Utils.hpp" +#include +#include +#include +#include +#include + +std::vector +compileSource(const std::string& source) +{ + std::ofstream fileOut("tmp_kp_shader.comp"); + fileOut << source; + fileOut.close(); + if (system( + std::string( + "glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv") + .c_str())) { + throw std::runtime_error("Error running glslangValidator command"); + } + std::ifstream fileStream("tmp_kp_shader.comp.spv", std::ios::binary); + std::vector buffer; + buffer.insert( + buffer.begin(), std::istreambuf_iterator(fileStream), {}); + return { reinterpret_cast(buffer.data()), + reinterpret_cast(buffer.data() + buffer.size()) }; +} diff --git a/ggml/src/ggml-kompute/kompute/test/shaders/Utils.hpp b/ggml/src/ggml-kompute/kompute/test/shaders/Utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fa8aa9729e8a31963c3d845b1d8dff4824dc826e --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/shaders/Utils.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +/** + * Compile a single glslang source from string value. This is only meant + * to be used for testing as it's non threadsafe, and it had to be removed + * from the glslang dependency and now can only run the CLI directly due to + * license issues: see https://github.com/KomputeProject/kompute/pull/235 + * + * @param source An individual raw glsl shader in string format + * @return The compiled SPIR-V binary in unsigned int32 format + */ +std::vector +compileSource(const std::string& source); diff --git a/ggml/src/ggml-kompute/kompute/test/shaders/glsl/CMakeLists.txt b/ggml/src/ggml-kompute/kompute/test/shaders/glsl/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c8c3cba9a29b04d4ae546c65ba7caa037a1ff6fb --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/shaders/glsl/CMakeLists.txt @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# ###################### +cmake_minimum_required(VERSION 3.20) + +vulkan_compile_shader(INFILE test_logistic_regression_shader.comp + OUTFILE test_logistic_regression_shader.hpp + NAMESPACE "kp") + +vulkan_compile_shader(INFILE test_op_custom_shader.comp + OUTFILE test_op_custom_shader.hpp + NAMESPACE "kp") + +vulkan_compile_shader(INFILE test_workgroup_shader.comp + OUTFILE test_workgroup_shader.hpp + NAMESPACE "kp") + +vulkan_compile_shader(INFILE test_shader.comp + OUTFILE test_shader.hpp + NAMESPACE "kp") + +add_library(test_shaders_glsl INTERFACE "${CMAKE_CURRENT_BINARY_DIR}/test_logistic_regression_shader.hpp" + "${CMAKE_CURRENT_BINARY_DIR}/test_op_custom_shader.hpp" + "${CMAKE_CURRENT_BINARY_DIR}/test_workgroup_shader.hpp" + "${CMAKE_CURRENT_BINARY_DIR}/test_shader.hpp") + +target_include_directories(test_shaders_glsl INTERFACE $) diff --git a/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_logistic_regression_shader.comp b/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_logistic_regression_shader.comp new file mode 100644 index 0000000000000000000000000000000000000000..6934fd29af348262e8aa1b909fda2246cbc1b276 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_logistic_regression_shader.comp @@ -0,0 +1,53 @@ +#version 450 + +layout (constant_id = 0) const float m = 0; + +layout (local_size_x = 1) in; + +layout(set = 0, binding = 0) buffer bxi { float xi[]; }; +layout(set = 0, binding = 1) buffer bxj { float xj[]; }; +layout(set = 0, binding = 2) buffer by { float y[]; }; +layout(set = 0, binding = 3) buffer bwin { float win[]; }; +layout(set = 0, binding = 4) buffer bwouti { float wouti[]; }; +layout(set = 0, binding = 5) buffer bwoutj { float woutj[]; }; +layout(set = 0, binding = 6) buffer bbin { float bin[]; }; +layout(set = 0, binding = 7) buffer bbout { float bout[]; }; +layout(set = 0, binding = 8) buffer blout { float lout[]; }; + +float sigmoid(float z) { + return 1.0 / (1.0 + exp(-z)); +} + +float inference(vec2 x, vec2 w, float b) { + // Compute the linear mapping function + float z = dot(w, x) + b; + // Calculate the y-hat with sigmoid + float yHat = sigmoid(z); + return yHat; +} + +float calculateLoss(float yHat, float y) { + return -(y * log(yHat) + (1.0 - y) * log(1.0 - yHat)); +} + +void main() { + uint idx = gl_GlobalInvocationID.x; + + vec2 wCurr = vec2(win[0], win[1]); + float bCurr = bin[0]; + + vec2 xCurr = vec2(xi[idx], xj[idx]); + float yCurr = y[idx]; + + float yHat = inference(xCurr, wCurr, bCurr); + + float dZ = yHat - yCurr; + vec2 dW = (1. / m) * xCurr * dZ; + float dB = (1. / m) * dZ; + wouti[idx] = dW.x; + woutj[idx] = dW.y; + bout[idx] = dB; + + lout[idx] = calculateLoss(yHat, yCurr); +} + diff --git a/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_op_custom_shader.comp b/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_op_custom_shader.comp new file mode 100644 index 0000000000000000000000000000000000000000..1e0e0d199edf17e3cee0509dae4a66f20664a1b6 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_op_custom_shader.comp @@ -0,0 +1,13 @@ +#version 450 + +layout (local_size_x = 1) in; + +layout(set = 0, binding = 0) buffer a { float pa[]; }; +layout(set = 0, binding = 1) buffer b { float pb[]; }; + +void main() { + uint index = gl_GlobalInvocationID.x; + pb[index] = pa[index]; + pa[index] = index; +} + diff --git a/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_shader.comp b/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_shader.comp new file mode 100644 index 0000000000000000000000000000000000000000..fb4a7ab62985d6702171cef0f1d4d94047839d3d --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_shader.comp @@ -0,0 +1,12 @@ +#version 450 + +layout (local_size_x = 1) in; + +layout(set = 0, binding = 0) buffer a { float pa[]; }; +layout(set = 0, binding = 1) buffer b { float pb[]; }; + +void main() { + uint index = gl_GlobalInvocationID.x; + pb[index] = pa[index]; + pa[index] = index; +} \ No newline at end of file diff --git a/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_workgroup_shader.comp b/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_workgroup_shader.comp new file mode 100644 index 0000000000000000000000000000000000000000..962804b27df5fcf1a29d32a6c406934b2e1ac553 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/test/shaders/glsl/test_workgroup_shader.comp @@ -0,0 +1,15 @@ +#version 450 + +layout (local_size_x = 1) in; + +// The input rebuild bind index is relative to index in parameter passed +layout(set = 0, binding = 0) writeonly buffer bout { float toutx[]; }; +layout(set = 0, binding = 1) writeonly buffer bout2 { float touty[]; }; + +void main() { + uint index = gl_WorkGroupID.x*gl_NumWorkGroups.y + gl_WorkGroupID.y; + + toutx[index] = gl_GlobalInvocationID.x; + touty[index] = gl_GlobalInvocationID.y; +} + diff --git a/ggml/src/ggml-kompute/kompute/vcpkg.json.opt b/ggml/src/ggml-kompute/kompute/vcpkg.json.opt new file mode 100644 index 0000000000000000000000000000000000000000..9ce3935a17e33f1d8ae948b321250614b780858a --- /dev/null +++ b/ggml/src/ggml-kompute/kompute/vcpkg.json.opt @@ -0,0 +1,10 @@ +{ + "name": "example", + "version-string": "0.8.1", + "dependencies": [ + "fmt", + "spdlog", + "vulkan-headers", + "gtest" + ] +} diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..77187efc1756dc98c52efa3b43cca89c140e81a8 --- /dev/null +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -0,0 +1,121 @@ +find_library(FOUNDATION_LIBRARY Foundation REQUIRED) +find_library(METAL_FRAMEWORK Metal REQUIRED) +find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + +message(STATUS "Metal framework found") + +ggml_add_backend_library(ggml-metal + ggml-metal.m + ) + +target_link_libraries(ggml-metal PRIVATE + ${FOUNDATION_LIBRARY} + ${METAL_FRAMEWORK} + ${METALKIT_FRAMEWORK} + ) + +if (GGML_METAL_NDEBUG) + add_compile_definitions(GGML_METAL_NDEBUG) +endif() + +if (GGML_METAL_USE_BF16) + add_compile_definitions(GGML_METAL_USE_BF16) +endif() + +# copy metal files to bin directory +configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) +configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) +configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) + +set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") +if (GGML_METAL_EMBED_LIBRARY) + enable_language(ASM) + + add_compile_definitions(GGML_METAL_EMBED_LIBRARY) + + set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") + set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") + + file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") + + # merge ggml-common.h and ggml-metal.metal into a single file + set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") + + add_custom_command( + OUTPUT "${METALLIB_EMBED_ASM}" + COMMAND echo "Embedding Metal library" + COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}" + COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}" + COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}" + COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}" + COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}" + COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}" + COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}" + COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}" + DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h + COMMENT "Generate assembly for embedded Metal library" + VERBATIM + ) + + target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}") +else() + if (GGML_METAL_SHADER_DEBUG) + # custom command to do the following: + # xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air + # xcrun -sdk macosx metallib ggml-metal.air -o default.metallib + # + # note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works + # disabling fast math is needed in order to pass tests/test-backend-ops + # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1 + # note: unfortunately, we have to call it default.metallib instead of ggml.metallib + # ref: https://github.com/ggerganov/whisper.cpp/issues/1720 + set(XC_FLAGS -fno-fast-math -fno-inline -g) + else() + set(XC_FLAGS -O3) + endif() + + # Append macOS metal versioning flags + if (GGML_METAL_MACOSX_VERSION_MIN) + message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation") + list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN}) + endif() + + if (GGML_METAL_STD) + message(STATUS "Adding -std=${GGML_METAL_STD} flag to metal compilation") + list (APPEND XC_FLAGS -std=${GGML_METAL_STD}) + endif() + + add_custom_command( + OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - | + xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h + COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal + DEPENDS ggml-metal.metal ${METALLIB_COMMON} + COMMENT "Compiling Metal kernels" + ) + + # FIXME: only add to the ggml-metal target? + add_custom_target( + ggml-metal-lib ALL + DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + ) +endif() # GGML_METAL_EMBED_LIBRARY + +if (NOT GGML_METAL_EMBED_LIBRARY) + install( + FILES src/ggml-metal/ggml-metal.metal + PERMISSIONS + OWNER_READ + OWNER_WRITE + GROUP_READ + WORLD_READ + DESTINATION ${CMAKE_INSTALL_BINDIR}) + + install( + FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + DESTINATION ${CMAKE_INSTALL_BINDIR} + ) +endif() diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h new file mode 100644 index 0000000000000000000000000000000000000000..17eab976f3ad114eaeb9ce108ecd635ce12fddc9 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -0,0 +1,622 @@ +#ifndef GGML_METAL_IMPL +#define GGML_METAL_IMPL + +// kernel parameters for mat-vec threadgroups +// +// N_R0: number of src0 rows to process per simdgroup +// N_SG: number of simdgroups per threadgroup +// +// TODO: for optimal performance, become function of the device and work size + +#define N_R0_Q4_0 4 +#define N_SG_Q4_0 2 + +#define N_R0_Q4_1 4 +#define N_SG_Q4_1 2 + +#define N_R0_Q5_0 4 +#define N_SG_Q5_0 2 + +#define N_R0_Q5_1 4 +#define N_SG_Q5_1 2 + +#define N_R0_Q8_0 4 +#define N_SG_Q8_0 2 + +#define N_R0_Q2_K 4 +#define N_SG_Q2_K 2 + +#define N_R0_Q3_K 2 +#define N_SG_Q3_K 2 + +#define N_R0_Q4_K 4 +#define N_SG_Q4_K 2 + +#define N_R0_Q5_K 2 +#define N_SG_Q5_K 2 + +#define N_R0_Q6_K 1 +#define N_SG_Q6_K 2 + +#define N_R0_IQ1_S 4 +#define N_SG_IQ1_S 2 + +#define N_R0_IQ1_M 4 +#define N_SG_IQ1_M 2 + +#define N_R0_IQ2_XXS 4 +#define N_SG_IQ2_XXS 2 + +#define N_R0_IQ2_XS 4 +#define N_SG_IQ2_XS 2 + +#define N_R0_IQ2_S 4 +#define N_SG_IQ2_S 2 + +#define N_R0_IQ3_XXS 4 +#define N_SG_IQ3_XXS 2 + +#define N_R0_IQ3_S 4 +#define N_SG_IQ3_S 2 + +#define N_R0_IQ4_NL 2 +#define N_SG_IQ4_NL 2 + +#define N_R0_IQ4_XS 2 +#define N_SG_IQ4_XS 2 + +// kernel argument structs +// +// - element counters (e.g. ne00) typically use int32_t to reduce register usage +// however, be careful from int overflows when using those in the kernel implementation +// +// - strides (e.g. nb00) use uint64_t + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} ggml_metal_kargs_concat; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} ggml_metal_kargs_bin; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_repeat; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cpy; + +typedef struct { + int64_t ne10; + int64_t ne11; + int64_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; + bool inplace; +} ggml_metal_kargs_set; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t n_past; + int32_t n_dims; + int32_t n_ctx_orig; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + int32_t sect_0; + int32_t sect_1; + int32_t sect_2; + int32_t sect_3; +} ggml_metal_kargs_rope; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mm; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mv; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; + int16_t nsg; + int16_t nxpsg; + int16_t r1ptg; +} ggml_metal_kargs_mul_mv_ext; + +typedef struct { + int32_t ne10; + int32_t ne11; // n_expert_used (bcast) + uint64_t nb11; + uint64_t nb12; + int32_t neh11; // n_tokens + uint64_t nbh11; + int32_t ne20; // n_expert_used + uint64_t nb21; +} ggml_metal_kargs_mul_mm_id_map0; + +typedef struct { + int32_t ne20; // n_expert_used + int32_t neh0; + int32_t neh1; + uint64_t nbh1; + uint64_t nbh2; + int32_t ne0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_mul_mm_id_map1; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t neh12; + uint64_t nbh10; + uint64_t nbh11; + uint64_t nbh12; + uint64_t nbh13; + int32_t neh0; + int32_t neh1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mm_id; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; + uint64_t nb1; +} ggml_metal_kargs_mul_mv_id; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_norm; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_rms_norm; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_l2_norm; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t n_groups; + float eps; +} ggml_metal_kargs_group_norm; + +typedef struct { + int32_t IC; + int32_t IL; + int32_t K; + int32_t s0; + uint64_t nb0; + uint64_t nb1; +} ggml_metal_kargs_conv_transpose_1d; + +typedef struct { + uint64_t ofs0; + uint64_t ofs1; + int32_t IW; + int32_t IH; + int32_t CHW; + int32_t s0; + int32_t s1; + int32_t p0; + int32_t p1; + int32_t d0; + int32_t d1; + int32_t N; + int32_t KH; + int32_t KW; + int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources +} ggml_metal_kargs_im2col; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne10; + int64_t ne11; + int64_t ne12; + int64_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_sum_rows; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; +} ggml_metal_kargs_soft_max; + +typedef struct { + int64_t ne00; + int64_t ne01; + int n_past; +} ggml_metal_kargs_diag_mask_inf; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int64_t ne10; + int64_t ne11; + uint64_t nb10; + uint64_t nb11; + int64_t ne0; + int64_t ne1; + int64_t ne2; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_ssm_conv; + +typedef struct { + int64_t d_state; + int64_t d_inner; + int64_t n_seq_tokens; + int64_t n_seqs; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb30; + uint64_t nb31; + uint64_t nb40; + uint64_t nb41; + uint64_t nb42; + uint64_t nb50; + uint64_t nb51; + uint64_t nb52; +} ggml_metal_kargs_ssm_scan; + +typedef struct { + int64_t ne00; + uint64_t nb01; + uint64_t nb02; + int64_t ne10; + uint64_t nb10; + uint64_t nb11; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_get_rows; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float sf0; + float sf1; + float sf2; + float sf3; +} ggml_metal_kargs_upscale; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_pad; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t p0; + int32_t p1; +} ggml_metal_kargs_pad_reflect_1d; + +typedef struct { + uint64_t nb1; + int dim; + int max_period; +} ggml_metal_kargs_timestep_embedding; + +typedef struct { + float slope; +} ggml_metal_kargs_leaky_relu; + +typedef struct { + int64_t ncols; + int64_t ncols_pad; +} ggml_metal_kargs_argsort; + +typedef struct { + int64_t ne0; + float start; + float step; +} ggml_metal_kargs_arange; + +typedef struct { + int32_t k0; + int32_t k1; + int32_t s0; + int32_t s1; + int32_t p0; + int32_t p1; + int64_t IH; + int64_t IW; + int64_t OH; + int64_t OW; + int64_t parallel_elements; +} ggml_metal_kargs_pool_2d; + +#endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m new file mode 100644 index 0000000000000000000000000000000000000000..bc93bc633a49b224d38d3485053c205ba2bbdec3 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -0,0 +1,6000 @@ +#import "ggml-metal.h" + +#import "ggml-impl.h" +#import "ggml-backend-impl.h" +#import "ggml-metal-impl.h" + +#import + +#import + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// max memory buffers that can be mapped to the device +#define GGML_METAL_MAX_BUFFERS 64 + +// max number of MTLCommandBuffer used to submit a graph for processing +#define GGML_METAL_MAX_COMMAND_BUFFERS 8 + +#ifndef TARGET_OS_VISION +#define TARGET_OS_VISION 0 +#endif + +// create residency sets only on macOS >= 15.0 +#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ + TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \ + TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \ + TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000 +#define GGML_METAL_HAS_RESIDENCY_SETS 1 +#endif + +// globals + +// overload of MTLGPUFamilyMetal3 (not available in some environments) +static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; + +// initialized in ggml_backend_metal_reg +static struct ggml_backend_reg g_ggml_backend_metal_reg; +static struct ggml_backend_device g_ggml_backend_metal_device; + +// information about a Metal device +// note: assumes single GPU device - the default one +// TODO: support multiple GPU devices +static struct ggml_backend_metal_device_context { + id mtl_device; + int mtl_device_ref_count; + id mtl_library; + + bool has_simdgroup_reduction; + bool has_simdgroup_mm; + bool has_residency_sets; + bool has_bfloat; + bool use_bfloat; + + char name[128]; +} g_ggml_ctx_dev_main = { + /*.mtl_device =*/ nil, + /*.mtl_device_ref_count =*/ 0, + /*.mtl_library =*/ nil, + /*.has_simdgroup_reduction =*/ false, + /*.has_simdgroup_mm =*/ false, + /*.has_residency_sets =*/ false, + /*.has_bfloat =*/ false, + /*.use_bfloat =*/ false, + /*.name =*/ "", +}; + +// acquire +static id ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) { + assert(ctx != NULL); + + if (ctx->mtl_device == nil) { + ctx->mtl_device = MTLCreateSystemDefaultDevice(); + } + + if (ctx->mtl_device) { + ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; + ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + + ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; + +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL; +#endif + + ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6]; + +#if defined(GGML_METAL_USE_BF16) + ctx->use_bfloat = ctx->has_bfloat; +#else + ctx->use_bfloat = false; +#endif + + strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1); + } + + ctx->mtl_device_ref_count++; + + return ctx->mtl_device; +} + +// release +static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) { + assert(ctx != NULL); + assert(ctx->mtl_device_ref_count > 0); + + ctx->mtl_device_ref_count--; + + if (ctx->mtl_device_ref_count == 0) { + if (ctx->mtl_library) { + [ctx->mtl_library release]; + ctx->mtl_library = nil; + } + + if (ctx->mtl_device) { + [ctx->mtl_device release]; + ctx->mtl_device = nil; + } + } +} + +// kernels + +struct ggml_metal_kernel { + id pipeline; +}; + +enum ggml_metal_kernel_type { + GGML_METAL_KERNEL_TYPE_ADD, + GGML_METAL_KERNEL_TYPE_ADD_ROW, + GGML_METAL_KERNEL_TYPE_SUB, + GGML_METAL_KERNEL_TYPE_SUB_ROW, + GGML_METAL_KERNEL_TYPE_MUL, + GGML_METAL_KERNEL_TYPE_MUL_ROW, + GGML_METAL_KERNEL_TYPE_DIV, + GGML_METAL_KERNEL_TYPE_DIV_ROW, + GGML_METAL_KERNEL_TYPE_REPEAT_F32, + GGML_METAL_KERNEL_TYPE_REPEAT_F16, + GGML_METAL_KERNEL_TYPE_REPEAT_I32, + GGML_METAL_KERNEL_TYPE_REPEAT_I16, + GGML_METAL_KERNEL_TYPE_SCALE, + GGML_METAL_KERNEL_TYPE_SCALE_4, + GGML_METAL_KERNEL_TYPE_CLAMP, + GGML_METAL_KERNEL_TYPE_TANH, + GGML_METAL_KERNEL_TYPE_RELU, + GGML_METAL_KERNEL_TYPE_SIGMOID, + GGML_METAL_KERNEL_TYPE_GELU, + GGML_METAL_KERNEL_TYPE_GELU_4, + GGML_METAL_KERNEL_TYPE_GELU_ERF, + GGML_METAL_KERNEL_TYPE_GELU_ERF_4, + GGML_METAL_KERNEL_TYPE_GELU_QUICK, + GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, + GGML_METAL_KERNEL_TYPE_SILU, + GGML_METAL_KERNEL_TYPE_SILU_4, + GGML_METAL_KERNEL_TYPE_ELU, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, + GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, + GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, + GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, + GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, + GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, + GGML_METAL_KERNEL_TYPE_RMS_NORM, + GGML_METAL_KERNEL_TYPE_L2_NORM, + GGML_METAL_KERNEL_TYPE_GROUP_NORM, + GGML_METAL_KERNEL_TYPE_NORM, + GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, + GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, + GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, + GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, + //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, + //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, + //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_F32, + GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, + GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, + GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, + GGML_METAL_KERNEL_TYPE_UPSCALE_F32, + GGML_METAL_KERNEL_TYPE_PAD_F32, + GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, + GGML_METAL_KERNEL_TYPE_ARANGE_F32, + GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, + GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, + GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, + GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_SET_I32, + GGML_METAL_KERNEL_TYPE_SET_F32, + GGML_METAL_KERNEL_TYPE_CPY_F32_F32, + GGML_METAL_KERNEL_TYPE_CPY_F32_F16, + GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, + GGML_METAL_KERNEL_TYPE_CPY_F16_F16, + GGML_METAL_KERNEL_TYPE_CPY_F16_F32, + GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, + GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, + GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, + GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, + GGML_METAL_KERNEL_TYPE_CONCAT, + GGML_METAL_KERNEL_TYPE_SQR, + GGML_METAL_KERNEL_TYPE_SQRT, + GGML_METAL_KERNEL_TYPE_SIN, + GGML_METAL_KERNEL_TYPE_COS, + GGML_METAL_KERNEL_TYPE_NEG, + GGML_METAL_KERNEL_TYPE_SUM_ROWS, + GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, + GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, + GGML_METAL_KERNEL_TYPE_ARGMAX, + + GGML_METAL_KERNEL_TYPE_COUNT +}; + +// +// ggml_metal_heap +// + +struct ggml_metal_heap { + // number of times the heap was unused + int n_unused; + + // total number of buffer allocations in this heap across all computes + int64_t n_alloc; + + // current offset in the heap - we reset this after each node in order to reuse the memory + size_t offs; + + // the currently allocated MTLBuffer objects in this heap + id obj; + + NSMutableArray * bufs; +}; + +static struct ggml_metal_heap * ggml_metal_heap_init(id device, size_t size) { + struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap)); + + MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init]; + desc.storageMode = MTLStorageModePrivate; + desc.cpuCacheMode = MTLCPUCacheModeDefaultCache; + desc.type = MTLHeapTypePlacement; + desc.size = size; + + heap->n_unused = 0; + heap->n_alloc = 0; + + heap->obj = [device newHeapWithDescriptor:desc]; + if (!heap->obj) { + GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size); + + free(heap); + + return false; + } + + [desc release]; + + heap->bufs = [[NSMutableArray alloc] init]; + + return heap; +} + +static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) { + heap->offs = 0; + + // count how many graph computes the heap ended up being unused + if ([heap->bufs count] > 0) { + heap->n_unused = 0; + } else { + heap->n_unused++; + } + + for (id buf in heap->bufs) { + [buf release]; + } + [heap->bufs removeAllObjects]; + + // tell the OS that it can reuse this memory if needed + // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc + [heap->obj setPurgeableState:MTLPurgeableStateVolatile]; +} + +static void ggml_metal_heap_free(struct ggml_metal_heap * heap) { + if (heap == nil) { + return; + } + + ggml_metal_heap_reset(heap); + + [heap->obj release]; + [heap->bufs release]; + + free(heap); +} + +@interface ggml_metal_heap_ptr : NSObject + +@property (nonatomic, assign) struct ggml_metal_heap * data; + +@end + +@implementation ggml_metal_heap_ptr +@end + +// +// ggml_metal_mem_pool +// + +struct ggml_metal_mem_pool { + id device; + + int n_heaps; // total number of heaps ever created (including those that were removed) + + NSMutableArray * heaps; + NSMutableArray * heaps_to_remove; +}; + +static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) { + struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool)); + + mem_pool->n_heaps = 0; + + mem_pool->heaps = [[NSMutableArray alloc] init]; + mem_pool->heaps_to_remove = [[NSMutableArray alloc] init]; + + return mem_pool; +} + +static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) { + GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps); + + size_t size_all = 0; + size_t size_cur = 0; + + for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { + GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data); + GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc); + GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused); + GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0); + GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]); + + if ([ptr.data->bufs count] > 0) { + size_cur += [ptr.data->obj size]; + } + size_all += [ptr.data->obj size]; + + ggml_metal_heap_free(ptr.data); + [ptr release]; + } + [mem_pool->heaps release]; + [mem_pool->heaps_to_remove release]; + + if (size_all > 0) { + GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0); + GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0); + } + + free(mem_pool); +} + +static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) { + for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) { + ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i]; + + struct ggml_metal_heap * heap = ptr.data; + ggml_metal_heap_reset(heap); + + // if the heap hasn't been used for a while, remove it + if (heap->n_unused >= 128) { + [mem_pool->heaps_to_remove addObject:@(i)]; + } + } + + if (mem_pool->heaps_to_remove.count > 0) { + // remove in reverse order + for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) { + NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue]; + ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index]; + + struct ggml_metal_heap * heap = ptr.data; + ggml_metal_heap_free(heap); + + [mem_pool->heaps removeObjectAtIndex:index]; + [ptr release]; + + if (i == 0) { + break; + } + } + + [mem_pool->heaps_to_remove removeAllObjects]; + } +} + +static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) { + for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { + ptr.data->offs = 0; + } +} + +static id ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) { + const size_t alignment = 256; + + const size_t size_aligned = GGML_PAD(size, alignment); + + // try one of the existing heaps + for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { + struct ggml_metal_heap * heap = ptr.data; + if (heap->offs + size_aligned <= [heap->obj size]) { + // if this is the first buffer in the heap for the current command buffer, tell the OS that + // it cannot free the memory used by the heap + // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc + if ([heap->bufs count] == 0) { + [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; + } + + id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; + if (buf == nil) { + GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); + return nil; + } + + heap->n_alloc++; + heap->offs += size_aligned; + + [heap->bufs addObject:buf]; + + return buf; + } + } + + // create a new heap that can fit this buffer + ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new]; + + struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned); + if (heap == NULL) { + GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned); + return NULL; + } + + //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]); + + heap_ptr.data = heap; + ggml_metal_heap_reset(heap); + + [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; + id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; + if (buf == nil) { + GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); + return NULL; + } + + heap->n_alloc++; + heap->offs += size_aligned; + + [heap->bufs addObject:buf]; + + [mem_pool->heaps addObject:heap_ptr]; + mem_pool->n_heaps++; + + return buf; +} + +struct ggml_metal_command_buffer { + id obj; + + // each command buffer has a memory pool from which it can allocate temporary buffers during the compute + struct ggml_metal_mem_pool * mem_pool; +}; + +struct ggml_backend_metal_context { + id device; + id queue; + + dispatch_queue_t d_queue; + + struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; + + // capture state + bool capture_next_compute; + bool capture_started; + + id capture_scope; + + // command buffer state + int n_cb; // number of extra threads used to submit the command buffers + int n_nodes_0; // number of nodes submitted by the main thread + int n_nodes_1; // remaining number of nodes submitted by the n_cb threads + int n_nodes_per_cb; + + struct ggml_cgraph * gf; + + // the callback given to the thread pool + void (^encode_async)(size_t ith); + + // n_cb command buffers + 1 used by the main thread + struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + + // abort ggml_metal_graph_compute if callback returns true + ggml_abort_callback abort_callback; + void * abort_callback_data; +}; + +// MSL code +// TODO: move the contents here when ready +// for now it is easier to work in a separate file +// static NSString * const msl_library_source = @"see metal.metal"; + +#if !GGML_METAL_EMBED_LIBRARY +// Here to assist with NSBundle Path Hack +@interface GGMLMetalClass : NSObject +@end +@implementation GGMLMetalClass +@end +#endif + +static void * ggml_metal_host_malloc(size_t n) { + void * data = NULL; + +#if TARGET_OS_OSX + kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE); + if (err != KERN_SUCCESS) { + GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); + return NULL; + } +#else + const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); + if (result != 0) { + GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); + return NULL; + } +#endif + + return data; +} + +// load library +// +// - first check if the library is embedded +// - then check if the library is in the bundle +// - if not found, load the source and compile it +// - if that fails, return NULL +static id ggml_metal_load_library(id device, bool use_bfloat) { + id metal_library = nil; + NSError * error = nil; + NSString * src = nil; + +#if GGML_METAL_EMBED_LIBRARY + GGML_LOG_INFO("%s: using embedded metal library\n", __func__); + + extern const char ggml_metallib_start[]; + extern const char ggml_metallib_end[]; + + src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding]; + +#else + +#ifdef SWIFT_PACKAGE + NSBundle * bundle = SWIFTPM_MODULE_BUNDLE; +#else + NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; +#endif + + NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"]; + if (path_lib == nil) { + // Try to find the resource in the directory where the current binary located. + NSString * current_binary = [[NSProcessInfo processInfo] arguments][0]; + NSString * bin_dir = [current_binary stringByDeletingLastPathComponent]; + NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]]; + if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) { + GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]); + NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error]; + if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) { + // Optionally, if this is a symlink, try to resolve it. + default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error]; + if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) { + // It is a relative path, adding the binary directory as directory prefix. + default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]]; + } + if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) { + // Link to the resource could not be resolved. + default_metallib_path = nil; + } else { + GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]); + } + } + } else { + // The resource couldn't be found in the binary's directory. + default_metallib_path = nil; + } + path_lib = default_metallib_path; + } + + if (path_lib != nil) { + // pre-compiled library found + NSURL * libURL = [NSURL fileURLWithPath:path_lib]; + GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]); + + metal_library = [device newLibraryWithURL:libURL error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } + } else { + GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); + + NSString * path_source; + NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; + + GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil"); + + if (path_resource) { + path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"]; + } else { + path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + } + + if (path_source == nil) { + GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); + path_source = @"ggml-metal.metal"; + } + + GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]); + + src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } + } +#endif + + if (!metal_library) { + @autoreleasepool { + // dictionary of preprocessor macros + NSMutableDictionary * prep = [NSMutableDictionary dictionary]; + + if (use_bfloat) { + [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"]; + } + +#if GGML_METAL_EMBED_LIBRARY + [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"]; +#endif + + MTLCompileOptions * options = [MTLCompileOptions new]; + options.preprocessorMacros = prep; + + //[options setFastMathEnabled:false]; + + metal_library = [device newLibraryWithSource:src options:options error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } + +#if !__has_feature(objc_arc) + [options release]; +#endif + } + } + +#if GGML_METAL_EMBED_LIBRARY + [src release]; +#endif // GGML_METAL_EMBED_LIBRARY + + return metal_library; +} + +static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) { + GGML_LOG_INFO("%s: allocating\n", __func__); + +#if TARGET_OS_OSX && !GGML_METAL_NDEBUG + // Show all the Metal device instances in the system + NSArray * devices = MTLCopyAllDevices(); + for (id device in devices) { + GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]); + } + [devices release]; // since it was created by a *Copy* C method +#endif + + // init context + struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context)); + struct ggml_backend_metal_device_context * ctx_dev = dev->context; + + id device = ggml_backend_metal_device_acq(ctx_dev); + + GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); + + ctx->device = device; + ctx->queue = [device newCommandQueue]; + if (ctx->queue == nil) { + GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); + return NULL; + } + + ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); + + // load library + if (ctx_dev->mtl_library == nil) { + ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat); + } + id metal_library = ctx_dev->mtl_library; + if (metal_library == nil) { + GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__); + return NULL; + } + + // print MTL GPU family: + GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[device name] UTF8String]); + + // determine max supported GPU family + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf + { + for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { + if ([device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { + if ([device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) { + if ([device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i); + break; + } + } + } + + GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false"); + GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false"); + GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false"); + GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); + + ctx->capture_next_compute = false; + ctx->capture_started = false; + ctx->capture_scope = nil; + + ctx->gf = nil; + ctx->encode_async = nil; + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + ctx->cmd_bufs[i].obj = nil; + + ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init(); + ctx->cmd_bufs[i].mem_pool->device = device; + } + +#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) + if (@available(macOS 10.12, iOS 16.0, *)) { + GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6); + } +#endif + + // load kernels + { + NSError * error = nil; + + for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { + ctx->kernels[i].pipeline = nil; + } + +#define GGML_METAL_ADD_KERNEL(e, name, supported) \ + if (supported) { \ + struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ + id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ + kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \ + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ + [metal_function release]; \ + if (error) { \ + GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + return NULL; \ + } \ + } else { \ + GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ + } + + const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm; + const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; + const bool use_bfloat = ctx_dev->use_bfloat; + + // simd_sum and simd_max requires MTLGPUFamilyApple7 + + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); + } + + return ctx; +} + +static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { + GGML_LOG_INFO("%s: deallocating\n", __func__); + + for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { + [ctx->kernels[i].pipeline release]; + } + + Block_release(ctx->encode_async); + + [ctx->queue release]; + + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + // ctx->cmd_bufs[i].obj is auto released + + ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool); + } + + dispatch_release(ctx->d_queue); + + free(ctx); +} + +// temporarily defined here for compatibility between ggml-backend and the old API + +struct ggml_backend_metal_buffer { + void * data; + size_t size; + + id metal; +}; + +struct ggml_backend_metal_buffer_context { + void * all_data; + size_t all_size; + bool owned; + + // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap + int n_buffers; + struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; + + // optional MTLResidencySet + id rset; +}; + +// rset init +static bool ggml_backend_metal_buffer_rset_init( + struct ggml_backend_metal_buffer_context * ctx, + struct ggml_backend_metal_device_context * ctx_dev, + id device) { + ctx->rset = nil; + + if (!ctx_dev->has_residency_sets) { + return true; + } + +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { + MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; + desc.label = @"ggml_backend_metal"; + desc.initialCapacity = ctx->n_buffers; + + NSError * error; + ctx->rset = [device newResidencySetWithDescriptor:desc error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + [desc release]; + return false; + } + + [desc release]; + + for (int i = 0; i < ctx->n_buffers; i++) { + [ctx->rset addAllocation:ctx->buffers[i].metal]; + } + + [ctx->rset commit]; + [ctx->rset requestResidency]; + + return true; + } +#else + GGML_UNUSED(ctx_dev); + GGML_UNUSED(device); +#endif + + return true; +} + +// rset free +static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) { +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { + if (ctx->rset) { + [ctx->rset endResidency]; + [ctx->rset removeAllAllocations]; + [ctx->rset release]; + } + } +#else + GGML_UNUSED(ctx); +#endif +} + +// finds the Metal buffer that contains the tensor data on the GPU device +// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the +// Metal buffer based on the host memory pointer +// +static id ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) { + //GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); + + const int64_t tsize = ggml_nbytes(t); + + ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; + + struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context; + + // find the view that contains the tensor fully + for (int i = 0; i < buf_ctx->n_buffers; ++i) { + const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data; + + //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size); + if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) { + *offs = (size_t) ioffs; + + //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); + + return buf_ctx->buffers[i].metal; + } + } + + GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); + + return nil; +} + +static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) { + const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm; + const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; + const bool use_bfloat = ctx_dev->use_bfloat; + + if (!use_bfloat) { + for (size_t i = 0, n = 3; i < n; ++i) { + if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { + return false; + } + } + } + + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_NEG: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + default: + return false; + } + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + case GGML_OP_CONCAT: + return true; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ACC: + case GGML_OP_REPEAT: + case GGML_OP_SCALE: + case GGML_OP_CONV_TRANSPOSE_1D: + return true; + case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_LOG: + return false; // TODO: implement + case GGML_OP_SUM_ROWS: + case GGML_OP_SOFT_MAX: + case GGML_OP_GROUP_NORM: + return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); + case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: + return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + case GGML_OP_ARGMAX: + return true; + case GGML_OP_NORM: + return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + case GGML_OP_ROPE: + return true; + case GGML_OP_IM2COL: + return op->src[0]->type == GGML_TYPE_F16; + case GGML_OP_POOL_1D: + return false; + case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + case GGML_OP_POOL_2D: + case GGML_OP_PAD: + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_ARGSORT: + case GGML_OP_LEAKY_RELU: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ARANGE: + return true; + case GGML_OP_FLASH_ATTN_EXT: + if (op->src[0]->ne[0] == 32) { + // head size == 32 (e.g. bert-bge-small) + // TODO: not sure if it is worth adding kernels for this size + return false; + } + if (op->src[0]->ne[0] == 576) { + // DeepSeek sizes + // TODO: disabled for now, until optmized + return false; + } + if (op->src[1]->type != op->src[2]->type) { + return false; + } + return has_simdgroup_mm; // TODO: over-restricted for vec-kernels + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + return true; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + return has_simdgroup_reduction && + (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32); + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_CONT: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + } + case GGML_TYPE_F16: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } + case GGML_TYPE_BF16: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_BF16: + return true; + default: + return false; + } + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } + default: + return false; + }; + } + case GGML_OP_SET: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_I32: + return true; + default: + return false; + }; + } + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_GET_ROWS: + { + return op->ne[3] == 1; + } + default: + return false; + } +} + +static bool ggml_metal_encode_node( + ggml_backend_t backend, + int idx, + id encoder, + struct ggml_metal_mem_pool * mem_pool) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + struct ggml_cgraph * gf = ctx->gf; + + struct ggml_tensor * node = ggml_graph_node(gf, idx); + + //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); + + struct ggml_tensor * src0 = node->src[0]; + struct ggml_tensor * src1 = node->src[1]; + struct ggml_tensor * src2 = node->src[2]; + struct ggml_tensor * dst = node; + + if (ggml_is_empty(dst)) { + return true; + } + + switch (dst->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + { + // noop -> next node + } return true; + default: + { + } break; + } + + if (!ggml_metal_supports_op(ctx_dev, dst)) { + GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); + GGML_ABORT("unsupported op"); + } + + ggml_metal_mem_pool_clear(mem_pool); + + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; + + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; + + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; + + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; + + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); + + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23); + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; + + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_src2 = 0; + size_t offs_dst = 0; + + id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; + id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; + id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; + id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; + +#if 0 + GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + if (src0) { + GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, + ggml_is_contiguous(src0), src0->name); + } + if (src1) { + GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + ggml_is_contiguous(src1), src1->name); + } + if (dst) { + GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, + dst->name); + } +#endif + + id device = ctx_dev->mtl_device; + + switch (dst->op) { + case GGML_OP_CONCAT: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + + const int32_t dim = ((const int32_t *) dst->op_params)[0]; + + ggml_metal_kargs_concat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + const size_t offs = 0; + + bool bcast_row = false; + + id pipeline = nil; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + bcast_row = true; + } else { + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } + + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ offs, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + if (bcast_row) { + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + } break; + case GGML_OP_REPEAT: + { + id pipeline; + + switch (src0t) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; + case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + ggml_metal_kargs_repeat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ACC: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t pnb1 = ((const int32_t *) dst->op_params)[0]; + const size_t pnb2 = ((const int32_t *) dst->op_params)[1]; + const size_t pnb3 = ((const int32_t *) dst->op_params)[2]; + const size_t offs = ((const int32_t *) dst->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) dst->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_SCALE: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); + + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_CLAMP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; + + float min; + float max; + memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + // we are not taking into account the strides, so for now require contiguous tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + + case GGML_UNARY_OP_TANH: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_RELU: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SIGMOID: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU_ERF: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SILU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_ELU: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_NEG: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + default: + { + GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_SQR: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SQRT: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SIN: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_COS: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SUM_ROWS: + { + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + + + ggml_metal_kargs_sum_rows args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFT_MAX: + { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + + int nth = 32; // SIMD width + + id pipeline = nil; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } + } + + float scale; + float max_bias; + + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + +// use this branch to test the ggml_metal_mem_pool functionality +#if 0 + // cpy to tmp buffer in MTLHeap + + id h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0)); + if (!h_src0) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0)); + return false; + } + + offs_src0 = 0; + + ggml_metal_kargs_cpy args_cpy = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne00, + /*.ne1 =*/ ne01, + /*.ne2 =*/ ne02, + /*.ne3 =*/ ne03, + /*.nb0 =*/ nb00, + /*.nb1 =*/ nb01, + /*.nb2 =*/ nb02, + /*.nb3 =*/ nb03, + }; + + if (src0->type == GGML_TYPE_F16) { + [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline]; + } else { + [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline]; + } + [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:h_src0 offset:0 atIndex:2]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type)); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)]; + +#else + id h_src0 = id_src0; +#endif + // softmax + + ggml_metal_kargs_soft_max args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_DIAG_MASK_INF: + { + const int n_past = ((const int32_t *)(dst->op_params))[0]; + + id pipeline = nil; + + if (ne00%8 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; + } + + ggml_metal_kargs_diag_mask_inf args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.n_past =*/ n_past, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + } break; + case GGML_OP_SSM_CONV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + + ggml_metal_kargs_ssm_conv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SSM_SCAN: + { + struct ggml_tensor * src3 = node->src[3]; + struct ggml_tensor * src4 = node->src[4]; + struct ggml_tensor * src5 = node->src[5]; + + GGML_ASSERT(src3); + GGML_ASSERT(src4); + GGML_ASSERT(src5); + + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; + id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + + const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); + + const uint64_t nb30 = src3->nb[0]; + const uint64_t nb31 = src3->nb[1]; + + const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); + const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + + const uint64_t nb40 = src4->nb[0]; + const uint64_t nb41 = src4->nb[1]; + const uint64_t nb42 = src4->nb[2]; + + const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); + const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); + const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + + const uint64_t nb50 = src5->nb[0]; + const uint64_t nb51 = src5->nb[1]; + const uint64_t nb52 = src5->nb[2]; + + const int64_t d_state = ne00; + const int64_t d_inner = ne01; + const int64_t n_seq_tokens = ne11; + const int64_t n_seqs = ne02; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + + ggml_metal_kargs_ssm_scan args = { + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_seq_tokens =*/ n_seq_tokens, + /*.n_seqs =*/ n_seqs, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb20 =*/ nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb30 =*/ nb30, + /*.nb31 =*/ nb31, + /*.nb40 =*/ nb40, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.nb50 =*/ nb50, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; + [encoder setBytes:&args length:sizeof(args) atIndex:7]; + + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_RWKV_WKV6: + { + const int64_t B = dst->src[5]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64); + + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; + + id id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil; + id id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil; + id id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; + + [encoder setBytes:&B length:sizeof(B) atIndex:7]; + [encoder setBytes:&T length:sizeof(T) atIndex:8]; + [encoder setBytes:&C length:sizeof(C) atIndex:9]; + [encoder setBytes:&H length:sizeof(H) atIndex:10]; + + [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)]; + } break; + case GGML_OP_RWKV_WKV7: + { + const int64_t B = dst->src[6]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64); + + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; + size_t offs_src6 = 0; + + id id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil; + id id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil; + id id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil; + id id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; + + [encoder setBytes:&B length:sizeof(B) atIndex:8]; + [encoder setBytes:&T length:sizeof(T) atIndex:9]; + [encoder setBytes:&C length:sizeof(C) atIndex:10]; + [encoder setBytes:&H length:sizeof(H) atIndex:11]; + + [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)]; + } break; + case GGML_OP_MUL_MAT: + { + GGML_ASSERT(ne00 == ne10); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const uint32_t r2 = ne12/ne02; + const uint32_t r3 = ne13/ne03; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + const int ne11_mm_min = 4; + + // first try to use small-batch mat-mv kernels + // these should be efficient for BS [2, ~8] + if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) && + ( + ( + ( + src0t == GGML_TYPE_F16 || // TODO: helper function + src0t == GGML_TYPE_Q4_0 || + src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || + src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_IQ4_NL || + false) && (ne11 >= 2 && ne11 <= 8) + ) || + ( + ( + src0t == GGML_TYPE_Q4_K || + src0t == GGML_TYPE_Q5_K || + src0t == GGML_TYPE_Q6_K || + false) && (ne11 >= 4 && ne11 <= 8) + ) + ) + ) { + // TODO: determine the optimal parameters based on grid utilization + // I still don't know why we should not always use the maximum available threads: + // + // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32 + // + // my current hypothesis is that the work grid is not evenly divisible for different nsg + // values and there can be some tail effects when nsg is high. need to confirm this + // + const int nsg = 2; // num simdgroups per threadgroup + const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup + const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time) + const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup + int r1ptg = 4; // num src1 rows per threadgroup + + // note: not sure how optimal are those across all different hardware. there might be someting cleverer + switch (ne11) { + case 2: + r1ptg = 2; break; + case 3: + case 6: + r1ptg = 3; break; + case 4: + case 7: + case 8: + r1ptg = 4; break; + case 5: + r1ptg = 5; break; + }; + + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F16: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q4_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q4_1: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q5_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q5_1: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q8_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q4_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q5_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q6_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_IQ4_NL: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + default: GGML_ABORT("not implemented"); + } + + ggml_metal_kargs_mul_mv_ext args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + /*.nsg =*/ nsg, + /*.nxpsg =*/ nxpsg, + /*.r1ptg =*/ r1ptg, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg); + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + if ([device supportsFamily:MTLGPUFamilyApple7] && + !ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && + src1t == GGML_TYPE_F32 && + ne00 % 32 == 0 && ne00 >= 64 && + (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { + //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (src0->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + default: break; + } + + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; + default: GGML_ABORT("MUL MAT-MAT not implemented"); + } + + ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + id pipeline = nil; + + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; + nr1 = 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: + { + nsg = 1; + nr0 = 1; + if (src1t == GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; + nr1 = ne11; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; + nr1 = 4; + } + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; + nr1 = 4; + } + } break; + case GGML_TYPE_BF16: + { + nsg = 1; + nr0 = 1; + if (src1t == GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; + nr1 = ne11; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; + nr1 = 4; + } + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; + nr1 = 4; + } + } break; + case GGML_TYPE_Q4_0: + { + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; + } break; + case GGML_TYPE_Q4_1: + { + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; + } break; + case GGML_TYPE_Q5_0: + { + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; + } break; + case GGML_TYPE_Q5_1: + { + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; + } break; + case GGML_TYPE_Q8_0: + { + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; + } break; + case GGML_TYPE_Q2_K: + { + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; + } break; + case GGML_TYPE_Q3_K: + { + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; + } break; + case GGML_TYPE_Q4_K: + { + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; + } break; + case GGML_TYPE_Q5_K: + { + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; + } break; + case GGML_TYPE_Q6_K: + { + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XXS: + { + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XS: + { + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_XXS: + { + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_S: + { + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; + } break; + case GGML_TYPE_IQ2_S: + { + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_S: + { + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_M: + { + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; + } break; + case GGML_TYPE_IQ4_NL: + { + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; + } break; + case GGML_TYPE_IQ4_XS: + { + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; + } break; + default: + { + GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); + GGML_ABORT("not implemented"); + } + }; + + ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + if (smem > 0) { + [encoder setThreadgroupMemoryLength:smem atIndex:0]; + } + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; + case GGML_OP_MUL_MAT_ID: + { + // src2 = ids + const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); + + GGML_ASSERT(src2t == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(src1t == GGML_TYPE_F32); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + + const uint32_t r2 = 1; + const uint32_t r3 = 1; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + // ne20 = n_used_experts + // ne21 = n_rows (batch size) + const int ne21_mm_id_min = 32; + + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + if ([device supportsFamily:MTLGPUFamilyApple7] && + ne00 % 32 == 0 && ne00 >= 64 && + (ne21 >= ne21_mm_id_min)) { + GGML_ASSERT(ne00 % 4 == 0); + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (src0->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + default: break; + } + + const int64_t neh10 = ne10; // n_embd + const int64_t neh11 = ne21; // n_tokens + const int64_t neh12 = ne02; // n_expert + + const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16); + const uint64_t nbh11 = nbh10*neh10; + const uint64_t nbh12 = nbh11*neh11; + const uint64_t nbh13 = nbh12*neh12; + + const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12; + id h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1); + if (!h_src1) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1); + return false; + } + + const int64_t neh0 = ne0; + const int64_t neh1 = ne21; + const int64_t neh2 = ne02; + + const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32); + const uint64_t nbh1 = nbh0*neh0; + const uint64_t nbh2 = nbh1*neh1; + //const uint64_t nbh3 = nbh2*neh2; + + const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2; + id h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst); + if (!h_dst) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst); + return false; + } + + // tokens per expert + const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02; + id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); + if (!h_tpe) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe); + return false; + } + + // id map + // [n_expert_used, n_tokens] + const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21; + id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); + if (!h_ids) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); + return false; + } + + { + const int nth = MIN(1024, ne10/4); + + ggml_metal_kargs_mul_mm_id_map0 args = { + ne10, + ne11, // n_expert_used (bcast) + nb11, + nb12, + neh11, // n_tokens + nbh11, + ne20, // n_expert_used + nb21, + }; + + id pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer: h_src1 offset:0 atIndex:3]; + [encoder setBuffer: h_tpe offset:0 atIndex:4]; + [encoder setBuffer: h_ids offset:0 atIndex:5]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + + ggml_metal_kargs_mul_mm_id args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.neh12 =*/ neh12, + /*.nbh10 =*/ nbh10, + /*.nbh11 =*/ nbh11, + /*.nbh12 =*/ nbh12, + /*.nbh13 =*/ nbh13, + /*.neh0 =*/ neh0, + /*.neh1 =*/ neh1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer: h_src1 offset:0 atIndex:2]; + [encoder setBuffer: h_tpe offset:0 atIndex:3]; + [encoder setBuffer: h_dst offset:0 atIndex:4]; + + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } + + { + GGML_ASSERT(ne0 % 4 == 0); + + const int nth = MIN(1024, ne0/4); + + ggml_metal_kargs_mul_mm_id_map1 args = { + ne20, // n_expert_used + neh0, + neh1, + nbh1, + nbh2, + ne0, + nb1, + nb2, + }; + + id pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer: h_dst offset:0 atIndex:1]; + [encoder setBuffer: h_ids offset:0 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + } else { + id pipeline = nil; + + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; + } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; + } break; + case GGML_TYPE_Q4_0: + { + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; + } break; + case GGML_TYPE_Q4_1: + { + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; + } break; + case GGML_TYPE_Q5_0: + { + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; + } break; + case GGML_TYPE_Q5_1: + { + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; + } break; + case GGML_TYPE_Q8_0: + { + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; + } break; + case GGML_TYPE_Q2_K: + { + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; + } break; + case GGML_TYPE_Q3_K: + { + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; + } break; + case GGML_TYPE_Q4_K: + { + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; + } break; + case GGML_TYPE_Q5_K: + { + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; + } break; + case GGML_TYPE_Q6_K: + { + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XXS: + { + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XS: + { + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_XXS: + { + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_S: + { + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; + } break; + case GGML_TYPE_IQ2_S: + { + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_S: + { + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_M: + { + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; + } break; + case GGML_TYPE_IQ4_NL: + { + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; + } break; + case GGML_TYPE_IQ4_XS: + { + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; + } break; + default: + { + GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t); + GGML_ABORT("not implemented"); + } + }; + + if (ggml_is_quantized(src0t)) { + GGML_ASSERT(ne00 >= nsg*nr0); + } + + ggml_metal_kargs_mul_mv_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb1 =*/ nb1, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; + + const int64_t _ne1 = 1; + const int64_t ne123 = ne20*ne21; + + if (smem > 0) { + [encoder setThreadgroupMemoryLength:smem atIndex:0]; + } + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; + case GGML_OP_GET_ROWS: + { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; + default: GGML_ABORT("not implemented"); + } + + ggml_metal_kargs_get_rows args = { + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + } break; + case GGML_OP_RMS_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_rms_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_L2_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline; + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_l2_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_GROUP_NORM: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); + + const int32_t n_groups = ((const int32_t *) dst->op_params)[0]; + + int nth = 32; // SIMD width + + //while (nth < ne00/4 && nth < 1024) { + // nth *= 2; + //} + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + + ggml_metal_kargs_group_norm args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.n_groups =*/ n_groups, + /*.eps =*/ eps, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ROPE: + { + + // make sure we have one or more position id(ne10) per token(ne02) + GGML_ASSERT(ne10 % ne02 == 0); + GGML_ASSERT(ne10 >= ne02); + + const int nth = MIN(1024, ne00); + + const int n_past = ((const int32_t *) dst->op_params)[0]; + const int n_dims = ((const int32_t *) dst->op_params)[1]; + const int mode = ((const int32_t *) dst->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_ctx_orig = ((const int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + // mrope + const int sect_0 = ((const int32_t *) dst->op_params)[11]; + const int sect_1 = ((const int32_t *) dst->op_params)[12]; + const int sect_2 = ((const int32_t *) dst->op_params)[13]; + const int sect_3 = ((const int32_t *) dst->op_params)[14]; + + id pipeline = nil; + + if (is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_mrope && !is_vision) { + GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_vision) { + GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } + + ggml_metal_kargs_rope args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.n_past =*/ n_past, + /*.n_dims =*/ n_dims, + /*.n_ctx_orig =*/ n_ctx_orig, + /*.freq_base =*/ freq_base, + /*.freq_scale =*/ freq_scale, + /*.ext_factor =*/ ext_factor, + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, + /* sect_0 =*/ sect_0, + /* sect_1 =*/ sect_1, + /* sect_2 =*/ sect_2, + /* sect_3 =*/ sect_3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_IM2COL: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; + + const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup; + + switch (dst->type) { + case GGML_TYPE_F32: { + pipeline = (is_gt_mttpt ? + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline + : + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline); + } break; + case GGML_TYPE_F16: { + pipeline = (is_gt_mttpt ? + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline + : + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline); + } break; + default: GGML_ABORT("fatal error"); + }; + + ggml_metal_kargs_im2col args = { + /*.ofs0 =*/ ofs0, + /*.ofs1 =*/ ofs1, + /*.IW =*/ IW, + /*.IH =*/ IH, + /*.CHW =*/ CHW, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.p0 =*/ p0, + /*.p1 =*/ p1, + /*.d0 =*/ d0, + /*.d1 =*/ d1, + /*.N =*/ N, + /*.KH =*/ KH, + /*.KW =*/ KW, + /*.KHW =*/ KH * KW, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + + if (is_gt_mttpt) { + const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N); + + const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + + [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + + const int32_t IC = src1->ne[1]; + const int32_t IL = src1->ne[0]; + + const int32_t K = src0->ne[0]; + + const int32_t OL = dst->ne[0]; + const int32_t OC = dst->ne[1]; + + id pipeline; + + switch (src0->type) { + case GGML_TYPE_F32: { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline; + } break; + default: GGML_ABORT("fatal error"); + }; + + ggml_metal_kargs_conv_transpose_1d args = { + /*.IC =*/ IC, + /*.IL =*/ IL, + /*.K =*/ K, + /*.s0 =*/ s0, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_UPSCALE: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + + ggml_metal_kargs_upscale args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.sf0 =*/ sf0, + /*.sf1 =*/ sf1, + /*.sf2 =*/ sf2, + /*.sf3 =*/ sf3 + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_PAD: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + + ggml_metal_kargs_pad args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3 + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_PAD_REFLECT_1D: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int32_t p0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[1]; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline; + + ggml_metal_kargs_pad_reflect_1d args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.p0 =*/ p0, + /*.p1 =*/ p1 + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARANGE: + { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float start; + float step; + + memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + + ggml_metal_kargs_arange args = { + /*.ne0 =*/ ne0, + /*.start =*/ start, + /*.step =*/ step + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + [encoder setBytes:&args length:sizeof(args) atIndex:1]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + + const int half = dim / 2; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + + ggml_metal_kargs_timestep_embedding args = { + /*.nb1 =*/ nb1, + /*.dim =*/ dim, + /*.max_period =*/ max_period + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + + const int nth = MIN(1024, half); + + [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARGSORT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); + + id pipeline = nil; + + switch (order) { + case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; + case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + + ggml_metal_kargs_argsort args = { + /*.ncols =*/ ne00, + /*.ncols_pad =*/ ne00_padded + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; + } break; + case GGML_OP_LEAKY_RELU: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + float slope; + memcpy(&slope, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + + ggml_metal_kargs_leaky_relu args = { + /*.slope =*/ slope + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne11 % 32 == 0); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == src2->type); + + //GGML_ASSERT(ggml_are_same_shape (src1, src2)); + GGML_ASSERT(ne11 == ne21); + GGML_ASSERT(ne12 == ne22); + + struct ggml_tensor * src3 = node->src[3]; + + size_t offs_src3 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + //const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + float scale; + float max_bias; + float logit_softcap; + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + id pipeline = nil; + + bool use_vec_kernel = false; + + // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) + // for now avoiding mainly to keep the number of templates/kernels a bit lower + // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 + if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { + switch (src1->type) { + case GGML_TYPE_F16: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + case GGML_TYPE_BF16: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + case GGML_TYPE_Q4_0: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + case GGML_TYPE_Q4_1: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + case GGML_TYPE_Q5_0: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + case GGML_TYPE_Q5_1: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + case GGML_TYPE_Q8_0: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } else { + use_vec_kernel = true; + + switch (ne00) { + case 64: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; + case 96: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; + case 128: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; + case 192: + { + if (ne20 == 128) { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } else { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } + } break; + case 256: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; + case 576: + { + if (ne20 == 512) { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } else { + GGML_LOG_ERROR("unsupported size: %lld\n", ne20); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.nb31 =*/ nb31, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + const int is_q = ggml_is_quantized(src1->type) ? 1 : 0; + + // 2*(2*ncpsg + nqptg)*(nsg) + // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) + // + // 16*32*(nsg) + // the shared memory needed for the simdgroups to load the KV cache + // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + if (smem > device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); + GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half4x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // ne00 + 2*ncpsg*(nsg) + // for each query, we load it as f16 in shared memory (ne00) + // and store the soft_max values and the mask + // + // ne00*(nsg) + // each simdgroup has a full f32 head vector in shared mem to accumulate results + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + if (smem > device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); + GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + { + id pipeline = nil; + + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); + + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_F16: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_BF16: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q4_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q4_1: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q5_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q5_1: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q8_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + default: GGML_ABORT("not implemented"); + } + + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + + } break; + case GGML_OP_SET: + { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + + // src0 and dst as viewed during set + const size_t dst_nb0 = ggml_element_size(src0); + + const size_t dst_nb1 = ((int32_t *) dst->op_params)[0]; + const size_t dst_nb2 = ((int32_t *) dst->op_params)[1]; + const size_t dst_nb3 = ((int32_t *) dst->op_params)[2]; + const size_t offset = ((int32_t *) dst->op_params)[3]; + const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst)); + } + + const int im0 = (ne10 == 0 ? 0 : ne10-1); + const int im1 = (ne11 == 0 ? 0 : ne11-1); + const int im2 = (ne12 == 0 ? 0 : ne12-1); + const int im3 = (ne13 == 0 ? 0 : ne13-1); + + GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst)); + + id pipeline = nil; + + switch (src0t) { + case GGML_TYPE_F32: + GGML_ASSERT(nb10 == sizeof(float)); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break; + case GGML_TYPE_I32: + GGML_ASSERT(nb10 == sizeof(int32_t)); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + ggml_metal_kargs_set args = { + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ dst_nb1, + /*.nb2 =*/ dst_nb2, + /*.nb3 =*/ dst_nb3, + /*.offs =*/ offset, + /*.inplace =*/ inplace, + }; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10); + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_POOL_2D: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt); + + const int32_t * opts = dst->op_params; + enum ggml_op_pool op = opts[0]; + + id pipeline = nil; + switch (src0t) { + case GGML_TYPE_F32: { + switch(op) { + case GGML_OP_POOL_AVG: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break; + case GGML_OP_POOL_MAX: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break; + default: GGML_ASSERT(false && "not implemented"); + } + } break; + default: GGML_ASSERT(false && "not implemented"); + } + + const int32_t k0 = opts[1]; + const int32_t k1 = opts[2]; + const int32_t s0 = opts[3]; + const int32_t s1 = opts[4]; + const int32_t p0 = opts[5]; + const int32_t p1 = opts[6]; + + const int64_t IH = src0->ne[1]; + const int64_t IW = src0->ne[0]; + + const int64_t N = dst->ne[3]; + const int64_t OC = dst->ne[2]; + const int64_t OH = dst->ne[1]; + const int64_t OW = dst->ne[0]; + + const int64_t parallel_elements = N * OC * OH * OW; + const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); + const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + + ggml_metal_kargs_pool_2d args_pool_2d = { + /* .k0 = */ k0, + /* .k1 = */ k1, + /* .s0 = */ s0, + /* .s1 = */ s1, + /* .p0 = */ p0, + /* .p1 = */ p1, + /* .IH = */ IH, + /* .IW = */ IW, + /* .OH = */ OH, + /* .OW = */ OW, + /* .parallel_elements = */ parallel_elements + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } break; + case GGML_OP_ARGMAX: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + + const int64_t nrows = ggml_nrows(src0); + + int nth = 32; // SIMD width + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + default: + { + GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } + + return true; +} + +static enum ggml_status ggml_metal_graph_compute( + ggml_backend_t backend, + struct ggml_cgraph * gf) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + // number of nodes encoded by the main thread (empirically determined) + const int n_main = 128; + + // number of threads in addition to the main thread + const int n_cb = ctx->n_cb; + + // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them + // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread + // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes + // each thread creates it's own command buffer and enqueues the ops in parallel + // + // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2 + + @autoreleasepool { + ctx->gf = gf; + + ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); + ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; + + ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; + + const bool should_capture = ctx->capture_next_compute; + if (should_capture) { + ctx->capture_next_compute = false; + + if (!ctx->capture_started) { + // create capture scope + ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device]; + + MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->capture_scope; + descriptor.destination = MTLCaptureDestinationGPUTraceDocument; + descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + + NSError * error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + } else { + [ctx->capture_scope beginScope]; + ctx->capture_started = true; + } + } + } + + // the main thread commits the first few commands immediately + // cmd_buf[n_cb] + { + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->cmd_bufs[n_cb].obj = cmd_buf; + + [cmd_buf enqueue]; + ctx->encode_async(n_cb); + } + + // prepare the rest of the command buffers asynchronously + // cmd_buf[0.. n_cb) + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->cmd_bufs[cb_idx].obj = cmd_buf; + + // always enqueue the first two command buffers + // enqueue all of the command buffers if we don't need to abort + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [cmd_buf enqueue]; + } + } + + dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); + + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id cmd_buf = ctx->cmd_bufs[n_cb].obj; + [cmd_buf waitUntilCompleted]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + } + + for (int i = 0; i < n_cb; ++i) { + id cmd_buf = ctx->cmd_bufs[i].obj; + [cmd_buf waitUntilCompleted]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + + id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); + if (!next_buffer) { + continue; + } + + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; + } + + if (!should_capture && ctx->capture_started) { + [ctx->capture_scope endScope]; + [[MTLCaptureManager sharedCaptureManager] stopCapture]; + } + } + + return GGML_STATUS_SUCCESS; +} + +//////////////////////////////////////////////////////////////////////////////// + +// backend interface + +static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + for (int i = 0; i < ctx->n_buffers; i++) { + [ctx->buffers[i].metal release]; + } + + ggml_backend_metal_buffer_rset_free(ctx); + ggml_backend_metal_device_rel(buffer->buft->device->context); + + if (ctx->owned) { +#if TARGET_OS_OSX + vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size); +#else + free(ctx->all_data); +#endif + } + + free(ctx); +} + +static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + return ctx->all_data; +} + +static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + memcpy((char *)tensor->data + offset, data, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + memcpy(data, (const char *)tensor->data + offset, size); + + GGML_UNUSED(buffer); +} + +static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + if (ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, ggml_nbytes(src)); + return true; + } + return false; + + GGML_UNUSED(buffer); +} + +static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + memset(ctx->all_data, value, ctx->all_size); +} + +static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_clear, + /* .reset = */ NULL, +}; + +// default buffer type + +static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "Metal"; + + GGML_UNUSED(buft); +} + +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG +#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) + if (@available(macOS 10.12, iOS 16.0, *)) { + GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0, + device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + + if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { + GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); + } + } else { + GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); + } +#endif +#endif + GGML_UNUSED(device); + GGML_UNUSED(size_aligned); +} + +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); + + const size_t size_page = sysconf(_SC_PAGESIZE); + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context; + id device = ggml_backend_metal_device_acq(ctx_dev); + + ctx->all_data = ggml_metal_host_malloc(size_aligned); + ctx->all_size = size_aligned; + ctx->owned = true; + ctx->n_buffers = 1; + + if (ctx->all_data != NULL) { + ctx->buffers[0].data = ctx->all_data; + ctx->buffers[0].size = size; + ctx->buffers[0].metal = nil; + + if (size_aligned > 0) { + ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data + length:size_aligned + options:MTLResourceStorageModeShared + deallocator:nil]; + } + } + + if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + + //ggml_backend_metal_log_allocated_size(device, size_aligned); + + return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); +} + +static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; + GGML_UNUSED(buft); +} + +static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + id device = ggml_backend_metal_device_acq(buft->device->context); + const size_t max_size = device.maxBufferLength; + ggml_backend_metal_device_rel(buft->device->context); + + return max_size; + + GGML_UNUSED(buft); +} + +static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return true; + + GGML_UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_metal_buffer_type_is_host, + }, + /* .device = */ &g_ggml_backend_metal_device, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_type_metal; +} + +static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { + return "Metal_Mapped"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) { + static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_from_ptr_type_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_metal_buffer_type_is_host, + }, + /* .device = */ &g_ggml_backend_metal_device, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_from_ptr_type_metal; +} + +// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr +ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { + struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); + + ctx->all_data = data; + ctx->all_size = size; + ctx->owned = false; + ctx->n_buffers = 0; + + const size_t size_page = sysconf(_SC_PAGESIZE); + + // page-align the data ptr + { + const uintptr_t offs = (uintptr_t) data % size_page; + data = (void *) ((char *) data - offs); + size += offs; + } + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main; + id device = ggml_backend_metal_device_acq(ctx_dev); + + // the buffer fits into the max buffer size allowed by the device + if (size_aligned <= device.maxBufferLength) { + ctx->buffers[ctx->n_buffers].data = data; + ctx->buffers[ctx->n_buffers].size = size; + ctx->buffers[ctx->n_buffers].metal = nil; + + if (size_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + return false; + } + } + + ggml_backend_metal_log_allocated_size(device, size_aligned); + + ++ctx->n_buffers; + } else { + // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into + // one of the views + const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case + const size_t size_step = device.maxBufferLength - size_ovlp; + const size_t size_view = device.maxBufferLength; + + for (size_t i = 0; i < size; i += size_step) { + const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); + + ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); + ctx->buffers[ctx->n_buffers].size = size_step_aligned; + ctx->buffers[ctx->n_buffers].metal = nil; + + if (size_step_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); + return false; + } + } + + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + + if (i + size_step < size) { + GGML_LOG_INFO("\n"); + } + + ++ctx->n_buffers; + } + } + + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); +} + +// backend + +static const char * ggml_backend_metal_name(ggml_backend_t backend) { + return "Metal"; + + GGML_UNUSED(backend); +} + +static void ggml_backend_metal_free(ggml_backend_t backend) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + ggml_backend_metal_device_rel(ctx_dev); + ggml_metal_free(ctx); + + free(backend); +} + +static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + return ggml_metal_graph_compute(backend, cgraph); +} + +static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; + + if (ctx->n_cb != n_cb) { + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); + + if (ctx->n_cb > 2) { + GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); + } + } + + if (ctx->encode_async) { + Block_release(ctx->encode_async); + } + + ctx->encode_async = Block_copy(^(size_t iter) { + const int cb_idx = iter; + const int n_cb_l = ctx->n_cb; + + const int n_nodes_0 = ctx->n_nodes_0; + const int n_nodes_1 = ctx->n_nodes_1; + + const int n_nodes_per_cb = ctx->n_nodes_per_cb; + + id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + + id encoder = [cmd_buf computeCommandEncoder]; + + int node_start = 0; + int node_end = n_nodes_0; + + if (cb_idx < n_cb_l) { + node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); + node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); + } + + const bool should_capture = ctx->capture_next_compute; + + struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; + ggml_metal_mem_pool_reset(mem_pool); + + for (int idx = node_start; idx < node_end; ++idx) { + if (should_capture) { + [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; + } + + const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool); + + if (should_capture) { + [encoder popDebugGroup]; + } + + if (!res) { + break; + } + } + + [encoder endEncoding]; + + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [cmd_buf commit]; + } + }); +} + +static struct ggml_backend_i ggml_backend_metal_i = { + /* .get_name = */ ggml_backend_metal_name, + /* .free = */ ggml_backend_metal_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_metal_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +static ggml_guid_t ggml_backend_metal_guid(void) { + static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 }; + return &guid; +} + +// TODO: remove in the future +ggml_backend_t ggml_backend_metal_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0); + + struct ggml_backend_metal_context * ctx = ggml_metal_init(dev); + if (ctx == NULL) { + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + return NULL; + } + + ggml_backend_t backend = malloc(sizeof(struct ggml_backend)); + + *backend = (struct ggml_backend) { + /* .guid = */ ggml_backend_metal_guid(), + /* .interface = */ ggml_backend_metal_i, + /* .device = */ dev, + /* .context = */ ctx, + }; + + ggml_backend_metal_set_n_cb(backend, 1); + + return backend; +} + +bool ggml_backend_is_metal(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid()); +} + +void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; + + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = user_data; +} + +bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; +} + +void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; + ctx->capture_next_compute = true; +} + +// backend device + +static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { + return "Metal"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { + // acq/rel just to populate ctx->name in case it hasn't been done yet + struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; + ggml_backend_metal_device_acq(ctx_dev); + ggml_backend_metal_device_rel(ctx_dev); + + return ctx_dev->name; +} + +static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + if (@available(macOS 10.12, iOS 16.0, *)) { + struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; + id device = ggml_backend_metal_device_acq(ctx_dev); + + *total = device.recommendedMaxWorkingSetSize; + *free = *total - device.currentAllocatedSize; + + ggml_backend_metal_device_rel(ctx_dev); + } else { + *free = 1; + *total = 1; + } +} + +static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_metal_device_get_name(dev); + props->description = ggml_backend_metal_device_get_description(dev); + props->type = ggml_backend_metal_device_get_type(dev); + ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = (struct ggml_backend_dev_caps) { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) { + struct ggml_backend_metal_context * ctx = ggml_metal_init(dev); + if (ctx == NULL) { + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + return NULL; + } + + ggml_backend_t backend = malloc(sizeof(struct ggml_backend)); + + *backend = (struct ggml_backend) { + /* .guid = */ ggml_backend_metal_guid(), + /* .interface = */ ggml_backend_metal_i, + /* .device = */ dev, + /* .context = */ ctx, + }; + + ggml_backend_metal_set_n_cb(backend, 1); + + return backend; + + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) { + return ggml_backend_metal_buffer_type(); + + GGML_UNUSED(dev); +} + +static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); + + ctx->all_data = ptr; + ctx->all_size = size; + ctx->owned = false; + ctx->n_buffers = 0; + + const size_t size_page = sysconf(_SC_PAGESIZE); + + // page-align the data ptr + { + const uintptr_t offs = (uintptr_t) ptr % size_page; + ptr = (void *) ((char *) ptr - offs); + size += offs; + } + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; + id device = ggml_backend_metal_device_acq(ctx_dev); + + // the buffer fits into the max buffer size allowed by the device + if (size_aligned <= device.maxBufferLength) { + ctx->buffers[ctx->n_buffers].data = ptr; + ctx->buffers[ctx->n_buffers].size = size; + ctx->buffers[ctx->n_buffers].metal = nil; + + if (size_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + return false; + } + } + + ggml_backend_metal_log_allocated_size(device, size_aligned); + + ++ctx->n_buffers; + } else { + // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into + // one of the views + const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case + const size_t size_step = device.maxBufferLength - size_ovlp; + const size_t size_view = device.maxBufferLength; + + for (size_t i = 0; i < size; i += size_step) { + const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); + + ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) ptr + i); + ctx->buffers[ctx->n_buffers].size = size_step_aligned; + ctx->buffers[ctx->n_buffers].metal = nil; + + if (size_step_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); + return false; + } + } + + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + + if (i + size_step < size) { + GGML_LOG_INFO("\n"); + } + + ++ctx->n_buffers; + } + } + + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); +} + +static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + struct ggml_backend_metal_device_context * ctx_dev = dev->context; + + return ggml_metal_supports_op(ctx_dev, op); +} + +static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; + + GGML_UNUSED(dev); +} + +static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + return false; + + GGML_UNUSED(dev); + GGML_UNUSED(op); +} + +static struct ggml_backend_device_i ggml_backend_metal_device_i = { + /* .get_name = */ ggml_backend_metal_device_get_name, + /* .get_description = */ ggml_backend_metal_device_get_description, + /* .get_memory = */ ggml_backend_metal_device_get_memory, + /* .get_type = */ ggml_backend_metal_device_get_type, + /* .get_props = */ ggml_backend_metal_device_get_props, + /* .init_backend = */ ggml_backend_metal_device_init, + /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_metal_device_supports_op, + /* .supports_buft = */ ggml_backend_metal_device_supports_buft, + /* .offload_op = */ ggml_backend_metal_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend registry + +static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { + return "Metal"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { + return 1; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + return &g_ggml_backend_metal_device; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static struct ggml_backend_feature g_ggml_backend_metal_features[] = { +#if defined(GGML_METAL_EMBED_LIBRARY) + { "EMBED_LIBRARY", "1" }, +#endif +#if defined(GGML_METAL_USE_BF16) + { "BF16", "1" }, +#endif + { nil, nil }, +}; + +static struct ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) { + return g_ggml_backend_metal_features; + + GGML_UNUSED(reg); +} + +static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (strcmp(name, "ggml_backend_get_features") == 0) { + return (void *)ggml_backend_metal_get_features; + } + + return NULL; + + GGML_UNUSED(reg); +} +static struct ggml_backend_reg_i ggml_backend_metal_reg_i = { + /* .get_name = */ ggml_backend_metal_reg_get_name, + /* .device_count = */ ggml_backend_metal_reg_device_count, + /* .device_get = */ ggml_backend_metal_reg_device_get, + /* .get_proc_address = */ ggml_backend_metal_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_metal_reg(void) { + // TODO: make this thread-safe somehow? + { + g_ggml_backend_metal_reg = (struct ggml_backend_reg) { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_metal_reg_i, + /* .context = */ NULL, + }; + + g_ggml_backend_metal_device = (struct ggml_backend_device) { + /* .iface = */ ggml_backend_metal_device_i, + /* .reg = */ &g_ggml_backend_metal_reg, + /* .context = */ &g_ggml_ctx_dev_main, + }; + } + + return &g_ggml_backend_metal_reg; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal new file mode 100644 index 0000000000000000000000000000000000000000..5d7760217f82644602fdb635f52943a8b9f89d56 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -0,0 +1,7094 @@ +#define GGML_COMMON_DECL_METAL +#define GGML_COMMON_IMPL_METAL +#if defined(GGML_METAL_EMBED_LIBRARY) +__embed_ggml-common.h__ +#else +#include "ggml-common.h" +#endif +#include "ggml-metal-impl.h" + +#include + +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) +#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } + +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf +// +// cmd: +// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/ggml-metal.metal +// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal +// +#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16) +#undef GGML_METAL_USE_BF16 +#endif + +#if defined(GGML_METAL_USE_BF16) +typedef matrix bfloat4x4; +#endif + +constexpr constant static float kvalues_iq4nl_f[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { + reg = (type4)(*(src)); +} + +#if defined(GGML_METAL_USE_BF16) +template +void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) { + reg = (type4)(*(src)); +} +#endif + +template +void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md; + reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = (il/4) ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = (il/4) ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i = 0; i < 2; i++) { + reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md; + reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md; + } +} + +template +void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m; + reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = (il/4) ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = (il/4) ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i = 0; i < 2; i++) { + reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m; + reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m; + } +} + +template +void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg_f[i/2][2*(i%2) + 0] = d * x0 + md; + reg_f[i/2][2*(i%2) + 1] = d * x1 + md; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = (il/4) ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il/4) ? 4 : 0; + + const int gh_mv = (il/4) ? 12 : 0; + const int gh_bk = (il/4) ? 0 : 4; + + for (int ii = 0; ii < 2; ii++) { + int i = 2*(il%4) + ii; + + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[2*ii + 0] = d * x0 + md; + reg[2*ii + 1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg_f[i/2][2*(i%2) + 0] = d * x0 + m; + reg_f[i/2][2*(i%2) + 1] = d * x1 + m; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = (il/4) ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il/4) ? 4 : 0; + + const int gh_mv = (il/4) ? 12 : 0; + const int gh_bk = (il/4) ? 0 : 4; + + for (int ii = 0; ii < 2; ii++) { + int i = 2*(il%4) + ii; + + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[2*ii + 0] = d * x0 + m; + reg[2*ii + 1] = d * x1 + m; + } +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const float d = xb->d; + + float4x4 reg_f; + + for (int i = 0; i < 16; i++) { + reg_f[i/4][i%4] = (qs[i + 16*il] * d); + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const float d = xb->d; + + for (int i = 0; i < 4; i++) { + reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d); + } +} + +template +void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { + const float d = xb->d; + const float min = xb->dmin; + device const uint8_t * q = (device const uint8_t *)xb->qs; + float dl, ml; + uint8_t sc = xb->scales[il]; + + q = q + 32*(il/8) + 16*(il&1); + il = (il/2)%4; + + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * q = (device const uint8_t *)xb->qs; + device const uint8_t * h = (device const uint8_t *)xb->hmask; + device const int8_t * scales = (device const int8_t *)xb->scales; + + q = q + 32 * (il/8) + 16 * (il&1); + h = h + 16 * (il&1); + uint8_t m = 1 << (il/2); + uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ + ((il/4)>0 ? 12 : 3); + uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; + uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) + : (scale_2&kmask2) | ((scale_1&kmask1) << 4); + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + const float ml = 4.f * dl; + + il = (il/2) & 3; + const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl *= coef; + + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); + } +} + +static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { + return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} + : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; +} + +template +void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) { + device const uchar * q = xb->qs; + + short is = (il/4) * 2; + q = q + (il/4) * 32 + 16 * (il&1); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il < 2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + device const uint8_t * qh = xb->qh; + + short is = (il/4) * 2; + q = q + 32 * (il/4) + 16 * (il&1); + qh = qh + 16 * (il&1); + uint8_t ul = 1 << (il/2); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.f; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; + } +} + +template +void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint16_t * ql = (device const uint16_t *)xb->ql; + device const uint16_t * qh = (device const uint16_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + + ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1); + qh = qh + 16*(il/8) + 8*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303); + const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F; + const float ml = d_all * sc * 32.f; + const float dl0 = d_all * sc; + const float dl1 = dl0 / 256.f; + const float dl2 = dl0 / (256.f * 256.f); + const float dl3 = dl0 / (256.f * 256.f * 256.f); + const uint8_t shr_h = il>2 ? 2 : 0; + const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4); + const uint8_t shr_l = il>1 ? 4 : 0; + for (int i = 0; i < 4; ++i) { + const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2; + const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1; + const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l); + reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml; + reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml; + reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml; + reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml; + } +} + +template +void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. + device const uint16_t * q2 = xb->qs + 4*ib32; + const uint32_t aux32_g = q2[0] | (q2[1] << 16); + const uint32_t aux32_s = q2[2] | (q2[3] << 16); + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; + const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); + signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint16_t * q2 = xb->qs + 4*ib32; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); + uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); + signs = ksigns_iq2xs[q2[2*il+1] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * q3 = xb->qs + 8*ib32; + device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; + constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); + constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); + uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } + grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); + grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); + signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 8*ib32; + device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + } + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + } +} + +template +void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * signs = qs + QK_K/8; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); + for (int i = 0; i < 8; ++i) { + reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); + reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); + } +} + +template +void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + const float d = xb->d; + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); + const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA); + const uint16_t h = qh[ib32] >> 6*il; + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml; + reg[1][i] = dl * (grid1[i] >> 4) + ml; + reg[2][i] = dl * (grid2[i] & 0xf) + ml; + reg[3][i] = dl * (grid2[i] >> 4) + ml; + } +} + +template +void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + device const uint16_t * sc = (device const uint16_t *)xb->scales; + + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = scale.f16; + + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * qh = xb->qh + 2*ib32 + il; + + const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); + const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml1; + reg[1][i] = dl * (grid1[i] >> 4) + ml1; + reg[2][i] = dl * (grid2[i] & 0xf) + ml2; + reg[3][i] = dl * (grid2[i] >> 4) + ml2; + } +} + +template +void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f; + reg[0] = d * kvalues_iq4nl_f[q8[0]]; + reg[1] = d * kvalues_iq4nl_f[q8[1]]; + reg[2] = d * kvalues_iq4nl_f[q8[2]]; + reg[3] = d * kvalues_iq4nl_f[q8[3]]; +} + +template +void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +enum ggml_sort_order { + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, +}; + +// general-purpose kernel for addition, subtraction, multiplication and division of two tensors +// pros: works for non-contiguous tensors, supports broadcast across all dims +// cons: not very efficient +kernel void kernel_add( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +kernel void kernel_sub( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +kernel void kernel_mul( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +kernel void kernel_div( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +template +kernel void kernel_repeat( + constant ggml_metal_kargs_repeat & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + const int i03 = i3%args.ne03; + const int i02 = i2%args.ne02; + const int i01 = i1%args.ne01; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i00 = i0%args.ne00; + *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); + } +} + +typedef decltype(kernel_repeat) kernel_repeat_t; + +template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] + src1[tpig % nb]; +} + +kernel void kernel_sub_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] - src1[tpig % nb]; +} + +kernel void kernel_mul_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] * src1[tpig % nb]; +} + +kernel void kernel_div_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] / src1[tpig % nb]; +} + +kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( + device const float4 * src0, + device float4 * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_clamp( + device const float * src0, + device float * dst, + constant float & min, + constant float & max, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); +} + +kernel void kernel_relu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_sigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + +kernel void kernel_tanh( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = precise::tanh(x); +} + +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +constant float SQRT_2_INV = 0.70710678118654752440084436210484f; + +kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! + // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +constant float p_erf = 0.3275911f; +constant float a1_erf = 0.254829592f; +constant float a2_erf = -0.284496736f; +constant float a3_erf = 1.421413741f; +constant float a4_erf = -1.453152027f; +constant float a5_erf = 1.061405429f; + +template +T erf_approx(T x) { + T sign_x = sign(x); + x = fabs(x); + T t = 1.0f / (1.0f + p_erf * x); + T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + return sign_x * y; +} + +kernel void kernel_gelu_erf( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); +} + +kernel void kernel_gelu_erf_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); +} + +kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_elu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); +} + +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + +kernel void kernel_sqrt( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sqrt(src0[tpig]); +} + +kernel void kernel_sin( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sin(src0[tpig]); +} + +kernel void kernel_cos( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = cos(src0[tpig]); +} + +kernel void kernel_neg( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = -src0[tpig]; +} + +kernel void kernel_sum_rows( + device const float * src0, + device float * dst, + constant ggml_metal_kargs_sum_rows & args, + uint3 tpig[[thread_position_in_grid]]) { + int64_t i3 = tpig.z; + int64_t i2 = tpig.y; + int64_t i1 = tpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + + float row_sum = 0; + + for (int64_t i0 = 0; i0 < args.ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} + +template +kernel void kernel_soft_max( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_soft_max & args, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (args.ne02*args.ne01); + const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; + const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + + device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); + + float slope = 1.0f; + + // ALiBi + if (args.max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + +template +kernel void kernel_soft_max_4( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_soft_max & args, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (args.ne02*args.ne01); + const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; + const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + + device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; + + float slope = 1.0f; + + if (args.max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + +kernel void kernel_diag_mask_inf( + device const float * src0, + device float * dst, + constant ggml_metal_kargs_diag_mask_inf & args, + uint3 tpig[[thread_position_in_grid]]) { + const int64_t i02 = tpig[2]; + const int64_t i01 = tpig[1]; + const int64_t i00 = tpig[0]; + + if (i00 > args.n_past + i01) { + dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY; + } else { + dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant ggml_metal_kargs_diag_mask_inf & args, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01; + const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00; + const int64_t i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= args.n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > args.n_past + i01) { + dst[i][k] = -INFINITY; + } + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_conv_f32 +kernel void kernel_ssm_conv_f32( + device const void * src0, + device const void * src1, + device float * dst, + constant ggml_metal_kargs_ssm_conv & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; + + device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +kernel void kernel_ssm_scan_f32( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device float * dst, + constant ggml_metal_kargs_ssm_scan & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i3 = tgpig.y; + + const int64_t nc = args.d_state; + // const int64_t nr = args.d_inner; + const int64_t n_t = args.n_seq_tokens; + // const int64_t n_s = args.n_seqs; + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02); + device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); + device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); + device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42); + device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52); + device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides + device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13); + + if (i2 > 0) { + s0 = s; + } + + // i1 == 0 + float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + int64_t i = i0; + float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + } +} + +kernel void kernel_rwkv_wkv6_f32( + device const float * k, + device const float * v, + device const float * r, + device const float * tf, + device const float * td, + device const float * state_in, + device float * dst, + constant uint & B, + constant uint & T, + constant uint & C, + constant uint & H, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const uint head_size = 64; // TODO: support head_size = 128 + const uint batch_id = tgpig.x / H; + const uint head_id = tgpig.x % H; + const uint tid = tpitg.x; + + if (batch_id >= B || head_id >= H) { + return; + } + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + threadgroup float _k[head_size]; + threadgroup float _r[head_size]; + threadgroup float _tf[head_size]; + threadgroup float _td[head_size]; + + float state[head_size]; + + for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + _tf[tid] = tf[head_id * head_size + tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + threadgroup_barrier(mem_flags::mem_threadgroup); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float v_val = v[t]; + float y = 0.0; + + for (uint j = 0; j < head_size; j += 4) { + float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + + float4 kv = k_vec * v_val; + + float4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + s_vec = s_vec * td_vec + kv; + state[j] = s_vec[0]; + state[j+1] = s_vec[1]; + state[j+2] = s_vec[2]; + state[j+3] = s_vec[3]; + } + + dst[t] = y; + } + + for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} + +kernel void kernel_rwkv_wkv7_f32( + device const float * r, + device const float * w, + device const float * k, + device const float * v, + device const float * a, + device const float * b, + device const float * state_in, + device float * dst, + constant uint & B, + constant uint & T, + constant uint & C, + constant uint & H, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const uint head_size = 64; // TODO: support head_size = 128 + const uint batch_id = tgpig.x / H; + const uint head_id = tgpig.x % H; + const uint tid = tpitg.x; + + if (batch_id >= B || head_id >= H) { + return; + } + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + threadgroup float _r[head_size]; + threadgroup float _w[head_size]; + threadgroup float _k[head_size]; + threadgroup float _a[head_size]; + threadgroup float _b[head_size]; + + float state[head_size]; + + for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i]; + } + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + threadgroup_barrier(mem_flags::mem_threadgroup); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float v_val = v[t]; + float y = 0.0, sa = 0.0; + + float4 sa_vec(0.0); + + for (uint j = 0; j < head_size; j += 4) { + float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + sa_vec += a_vec * s_vec; + } + sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3]; + + for (uint j = 0; j < head_size; j += 4) { + float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + + float4 kv = k_vec * v_val; + + s_vec = s_vec * w_vec + kv + sa * b_vec; + y += dot(s_vec, r_vec); + + state[j] = s_vec[0]; + state[j+1] = s_vec[1]; + state[j+2] = s_vec[2]; + state[j+3] = s_vec[3]; + } + + dst[t] = y; + } + + for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i] = state[i]; + } +} + +kernel void kernel_argmax( + device const void * x, + device int32_t * dst, + constant int64_t & ncols, + constant uint64_t & nb01, + threadgroup float * shared_maxval [[threadgroup(0)]], + threadgroup int32_t * shared_argmax [[threadgroup(1)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01); + + float lmax = -INFINITY; + int32_t larg = -1; + + for (int i00 = tpitg; i00 < ncols; i00 += ntg) { + if (x_row[i00] > lmax) { + lmax = x_row[i00]; + larg = i00; + } + } + + // find the argmax value in the block + float max_val = simd_max(lmax); + int32_t arg_val = simd_max(select(-1, larg, lmax == max_val)); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + shared_maxval[tiisg] = -INFINITY; + shared_argmax[tiisg] = -1; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shared_maxval[sgitg] = max_val; + shared_argmax[sgitg] = arg_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = shared_maxval[tiisg]; + arg_val = shared_argmax[tiisg]; + + float max_val_reduced = simd_max(max_val); + int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced)); + + dst[tgpig] = arg_val_reduced; + + return; + } + + dst[tgpig] = arg_val; +} + +kernel void kernel_norm( + constant ggml_metal_kargs_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float4 sumf4(0.0f); + + float sumf = 0.0f; + + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf4 += x[i00]; + } + sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + + sumf = 0.0f; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = x[i00] - mean; + sumf += dot(y[i00], y[i00]); + } + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float variance = sumf/args.ne00; + + const float scale = 1.0f/sqrt(variance + args.eps); + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = y[i00] * scale; + } +} + +kernel void kernel_rms_norm( + constant ggml_metal_kargs_rms_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float sumf = 0.0f; + + // parallel sum + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); + } + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + const float scale = 1.0f/sqrt(mean + args.eps); + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + +kernel void kernel_l2_norm( + constant ggml_metal_kargs_l2_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float sumf = 0.0f; + + // parallel sum + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); + } + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float scale = 1.0f/sqrt(max(sumf, args.eps)); + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + +kernel void kernel_group_norm( + device const float * src0, + device float * dst, + constant ggml_metal_kargs_group_norm & args, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t ne = args.ne00*args.ne01*args.ne02; + const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups); + + int start = tgpig * gs; + int end = start + gs; + + start += tpitg; + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += ntg) { + tmp += src0[j]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float mean = tmp / gs; + tmp = 0.0f; + + for (int j = start; j < end; j += ntg) { + float xi = src0[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float variance = tmp / gs; + const float scale = 1.0f/sqrt(variance + args.eps); + for (int j = start; j < end; j += ntg) { + dst[j] *= scale; + } +} + +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2); + + for (int i = 0; i < 8; i += 2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); + } + + return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]); +} + +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); + } + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; +} + +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + + return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]); +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; +} + +template +void mul_vec_q_n_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const int nb = args.ne00/QK4_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_q_type * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); + } + + float yl[16]; // src1 vector cache + float sumf[nr0] = {0.f}; + + const short ix = (tiisg/2); + const short il = (tiisg%2)*8; + + device const float * yb = y + ix*QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { + float sumy[2] = { 0.f, 0.f }; + +#pragma unroll + for (short i = 0; i < 8; i += 2) { + sumy[0] += yb[i + 0] + yb[i + 1]; + yl[i + 0] = yb[i + 0]; + yl[i + 1] = yb[i + 1]/256.f; + + sumy[1] += yb[i + 16] + yb[i + 17]; + yl[i + 8] = yb[i + 16]/16.f; + yl[i + 9] = yb[i + 17]/4096.f; + } + +#pragma unroll + for (short row = 0; row < nr0; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); + } + + yb += QK4_0 * 16; + } + + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +kernel void kernel_mul_mv_q4_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +kernel void kernel_mul_mv_q4_1_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +kernel void kernel_mul_mv_q5_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +#define NB_Q8_0 8 + +template +void kernel_mul_mv_q8_0_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const int nb = args.ne00/QK8_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_q8_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); + } + + float yl[NB_Q8_0]; + float sumf[nr0] = { 0.f }; + + const short ix = tiisg/4; + const short il = tiisg%4; + + device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; + + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (int ib = ix; ib < nb; ib += nw/4) { + for (short i = 0; i < NB_Q8_0; ++i) { + yl[i] = yb[i]; + } + + for (short row = 0; row < nr0; row++) { + device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; + float sumq = 0.f; + for (short iq = 0; iq < NB_Q8_0; ++iq) { + sumq += qs[iq] * yl[iq]; + } + sumf[row] += sumq*ax[row][ib].d; + } + + yb += nw*NB_Q8_0; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q8_0_f32")]] +kernel void kernel_mul_mv_q8_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +// mat-vec kernel processing in chunks of float4 +// chpb - chunks per quantization block +template +void kernel_mul_mv_ext_q4_f32_impl( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short chpt = 4; // chunks per thread + + //const short nxpsg = (32); + const short nypsg = (32/nxpsg); + + const short tx = tiisg%nxpsg; + const short ty = tiisg/nxpsg; + + const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i11 = tgpig.y*r1ptg; + const int i1m = tgpig.z; + + const int i12 = i1m%args.ne12; + const int i13 = i1m/args.ne12; + + const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; + + device const float4 * y4[r1ptg]; + + for (int ir1 = 0; ir1 < r1ptg; ++ir1) { + y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1; + } + + float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; + + short cch = tx%chpb; // current chunk index + + for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) { + float4 lx[chpt]; + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { + deq_t4(xq, cch, lx[ch]); + + cch += nxpsg; + if (cch >= chpb) { + xq += cch/chpb; + cch %= chpb; + } + } + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]); + + } + } + +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + y4[ir1] += chpt*nxpsg; + } + } + + // reduce only the threads in each row + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + if (nxpsg >= 32) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); + } + if (nxpsg >= 16) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); + } + if (nxpsg >= 8) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); + } + if (nxpsg >= 4) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); + } + if (nxpsg >= 2) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); + } + + //sumf[ir1] = simd_sum(sumf[ir1]); + } + + if (tx == 0) { + for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { + device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; + + if (i01 < args.ne01) { + dst_f32[i01] = sumf[ir1]; + } + } + } +} + +// mat-vec kernel processing in chunks of float4x4 +template +void kernel_mul_mv_ext_q4x4_f32_impl( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short chpt = 1; + + //const short nxpsg = (32); + const short nypsg = (32/nxpsg); + + const short tx = tiisg%nxpsg; + const short ty = tiisg/nxpsg; + + const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i11 = tgpig.y*r1ptg; + const int i1m = tgpig.z; + + const int i12 = i1m%args.ne12; + const int i13 = i1m/args.ne12; + + const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; + + device const float4x4 * y4x4[r1ptg]; + + for (int ir1 = 0; ir1 < r1ptg; ++ir1) { + y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1; + } + + float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; + + short cch = tx%chpb; + + for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) { + float4x4 lx[chpt]; + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { + deq_t4x4(xq, cch, lx[ch]); + + cch += nxpsg; + if (cch >= chpb) { + xq += cch/chpb; + cch %= chpb; + } + } + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + sumf[ir1] += + dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) + + dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) + + dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) + + dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]); + + } + } + +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + y4x4[ir1] += chpt*nxpsg; + } + } + + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + if (nxpsg >= 32) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); + } + if (nxpsg >= 16) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); + } + if (nxpsg >= 8) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); + } + if (nxpsg >= 4) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); + } + if (nxpsg >= 2) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); + } + + //sumf[ir1] = simd_sum(sumf[ir1]); + } + + if (tx == 0) { + for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { + device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; + + if (i01 < args.ne01) { + dst_f32[i01] = sumf[ir1]; + } + } + } +} + +// dispatchers needed for compile-time nxpsg +// epb - elements per quantization block +template +kernel void kernel_mul_mv_ext_q4_f32_disp( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + } +} + +template +kernel void kernel_mul_mv_ext_q4x4_f32_disp( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + } +} + +typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t; +typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t; + +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>; + +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>; + +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>; + +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>; + +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; + +#define N_MV_T_T 4 + +template +void kernel_mul_mv_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig, + ushort tiisg) { + const int r0 = tgpig.x; + const int rb = tgpig.y*N_MV_T_T; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + device const T0 * x = (device const T0 *) (src0 + offset0); + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + + if (args.ne00 < 128) { + for (int row = 0; row < N_MV_T_T; ++row) { + int r1 = rb + row; + if (r1 >= args.ne11) { + break; + } + + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); + + float sumf = 0; + for (int i = tiisg; i < args.ne00; i += 32) { + sumf += (T0) x[i] * (T1) y[i]; + } + + float sum_all = simd_sum(sumf); + if (tiisg == 0) { + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; + } + } + } else { + device const T04 * x4 = (device const T04 *) x; + for (int row = 0; row < N_MV_T_T; ++row) { + int r1 = rb + row; + if (r1 >= args.ne11) { + break; + } + + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); + device const T14 * y4 = (device const T14 *) y; + + float sumf = 0; + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], (float4) y4[i]); + } + + float sum_all = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; + } + } + } +} + +template +kernel void kernel_mul_mv( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_impl( + args, + src0, + src1, + dst, + tgpig, + tiisg); +} + +typedef decltype(kernel_mul_mv) mul_mv_t; + +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv; +#endif + +template +kernel void kernel_mul_mv_1row( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T * x = (device const T *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + float sumf = 0; + if (args.ne00 < 128) { + for (int i = tiisg; i < args.ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float sum_all = simd_sum(sumf); + if (tiisg == 0) { + dst_f32[r0] = sum_all; + } + } else { + device const T4 * x4 = (device const T4 *) x; + device const float4 * y4 = (device const float4 *) y; + + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], y4[i]); + } + + float sum_all = simd_sum(sumf); + + if (tiisg == 0) { + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[r0] = sum_all; + } + } +} + +typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; + +template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; +#endif + +// Assumes row size (ne00) is a multiple of 4 +template +kernel void kernel_mul_mv_l4( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = args.ne11; + const int r0 = tgpig.x; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + device const T4 * x4 = (device const T4 *) (src0 + offset0); + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + + for (int r1 = 0; r1 < nrows; ++r1) { + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const float4 * y4 = (device const float4 *) (src1 + offset1); + + float sumf = 0; + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], y4[i]); + } + + float sum_all = simd_sum(sumf); + if (tiisg == 0) { + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; + } + } +} + +typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; + +template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +#endif + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale, + thread float * cos_theta, thread float * sin_theta) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + *cos_theta = cos(theta) * mscale; + *sin_theta = sin(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +static void rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); +} + +template +kernel void kernel_rope_norm( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_neox( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_multi( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + // mrope theta calculations + // note: the rest is the same as kernel_rope_neox + const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3; + const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1 + const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 + const int sector = ic % sect_dims; + + float theta_base; + if (sector < args.sect_0) { + theta_base = (float) pos[i2]; + } else if (sector < sec_w01) { + theta_base = (float) pos[i2 + args.ne02]; + } else if (sector < sec_w012) { + theta_base = (float) pos[i2 + args.ne02 * 2]; + } else { + theta_base = (float) pos[i2 + args.ne02 * 3]; + } + // end of mrope + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_vision( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < 2*args.n_dims) { // different from kernel_rope_multi + const int ic = i0/2; + + // mrope theta calculations (only support 2 dimensions) + const int sect_dims = args.sect_0 + args.sect_1; + const int sector = ic % sect_dims; + + float p; + float theta_base; + if (sector < args.sect_1) { + p = (float) sector; + theta_base = (float) pos[i2]; + } else { + p = (float) sector - args.sect_0; + theta_base = (float) pos[i2 + args.ne02]; + } + + const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); + // end of mrope + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims]; // different from kernel_rope_multi + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +typedef decltype(kernel_rope_norm) kernel_rope_norm_t; +typedef decltype(kernel_rope_neox) kernel_rope_neox_t; +typedef decltype(kernel_rope_multi) kernel_rope_multi_t; +typedef decltype(kernel_rope_vision) kernel_rope_vision_t; + +template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; +template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; + +template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; + +template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi; +template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi; + +template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision; +template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; + +typedef void (im2col_t)( + device const float * x, + device char * dst, + constant ggml_metal_kargs_im2col & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col( + device const float * x, + device char * dst, + constant ggml_metal_kargs_im2col & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +// const int64_t IC = tgpg[0]; + const int64_t OH = tgpg[1]; + const int64_t OW = tgpg[2]; + +// const int64_t N = ntg[0]; + const int64_t KH = ntg[1]; + const int64_t KW = ntg[2]; + + const int64_t in = tpitg[0]; + const int64_t ikh = tpitg[1]; + const int64_t ikw = tpitg[2]; + + const int64_t iic = tgpig[0]; + const int64_t ioh = tgpig[1]; + const int64_t iow = tgpig[2]; + + const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; + const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; + + const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { + pdst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; + pdst[offset_dst] = x[offset_src]; + } +} + +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + +typedef void (im2col_ext_t)( + device const float * x, + device char * dst, + constant ggml_metal_kargs_im2col & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col_ext( + device const float * x, + device char * dst, + constant ggml_metal_kargs_im2col & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int64_t KHW = (int64_t)args.KHW; + + const int64_t d = tgpig[0] / args.CHW; + const int64_t chw = tgpig[0] % args.CHW; + const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int64_t HW = tgpig[0] % KHW; + + const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= args.N) { + return; + } + + const int64_t tpitg_1 = HW / args.KW; + const int64_t tpitg_2 = HW % args.KW; + + const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; + const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; + + const int64_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + + (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { + pdst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; + pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; + } +} + +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; + +typedef void (conv_transpose_1d_t)( + device const float * src0, + device const float * src1, + device char * dst, + constant ggml_metal_kargs_conv_transpose_1d & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template +kernel void kernel_conv_transpose_1d( + device const T * src0, + device const float * src1, + device char * dst, + constant ggml_metal_kargs_conv_transpose_1d & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]) { + + float v = 0.0f; + + for (int64_t c = 0; c < args.IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1]; + const int32_t input_offset = c * args.IL; + + for (int64_t i = 0; i < args.IL; i++) { + if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) { + v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i]; + } + } + } + + device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1); + + dst_ptr[0] = v; +} + +template [[host_name("kernel_conv_transpose_1d_f32_f32")]] +kernel void kernel_conv_transpose_1d( + device const float * src0, + device const float * src1, + device char * dst, + constant ggml_metal_kargs_conv_transpose_1d & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template [[host_name("kernel_conv_transpose_1d_f16_f32")]] +kernel void kernel_conv_transpose_1d( + device const half * src0, + device const float * src1, + device char * dst, + constant ggml_metal_kargs_conv_transpose_1d & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +kernel void kernel_upscale_f32( + device const char * src0, + device char * dst, + constant ggml_metal_kargs_upscale & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3/args.sf3; + const int64_t i02 = i2/args.sf2; + const int64_t i01 = i1/args.sf1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int64_t i00 = i0/args.sf0; + + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_ptr[0] = src0_ptr[0]; + } +} + +kernel void kernel_pad_f32( + device const char * src0, + device char * dst, + constant ggml_metal_kargs_pad & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + + if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.ne00) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } + } + + return; + } + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + dst_ptr[i0] = 0.0f; + } +} + +kernel void kernel_pad_reflect_1d_f32( + device const char * src0, + device char * dst, + constant ggml_metal_kargs_pad_reflect_1d & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + + if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.p0) { + dst_ptr[i0] = src0_ptr[args.p0 - i0]; + } else if (i0 < args.ne0 - args.p1) { + dst_ptr[i0] = src0_ptr[i0 - args.p0]; + } else { + dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1]; + } + } + } +} + +kernel void kernel_arange_f32( + device char * dst, + constant ggml_metal_kargs_arange & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + dst_ptr[i0] = args.start + args.step * i0; + } +} + +kernel void kernel_timestep_embedding_f32( + device const char * src0, + device char * dst, + constant ggml_metal_kargs_timestep_embedding & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + int i = tgpig.x; + device float * embed_data = (device float *)(dst + i*args.nb1); + + int half_ = args.dim / 2; + for (int j = tpitg.x; j < half_; j += ntg.x) { + float timestep = ((device float *)src0)[i]; + float freq = (float)exp(-log((float)args.max_period) * j / half_); + float arg = timestep * freq; + embed_data[j ] = cos(arg); + embed_data[j + half_] = sin(arg); + } + + if (args.dim % 2 != 0 && tpitg.x == 0) { + embed_data[args.dim] = 0.f; + } +} + +// bitonic sort implementation following the CUDA kernels as reference +typedef void (argsort_t)( + device const float * x, + device int32_t * dst, + constant ggml_metal_kargs_argsort & args, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_f32_i32( + device const float * x, + device int32_t * dst, + constant ggml_metal_kargs_argsort & args, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + // bitonic sort + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= args.ncols_pad) return; + + device const float * x_row = x + row * args.ncols; + threadgroup int32_t * dst_row = shared_values; + + // initialize indices + dst_row[col] = col; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= args.ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= args.ncols || + (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= args.ncols || + (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // copy the result to dst without the padding + if (col < args.ncols) { + dst[row * args.ncols + col] = dst_row[col]; + } +} + +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; + +kernel void kernel_leaky_relu_f32( + device const float * src0, + device float * dst, + constant ggml_metal_kargs_leaky_relu & args, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope; +} + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // value type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short DK, // K head size + short DV, // V head size + short Q = 8, // queries per threadgroup + short KV = 8, // key/value processed per each simdgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]*Q; + + constexpr short DK4 = DK/4; + constexpr short DK8 = DK/8; + constexpr short DK16 = DK/16; + constexpr short DV4 = DV/4; + constexpr short DV8 = DV/8; + constexpr short DV16 = DV/16; + + constexpr short NW = N_SIMDWIDTH; + constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + + const short TS = nsg*SH; // shared memory size per query in (s_t == float) + const short T = 2*DK + 2*TS; // shared memory size per query in (half) + + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix + + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + o8x8_t lo[DV8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + + for (short i = tiisg; i < DK4; i += NW) { + if (iq1 + j < args.ne01) { + sq4[j*DK4 + i] = (q4_t) q4[i]; + } else { + sq4[j*DK4 + i] = 0; + } + } + } + + // zero out lo + for (short i = 0; i < DV8; ++i) { + lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TS + i] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S[Q] = { [0 ... Q-1] = 0.0f }; + float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 }; + + // thread indices inside the simdgroup + // TODO: see if we can utilize quad-group functions for better performance + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) + const short tx = tiisg%4; + const short ty = tiisg/4; + + // broadcast kv + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + const bool has_mask = mask != q; + + float slope = 1.0f; + + // ALiBi + if (args.max_bias > 0.0f) { + const short h = iq2; + + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exph); + } + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= args.ne11) { + break; + } + + if (has_mask) { + // used to detect blocks full of -INF + float smax = -INFINITY; + + // load the mask in shared memory + #pragma unroll(Q) + for (short j = 0; j < Q; ++j) { + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); + + const float m = pm[ic + tiisg]; + + ss[j*TS + C + tiisg] = m; + smax = max(smax, m); + } + + smax = simd_max(smax); + + if (smax == -INFINITY) { + continue; + } + } + + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + + #pragma unroll(DK8) + for (short i = 0; i < DK8; ++i) { + k8x8_t mk; + simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + + q8x8_t mq; + simdgroup_load(mq, sq + i*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + + if (DK16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (short k = 0; k < 4; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + if (ii + tx < DK16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DK16; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } + } + } + + // cast qk_t -> s_t + //s8x8_t mqks(1.0f); + //simdgroup_multiply(mqks, mqk, mqks); + //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); + + simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + } + } + + // online softmax + { + for (ushort j = 0; j < Q; ++j) { + const float m = M[j]; + + // scale and apply the logitcap / mask + float s = ss[j*TS + tiisg]*args.scale; + + if (args.logit_softcap != 0.0f) { + s = args.logit_softcap*precise::tanh(s); + } + + // mqk = mqk + mask*slope + s += slope*ss[j*TS + C + tiisg]; + + M[j] = simd_max(max(M[j], s)); + + const float ms = exp(m - M[j]); + const float vs = exp(s - M[j]); + + S[j] = S[j]*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TS + tiisg] = vs; + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg == j) { + ss[j*TS + 2*C + j] = ms; + } + } + } + + // O = diag(ms)*O + { + s8x8_t ms; + simdgroup_load(ms, ss + 2*C, TS, 0, false); + + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + simdgroup_multiply(lo[i], ms, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, TS, 0, false); + + if (is_same::value) { + // we can read directly from global memory + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + v8x8_t mv; + simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 + + simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); + } + } else { + for (short ii = 0; ii < DV16; ii += 4) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + + if (DV16%4 == 0) { + // no need for bound checks + { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (short k = 0; k < 4; ++k) { + v8x8_t mv; + + simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + + simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + } + } else { + if (ii + tx < DV16) { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DV16; ++k) { + v8x8_t mv; + + simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + + simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + } + } + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = tiisg; j < Q; j += NW) { + ss[j*TS + 0] = S[j]; + ss[j*TS + 1] = M[j]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation + threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); + + // store result to shared memory in F32 + if (sgitg == 0) { + for (short i = 0; i < DV8; ++i) { + //simdgroup_store(lo[i], so + i*8, DV, 0, false); + simdgroup_float8x8 t(1.0f); + simdgroup_multiply(t, lo[i], t); + simdgroup_store(t, so + i*8, DV, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // reduce the warps sequentially + for (ushort sg = 1; sg < nsg; ++sg) { + if (sgitg == sg) { + for (short j = tiisg; j < Q; j += NW) { + const float S0 = ss[j*TS - 1*SH + 0]; + const float S1 = ss[j*TS + 0]; + + const float M0 = ss[j*TS - 1*SH + 1]; + const float M1 = ss[j*TS + 1]; + + const float M = max(M0, M1); + + float ms0 = exp(M0 - M); + float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + ss[j*TS + 0] = S; + ss[j*TS + 1] = M; + + ss[j*TS + 2*C + j - 1*SH] = ms0; + ss[j*TS + 2*C + j ] = ms1; + } + + //simdgroup_barrier(mem_flags::mem_threadgroup); + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + s8x8_t ms0; + s8x8_t ms1; + + simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); + simdgroup_load(ms1, ss + 2*C, TS, 0, false); + + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + simdgroup_float8x8 t; + + simdgroup_load (t, so + i*8, DV, 0, false); + simdgroup_multiply(t, ms0, t); + + simdgroup_multiply_accumulate(t, ms1, lo[i], t); + simdgroup_store(t, so + i*8, DV, 0, false); + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); + + // final rescale with 1/S and store to global memory + for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { + const float S = 1.0f/sf[j*TS + 0]; + + device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; + + for (short i = tiisg; i < DV4; i += NW) { + dst4[i] = (float4) so4[j*DV4 + i]*S; + } + } +} + +// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as +// template to be able to explore different combinations +// +#define FA_TYPES \ + float, float4, simdgroup_float8x8, \ + half, half4x4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8 + //float, float4, simdgroup_float8x8 + +#define FA_TYPES_BF \ + bfloat, bfloat4, simdgroup_bfloat8x8, \ + bfloat, bfloat4x4, simdgroup_bfloat8x8, \ + bfloat, bfloat4x4, simdgroup_bfloat8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8 + //float, float4, simdgroup_float8x8 + +typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +#endif + +template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +#undef FA_TYPES +#undef FA_TYPES_BF + +template< + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory + short nl_k, + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // value type in device memory + short nl_v, + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]; + + constexpr short DK4 = DK/4; + constexpr short DV4 = DV/4; + constexpr short NW = N_SIMDWIDTH; + constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads + constexpr short SH = 4*C; // shared memory per simdgroup + + const short T = DK + nsg*SH; // shared memory size per query in (half) + + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results + + // store the result for all queries in local memory (the O matrix from the paper) + o4_t lo[DV4/NL]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + + for (short i = tiisg; i < DK4; i += NW) { + if (iq1 < args.ne01) { + sq4[i] = (q4_t) q4[i]; + } else { + sq4[i] = (q4_t) 0.0f; + } + } + + // zero out lo + for (short i = 0; i < DV4/NL; ++i) { + lo[i] = (o4_t) 0.0f; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = (s4_t) 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S = 0.0f; + float M = -__FLT_MAX__/2; + + // thread indices inside the simdgroup + const short tx = tiisg%NL; + const short ty = tiisg/NL; + + // broadcast kv + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + const bool has_mask = mask != q; + + // pointer to the mask + device const half * pm = (device const half *) (mask + iq1*args.nb31); + + float slope = 1.0f; + + // ALiBi + if (args.max_bias > 0.0f) { + const short h = iq2; + + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exph); + } + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= args.ne11) { + break; + } + + if (has_mask) { + sm[tiisg] = pm[ic + tiisg]; + } + + // skip -INF blocks + if (simd_max(sm[tiisg]) == -INFINITY) { + continue; + } + + // Q*K^T + { + // each simdgroup processes 1 query and NE (NW/NL) head elements + for (short cc = 0; cc < C/NE; ++cc) { + qk_t mqk = 0.0f; + + device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + + #pragma unroll(DK4/NL) + for (short ii = 0; ii < DK4; ii += NL) { + const short i = ii + tx; + + k4_t mk; + deq_k_t4(pk + i/nl_k, i%nl_k, mk); + + // note: this is less precise than the version below + //mqka[0] += dot(mq[0], mk[0]); + //mqka[1] += dot(mq[1], mk[1]); + //mqka[2] += dot(mq[2], mk[2]); + //mqka[3] += dot(mq[3], mk[3]); + + //q4x4_t mq = sq4x4[i]; + //mqka[0] += dot((float4) mq[0], (float4) mk[0]); + //mqka[1] += dot((float4) mq[1], (float4) mk[1]); + //mqka[2] += dot((float4) mq[2], (float4) mk[2]); + //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + + mqk += dot((float4) mk, (float4) sq4[i]); + } + + static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails + + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + if (NE <= 1) { + mqk += simd_shuffle_down(mqk, 16); + } + if (NE <= 2) { + mqk += simd_shuffle_down(mqk, 8); + } + if (NE <= 4) { + mqk += simd_shuffle_down(mqk, 4); + } + if (NE <= 8) { + mqk += simd_shuffle_down(mqk, 2); + } + if (NE <= 16) { + mqk += simd_shuffle_down(mqk, 1); + } + + // mqk = mqk*scale + mask*slope + if (tx == 0) { + mqk *= args.scale; + + if (args.logit_softcap != 0.0f) { + mqk = args.logit_softcap*precise::tanh(mqk); + } + + mqk += sm[NE*cc + ty]*slope; + + ss[NE*cc + ty] = mqk; + } + } + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + // online softmax + { + const float m = M; + const float s = ss[tiisg]; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[tiisg] = vs; + + // O = diag(ms)*O + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { + lo[ii/NL] *= ms; + } + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + // O = O + (Q*K^T)*V + { + //#pragma unroll(C/NE) + for (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + + const s4_t ms(ss[NE*cc + ty]); + + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { + const short i = ii + tx; + + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii/NL] += o4_t(float4(mv)*float4(ms)); + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = (s_t) S; + ss[1] = (s_t) M; + } + } + + // simdgroup reduce (NE = 4) + // [ 0, 8, 16, 24] -> [ 0] + // [ 1, 9, 17, 25] -> [ 1] + // [ 2, 10, 18, 26] -> [ 2] + // [ 3, 11, 19, 27] -> [ 3] + // [ 4, 12, 20, 28] -> [ 4] + // [ 5, 13, 21, 29] -> [ 5] + // [ 6, 14, 22, 30] -> [ 6] + // [ 7, 15, 23, 31] -> [ 7] + for (short ii = 0; ii < DV4; ii += NL) { + if (NE > 1) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); + } + + if (NE > 2) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); + } + + if (NE > 4) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); + } + + if (NE > 8) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); + } + + if (NE > 16) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // store results to shared memory + for (short i = tiisg; i < DV4; i += NL) { + sr4[i] = lo[i/NL]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const float S0 = ss[ 0]; + const float S1 = ss[r*(SH/2) + 0]; + + const float M0 = ss[ 1]; + const float M1 = ss[r*(SH/2) + 1]; + + const float M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short i = tiisg; i < DV4; i += NW) { + sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S; + } + } +} + +// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem +// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max +// +#define FA_TYPES \ + half4, \ + half4, \ + half4, \ + float, \ + float, float4, \ + float4 + +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +#undef FA_TYPES + +template +kernel void kernel_set( + constant ggml_metal_kargs_set & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i13 = tgpig[2]; + const int i12 = tgpig[1]; + const int i11 = tgpig[0]; + + const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10; + + const int64_t i3 = n / (args.ne12*args.ne11*args.ne10); + const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10); + const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10; + + device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs); + + for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) { + device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10); + dst_data[i10] = (T) src[0]; + } +} + +typedef decltype(kernel_set) kernel_set_t; + +template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set; +template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set; + +template +kernel void kernel_cpy( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + + device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { + device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + dst_data[i00] = (T1) src[0]; + } +} + +typedef decltype(kernel_cpy) kernel_cpy_t; + +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +#endif +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +#endif + +kernel void kernel_cpy_f32_q8_0( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0; + + device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK8_0].d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst_data[i00/QK8_0].qs[j] = round(x0); + } + } +} + +kernel void kernel_cpy_f32_q4_0( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0; + + device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_0].d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_0/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + dst_data[i00/QK4_0].qs[j] = xi0; + dst_data[i00/QK4_0].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_cpy_f32_q4_1( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1; + + device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < QK4_1; j++) { + const float v = src[j]; + if (min > v) min = v; + if (max < v) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_1].d = d; + dst_data[i00/QK4_1].m = min; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK4_1/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + dst_data[i00/QK4_1].qs[j] = xi0; + dst_data[i00/QK4_1].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_cpy_f32_q5_0( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0; + + device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK5_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_0].d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK5_0/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_0].qh[j] = qh8[j]; + } + } +} + +kernel void kernel_cpy_f32_q5_1( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1; + + device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float max = src[0]; + float min = src[0]; + + for (int j = 1; j < QK5_1; j++) { + const float v = src[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_1].d = d; + dst_data[i00/QK5_1].m = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_1].qh[j] = qh8[j]; + } + } +} + +static inline int best_index_int8(int n, constant float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + +kernel void kernel_cpy_f32_iq4_nl( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL; + + device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_NL; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / kvalues_iq4nl_f[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_NL/2 + j]*id; + + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); + + dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); + + const float v0 = kvalues_iq4nl_f[xi0]; + const float v1 = kvalues_iq4nl_f[xi1]; + const float w0 = src[0 + j]*src[0 + j]; + const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; + sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + + } + + dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; + } +} + +template +kernel void kernel_cpy_q_f32( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + + device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + T4x4 temp; + dequantize_func(src_data + i00/nl, i00%nl, temp); + dst_data[i00] = temp; + } +} + +typedef decltype(kernel_cpy_q_f32) cpy_q_f_t; + +template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; + +template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; + +kernel void kernel_concat( + constant ggml_metal_kargs_concat & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + int o[4] = {0, 0, 0, 0}; + o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); + + device const float * x; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00); + } else { + x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10); + } + + device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + *y = *x; + } +} + +template +void kernel_mul_mv_q2_K_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[nr0]={0.f}; + + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 + const short is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + + for (int ib = ix; ib < nb; ib += 4) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (short i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; + } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (short row = 0; row < nr0; row++) { + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += args.nb01/2; + sc += args.nb01; + dh += args.nb01/2; + } + + y4 += 4 * QK_K; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_q2_K_f32")]] +kernel void kernel_mul_mv_q2_K_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_q3_K_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); + + float yl[32]; + + //const uint16_t kmask1 = 0x3030; + //const uint16_t kmask2 = 0x0f0f; + + const short tid = tiisg/4; + const short ix = tiisg%4; + const short ip = tid/4; // 0 or 1 + const short il = 2*((tid%4)/2); // 0 or 2 + const short ir = tid%2; + const short l0 = 8*ir; + + // One would think that the Metal compiler would figure out that ip and il can only have + // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it + // with these two tales. + // + // Possible masks for the high bit + const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 + {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 + {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 + {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 + + // Possible masks for the low 2 bits + const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; + + const ushort4 hm = mm[2*ip + il/2]; + + const short shift = 2*il; + + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; + + const uint16_t s_shift1 = 4*ip; + const uint16_t s_shift2 = s_shift1 + il; + + const short q_offset = 32*ip + l0; + const short y_offset = 128*ip + 32*il + l0; + + device const float * y1 = yy + ix*QK_K + y_offset; + + uint32_t scales32, aux32; + thread uint16_t * scales16 = (thread uint16_t *)&scales32; + thread const int8_t * scales = (thread const int8_t *)&scales32; + + float sumf1[nr0] = {0.f}; + float sumf2[nr0] = {0.f}; + + for (int i = ix; i < nb; i += 4) { + for (short l = 0; l < 8; ++l) { + yl[l+ 0] = y1[l+ 0]; + yl[l+ 8] = y1[l+16]; + yl[l+16] = y1[l+32]; + yl[l+24] = y1[l+48]; + } + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const uint16_t * a = (device const uint16_t *)(x[i].scales); + device const half * dh = &x[i].d; + + for (short row = 0; row < nr0; ++row) { + const float d_all = (float)dh[0]; + + scales16[0] = a[4]; + scales16[1] = a[5]; + aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; + scales16[0] = a[il+0]; + scales16[1] = a[il+1]; + scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; + + float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; + for (short l = 0; l < 8; l += 2) { + const int32_t qs = q[l/2]; + s1 += yl[l+0] * (qs & qm[il/2][0]); + s2 += yl[l+1] * (qs & qm[il/2][1]); + s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); + s4 += yl[l+16] * (qs & qm[il/2][2]); + s5 += yl[l+17] * (qs & qm[il/2][3]); + s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); + } + float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[0] - 32); + sumf2[row] += d2 * (scales[2] - 32); + + s1 = s2 = s3 = s4 = s5 = s6 = 0; + for (short l = 0; l < 8; l += 2) { + const int32_t qs = q[l/2+8]; + s1 += yl[l+8] * (qs & qm[il/2][0]); + s2 += yl[l+9] * (qs & qm[il/2][1]); + s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); + s4 += yl[l+24] * (qs & qm[il/2][2]); + s5 += yl[l+25] * (qs & qm[il/2][3]); + s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); + } + d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[1] - 32); + sumf2[row] += d2 * (scales[3] - 32); + + q += args.nb01/2; + h += args.nb01/2; + a += args.nb01/2; + dh += args.nb01/2; + } + + y1 += 4 * QK_K; + } + + for (int row = 0; row < nr0; ++row) { + const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); + sumf1[row] = simd_sum(sumf); + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + if (tiisg == 0) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + dst_f32[first_row + row] = sumf1[row]; + } + } +} + +[[host_name("kernel_mul_mv_q3_K_f32")]] +kernel void kernel_mul_mv_q3_K_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_q4_K_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 + + const int nb = args.ne00/QK_K; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[16]; + float yh[16]; + + float sumf[nr0]={0.f}; + + device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + for (int ib = ix; ib < nb; ib += 4) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + + for (short i = 0; i < 8; ++i) { + yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (short row = 0; row < nr0; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + device const uint16_t * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + + for (short i = 0; i < 4; ++i) { + acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F); + acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00); + acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0); + acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000); + acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F); + acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00); + acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0); + acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += args.nb01/2; + sc += args.nb01/2; + dh += args.nb01/2; + } + + y4 += 4 * QK_K; + } + + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_q4_K_f32")]] +kernel void kernel_mul_mv_q4_K_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_q5_K_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); + + float sumf[nr0]={0.f}; + + float yl[16], yh[16]; + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const short tid = tiisg/4; + const short ix = tiisg%4; + const short iq = tid/4; + const short ir = tid%4; + + const short l0 = 8*ir; + const short q_offset = 32*iq + l0; + const short y_offset = 64*iq + l0; + + const uint8_t hm1 = 1u << (2*iq); + const uint8_t hm2 = hm1 << 1; + const uint8_t hm3 = hm1 << 4; + const uint8_t hm4 = hm2 << 4; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + device const float * y1 = yy + ix*QK_K + y_offset; + + for (int i = ix; i < nb; i += 4) { + device const uint8_t * q1 = x[i].qs + q_offset; + device const uint8_t * qh = x[i].qh + l0; + device const half * dh = &x[i].d; + device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; + + device const float * y2 = y1 + 128; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (short l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + } + + for (short row = 0; row < nr0; ++row) { + device const uint8_t * q2 = q1 + 64; + + sc16[0] = a[0] & kmask1; + sc16[1] = a[2] & kmask1; + sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); + sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); + + float4 acc1 = {0.f}; + float4 acc2 = {0.f}; + for (short l = 0; l < 8; ++l) { + uint8_t h = qh[l]; + acc1[0] += yl[l+0] * (q1[l] & 0x0F); + acc1[1] += yl[l+8] * (q1[l] & 0xF0); + acc1[2] += yh[l+0] * (q2[l] & 0x0F); + acc1[3] += yh[l+8] * (q2[l] & 0xF0); + acc2[0] += h & hm1 ? yl[l+0] : 0.f; + acc2[1] += h & hm2 ? yl[l+8] : 0.f; + acc2[2] += h & hm3 ? yh[l+0] : 0.f; + acc2[3] += h & hm4 ? yh[l+8] : 0.f; + } + const float dall = dh[0]; + const float dmin = dh[1]; + sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += args.nb01; + qh += args.nb01; + dh += args.nb01/2; + a += args.nb01/2; + } + + y1 += 4 * QK_K; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q5_K_f32")]] +kernel void kernel_mul_mv_q5_K_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_q6_K_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const uint8_t kmask1 = 0x03; + const uint8_t kmask2 = 0x0C; + const uint8_t kmask3 = 0x30; + const uint8_t kmask4 = 0xC0; + + const int nb = args.ne00/QK_K; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); + + float sumf[nr0] = { 0.f }; + + float yl[16]; + + const short tid = tiisg/2; + const short ix = tiisg%2; + const short ip = tid/8; // 0 or 1 + const short il = tid%8; + const short l0 = 4*il; + const short is = 8*ip + l0/16; + + const short y_offset = 128*ip + l0; + const short q_offset_l = 64*ip + l0; + const short q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += 2) { + device const uint8_t * q1 = x[i].ql + q_offset_l; + device const uint8_t * q2 = q1 + 32; + device const uint8_t * qh = x[i].qh + q_offset_h; + device const int8_t * sc = x[i].scales + is; + device const half * dh = &x[i].d; + + device const float * y = yy + i * QK_K + y_offset; + + for (short l = 0; l < 4; ++l) { + yl[4*l + 0] = y[l + 0]; + yl[4*l + 1] = y[l + 32]; + yl[4*l + 2] = y[l + 64]; + yl[4*l + 3] = y[l + 96]; + } + + for (short row = 0; row < nr0; ++row) { + const float dall = dh[0]; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + for (short l = 0; l < 4; ++l) { + sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + q1 += args.nb01; + q2 += args.nb01; + qh += args.nb01; + sc += args.nb01; + dh += args.nb01/2; + } + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_q6_K_f32")]] +kernel void kernel_mul_mv_q6_K_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +// ======================= "True" 2-bit + +template +void kernel_mul_mv_iq2_xxs_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[nr0]={0.f}; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + for (short i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xxs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const half * dh = &xr->d; + + for (short row = 0; row < nr0; row++) { + const float db = dh[0]; + device const uint8_t * aux8 = (device const uint8_t *)q2; + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float sum = 0; + for (short l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; + for (short j = 0; j < 8; ++j) { + sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d * sum; + + dh += args.nb01/2; + q2 += args.nb01/2; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xxs_f32")]] +kernel void kernel_mul_mv_iq2_xxs_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq2_xs_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[nr0]={0.f}; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512); + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + for (short i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const uint8_t * sc = xr->scales + ib; + device const half * dh = &xr->d; + + for (short row = 0; row < nr0; row++) { + const float db = dh[0]; + const uint8_t ls1 = sc[0] & 0xf; + const uint8_t ls2 = sc[0] >> 4; + const float d1 = db * (0.5f + ls1); + const float d2 = db * (0.5f + ls2); + + float sum1 = 0, sum2 = 0; + for (short l = 0; l < 2; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; + for (short j = 0; j < 8; ++j) { + sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + for (short l = 2; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; + for (short j = 0; j < 8; ++j) { + sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d1 * sum1 + d2 * sum2; + + dh += args.nb01/2; + q2 += args.nb01/2; + sc += args.nb01; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xs_f32")]] +kernel void kernel_mul_mv_iq2_xs_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq3_xxs_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[nr0]={0.f}; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + for (short i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_xxs * xr = x + ibl; + device const uint8_t * q3 = xr->qs + 8 * ib; + device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; + device const half * dh = &xr->d; + + for (short row = 0; row < nr0; row++) { + const float db = dh[0]; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float2 sum = {0}; + for (short l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; + for (short j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += args.nb01/2; + q3 += args.nb01; + gas += args.nb01/2; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all * 0.5f; + } + } +} + +[[host_name("kernel_mul_mv_iq3_xxs_f32")]] +kernel void kernel_mul_mv_iq3_xxs_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq3_s_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[nr0]={0.f}; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem; + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + for (short i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 8 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + (ib/2); + device const uint8_t * signs = xr->signs + 4 * ib; + device const half * dh = &xr->d; + + for (short row = 0; row < nr0; row++) { + const float db = dh[0]; + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); + + float2 sum = {0}; + for (short l = 0; l < 4; ++l) { + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); + for (short j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_iq3_s_f32")]] +kernel void kernel_mul_mv_iq3_s_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq2_s_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[nr0]={0.f}; + + const int nb32 = nb * (QK_K / 32); + + //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem; + //{ + // int nval = 32; + // int pos = (32*sgitg + tiisg)*nval; + // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i]; + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + const short ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + for (short i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + ib; + device const uint8_t * signs = qs + QK_K/8; + device const half * dh = &xr->d; + + for (short row = 0; row < nr0; row++) { + const float db = dh[0]; + const float d1 = db * (0.5f + (sc[0] & 0xf)); + const float d2 = db * (0.5f + (sc[0] >> 4)); + + float2 sum = {0}; + for (short l = 0; l < 2; ++l) { + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + for (short j = 0; j < 8; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); + sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); + } + } + sumf[row] += d1 * sum[0] + d2 * sum[1]; + + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_s_f32")]] +kernel void kernel_mul_mv_iq2_s_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq1_s_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[nr0]={0.f}; + + const int nb32 = nb * (QK_K / 32); + + const short ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + float sumy = 0; + for (short i = 0; i < 32; ++i) { + yl[i] = y4[i]; + sumy += yl[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint16_t * qh = xr->qh + ib; + device const half * dh = &xr->d; + + for (short row = 0; row < nr0; row++) { + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); + + float sum = 0; + for (short j = 0; j < 4; ++j) { + sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); + + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01/2; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq1_m_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[nr0]={0.f}; + + const int nb32 = nb * (QK_K / 32); + + const short ix = tiisg; + + device const float * y4 = y + 32 * ix; + + iq1m_scale_t scale; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + float4 sumy = {0.f}; + for (short i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_m * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + 2 * ib; + device const uint16_t * sc = (device const uint16_t *)xr->scales; + + for (short row = 0; row < nr0; row++) { + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); + + float2 sum = {0.f}; + for (short j = 0; j < 4; ++j) { + sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); + sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + + sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); + + sc += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq4_nl_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + const int nb = args.ne00/QK4_NL; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + const short ix = tiisg/2; // 0...15 + const short it = tiisg%2; // 0 or 1 + + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[nr0]={0.f}; + + device const float * yb = y + ix * QK4_NL + it * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ib = ix; ib < nb; ib += 16) { + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; + + for (short row = 0; row < nr0; row++) { + device const block_iq4_nl & xb = x[row*nb + ib]; + device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] | (q4[1] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[2] | (q4[3] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + } + + yb += 16 * QK4_NL; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq4_xs_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + const short ix = tiisg/16; // 0 or 1 + const short it = tiisg%16; // 0...15 + const short ib = it/2; + const short il = it%2; + + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[nr0]={0.f}; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; + + for (short row = 0; row < nr0; ++row) { + device const block_iq4_xs & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = (q4[0] ) & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = (q4[1] ) & 0x0f0f0f0f; + aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; + sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + } + + yb += 2 * QK_K; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_iq4_xs_f32")]] +kernel void kernel_mul_mv_iq4_xs_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +kernel void kernel_get_rows_q( + device const void * src0, + device const void * src1, + device float * dst, + constant ggml_metal_kargs_get_rows & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + + const int64_t i02 = i11; + + for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { + float4x4 temp; + dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; + } +} + +template +kernel void kernel_get_rows_f( + device const void * src0, + device const void * src1, + device float * dst, + constant ggml_metal_kargs_get_rows & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { + (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = + ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; + } +} + +kernel void kernel_get_rows_i32( + device const void * src0, + device const void * src1, + device int32_t * dst, + constant ggml_metal_kargs_get_rows & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { + (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = + ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; + } +} + + +#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B +#define BLOCK_SIZE_K 32 +#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A +#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B +#define THREAD_PER_BLOCK 128 +#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers +#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers +#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 +#define SG_MAT_ROW 8 + +// each block_q contains 16*nl weights +template +kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); + + const int r0 = tgpig.y; + const int r1 = tgpig.x; + const int im = tgpig.z; + + // if this block is of 64x32 shape or smaller + const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_T8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 mc[8]; + + for (short i = 0; i < 8; i++){ + mc[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; + + device const float * y = (device const float *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*(r1*BLOCK_SIZE_N + thread_col) + + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + T4x4 temp_a; + dequantize_func(x, il, temp_a); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(16) + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2 + nl - 1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + + #pragma unroll(4) + for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { + #pragma unroll(4) + for (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); + } + + simdgroup_barrier(mem_flags::mem_none); + + #pragma unroll(2) + for (short i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + } + + #pragma unroll(8) + for (short i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); + } + + lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE; + lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE; + } + } + + if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < n_rows/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); + } + } + } + } +} + +template +kernel void kernel_mul_mm_id_map0( + constant ggml_metal_kargs_mul_mm_id_map0 & args, + device const char * src1, + device const char * src2, + device char * hsrc1, + device char * htpe, + device char * hids, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int ide = tgpig[0]; // expert id + + int n_all = 0; + + device int32_t * ids_i32 = (device int32_t *) (hids); + + for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens + device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21); + + for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used + if (src2_i32[i20] != ide) { + continue; + } + + device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11); + device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11); + + for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) { + hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]); + } + + if (tpitg.x == 0) { + ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all; + } + + ++n_all; + } + } + + if (tpitg.x == 0) { + device int32_t * tpe_i32 = (device int32_t *) (htpe); + tpe_i32[ide] = n_all; + } +} + +typedef decltype(kernel_mul_mm_id_map0) kernel_mul_mm_id_map0_t; + +template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0; + +template +kernel void kernel_mul_mm_id_map1( + constant ggml_metal_kargs_mul_mm_id_map1 & args, + device const char * hdst, + device const char * hids, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i20 = tgpig[0]; // used expert + const int i21 = tgpig[1]; // token + + device const int32_t * ids_i32 = (device const int32_t *) (hids); + device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2); + + const int id = ids_i32[i21*args.ne20 + i20]; + + const int ide = id / args.neh1; + const int idt = id % args.neh1; + + device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2); + + for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) { + dst_f32x4[i0] = hdst_f32x4[i0]; + } +} + +typedef decltype(kernel_mul_mm_id_map1) kernel_mul_mm_id_map1_t; + +template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1; + +template +kernel void kernel_mul_mm_id( + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0, + device const char * src1, + device const char * tpe, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup half * sb = (threadgroup half *)(shmem + 4096); + + const int r0 = tgpig.y; + const int r1 = tgpig.x; + const int im = tgpig.z; + + device const int32_t * tpe_i32 = (device const int32_t *) (tpe); + + const int neh1 = tpe_i32[im]; + + if (r1*BLOCK_SIZE_N >= neh1) { + return; + } + + // if this block is of 64x32 shape or smaller + const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_T8x8 ma[4]; + simdgroup_half8x8 mb[2]; + simdgroup_float8x8 mc[8]; + + for (short i = 0; i < 8; i++){ + mc[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + + const int i12 = im%args.neh12; + const int i13 = im/args.neh12; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; + + device const half * y = (device const half *)(src1 + + args.nbh13*i13 + + args.nbh12*i12 + + args.nbh11*(r1*BLOCK_SIZE_N + thread_col) + + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + T4x4 temp_a; + dequantize_func(x, il, temp_a); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(16) + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } + + *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2 + nl - 1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + + #pragma unroll(4) + for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { + #pragma unroll(4) + for (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); + } + + simdgroup_barrier(mem_flags::mem_none); + + #pragma unroll(2) + for (short i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + } + + #pragma unroll(8) + for (short i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); + } + + lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE; + lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE; + } + } + + if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0; + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < n_rows/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); + } + } + } + } +} + +#define QK_NL 16 + +// +// get rows +// + +typedef decltype(kernel_get_rows_f) get_rows_f_t; + +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +#endif + +typedef decltype(kernel_get_rows_q) get_rows_q_t; + +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; + +// +// matrix-matrix multiplication +// + +typedef decltype(kernel_mul_mm) mul_mm_t; + +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; +#endif +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; + +// +// indirect matrix-matrix multiplication +// + +typedef decltype(kernel_mul_mm_id) mul_mm_id; + +template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +#endif +template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; + + +// +// matrix-vector multiplication +// + +typedef void (kernel_mul_mv_impl_t)( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig, + ushort tiisg); + +typedef void (kernel_mul_mv2_impl_t)( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg); + +template +void mmv_fn( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { + impl_fn(args, src0, src1, dst, tgpig, tiisg); +} + +template +void mmv_fn( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { + impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; + +template +kernel void kernel_mul_mv_id( + constant ggml_metal_kargs_mul_mv_id & args, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const int iid1 = tgpig.z/args.nei0; + const int idx = tgpig.z%args.nei0; + + tgpig.z = 0; + + const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx]; + + const int64_t i11 = idx % args.ne11; + const int64_t i12 = iid1; + + const int64_t i1 = idx; + const int64_t i2 = i12; + + device const char * src0_cur = src0s + i02*args.nb02; + device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12; + + device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float); + + ggml_metal_kargs_mul_mv args0 = { + /*.ne00 =*/ args.ne00, + /*.ne01 =*/ args.ne01, + /*.ne02 =*/ 1, // args.ne02, + /*.nb00 =*/ args.nb00, + /*.nb01 =*/ args.nb01, + /*.nb02 =*/ args.nb02, + /*.nb03 =*/ args.nb02, // args.ne02 == 1 + /*.ne10 =*/ args.ne10, + /*.ne11 =*/ 1, // args.ne11, + /*.ne12 =*/ 1, // args.ne12, + /*.nb10 =*/ args.nb10, + /*.nb11 =*/ args.nb11, + /*.nb12 =*/ args.nb12, + /*.nb13 =*/ args.nb12, // ne12 == 1 + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ 1, // args.ne1, + /*.r2 =*/ 1, + /*.r3 =*/ 1, + }; + + impl_fn( + args0, + /* src0 */ src0_cur, + /* src1 */ src1_cur, + /* dst */ dst_cur, + shmem, + tgpig, + tiitg, + tiisg, + sgitg); +} + +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; + +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#endif +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +kernel void kernel_pool_2d_max_f32( + device const float * src0, + device float * dst, + constant ggml_metal_kargs_pool_2d & args, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = args.IH * args.IW; + const int O_HW = args.OH * args.OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / args.OW; + const int cur_ow = idx % O_HW % args.OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * args.s1 - args.p1; + const int bh = MAX(0, start_h); + const int eh = MIN(args.IH, start_h + args.k1); + const int start_w = cur_ow * args.s0 - args.p0; + const int bw = MAX(0, start_w); + const int ew = MIN(args.IW, start_w + args.k0); + + float res = -INFINITY; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + res = MAX(res, i_ptr[i * args.IW + j]); + } + } + + o_ptr[cur_oh * args.OW + cur_ow] = res; +} + +kernel void kernel_pool_2d_avg_f32( + device const float * src0, + device float * dst, + constant ggml_metal_kargs_pool_2d & args, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = args.IH * args.IW; + const int O_HW = args.OH * args.OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / args.OW; + const int cur_ow = idx % O_HW % args.OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * args.s1 - args.p1; + const int bh = MAX(0, start_h); + const int eh = MIN(args.IH, start_h + args.k1); + const int start_w = cur_ow * args.s0 - args.p0; + const int bw = MAX(0, start_w); + const int ew = MIN(args.IW, start_w + args.k0); + // const float scale = 1. / ((eh - bh) * (ew - bw)); + const float scale = 1. / (args.k0 * args.k1); + + float res = 0; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + float cur = i_ptr[i * args.IW + j]; + res += cur * scale; + } + } + + o_ptr[cur_oh * args.OW + cur_ow] = res; +} diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..971314debc714ffb805ded0d9d4ac1e0101f0024 --- /dev/null +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -0,0 +1,113 @@ +if (NOT EXISTS $ENV{MUSA_PATH}) + if (NOT EXISTS /opt/musa) + set(MUSA_PATH /usr/local/musa) + else() + set(MUSA_PATH /opt/musa) + endif() +else() + set(MUSA_PATH $ENV{MUSA_PATH}) +endif() + +set(CMAKE_C_COMPILER "${MUSA_PATH}/bin/clang") +set(CMAKE_C_EXTENSIONS OFF) +set(CMAKE_CXX_COMPILER "${MUSA_PATH}/bin/clang++") +set(CMAKE_CXX_EXTENSIONS OFF) + +list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake") + +find_package(MUSAToolkit) + +if (MUSAToolkit_FOUND) + message(STATUS "MUSA Toolkit found") + + if (NOT DEFINED MUSA_ARCHITECTURES) + set(MUSA_ARCHITECTURES "21;22;31") + endif() + message(STATUS "Using MUSA architectures: ${MUSA_ARCHITECTURES}") + + file(GLOB GGML_HEADERS_MUSA "../ggml-cuda/*.cuh") + list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h") + list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh") + + file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu") + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + file(GLOB SRCS "../ggml-musa/*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + + if (GGML_CUDA_FA_ALL_QUANTS) + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) + else() + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + endif() + + set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) + foreach(SOURCE ${GGML_SOURCES_MUSA}) + set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu") + foreach(ARCH ${MUSA_ARCHITECTURES}) + set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}") + endforeach() + set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS ${COMPILE_FLAGS}) + endforeach() + + ggml_add_backend_library(ggml-musa + ${GGML_HEADERS_MUSA} + ${GGML_SOURCES_MUSA} + ) + + # TODO: do not use CUDA definitions for MUSA + if (NOT GGML_BACKEND_DL) + target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) + endif() + + add_compile_definitions(GGML_USE_MUSA) + add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) + + if (GGML_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_CUDA_FORCE_MMQ) + endif() + + if (GGML_CUDA_FORCE_CUBLAS) + add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) + endif() + + if (GGML_CUDA_NO_VMM) + add_compile_definitions(GGML_CUDA_NO_VMM) + endif() + + if (NOT GGML_CUDA_FA) + add_compile_definitions(GGML_CUDA_NO_FA) + endif() + + if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) + add_compile_definitions(GGML_CUDA_F16) + endif() + + if (GGML_CUDA_NO_PEER_COPY) + add_compile_definitions(GGML_CUDA_NO_PEER_COPY) + endif() + + if (GGML_STATIC) + # TODO: mudnn has not provided static libraries yet + target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static) + else() + target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn) + endif() + + if (GGML_CUDA_NO_VMM) + # No VMM requested, no need to link directly with the musa driver lib (libmusa.so) + else() + target_link_libraries(ggml-musa PRIVATE MUSA::musa_driver) + endif() +else() + message(FATAL_ERROR "MUSA Toolkit not found") +endif() diff --git a/ggml/src/ggml-musa/mudnn.cu b/ggml/src/ggml-musa/mudnn.cu new file mode 100644 index 0000000000000000000000000000000000000000..020c1702c45c0ec4bae3caccb49d1d793b5fc460 --- /dev/null +++ b/ggml/src/ggml-musa/mudnn.cu @@ -0,0 +1,112 @@ +#include +#include + +#include "mudnn.cuh" + +namespace mudnn = musa::dnn; + +// Returns a human-readable error string for mudnn::Status +const char* mudnnGetErrorString(mudnn::Status err) { + switch (err) { + case mudnn::Status::SUCCESS: + return "Success"; + case mudnn::Status::INVALID_PARAMETER: + return "Invalid parameter"; + case mudnn::Status::NOT_INITIALIZED: + return "Not initialized"; + case mudnn::Status::ALLOC_FAILED: + return "Allocation failed"; + case mudnn::Status::NOT_SUPPORTED: + return "Not supported"; + case mudnn::Status::INTERNAL_ERROR: + return "Internal error"; + case mudnn::Status::ARCH_MISMATCH: + return "Architecture mismatch"; + case mudnn::Status::EXECUTION_FAILED: + return "Execution failed"; + default: + return "Unknown mudnn status"; + } +} + +// Error checking macro for MUDNN calls +#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString) + +namespace { + // Thread-safe cache for mudnn::Handle objects per device + std::unordered_map> handle_cache; + std::mutex handle_cache_mutex; + + mudnn::Handle* get_cached_handle(int device_id) { + std::lock_guard lock(handle_cache_mutex); + auto it = handle_cache.find(device_id); + if (it != handle_cache.end()) { + return it->second.get(); + } + auto handle = std::make_unique(device_id); + mudnn::Handle* handle_ptr = handle.get(); + handle_cache[device_id] = std::move(handle); + return handle_ptr; + } +} + +// Extracts dimensions and strides from a ggml_tensor +int get_ggml_dims_and_strides(const ggml_tensor* tensor, + std::vector& dims, + std::vector& strides) { + const int ndims = ggml_n_dims(tensor); + const size_t element_size = ggml_element_size(tensor); + + dims.resize(ndims); + strides.resize(ndims); + + for (int i = 0; i < ndims; ++i) { + dims[i] = tensor->ne[i]; + strides[i] = tensor->nb[i] / static_cast(element_size); + } + return ndims; +} + +// Converts ggml_type to mudnn::Tensor::Type +mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return mudnn::Tensor::Type::FLOAT; + case GGML_TYPE_F16: + return mudnn::Tensor::Type::HALF; + + // TODO: Add support for other types + + default: + MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED); + } + + return mudnn::Tensor::Type::FLOAT; // Default fallback +} + +// Asynchronous memory copy using mudnn::Unary::IDENTITY +musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) { + mudnn::Tensor tensor_dst, tensor_src; + + MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type))); + MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type))); + + std::vector dims, strides; + const int ndims = get_ggml_dims_and_strides(src, dims, strides); + + MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data())); + MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data())); + MUDNN_CHECK(tensor_dst.SetAddr(dst->data)); + MUDNN_CHECK(tensor_src.SetAddr(src->data)); + + mudnn::Unary op; + MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY)); + MUDNN_CHECK(op.SetAlpha(0.0f)); + MUDNN_CHECK(op.SetBeta(0.0f)); + + mudnn::Handle* handle = get_cached_handle(ctx.device); + MUDNN_CHECK(handle->SetStream(ctx.stream())); + MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src)); + + return musaSuccess; +} diff --git a/ggml/src/ggml-musa/mudnn.cuh b/ggml/src/ggml-musa/mudnn.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a63be5755c79ca00445f36f1921604995921adee --- /dev/null +++ b/ggml/src/ggml-musa/mudnn.cuh @@ -0,0 +1,12 @@ +#pragma once + +#include "../include/ggml.h" +#include "../ggml-cuda/common.cuh" + +// Asynchronously copies data from src tensor to dst tensor using the provided context. +// Returns a musaError_t indicating success or failure. +musaError_t mudnnMemcpyAsync( + ggml_backend_cuda_context &ctx, + const ggml_tensor *dst, + const ggml_tensor *src +); diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0e2a419649cea21b4373438acd151e897e0b6f92 --- /dev/null +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -0,0 +1,109 @@ +find_package(OpenCL REQUIRED) +find_package(Python3 REQUIRED) + +set(TARGET_NAME ggml-opencl) + +ggml_add_backend_library(${TARGET_NAME} + ggml-opencl.cpp + ../../include/ggml-opencl.h) +target_link_libraries(${TARGET_NAME} PRIVATE ${OpenCL_LIBRARIES}) +target_include_directories(${TARGET_NAME} PRIVATE ${OpenCL_INCLUDE_DIRS}) + +if (GGML_OPENCL_PROFILING) + message(STATUS "OpenCL profiling enabled (increases CPU overhead)") + add_compile_definitions(GGML_OPENCL_PROFILING) +endif () + +add_compile_definitions(GGML_OPENCL_SOA_Q) +add_compile_definitions(GGML_OPENCL_TARGET_VERSION=${GGML_OPENCL_TARGET_VERSION}) + +if (GGML_OPENCL_USE_ADRENO_KERNELS) + message(STATUS "OpenCL will use matmul kernels optimized for Adreno") + add_compile_definitions(GGML_OPENCL_USE_ADRENO_KERNELS) +endif () + +if (GGML_OPENCL_EMBED_KERNELS) + add_compile_definitions(GGML_OPENCL_EMBED_KERNELS) + + set(EMBED_KERNEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/kernels/embed_kernel.py") + file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated") + + target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/autogenerated") +endif () + +function(ggml_opencl_add_kernel KNAME) + set(KERN_HDR ${CMAKE_CURRENT_BINARY_DIR}/autogenerated/${KNAME}.cl.h) + set(KERN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/kernels/${KNAME}.cl) + + if (GGML_OPENCL_EMBED_KERNELS) + message(STATUS "opencl: embedding kernel ${KNAME}") + + # Python must be accessible from command line + add_custom_command( + OUTPUT ${KERN_HDR} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} ${KERN_SRC} ${KERN_HDR} + DEPENDS ${KERN_SRC} ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ${KERN_HDR}" + ) + + target_sources(${TARGET_NAME} PRIVATE ${KERN_HDR}) + else () + message(STATUS "opencl: adding kernel ${KNAME}") + configure_file(${KERN_SRC} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${KNAME}.cl COPYONLY) + endif () +endfunction() + +set(GGML_OPENCL_KERNELS + add + argsort + clamp + cpy + cvt + diag_mask_inf + div + gelu + gemv_noshuffle_general + gemv_noshuffle + get_rows + group_norm + im2col_f32 + im2col_f16 + mul_mat_Ab_Bi_8x4 + mul_mv_f16_f16 + mul_mv_f16_f32_1row + mul_mv_f16_f32_l4 + mul_mv_f16_f32 + mul_mv_f32_f32 + mul_mv_q4_0_f32 + mul_mv_q4_0_f32_v + mul_mv_q4_0_f32_8x_flat + mul_mv_q4_0_f32_1d_8x_flat + mul_mv_q4_0_f32_1d_16x_flat + mul_mv_q6_k + mul_mv_id_q4_0_f32_8x_flat + mul + norm + relu + rms_norm + rope + scale + sigmoid + silu + softmax_4_f32 + softmax_4_f16 + softmax_f32 + softmax_f16 + sub + sum_rows + transpose + concat + tsembd + upscale + tanh + pad + repeat +) + +foreach (K ${GGML_OPENCL_KERNELS}) + ggml_opencl_add_kernel(${K}) +endforeach() diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..628e574f0f71e6a7b0a023a5f00876d36463bf1a --- /dev/null +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -0,0 +1,6665 @@ +#define CL_TARGET_OPENCL_VERSION GGML_OPENCL_TARGET_VERSION +#define CL_USE_DEPRECATED_OPENCL_1_2_APIS + +// suppress warnings in CL headers for GCC and Clang +#pragma GCC diagnostic ignored "-Woverlength-strings" +#ifdef __clang__ +#pragma GCC diagnostic ignored "-Wgnu-anonymous-struct" +#endif + +#include "ggml-opencl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" +#include "ggml.h" + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define UNUSED(x) (void)(x) + +#define CL_CHECK(err) \ + do { \ + cl_int err_ = (err); \ + if (err_ != CL_SUCCESS) { \ + GGML_LOG_ERROR("ggml_opencl: %s error %d at %s:%d\n", \ + #err, err_, __FILE__, __LINE__); \ + GGML_ASSERT(0); \ + } \ + } while (0) + +//------------------------------------------------------------------------------ +// OpenCL +//------------------------------------------------------------------------------ + +bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor); + +enum GPU_FAMILY { + ADRENO, + INTEL, + UNKNOWN, +}; + +enum ADRENO_GPU_GEN { + ADRENO_UNKNOWN, + A7X, + A8X, + X1E, +}; + +enum ADRENO_CL_COMPILER_TYPE { + E031, + DX, +}; + +struct ggml_cl_version { + cl_uint major = 0; + cl_uint minor = 0; +}; + + +struct ggml_cl_compiler_version { + ADRENO_CL_COMPILER_TYPE type; + int major = -1; + int minor = -1; + int patch = -1; + + bool same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return major == x && minor == y && patch == z && type == t; + } + bool newer_than(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return major*10000 + minor*100 + patch > x*10000 + y*100 + z && type == t; + } + bool newer_than_or_same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return same(t, x, y, z) || newer_than(t, x, y, z); + } +}; + +static size_t align_to(size_t value, size_t to_alignment) { + GGML_ASSERT(to_alignment && "Invalid alignment (must be non-zero)"); + GGML_ASSERT((to_alignment & (to_alignment - 1)) == 0 && "to_alignment must be power-of-two"); + + return ((value + to_alignment - 1) / to_alignment) * to_alignment; +} + + +// Parses a version string of form "XX.YY ". On an error returns ggml_cl_version with all zeroes. +static ggml_cl_version parse_cl_version(std::string_view str) { + size_t major_str_begin = 0; + size_t major_str_end = str.find(".", major_str_begin); + if (major_str_end == std::string::npos) { + return {}; + } + + size_t minor_str_begin = major_str_end + 1; + size_t minor_str_end = str.find(" ", minor_str_begin); + if (minor_str_end == std::string::npos) { + return {}; + } + + cl_uint version_major; + if (std::from_chars(str.data() + major_str_begin, str.data() + major_str_end, version_major).ec != std::errc{}) { + return {}; + } + + cl_uint version_minor; + if (std::from_chars(str.data() + minor_str_begin, str.data() + minor_str_end, version_minor).ec != std::errc{}) { + return {}; + } + return { version_major, version_minor }; +} + +// Returns OpenCL platform's version. On an error returns ggml_cl_version with all zeroes. +static ggml_cl_version get_opencl_platform_version(cl_platform_id platform) { + size_t param_size; + CL_CHECK(clGetPlatformInfo(platform, CL_PLATFORM_VERSION, 0, nullptr, ¶m_size)); + std::unique_ptr param_storage(new char[param_size]); + CL_CHECK(clGetPlatformInfo(platform, CL_PLATFORM_VERSION, param_size, param_storage.get(), nullptr)); + + auto param_value = std::string_view(param_storage.get(), param_size); + const std::string version_prefix = "OpenCL "; // Suffix: "XX.YY " + if (param_value.find(version_prefix) != 0) { + return {}; + } + param_value.remove_prefix(version_prefix.length()); + return parse_cl_version(param_value); +} + +// Return a version to use in OpenCL C compilation. On an error returns ggml_cl_version with all zeroes. +static ggml_cl_version get_opencl_c_version(ggml_cl_version platform_version, cl_device_id device) { + size_t param_size; + +#if CL_TARGET_OPENCL_VERSION >= 300 + if (platform_version.major >= 3) { + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_ALL_VERSIONS, 0, nullptr, ¶m_size)); + if (!param_size) { + return {}; + } + + std::unique_ptr versions(new cl_name_version[param_size]); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_ALL_VERSIONS, param_size, versions.get(), nullptr)); + unsigned versions_count = param_size / sizeof(cl_name_version); + + cl_version version_max = 0; + for (unsigned i = 0; i < versions_count; i++) { + version_max = std::max(versions[i].version, version_max); + } + + return { CL_VERSION_MAJOR(version_max), CL_VERSION_MINOR(version_max) }; + } +#else + GGML_UNUSED(platform_version); +#endif // CL_TARGET_OPENCL_VERSION >= 300 + + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_VERSION, 0, nullptr, ¶m_size)); + if (!param_size) { + return {}; + } + + std::unique_ptr param_storage(new char[param_size]); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_VERSION, param_size, param_storage.get(), nullptr)); + auto param_value = std::string_view(param_storage.get(), param_size); + + const std::string version_prefix = "OpenCL C "; // Suffix: "XX.YY " + if (param_value.find(version_prefix) != 0) { + return {}; + } + param_value.remove_prefix(version_prefix.length()); + + return parse_cl_version(param_value); +} + +static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) { + if (strstr(device_name, "730") || + strstr(device_name, "740") || + strstr(device_name, "750")) { + return ADRENO_GPU_GEN::A7X; + } + + if (strstr(device_name, "830")) { + return ADRENO_GPU_GEN::A8X; + } + + if (strstr(device_name, "X1")) { + return ADRENO_GPU_GEN::X1E; + } + + return ADRENO_GPU_GEN::ADRENO_UNKNOWN; +} + +static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *driver_version) { + std::string driver_ver_str(driver_version); + ADRENO_CL_COMPILER_TYPE type = ADRENO_CL_COMPILER_TYPE::E031; + size_t compiler_ver_pos = driver_ver_str.find("E031"); + size_t compiler_ver_len = 13; + size_t compiler_major_offset = 5; + size_t compiler_minor_offset = 8; + size_t compiler_patch_offset = 11; + + if (compiler_ver_pos == std::string::npos) { + compiler_ver_pos = driver_ver_str.find("DX"); + if (compiler_ver_pos == std::string::npos) { + return {}; + } + type = ADRENO_CL_COMPILER_TYPE::DX; + compiler_ver_len = 11; + compiler_major_offset = 3; + } + + std::string compiler_ver_str = driver_ver_str.substr(compiler_ver_pos, compiler_ver_len); + int major = std::atoi(compiler_ver_str.substr(compiler_major_offset, 2).c_str()); + int minor = std::atoi(compiler_ver_str.substr(compiler_minor_offset, 2).c_str()); + int patch = std::atoi(compiler_ver_str.substr(compiler_patch_offset, 2).c_str()); + return { type, major, minor, patch }; +} + +struct ggml_backend_opencl_context; + +// backend device context +struct ggml_backend_opencl_device_context { + cl_platform_id platform; + std::string platform_name; + + cl_device_id device; + std::string device_name; + cl_device_type device_type; + std::string device_version; + + // Initialized by ggml_cl2_init(). + ggml_backend_opencl_context * backend_ctx = nullptr; + + // Initialized by ggml_backend_opencl_device_get_buffer_type() + ggml_backend_buffer_type buffer_type; + + cl_context context = nullptr; +}; + +// backend context +struct ggml_backend_opencl_context { + cl_device_id device; + std::string device_name; + + std::string driver_version; + + GPU_FAMILY gpu_family; + ADRENO_GPU_GEN adreno_gen; + + cl_int alignment; + size_t max_alloc_size; + bool fp16_support; + bool has_vector_subgroup_broadcast; + ggml_cl_compiler_version adreno_cl_compiler_version; + + int adreno_wave_size; + + cl_bool non_uniform_workgroups; + + cl_context context; + cl_command_queue queue; + + cl_program program_add; + cl_program program_clamp; + cl_program program_cpy; + cl_program program_cvt; + cl_program program_diag_mask_inf; + cl_program program_gelu; + cl_program program_gemv_noshuffle_general; + cl_program program_gemv_noshuffle; + cl_program program_get_rows; + cl_program program_im2col_f16; + cl_program program_im2col_f32; + cl_program program_mul_mat_Ab_Bi_8x4; + cl_program program_mul_mv_q4_0_f32; + cl_program program_mul_mv_q4_0_f32_v; + cl_program program_mul_mv_q4_0_f32_8x_flat; + cl_program program_mul_mv_q4_0_f32_1d_8x_flat; + cl_program program_mul_mv_q4_0_f32_1d_16x_flat; + cl_program program_mul_mv_q6_K; + cl_program program_mul_mv_f16_f16; + cl_program program_mul_mv_f16_f32_1row; + cl_program program_mul_mv_f16_f32_l4; + cl_program program_mul_mv_f16_f32; + cl_program program_mul_mv_f32_f32; + cl_program program_mul; + cl_program program_div; + cl_program program_sub; + cl_program program_norm; + cl_program program_relu; + cl_program program_rms_norm; + cl_program program_group_norm; + cl_program program_rope; + cl_program program_scale; + cl_program program_silu; + cl_program program_sigmoid; + cl_program program_softmax_f32; + cl_program program_softmax_f16; + cl_program program_softmax_4_f32; + cl_program program_softmax_4_f16; + cl_program program_argsort_f32_i32; + cl_program program_sum_rows_f32; + cl_program program_repeat; + cl_program program_pad; + cl_program program_tanh; + cl_program program_upscale; + cl_program program_concat; + cl_program program_tsembd; + cl_program program_mul_mv_id_q4_0_f32_8x_flat; + + cl_kernel kernel_add, kernel_add_row; + cl_kernel kernel_mul, kernel_mul_row; + cl_kernel kernel_div, kernel_div_row; + cl_kernel kernel_sub, kernel_sub_row; + cl_kernel kernel_scale; + cl_kernel kernel_silu, kernel_silu_4; + cl_kernel kernel_gelu, kernel_gelu_4; + cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; + cl_kernel kernel_relu; + cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; + cl_kernel kernel_clamp; + cl_kernel kernel_norm; + cl_kernel kernel_rms_norm; + cl_kernel kernel_group_norm; + cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; + cl_kernel kernel_soft_max, kernel_soft_max_4; + cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; + cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; + cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; + cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; + cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; + cl_kernel kernel_mul_mat_f32_f32; + cl_kernel kernel_mul_mat_f16_f16; + cl_kernel kernel_mul_mat_f16_f32_1row; + cl_kernel kernel_mul_mat_f16_f32; + cl_kernel kernel_mul_mat_f16_f32_l4; + cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; + cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; + cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; + cl_kernel kernel_convert_block_q4_0_noshuffle; + cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; + cl_kernel kernel_mul_mv_q6_K_f32; + cl_kernel kernel_im2col_f32, kernel_im2col_f16; + cl_kernel kernel_argsort_f32_i32; + cl_kernel kernel_sum_rows_f32; + cl_kernel kernel_repeat; + cl_kernel kernel_pad; + cl_kernel kernel_tanh_f32_nd; + cl_kernel kernel_tanh_f16_nd; + cl_kernel kernel_upscale; + cl_kernel kernel_upscale_bilinear; + cl_kernel kernel_concat_f32_contiguous; + cl_kernel kernel_concat_f32_non_contiguous; + cl_kernel kernel_timestep_embedding; + cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Transpose kernels + cl_program program_transpose; + + cl_kernel kernel_transpose_32; + cl_kernel kernel_transpose_32_16; + cl_kernel kernel_transpose_16; + + cl_mem A_s_d_max; // max scale buffer size for transpose + cl_mem A_q_d_max; // max weight buffer size for transpose + cl_mem B_d_max; // max activation buffer size for transpose + + // Gemm and Gemv related programs, kernels, etc + cl_program program_CL_gemm; + cl_program program_CL_gemv_general; + cl_program program_CL_gemv_4096_1_11008; + cl_program program_CL_gemv_4096_1_4096; + cl_program program_CL_gemv_11008_1_4096; + cl_program program_CL_gemv_32000_1_4096; + cl_kernel CL_mul_mat_Ab_Bi_8x4; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS +}; + +// All registered devices with a default device in the front. +static std::vector g_ggml_backend_opencl_devices; + +// Profiling +#ifdef GGML_OPENCL_PROFILING +struct ProfilingInfo { + std::string op_name; + std::string kernel_name; + + cl_kernel kernel; + cl_event evt; + + cl_ulong cmd_queued; + cl_ulong cmd_submit; + cl_ulong cmd_start; + cl_ulong cmd_end; + cl_ulong overhead_start; + cl_ulong overhead_end; + // For the times below, see spec for clGetEventProfilingInfo + // The time kernel spent in cmd queue - SUBMIT - QUEUED + cl_ulong cmd_queued_duration_ns; + // The time kernel spent for submission - START - SUBMIT + cl_ulong cmd_submit_duration_ns; + // Kernel execution time in nanoseconds - END - START + cl_ulong cmd_duration_ns; + // The time for the kernel to complete - COMPLETE - END + cl_ulong cmd_complete_duration_ns; + // Total time to finish the kernel - COMPELTE - QUEUED + cl_ulong cmd_total_duration_ns; + // Global and local work sizes. + size_t global_size[3]; + size_t local_size[3]; + // Op output size. + size_t output_size[4]; +}; + +std::vector g_profiling_info; +#endif + +inline std::string read_file(const std::string &path) { + std::ifstream ifs(path); + if (!ifs) { + return ""; + } + std::string text; + ifs.seekg(0, std::ios::end); + text.resize(ifs.tellg()); + ifs.seekg(0, std::ios::beg); + ifs.read(&text[0], text.size()); + return text; +} + +static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) { + cl_program p; + char *program_log; + size_t program_size; + size_t log_size; + int err; + + program_size = strlen(program_buffer); + + p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err); + if(err < 0) { + GGML_LOG_ERROR("OpenCL error creating program"); + exit(1); + } + + err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL); + if(err < 0) { + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size); + program_log = (char*) malloc(log_size + 1); + program_log[log_size] = '\0'; + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL); + GGML_LOG_ERROR("ggml_opencl: kernel compile error:\n\n%s\n", program_log); + free(program_log); + exit(1); + } + + return p; +} + +static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) { + cl_int err; + + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels"); + + // add + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "add.cl.h" + }; +#else + const std::string kernel_src = read_file("add.cl"); +#endif + backend_ctx->program_add = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err)); + CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err)); + GGML_LOG_CONT("."); + } + + // clamp + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "clamp.cl.h" + }; +#else + const std::string kernel_src = read_file("clamp.cl"); +#endif + backend_ctx->program_clamp = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program_clamp, "kernel_clamp", &err), err)); + GGML_LOG_CONT("."); + } + + // cpy + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cpy.cl.h" + }; +#else + const std::string kernel_src = read_file("cpy.cl"); +#endif + backend_ctx->program_cpy = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // cvt + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cvt.cl.h" + }; +#else + const std::string kernel_src = read_file("cvt.cl"); +#endif + backend_ctx->program_cvt = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + GGML_LOG_CONT("."); + } + + // diag_mask_inf + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "diag_mask_inf.cl.h" + }; +#else + const std::string kernel_src = read_file("diag_mask_inf.cl"); +#endif + backend_ctx->program_diag_mask_inf = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program_diag_mask_inf, "kernel_diag_mask_inf_8", &err), err)); + CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program_diag_mask_inf, "kernel_diag_mask_inf", &err), err)); + GGML_LOG_CONT("."); + } + + // gelu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gelu.cl.h" + }; +#else + const std::string kernel_src = read_file("gelu.cl"); +#endif + backend_ctx->program_gelu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err)); + GGML_LOG_CONT("."); + } + + // get_rows + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "get_rows.cl.h" + }; +#else + const std::string kernel_src = read_file("get_rows.cl"); +#endif + backend_ctx->program_get_rows = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_q4_0", &err), err)); + GGML_LOG_CONT("."); + } + + // im2col_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "im2col_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("im2col_f32.cl"); +#endif + backend_ctx->program_im2col_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col_f32, "kernel_im2col_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // im2col_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "im2col_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("im2col_f16.cl"); +#endif + backend_ctx->program_im2col_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col_f16, "kernel_im2col_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32, "kernel_mul_mat_q4_0_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_v + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_v.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_v.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_v = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_v, "kernel_mul_mat_q4_0_f32_v", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_8x_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_8x_flat, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_1d_8x_flat + // This kernel does not compiler on Adreno cl compiler 38.01. Skip it for + // those compiler versions since it is anyway not used for Adreno. + if (backend_ctx->gpu_family != ADRENO || + backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) || + backend_ctx->adreno_cl_compiler_version.type == DX) { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_1d_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_1d_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_1d_16x_flat + // This kernel does not compiler on Adreno cl compiler 38.01. Skip it for + // those compiler versions since it is anyway not used for Adreno. + if (backend_ctx->gpu_family != ADRENO || + backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) || + backend_ctx->adreno_cl_compiler_version.type == DX) { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_1d_16x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_1d_16x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q6_k + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q6_k.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q6_k.cl"); +#endif + backend_ctx->program_mul_mv_q6_K = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_mul_mv_q6_K, "kernel_mul_mv_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f16.cl"); +#endif + backend_ctx->program_mul_mv_f16_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program_mul_mv_f16_f16, "kernel_mul_mat_f16_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32_1row + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32_1row.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32_1row.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32_1row = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_1row, "kernel_mul_mat_f16_f32_1row", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32_l4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32_l4.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32_l4.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32_l4 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_l4, "kernel_mul_mat_f16_f32_l4", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32, "kernel_mul_mat_f16_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f32_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f32_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f32_f32.cl"); +#endif + backend_ctx->program_mul_mv_f32_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program_mul_mv_f32_f32, "kernel_mul_mat_f32_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul.cl.h" + }; +#else + const std::string kernel_src = read_file("mul.cl"); +#endif + backend_ctx->program_mul = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err)); + GGML_LOG_CONT("."); + } + + // norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "norm.cl.h" + }; +#else + const std::string kernel_src = read_file("norm.cl"); +#endif + backend_ctx->program_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // relu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "relu.cl.h" + }; +#else + const std::string kernel_src = read_file("relu.cl"); +#endif + backend_ctx->program_relu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program_relu, "kernel_relu", &err), err)); + GGML_LOG_CONT("."); + } + + // rms_norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "rms_norm.cl.h" + }; +#else + const std::string kernel_src = read_file("rms_norm.cl"); +#endif + backend_ctx->program_rms_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // rope + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "rope.cl.h" + }; +#else + const std::string kernel_src = read_file("rope.cl"); +#endif + backend_ctx->program_rope = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // scale + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "scale.cl.h" + }; +#else + const std::string kernel_src = read_file("scale.cl"); +#endif + backend_ctx->program_scale = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program_scale, "kernel_scale", &err), err)); + GGML_LOG_CONT("."); + } + + // silu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "silu.cl.h" + }; +#else + const std::string kernel_src = read_file("silu.cl"); +#endif + backend_ctx->program_silu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program_silu, "kernel_silu", &err), err)); + CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program_silu, "kernel_silu_4", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_f32.cl"); +#endif + backend_ctx->program_softmax_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program_softmax_f32, "kernel_soft_max", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_f16.cl"); +#endif + backend_ctx->program_softmax_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program_softmax_f16, "kernel_soft_max_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_4_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_4_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_4_f32.cl"); +#endif + backend_ctx->program_softmax_4_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program_softmax_4_f32, "kernel_soft_max_4", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_4_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_4_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_4_f16.cl"); +#endif + backend_ctx->program_softmax_4_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program_softmax_4_f16, "kernel_soft_max_4_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // argsort + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "argsort.cl.h" + }; +#else + const std::string kernel_src = read_file("argsort.cl"); +#endif + backend_ctx->program_argsort_f32_i32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); + GGML_LOG_CONT("."); + } + + // div + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "div.cl.h" + }; +#else + const std::string kernel_src = read_file("div.cl"); +#endif + backend_ctx->program_div = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err)); + CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err)); + GGML_LOG_CONT("."); + } + + // sub + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sub.cl.h" + }; +#else + const std::string kernel_src = read_file("sub.cl"); +#endif + backend_ctx->program_sub = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err)); + CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err)); + GGML_LOG_CONT("."); + } + + // sum_rows + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sum_rows.cl.h" + }; +#else + const std::string kernel_src = read_file("sum_rows.cl"); +#endif + backend_ctx->program_sum_rows_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // sigmoid + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sigmoid.cl.h" + }; +#else + const std::string kernel_src = read_file("sigmoid.cl"); +#endif + backend_ctx->program_sigmoid = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sigmoid_f32 = clCreateKernel(backend_ctx->program_sigmoid, "kernel_sigmoid_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sigmoid_f16 = clCreateKernel(backend_ctx->program_sigmoid, "kernel_sigmoid_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // group_norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "group_norm.cl.h" + }; +#else + const std::string kernel_src = read_file("group_norm.cl"); +#endif + backend_ctx->program_group_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // repeat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "repeat.cl.h" + }; +#else + const std::string kernel_src = read_file("repeat.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_repeat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_repeat = clCreateKernel(backend_ctx->program_repeat, "kernel_repeat", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: repeat kernel source not found or empty. Repeat operations will not be available.\n"); + backend_ctx->program_repeat = nullptr; + backend_ctx->kernel_repeat = nullptr; + } + } + + // pad + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "pad.cl.h" + }; +#else + const std::string kernel_src = read_file("pad.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_pad = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_pad = clCreateKernel(backend_ctx->program_pad, "kernel_pad", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: pad kernel source not found or empty. Pad operations will not be available.\n"); + backend_ctx->program_pad = nullptr; + backend_ctx->kernel_pad = nullptr; + } + } + + // tanh + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "tanh.cl.h" + }; +#else + const std::string kernel_src = read_file("tanh.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_tanh = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_tanh_f32_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f32_nd", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f16_nd", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: tanh kernel source not found or empty. Tanh operation will not be available.\n"); + backend_ctx->program_tanh = nullptr; + backend_ctx->kernel_tanh_f32_nd = nullptr; + backend_ctx->kernel_tanh_f16_nd = nullptr; + } + } + + // upscale + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "upscale.cl.h" + }; +#else + const std::string kernel_src = read_file("upscale.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_upscale = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_upscale = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale", &err), err)); + if (backend_ctx->program_upscale) { + cl_int err_bilinear; + backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear); + if (err_bilinear != CL_SUCCESS) { + GGML_LOG_WARN("ggml_opencl: kernel_upscale_bilinear not found in upscale.cl. Bilinear upscale will not be available. Error: %d\n", err_bilinear); + backend_ctx->kernel_upscale_bilinear = nullptr; + } + } else { + backend_ctx->kernel_upscale_bilinear = nullptr; + } + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: upscale kernel source not found or empty. Upscale operations will not be available.\n"); + backend_ctx->program_upscale = nullptr; + backend_ctx->kernel_upscale = nullptr; + backend_ctx->kernel_upscale_bilinear = nullptr; + } + } + + // concat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "concat.cl.h" + }; +#else + + const std::string kernel_src = read_file("concat.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_concat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_concat_f32_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_contiguous", &err), err)); + CL_CHECK((backend_ctx->kernel_concat_f32_non_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_non_contiguous", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: concat kernel source not found or empty. Concat operations will not be available.\n"); + backend_ctx->program_concat = nullptr; + backend_ctx->kernel_concat_f32_contiguous = nullptr; + backend_ctx->kernel_concat_f32_non_contiguous = nullptr; + } + } + + // timestep_embedding + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "tsembd.cl.h" + }; +#else + + const std::string kernel_src = read_file("tsembd.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_tsembd = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_timestep_embedding = clCreateKernel(backend_ctx->program_tsembd, "kernel_timestep_embedding", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: timestep_embedding kernel source not found or empty. This op will not be available.\n"); + backend_ctx->program_tsembd = nullptr; + backend_ctx->kernel_timestep_embedding = nullptr; + } + } + + // mul_mv_id_q4_0_f32_8x_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_id_q4_0_f32_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_id_q4_0_f32_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat, "kernel_mul_mv_id_q4_0_f32_8x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // Adreno kernels +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // transpose + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "transpose.cl.h" + }; +#else + const std::string kernel_src = read_file("transpose.cl"); +#endif + backend_ctx->program_transpose = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_general + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "gemv_noshuffle_general.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general.cl"); +#endif + + backend_ctx->program_CL_gemv_general = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle + { + // Gemv 2048, 16384 + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv { + #include "gemv_noshuffle.cl.h" + }; +#else + const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle.cl"); +#endif + + backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 2048, 16384 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 5504, 44032 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=5504 " + " -DBLOCK_STRIDE_A=44032 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 16000, 128000 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=16000 " + " -DBLOCK_STRIDE_A=128000 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mat_Ab_Bi_8x4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemm { + #include "mul_mat_Ab_Bi_8x4.cl.h" + }; +#else + const std::string kernel_src_CL_gemm = read_file("mul_mat_Ab_Bi_8x4.cl"); +#endif + backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + GGML_LOG_CONT("."); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_CONT("\n"); +} + +// XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { +// XXX static bool initialized = false; +// XXX static ggml_backend_opencl_context *backend_ctx = nullptr; + +static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev); + +namespace /* anonymous */ { +extern struct ggml_backend_device_i ggml_backend_opencl_device_i; +} + +// Look for available and suitable devices. +static std::vector ggml_opencl_probe_devices(ggml_backend_reg * reg) { + std::vector found_devices; + +#ifdef GGML_OPENCL_PROFILING + GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n"); +#endif + + struct cl_device; + struct cl_platform { + cl_platform_id id; + unsigned number; + char name[128]; + char vendor[128]; + struct cl_device * devices; + unsigned n_devices; + struct cl_device * default_device; + }; + + struct cl_device { + struct cl_platform * platform; + cl_device_id id; + unsigned number; + cl_device_type type; + char name[128]; + char version[128]; + }; + + enum { NPLAT = 16, NDEV = 16 }; + + struct cl_platform platforms[NPLAT]; + unsigned n_platforms = 0; + struct cl_device devices[NDEV]; + unsigned n_devices = 0; + struct cl_device * default_device = NULL; + unsigned default_platform_number = 0; + + cl_platform_id platform_ids[NPLAT]; + if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) { + GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n"); + return found_devices; + } + + for (unsigned i = 0; i < n_platforms; i++) { + struct cl_platform * p = &platforms[i]; + p->number = i; + p->id = platform_ids[i]; + CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL)); + CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL)); + + cl_device_id device_ids[NDEV]; + cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices); + if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) { + p->n_devices = 0; + } else { + CL_CHECK(clGetDeviceIDsError); + } + p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL; + p->default_device = NULL; + + for (unsigned j = 0; j < p->n_devices; j++) { + struct cl_device * d = &devices[n_devices]; + d->number = n_devices++; + d->id = device_ids[j]; + d->platform = p; + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL)); + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL)); + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_VERSION, sizeof(d->version), &d->version, NULL)); + + if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) { + p->default_device = d; + } + } + + if (default_device == NULL && p->default_device != NULL) { + default_device = p->default_device; + default_platform_number = i; + } + } + + if (n_devices == 0) { + GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n"); + return found_devices; + } + + char * user_platform_string = getenv("GGML_OPENCL_PLATFORM"); + char * user_device_string = getenv("GGML_OPENCL_DEVICE"); + int user_platform_number = -1; + int user_device_number = -1; + cl_device * candidate_devices = nullptr; + unsigned n_candidate_devices = 0; + + unsigned n; + if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) { + user_platform_number = (int)n; + } + if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) { + user_device_number = (int)n; + } + if (user_platform_number != -1 && user_device_number != -1) { + cl_platform* platform = &platforms[user_platform_number]; + if ((unsigned)user_device_number >= platform->n_devices) { + GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number); + exit(1); + } + default_device = &platform->devices[user_device_number]; + candidate_devices = platform->devices; + n_candidate_devices = platform->n_devices; + } else { + // Choose a platform by matching a substring. + if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) { + for (unsigned i = 0; i < n_platforms; i++) { + struct cl_platform * p = &platforms[i]; + if (strstr(p->name, user_platform_string) != NULL || + strstr(p->vendor, user_platform_string) != NULL) { + user_platform_number = (int)i; + break; + } + } + if (user_platform_number == -1) { + GGML_LOG_ERROR("ggml_opencl: no platform matching '%s' was found.\n", user_platform_string); + exit(1); + } + } + + int platform_idx = user_platform_number != -1 ? user_platform_number : default_platform_number; + struct cl_platform * p = &platforms[platform_idx]; + candidate_devices = p->devices; + n_candidate_devices = p->n_devices; + default_device = p->default_device; + if (n_candidate_devices == 0) { + GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name); + exit(1); + } + + if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) { + for (unsigned i = 0; i < n_candidate_devices; i++) { + struct cl_device * d = &candidate_devices[i]; + if (strstr(d->name, user_device_string) != NULL) { + user_device_number = d->number; + break; + } + } + if (user_device_number == -1) { + GGML_LOG_ERROR("ggml_opencl: no device matching '%s' was found.\n", user_device_string); + exit(1); + } + } + if (user_device_number != -1) { + candidate_devices = &devices[user_device_number]; + n_candidate_devices = 1; + default_device = &candidate_devices[0]; + } + + GGML_ASSERT(n_candidate_devices > 0); + + if (default_device == NULL) { + default_device = &candidate_devices[0]; + } + } + + GGML_ASSERT(n_candidate_devices != 0 && candidate_devices); + + // Put the default device in front. + for (unsigned i = 1; i < n_candidate_devices; i++) { + if (&candidate_devices[i] == default_device) { + std::swap(candidate_devices[0], candidate_devices[i]); + default_device = &candidate_devices[0]; + break; + } + } + + GGML_LOG_INFO("ggml_opencl: selected platform: '%s'\n", default_device->platform->name); + + std::vector device_ids; + for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) { + device_ids.push_back(dev->id); + } + + cl_int err; + cl_context shared_context; + cl_context_properties properties[] = { (intptr_t) CL_CONTEXT_PLATFORM, (intptr_t) default_device->platform->id, 0 }; + + CL_CHECK( + (shared_context = clCreateContext(properties, device_ids.size(), device_ids.data(), NULL, NULL, &err), err)); + + for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) { + GGML_LOG_INFO("\nggml_opencl: device: '%s (%s)'\n", dev->name, dev->version); + + auto dev_ctx = std::unique_ptr(new ggml_backend_opencl_device_context{ + /*.platform =*/dev->platform->id, + /*.platform_nane =*/dev->platform->name, + /*.device =*/dev->id, + /*.device_name =*/dev->name, + /*.device_type =*/dev->type, + /*.device_version =*/dev->version, + /*.backend_ctx =*/nullptr, + /*.buffer_type =*/{}, + /*.context =*/shared_context, + }); + + found_devices.push_back(ggml_backend_device{ + /* .iface = */ ggml_backend_opencl_device_i, + /* .reg = */ reg, + /* .context = */ dev_ctx.get(), + }); + + if (!ggml_cl2_init(&found_devices.back())) { + found_devices.pop_back(); + GGML_LOG_INFO("ggml_opencl: drop unsupported device.\n"); + continue; + } + + dev_ctx.release(); + } + + if (found_devices.size()) { + auto * dev_ctx = static_cast(found_devices.front().context); + GGML_LOG_INFO("ggml_opencl: default device: '%s (%s)'\n", dev_ctx->device_name.c_str(), + dev_ctx->device_version.c_str()); + + if (dev_ctx->device_type != CL_DEVICE_TYPE_GPU) { + GGML_LOG_WARN("ggml_opencl: warning, the default device is not a GPU: '%s'.\n", + dev_ctx->device_name.c_str()); + } + } + + return found_devices; +} + +// Initialize device if it is supported (returns nullptr if it is not). +static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { + GGML_ASSERT(dev); + GGML_ASSERT(dev->context); + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + GGML_ASSERT(dev_ctx->platform); + GGML_ASSERT(dev_ctx->device); + + if (dev_ctx->backend_ctx) { + return dev_ctx->backend_ctx; + } + + auto backend_ctx = std::make_unique(); + backend_ctx->device = dev_ctx->device; + backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + + if (strstr(dev_ctx->device_name.c_str(), "Adreno") || + strstr(dev_ctx->device_name.c_str(), "Qualcomm") || + strstr(dev_ctx->device_version.c_str(), "Adreno")) { + backend_ctx->gpu_family = GPU_FAMILY::ADRENO; + // Usually device version contains the detailed device name + backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); + if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { + backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); + } + + // Use wave size of 64 for all Adreno GPUs. + backend_ctx->adreno_wave_size = 64; + } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { + backend_ctx->gpu_family = GPU_FAMILY::INTEL; + } else { + GGML_LOG_ERROR("Unsupported GPU: %s\n", dev_ctx->device_name.c_str()); + backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + return nullptr; + } + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { + GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " + "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); + return nullptr; + } +#endif + + // Populate backend device name + backend_ctx->device_name = dev_ctx->device_name; + + // A local ref of cl_device_id for convenience + cl_device_id device = backend_ctx->device; + + ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); + + // Check device OpenCL version, OpenCL 2.0 or above is required + ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device); + if (opencl_c_version.major < 2) { + GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n"); + return nullptr; + } + + // Check driver version + size_t driver_version_str_size; + clGetDeviceInfo(device, CL_DRIVER_VERSION, 0, NULL, &driver_version_str_size); + char *driver_version = (char *)alloca(driver_version_str_size + 1); + clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL); + driver_version[driver_version_str_size] = '\0'; + GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); + backend_ctx->driver_version = driver_version; + + backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); + backend_ctx->has_vector_subgroup_broadcast = + backend_ctx->adreno_cl_compiler_version.major >= 47 || + backend_ctx->adreno_cl_compiler_version.major == 17; + GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", + backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); + + size_t ext_str_size; + clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); + char *ext_buffer = (char *)alloca(ext_str_size + 1); + clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); + ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + // Check if ext_buffer contains cl_khr_fp16 + backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; + GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + + // fp16 is required + if (!backend_ctx->fp16_support) { + GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n"); + return nullptr; + } + + // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes + // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) + if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && + strstr(ext_buffer, "cl_intel_subgroups") == NULL) { + GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " + "(note that subgroups is an optional feature in OpenCL 3.0)\n"); + return nullptr; + } + + cl_uint base_align_in_bits; + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); + GGML_ASSERT(base_align_in_bits % 8u == 0); + backend_ctx->alignment = base_align_in_bits / 8u; + GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); + + clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); + GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); + + // Check SVM. + cl_device_svm_capabilities svm_caps; + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0)); + GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", + svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", + svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", + svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", + svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + + if (opencl_c_version.major >= 3) { + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool), + &backend_ctx->non_uniform_workgroups, 0)); + } else { + GGML_ASSERT(opencl_c_version.major == 2); + // Non-uniform workgroup sizes is mandatory feature in v2.x. + backend_ctx->non_uniform_workgroups = true; + } + + // Print out configurations +#ifdef GGML_OPENCL_SOA_Q + GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); +#endif // GGML_OPENCL_SOA_Q + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_int err; + + // A local ref of cl_context for convenience + cl_context context = backend_ctx->context = dev_ctx->context; + + //CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err), + // (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err : + // (queue = clCreateCommandQueue(context, device, 0, &err), err) + //))); + cl_command_queue_properties command_queue_props = 0; +#ifdef GGML_OPENCL_PROFILING + command_queue_props |= CL_QUEUE_PROFILING_ENABLE; +#endif + CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); + + // Load kernels + load_cl_kernels(backend_ctx.get(), opencl_c_version); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Allocate intermediate buffers and images + size_t required_A_q_d_bytes = 311164928; + size_t required_A_s_d_bytes = 38895616; + size_t required_B_d_bytes = 45088768; + + // Ensure buffer sizes do not exceed the maximum allocation size + size_t max_A_q_d_bytes = MIN(required_A_q_d_bytes, backend_ctx->max_alloc_size); + size_t max_A_s_d_bytes = MIN(required_A_s_d_bytes, backend_ctx->max_alloc_size); + size_t max_B_d_bytes = MIN(required_B_d_bytes, backend_ctx->max_alloc_size); + if (required_A_q_d_bytes > backend_ctx->max_alloc_size) { + GGML_LOG_WARN("ggml_opencl: A_q_d buffer size reduced from %zu to %zu due to device limitations.\n", + required_A_q_d_bytes, max_A_q_d_bytes); + } + if (required_A_s_d_bytes > backend_ctx->max_alloc_size) { + GGML_LOG_WARN("ggml_opencl: A_s_d buffer size reduced from %zu to %zu due to device limitations.\n", + required_A_s_d_bytes, max_A_s_d_bytes); + } + if (required_B_d_bytes > backend_ctx->max_alloc_size) { + GGML_LOG_WARN("ggml_opencl: B_d buffer size reduced from %zu to %zu due to device limitations.\n", + required_B_d_bytes, max_B_d_bytes); + } + + CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err)); + CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err)); + CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err)); +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + dev_ctx->backend_ctx = backend_ctx.release(); + return dev_ctx->backend_ctx; +} + +static void ggml_cl2_free(void) { +#ifdef GGML_OPENCL_PROFILING + FILE * fperf = fopen("cl_profiling.csv", "w"); + if (!fperf) { + GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + return; + } + + // Populate profiling info + for (ProfilingInfo & info : g_profiling_info) { + cl_ulong cmd_queued; + cl_ulong cmd_submit; + cl_ulong cmd_start; + cl_ulong cmd_end; + cl_ulong cmd_complete; + + CL_CHECK(clWaitForEvents(1, &info.evt)); + CL_CHECK(clGetEventProfilingInfo( + info.evt, CL_PROFILING_COMMAND_QUEUED, sizeof(cl_ulong), &cmd_queued, NULL)); + CL_CHECK(clGetEventProfilingInfo( + info.evt, CL_PROFILING_COMMAND_SUBMIT, sizeof(cl_ulong), &cmd_submit, NULL)); + CL_CHECK(clGetEventProfilingInfo( + info.evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &cmd_start, NULL)); + CL_CHECK(clGetEventProfilingInfo( + info.evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &cmd_end, NULL)); + CL_CHECK(clGetEventProfilingInfo( + info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL)); + CL_CHECK(clReleaseEvent(info.evt)); + + char kernel_name[512]; + CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME, + sizeof(kernel_name), kernel_name, NULL)); + info.kernel_name = kernel_name; + + info.cmd_queued = cmd_queued; + info.cmd_submit = cmd_submit; + info.cmd_start = cmd_start; + info.cmd_end = cmd_end; + + info.cmd_queued_duration_ns = cmd_submit - cmd_queued; + info.cmd_submit_duration_ns = cmd_start - cmd_submit; + info.cmd_duration_ns = cmd_end - cmd_start; + info.cmd_complete_duration_ns = cmd_complete - cmd_end; + info.cmd_total_duration_ns = cmd_complete - cmd_queued; + } + + // Dump a csv + float total_kernel_time = 0; + fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n"); + for (const ProfilingInfo & info : g_profiling_info) { + total_kernel_time += info.cmd_duration_ns/1.e6f; + fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n", + info.op_name.c_str(), info.kernel_name.c_str(), + info.cmd_queued_duration_ns/1.e6f, + info.cmd_submit_duration_ns/1.e6f, + info.cmd_duration_ns/1.e6f, + info.cmd_complete_duration_ns/1.e6f, + info.cmd_total_duration_ns/1.e6f, + info.global_size[0], info.global_size[1], info.global_size[2], + info.local_size[0], info.local_size[1], info.local_size[2], + info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]); + } + fclose(fperf); + + GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time); + + // Dump a simple chrome trace + FILE* ftrace = fopen("cl_trace.json", "w"); + if (!ftrace) { + GGML_LOG_ERROR("Failed to open cl_trace.json\n"); + return; + } + + fprintf(ftrace, "[\n"); + for (const ProfilingInfo & info : g_profiling_info) { + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n", + info.kernel_name.c_str(), info.cmd_queued/1000); + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n", + info.kernel_name.c_str(), info.cmd_submit/1000); + + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n", + info.kernel_name.c_str(), info.cmd_start/1000); + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n", + info.kernel_name.c_str(), info.cmd_end/1000); + } + fclose(ftrace); +#endif +} + +//------------------------------------------------------------------------------ +// Tensor extra management +//------------------------------------------------------------------------------ +struct ggml_tensor_extra_cl { + // The buffer object that holds the data. + cl_mem data_device; + // The offset into the buffer object. This is primarily for scratch buffer + // and view operation. + // NB: this offset no longer includes view offset (view_offs). Whenever this + // offset is used, view_offs should be considered. + cl_ulong offset; + // The actual size of the cl_mem object. This is needed when returning the + // block to the pool. + size_t actual_size; + + void reset() { + data_device = nullptr; + offset = 0; + actual_size = 0; + } +}; + +// Additional tensor extra structs for quantized tensors. +// These tensors are loaded from files and should not be allocated in scratch -- +// they should always be allocated from the pool. Hence, they do not have an +// `offset`, which indicate their locations in the scratch buffer. +struct ggml_tensor_extra_cl_q4_0 { + // Quantized values. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Size of quantized values. + size_t size_q = 0; + // Size of scales. + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q4_0() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + // Currently, q_img and d_img are only initialized when SMALL_ALLOC is + // enabled. They point to the images in ggml_backend_opencl_buffer_context. + // So, there is no need to release them here. + // TODO: initialize them for non SMALL_PATH path, or remove them. + q_img = nullptr; + d_img = nullptr; + size_q = 0; + size_d = 0; + } +}; + +//------------------------------------------------------------------------------ +// Backend API +//------------------------------------------------------------------------------ + +// +// backend +// +static const char * ggml_backend_opencl_name(ggml_backend_t backend) { + return "OpenCL"; + + UNUSED(backend); +} + +static void ggml_backend_opencl_free(ggml_backend_t backend) { + ggml_cl2_free(); + + GGML_UNUSED(backend); +} + +static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_UNUSED(backend); + GGML_UNUSED(tensor); + GGML_UNUSED(data); + GGML_UNUSED(offset); + GGML_UNUSED(size); +} + +static void ggml_backend_opencl_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_UNUSED(backend); + GGML_UNUSED(tensor); + GGML_UNUSED(data); + GGML_UNUSED(offset); + GGML_UNUSED(size); +} + +static bool ggml_backend_opencl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { + GGML_UNUSED(backend); + GGML_UNUSED(src); + GGML_UNUSED(dst); + return false; +} + +static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { + auto * backend_ctx = static_cast(backend->context); + + cl_event evt; + CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, 0, nullptr, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseEvent(evt)); +} + +// Syncronizes the 'backend_ctx's device with others so that commands +// enqueued to it won't start until commands in the other devices have +// completed. +static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { + if (g_ggml_backend_opencl_devices.size() < 2) + return; // No other devices to synchronize with. + + std::vector events; + events.reserve(g_ggml_backend_opencl_devices.size()); + + for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) { + auto * other_backend_ctx = ggml_cl2_init(&backend_dev); + if (backend_ctx != other_backend_ctx) { + cl_event ev; + CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev)); + CL_CHECK(clFlush(other_backend_ctx->queue)); + events.push_back(ev); + } + } + + CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, events.size(), events.data(), nullptr)); + for (auto ev : events) { + CL_CHECK(clReleaseEvent(ev)); + } +} + +static void sync_with_other_backends(ggml_backend_t backend) { + auto * backend_ctx = static_cast(backend->context); + sync_with_other_backends(backend_ctx); +} + +static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + // NOTE: this may oversynchronize by synchronizing with + // backends/devices which don't compute 'cgraph's + // dependencies. + sync_with_other_backends(backend); + + if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + bool ok = ggml_cl_compute_forward(backend, node); + if (!ok) { + GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + } + + return GGML_STATUS_SUCCESS; +} + +static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + GGML_UNUSED(dev); + + switch (op->op) { + case GGML_OP_NONE: + return true; + case GGML_OP_GET_ROWS: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + case GGML_TYPE_Q4_0: +#ifdef GGML_OPENCL_SOA_Q + // We do not support flattened Q4_0 (and possibly other Q's) + return false; +#else // GGML_OPENCL_SOA_Q + return true; +#endif // GGML_OPENCL_SOA_Q + default: + return false; + } + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_CONT: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + case GGML_TYPE_F16: + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + default: + return false; + } + case GGML_OP_ADD: + case GGML_OP_SCALE: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SUB: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_GELU_QUICK: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_UNARY_OP_SIGMOID: + return ggml_is_contiguous(op->src[0]); + case GGML_UNARY_OP_TANH: + return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || + (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); + default: + return false; + } + case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SOFT_MAX: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + return true; + case GGML_OP_REPEAT: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded + case GGML_OP_PAD: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && + op->src[0]->ne[3] == 1 && op->ne[3] == 1; + case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_CONCAT: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_TIMESTEP_EMBEDDING: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_GROUP_NORM: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_MUL_MAT: + if (op->src[0]->type == GGML_TYPE_F16) { + return true; + } else if (op->src[0]->type == GGML_TYPE_F32) { + return op->src[1]->type == GGML_TYPE_F32; + } else if (op->src[0]->type == GGML_TYPE_Q4_0 || + op->src[0]->type == GGML_TYPE_Q6_K) { + return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + } + return false; + case GGML_OP_MUL_MAT_ID: + if (op->src[0]->type == GGML_TYPE_Q4_0) { + if (op->src[1]->type == GGML_TYPE_F32) { + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + } + } + return false; + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + return true; + case GGML_OP_DIAG_MASK_INF: + return op->ne[3] == 1; + case GGML_OP_ROPE: { + const int mode = ((const int32_t *) op->op_params)[2]; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + if (is_mrope && !is_vision) { + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16) { + return true; + } + return false; + } + if (is_vision) { + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16) { + return true; + } + return false; + } + return true; + } + case GGML_OP_IM2COL: + return true; + case GGML_OP_ARGSORT: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SUM_ROWS: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); + default: + return false; + } +} + +// Forward declaration - implementation appears later in the file. +static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type); + +static ggml_guid_t ggml_backend_opencl_guid() { + static ggml_guid guid = { 0xde, 0xe0, 0x70, 0xa2, 0x73, 0x4e, 0x4d, 0xbc, 0xb0, 0xc7, 0x4f, 0xd4, 0x6d, 0x4e, 0x90, 0xfe }; + return &guid; +} + +static ggml_backend_i ggml_backend_opencl_i = { + /* .get_name = */ ggml_backend_opencl_name, + /* .free = */ ggml_backend_opencl_free, + /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ + /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ + /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ + /* .synchronize = */ ggml_backend_opencl_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_opencl_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +ggml_backend_t ggml_backend_opencl_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0); + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_opencl_guid(), + /* .interface = */ ggml_backend_opencl_i, + /* .device = */ dev, + /* .context = */ backend_ctx + }; + + return backend; +} + +bool ggml_backend_is_opencl(ggml_backend_t backend) { + return backend && backend->iface.get_name == ggml_backend_opencl_name; +} + +// +// buffer +// +struct ggml_backend_opencl_buffer_context { + // A buffer context can hold multiple cl_mem objects. This is for flattening + // quantized weights and should be used with GGML_OPENCL_SMALL_ALLOC where + // each tensor is allocated a separate buffer. When flattening is enabled + // with small allocation, each tensor is backed by two cl_mem objects (for + // quants and scales) packed into a backend_opencl_buffer. + ggml_backend_opencl_buffer_context(cl_mem buf) + : name("OpenCL") { + buffer.push_back(buf); + } + + ~ggml_backend_opencl_buffer_context() { + for (cl_mem buf : buffer) { + CL_CHECK(clReleaseMemObject(buf)); + } + for (cl_mem im : img) { + CL_CHECK(clReleaseMemObject(im)); + } + + // Delete all extras to trigger their destructors + for (ggml_tensor_extra_cl * e : temp_tensor_extras) { + delete e; + } + for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0) { + delete e; + } + for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { + delete e; + } + } + + ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() { + ggml_tensor_extra_cl * extra; + if (temp_tensor_extras.empty()) { + extra = new ggml_tensor_extra_cl(); + } else { + extra = temp_tensor_extras.back(); + temp_tensor_extras.pop_back(); + } + + temp_tensor_extras_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q4_0 * ggml_opencl_alloc_temp_tensor_extra_q4_0() { + ggml_tensor_extra_cl_q4_0 * extra; + if (temp_tensor_extras_q4_0.empty()) { + extra = new ggml_tensor_extra_cl_q4_0(); + } else { + extra = temp_tensor_extras_q4_0.back(); + temp_tensor_extras_q4_0.pop_back(); + } + + temp_tensor_extras_q4_0_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + void reset() { + for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { + temp_tensor_extras.push_back(e); + } + temp_tensor_extras_in_use.clear(); + + for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { + temp_tensor_extras_q4_0.push_back(e); + } + temp_tensor_extras_q4_0_in_use.clear(); + } + + // Pools for extras. Available extras are in `temp_tensor_extras`. Extras + // being used are in `temp_tensor_extras_in_use`. At the first run, new + // extras get created and put in `in_use`. When the buffer is reset via + // the `reset` callback, all extras in `in_use` get moved to available extras + // for reuse. + std::vector temp_tensor_extras; + std::vector temp_tensor_extras_in_use; + std::vector temp_tensor_extras_q4_0; + std::vector temp_tensor_extras_q4_0_in_use; + + // The buffer_context is initially created by ggml_backend_buft_alloc_buffer + // before any tensor is initialized (at the beginning of alloc_tensor_range). + // Hence, there is alway a buffer object in this vector. When each tensor is + // being initialized, this original buffer object will be released if both + // flattening and small allocation are enabled, and additional buffer + // objects will be created in init_tensor to represent flattened quantized + // weights. + std::vector buffer; + // These are image1d_buffer_t objects that wrap around the quants and scales. + // For Q4_0 quantization, there should be two of them - one for quants and + // one for scales. They should be populated only when flattening and small + // allocation are enabled. + std::vector img; + std::string name; +}; + +static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + delete ctx; +} + +static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer->buft->device); + return (void *) (uintptr_t) backend_ctx->alignment; +} + +static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + + ggml_cl2_init(buffer->buft->device); + + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + + ggml_tensor_extra_cl * view_extra = (ggml_tensor_extra_cl *) tensor->view_src->extra; + GGML_ASSERT(view_extra && "view_extra is nullptr?"); + + // Reuse extra of the parent tensor. The offset of this view tensor + // becomes `extra->offset + view_offs` and needs to be calculated when + // it is used. This changes is needed because of the change to + // ggml_alloc.c in https://github.com/ggerganov/llama.cpp/pull/7640. + // `buffer` passed in here will always be `tensor->buffer`. It is OK + // to allocate extras from the same buffer context for ordinary + // intermediate tensors. But for views into kv cache tensors, doing so + // would mess up the extras used by kv cache. + // Before #7640, `buffer` is for intermediate tensors, which is always + // different from that of kv cache tensors. + // + // NB: now extra->offset no longer accounts for view_offs. + // NB: this should not apply to weight tensors (for end-to-end runs, but + // may apply for test-backend-ops). + // FIXME: if any unexpected results are seen, double check the offset - + // there could be other places that need fix. + tensor->extra = view_extra; + } else { + { + size_t offset = (char *) tensor->data - (char *) ggml_backend_opencl_buffer_get_base(buffer); + + ggml_tensor_extra_cl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra(); + extra->offset = offset; + extra->data_device = ctx->buffer[0]; + extra->actual_size = ggml_nbytes(tensor); + + tensor->extra = extra; + } + } + return GGML_STATUS_SUCCESS; +} + +// The optimized gemm and gemv kernels are used for large matrices without batch. +// tensor is the quantized weights matrix. +inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + int64_t threshold_ne0 = 512; + int64_t threshold_ne1 = 512; + if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && + backend_ctx->adreno_cl_compiler_version.type != DX) { + threshold_ne0 = 128; + threshold_ne1 = 128; + } + return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && + tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + + cl_context context = backend_ctx->context; + cl_command_queue queue = backend_ctx->queue; + +#ifdef GGML_OPENCL_SOA_Q + // We separate the quantized bits and scale from block_q4_0 by using an + // additional kernel, where each thread handles a block. We first read the + // original weights into a temporary buffer, then create two separate + // buffers for quantized bits and scales, which are then populated by the + // conversion kernel. + if (tensor->type == GGML_TYPE_Q4_0) { + // Tensors should have been preallocated, therefore they should + // already have ggml_tensor_extra_cl as extra. + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_0(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + // We consider the specified offset arg as always, although For weights + // the offset arg should be 0 (we do not assert this). + //GGML_ASSERT(offset == 0); + + // We create subbuffers from the original tensor buffer for scales and + // quants - i.e., scales and quants are aliases into the buffer obejct + // that backs the original tensor. This is a cleaner way to adapt to the + // new memory management. + // In the old code, we allocate new buffers for scales and quants + // respectively, which could still be done but would result in double + // allocation; properly deallocating the preallocated buffer that backs + // the tensors is tricky and would leak the backend specific information + // into the general backend code. + // Does this create misaligned subbuffers (alignment is 1024) in certain + // cases ? + cl_buffer_region region; + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, then quants. + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + //cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; + + // The optimized kernels need weights in natural order, so unshuffle. + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; + #endif // GGML_OPENCL_USE_ADRENO_KERNELS + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + // transpose the weights and scales + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Only do transpose for large, non batched matrix + // TODO: use preallocated images instead of sub-buffer then image + if (use_adreno_kernels(backend_ctx, tensor)) { + // <----------------------------------------------------------------------------------> // + // start transpose + // <----------------------------------------------------------------------------------> // + int M = tensor->ne[1]; // ne01 + int K = tensor->ne[0]; // ne00 + + //For matrix-vector multiplication kernel, we assume K is a multiple of 32 + GGML_ASSERT(K % 32 == 0); + //For transpose kernels, we assume K is a multiple of 4 (satisfied by prior assert), and M is a multiple of 4 + GGML_ASSERT(M % 4 == 0); + + // transpose is out of place, so we need to allocate transposed buffers + // <----------------------------------------------------------------------------------> // + // use sub_buffer of max buffer size instead + + size_t q_size_bytes = K * M / 8 * sizeof(float); + cl_buffer_region region; + region.origin = 0; + region.size = q_size_bytes; + cl_mem qT_d = clCreateSubBuffer( + backend_ctx->A_q_d_max, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &err); + // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err); + CL_CHECK(err); + + // size_t d_size_bytes = M * (K / 32) / 2 * sizeof(float); + size_t d_size_bytes = M * (K / 32) * 2; + region.origin = 0; + region.size = d_size_bytes; + cl_mem dT_d = clCreateSubBuffer( + backend_ctx->A_s_d_max, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &err); + // cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err); + CL_CHECK(err); + + // <----------------------------------------------------------------------------------> // + + + // create images from the buffers + // <----------------------------------------------------------------------------------> // + cl_mem q_d_image1D; + cl_mem d_d_image1D; + cl_mem qT_d_image1D; + cl_mem dT_d_image1D; + + cl_image_format img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; + cl_image_desc img_desc_1d; + + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 4 / 4; + img_desc_1d.buffer = extra->q; + q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 4 / 4; + img_desc_1d.buffer = qT_d; + qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 32 / 4; + img_desc_1d.buffer = extra->d; + d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 32 / 4; + img_desc_1d.buffer = dT_d; + dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + // <----------------------------------------------------------------------------------> // + + // set up and call the transpose kernels + // <----------------------------------------------------------------------------------> // + // weights + int height_q = M / 4; + int width_q = K / 4 / 4; + kernel = backend_ctx->kernel_transpose_16; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); + + size_t local_size_q[3] = {4, 16, 1}; + size_t global_size_q[3] = {static_cast(width_q), static_cast(height_q), 1}; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + // scales + int height_s = M / 4; + int width_s = K / 32 / 4; + + kernel = backend_ctx->kernel_transpose_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); + + size_t local_size_s[3] = {4, 16, 1}; + size_t global_size_s[3] = {static_cast(width_s), static_cast(height_s), 1}; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + // <----------------------------------------------------------------------------------> // + + // copy transposed buffer contents to original buffers + // <----------------------------------------------------------------------------------> // + // weights + CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + // scales + CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + // <----------------------------------------------------------------------------------> // + + // deallocate transpose buffers + // <----------------------------------------------------------------------------------> // + CL_CHECK(clReleaseMemObject(qT_d)); + CL_CHECK(clReleaseMemObject(dT_d)); + + // deallocate temporary images + CL_CHECK(clReleaseMemObject(q_d_image1D)); + CL_CHECK(clReleaseMemObject(d_d_image1D)); + CL_CHECK(clReleaseMemObject(qT_d_image1D)); + CL_CHECK(clReleaseMemObject(dT_d_image1D)); + // <----------------------------------------------------------------------------------> // + // end transpose + // <----------------------------------------------------------------------------------> // + } + #endif // GGML_OPENCL_USE_ADRENO_KERNELS + + return; + } +#endif // GGML_OPENCL_SOA_Q + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + CL_CHECK(clEnqueueWriteBuffer( + queue, extra->data_device, CL_TRUE, extra->offset + offset, + size, data, 0, NULL, NULL)); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->extra); + + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + + cl_context context = backend_ctx->context; + cl_command_queue queue = backend_ctx->queue; + + // Make sure all previously submitted commands in other devices are finished. + sync_with_other_backends(backend_ctx); + +#ifdef GGML_OPENCL_SOA_Q + // In end-to-end runs, get_tensor is usually used to get back the logits, + // where we can simply do clEnqueueReadBuffer since they are f32. + // However, in test-backend-ops, the GPU graph is copied to the CPU backend, + // which requires reading back quantized weight tensors. + // To properly support this, we need to restore block_q4_0 struct arrays + // from the flattened buffers. + if (tensor->type == GGML_TYPE_Q4_0) { + ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_SOA_Q + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + + CL_CHECK(clEnqueueReadBuffer( + queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset, + size, data, 0, NULL, NULL)); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_dev_t dev = buffer->buft->device; + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + cl_command_queue queue = backend_ctx->queue; + + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + for (cl_mem buf : ctx->buffer) { + CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL)); + } + CL_CHECK(clFinish(queue)); +} + +static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ctx->reset(); +} + +static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = { + /* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer, + /* .get_base = */ ggml_backend_opencl_buffer_get_base, + /* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_opencl_buffer_clear, + /* .reset = */ ggml_backend_opencl_buffer_reset, +}; + +// +// buffer type +// + +static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type) { + return "OpenCL"; + + GGML_UNUSED(buffer_type); +} + +static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) { + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device); + + // clCreateBuffer returns -61 for size 0 + size = std::max(size, (size_t)1); + + cl_int err; + cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err); + if (err != CL_SUCCESS) { + GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0); + return nullptr; + } + + ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context(mem); + + return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size); +} + +static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); + return backend_ctx->alignment; +} + +static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { + static size_t max_size = -1; + if (max_size == (size_t)-1) { + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); + max_size = backend_ctx->max_alloc_size; + } + return max_size; +} + +static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return ggml_backend_is_opencl(backend); + + UNUSED(buft); +} + +static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = { + /* .get_name = */ ggml_backend_opencl_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_opencl_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_opencl_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_opencl_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, + /* .is_host = */ NULL, +}; + +// +// backend device +// + +static const char * ggml_backend_opencl_device_get_name(ggml_backend_dev_t dev) { + return "GPUOpenCL"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + return dev_ctx->device_name.c_str(); +} + +static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + *free = 1; + *total = 1; + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_opencl_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_opencl_device_get_name(dev); + props->description = ggml_backend_opencl_device_get_description(dev); + props->type = ggml_backend_opencl_device_get_type(dev); + ggml_backend_opencl_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = ggml_backend_dev_caps { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) { + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev); + + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_opencl_guid(), + /* .interface = */ ggml_backend_opencl_i, + /* .device = */ dev, + /* .context = */ backend_ctx, + }; + + return backend; + + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) { + auto * dev_ctx = static_cast(dev->context); + + dev_ctx->buffer_type = ggml_backend_buffer_type{ + /* .iface = */ ggml_backend_opencl_buffer_type_interface, + /* .device = */ dev, + /* .context = */ nullptr, + }; + + return &dev_ctx->buffer_type; +} + +static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + GGML_UNUSED(dev); + GGML_UNUSED(ptr); + GGML_UNUSED(size); + GGML_UNUSED(max_tensor_size); + return nullptr; +} + +static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + return ggml_opencl_supports_op(dev, op); +} + +static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + // Check 'dev' and 'buffer_type' are not objects belonging to this backend. + if (dev->iface.get_name != ggml_backend_opencl_device_get_name || + buft->iface.get_name != ggml_backend_opencl_buffer_type_get_name) { + return false; + } + + // Check cl_context is the same. clEnqueue* commands may not use + // buffers from another cl_context. + ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev); + ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device); + return backend_ctx0->context == backend_ctx1->context; +} + +namespace /* anonymous */ { +struct ggml_backend_device_i ggml_backend_opencl_device_i = { + /* .get_name = */ ggml_backend_opencl_device_get_name, + /* .get_description = */ ggml_backend_opencl_device_get_description, + /* .get_memory = */ ggml_backend_opencl_device_get_memory, + /* .get_type = */ ggml_backend_opencl_device_get_type, + /* .get_props = */ ggml_backend_opencl_device_get_props, + /* .init_backend = */ ggml_backend_opencl_device_init, + /* .get_buffer_type = */ ggml_backend_opencl_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_opencl_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_opencl_device_supports_op, + /* .supports_buft = */ ggml_backend_opencl_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; +} + +// Backend registry + +static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) { + return "OpenCL"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) { + return g_ggml_backend_opencl_devices.size(); + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index < ggml_backend_opencl_reg_device_count(reg)); + + return &g_ggml_backend_opencl_devices[index]; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = { + /* .get_name = */ ggml_backend_opencl_reg_get_name, + /* .device_count = */ ggml_backend_opencl_reg_device_count, + /* .device_get = */ ggml_backend_opencl_reg_device_get, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_opencl_reg(void) { + static std::mutex mutex; + static ggml_backend_reg reg; + static bool initialized = false; + std::lock_guard lock(mutex); + + if (initialized) { + return ® + } + initialized = true; + + g_ggml_backend_opencl_devices = ggml_opencl_probe_devices(®); + + reg = ggml_backend_reg{ + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_opencl_reg_i, + /* .context = */ NULL, + }; + + return ® +} + +GGML_BACKEND_DL_IMPL(ggml_backend_opencl_reg) + +//------------------------------------------------------------------------------ +// Debugging utils +//------------------------------------------------------------------------------ +#if 0 +#define QK4_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, + "wrong q4_0 block size/padding"); + +#include +#ifdef __cplusplus +#include "half.hpp" +#endif + +static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tensor) { + void * buf = malloc(ggml_nbytes(tensor)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; +#ifdef GGML_OPENCL_SOA_Q + void * buf_q; + void * buf_d; +#endif + + // Make sure everything is done. + CL_CHECK(clFinish(queue)); + +#ifdef GGML_OPENCL_SOA_Q + if (tensor->type == GGML_TYPE_Q4_0) { + ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *) tensor->extra; + GGML_ASSERT(extra); + + size_t size_q = ggml_nelements(tensor)/QK4_0 * QK4_0/2; + size_t size_d = ggml_nelements(tensor)/QK4_0 * sizeof(ggml_fp16_t); + GGML_ASSERT(size_q + size_d == ggml_nbytes(tensor)); + buf_q = malloc(size_q); + buf_d = malloc(size_d); + + CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); + } else { + // Read out the tensor from GPU memory. + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE, + extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); + } +#else + // Read out the tensor from GPU memory. + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE, + extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); +#endif // GGML_OPENCL_SOA_Q + + // Open file and dump. + char fname[512]; + sprintf(fname, "./tensor-dumps/%s.txt", tensor->name); + FILE * f = fopen(fname, "w"); + if (!f) { + printf("Failed to open %s\n", fname); + return; + } + + if (tensor->type == GGML_TYPE_F32) { + float * data = (float *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%f\n", data[i]); + } + } else if (tensor->type == GGML_TYPE_I32) { + int * data = (int *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%d\n", data[i]); + } + } else if (tensor->type == GGML_TYPE_F16) { +#ifdef __cplusplus + half_float::half * data = (half_float::half *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (std::isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%f\n", float(data[i])); + } +#endif + } else if (tensor->type == GGML_TYPE_Q4_0) { +#ifdef GGML_OPENCL_SOA_Q + ggml_fp16_t * data_d = (ggml_fp16_t *)buf_d; + unsigned char * data_q = (unsigned char *)buf_q; + + for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) { + fprintf(f, "%04x, ", data_d[i]); + for (int k = 0; k < QK4_0/2; ++k) { + fprintf(f, "%02x, ", data_q[k]); + } + fprintf(f, "\n"); + data_q += QK4_0/2; + } + free(buf_d); + free(buf_q); +#else + block_q4_0 * data = (block_q4_0 *) buf; + for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) { + fprintf(f, "%04x, ", data[i].d); + for (int k = 0; k < QK4_0/2; ++k) { + fprintf(f, "%02x, ", data[i].qs[k]); + } + fprintf(f, "\n"); + } +#endif // GGML_OPENCL_SOA_Q + } + free(buf); + fflush(f); + fclose(f); +} +#else +#define dump_tensor(tensor) +#endif + +//------------------------------------------------------------------------------ +// Profiling utility +//------------------------------------------------------------------------------ +#ifdef GGML_OPENCL_PROFILING +static void populateProfilingInfo( + ProfilingInfo& info, cl_event evt, cl_kernel kernel, + size_t global_size[3], size_t local_size[3], + const ggml_tensor * tensor) { + info.op_name = tensor->name; + info.kernel = kernel; + info.evt = evt; + + info.local_size[0] = local_size[0]; + info.local_size[1] = local_size[1]; + info.local_size[2] = local_size[2]; + info.global_size[0] = global_size[0]; + info.global_size[1] = global_size[1]; + info.global_size[2] = global_size[2]; + info.output_size[0] = tensor->ne[0]; + info.output_size[1] = tensor->ne[1]; + info.output_size[2] = tensor->ne[2]; + info.output_size[3] = tensor->ne[3]; +} +#endif + +//------------------------------------------------------------------------------ +// Ops +//------------------------------------------------------------------------------ + +static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); +} + +static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + UNUSED(backend); + UNUSED(src0); + UNUSED(src1); + UNUSED(dst); +} + +static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0 ? src0->ne[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const int ne10 = src1 ? src1->ne[0] : 0; + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb1 = dst ? dst->nb[1] : 0; + const cl_ulong nb2 = dst ? dst->nb[2] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_get_rows_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_get_rows_f16; + break; + case GGML_TYPE_Q4_0: + kernel = backend_ctx->kernel_get_rows_q4_0; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); + + size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1}; + size_t local_work_size[] = {1, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + const int ne2 = dst ? dst->ne[2] : 0; + const int ne3 = dst ? dst->ne[3] : 0; + + const cl_ulong nb0 = dst ? dst->nb[0] : 0; + const cl_ulong nb1 = dst ? dst->nb[1] : 0; + const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const cl_ulong nb3 = dst ? dst->nb[3] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_add_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_add; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); + } + + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + const int ne2 = dst ? dst->ne[2] : 0; + const int ne3 = dst ? dst->ne[3] : 0; + + const cl_ulong nb0 = dst ? dst->nb[0] : 0; + const cl_ulong nb1 = dst ? dst->nb[1] : 0; + const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const cl_ulong nb3 = dst ? dst->nb[3] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_mul_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_mul; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); + } + + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const int ne0 = dst->ne[0]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_div_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_div; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + } + + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const int ne0 = dst->ne[0]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_sub_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_sub; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + } + + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL); +#endif +} + +static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_quick_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu_quick; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL); +#endif +} + +static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_silu_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_silu; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel = backend_ctx->kernel_relu; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + const int64_t n = ggml_nelements(dst); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sigmoid_f32; + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_sigmoid_f16; + } else { + GGML_ASSERT(false && "Unsupported data types for sigmoid (input and output must be both f32 or f16)"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + const int64_t n = ggml_nelements(dst); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float min; + float max; + memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + + cl_kernel kernel = backend_ctx->kernel_clamp; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &min)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &max)); + + const int64_t n = ggml_nelements(dst); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int nth = MIN(64, ne00); + + cl_kernel kernel = backend_ctx->kernel_norm; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL)); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + //ggml_backend_opencl_device_context * dev_ctx = + // (ggml_backend_opencl_device_context *)backend->device->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + GGML_ASSERT(ne00 % 4 == 0); + + const int nth = MIN(64, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_rms_norm; + + // Note, this kernel declares local memory in kernel args and the size + // depends on subgroup size. + // Note, this requires OpenCL 2.1 and above + // For now we use fixed subgroup size to simplify support for OpenCL 2.0. + size_t sgs; + //CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, + // CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, + // sizeof(local_work_size), local_work_size, + // sizeof(size_t), &sgs, NULL)); + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); + // This is local memory - the size depends on subgroup size. + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL)); + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + int32_t n_groups = ((const int32_t *) dst->op_params)[0]; + int32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + n_groups - 1) / n_groups); + float eps = ((const float *) dst->op_params)[1]; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne = ne00*ne01*ne02; + + cl_kernel kernel = backend_ctx->kernel_group_norm; + + size_t sgs = 64; + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &group_size)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); + + size_t global_work_size[] = {(size_t)n_groups*sgs, 1, 1}; + size_t local_work_size[] = {(size_t)sgs, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0_abs = extra0->offset + src0->view_offs; + cl_ulong offsetd_abs = extrad->offset + dst->view_offs; + + cl_kernel kernel; + if (dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32_nd; + } else if (dst->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_tanh_f16_nd; + } else { + GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh"); + } + GGML_ASSERT(kernel != nullptr); + + const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3]; + const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3]; + const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); + + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); + + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); + + size_t global_work_size[3]; + if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements + return; + } + global_work_size[0] = (size_t)ne10; + global_work_size[1] = (size_t)ne11; + global_work_size[2] = (size_t)ne12; + + size_t lws0 = 16, lws1 = 4, lws2 = 1; + if (ne10 < 16) lws0 = ne10; + if (ne11 < 4) lws1 = ne11; + if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; + + while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; + while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; + while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; + + + size_t local_work_size[] = {lws0, lws1, lws2}; + + size_t* local_work_size_ptr = local_work_size; + if (!backend_ctx->non_uniform_workgroups) { + if (global_work_size[0] % local_work_size[0] != 0 || + global_work_size[1] % local_work_size[1] != 0 || + global_work_size[2] % local_work_size[2] != 0) { + local_work_size_ptr = NULL; + } + } + if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; + + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr ? local_work_size : (size_t[3]){0,0,0}, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(dst->type == src0->type); + + UNUSED(src1_shape_def); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + if (backend_ctx->kernel_repeat == nullptr) { + GGML_LOG_WARN("%s: repeat kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + const int src0_ne0 = src0->ne[0]; const int src0_ne1 = src0->ne[1]; const int src0_ne2 = src0->ne[2]; const int src0_ne3 = src0->ne[3]; + const cl_ulong src0_nb0 = src0->nb[0]; const cl_ulong src0_nb1 = src0->nb[1]; const cl_ulong src0_nb2 = src0->nb[2]; const cl_ulong src0_nb3 = src0->nb[3]; + + const int dst_ne0 = dst->ne[0]; const int dst_ne1 = dst->ne[1]; const int dst_ne2 = dst->ne[2]; const int dst_ne3 = dst->ne[3]; + const cl_ulong dst_nb0 = dst->nb[0]; const cl_ulong dst_nb1 = dst->nb[1]; const cl_ulong dst_nb2 = dst->nb[2]; const cl_ulong dst_nb3 = dst->nb[3]; + + cl_kernel kernel = backend_ctx->kernel_repeat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &src0_ne0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &src0_ne1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &src0_ne2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &src0_ne3)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &src0_nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &src0_nb1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src0_nb2)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &src0_nb3)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &dst_ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &dst_ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &dst_ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dst_ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &dst_nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &dst_nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &dst_nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &dst_nb3)); + + size_t gws0 = dst_ne1 > 0 ? (size_t)dst_ne1 : 1; + size_t gws1 = dst_ne2 > 0 ? (size_t)dst_ne2 : 1; + size_t gws2 = dst_ne3 > 0 ? (size_t)dst_ne3 : 1; + + size_t global_work_size[] = { gws0, gws1, gws2 }; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, (size_t[3]){0,0,0}, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + if (backend_ctx->kernel_pad == nullptr) { + GGML_LOG_WARN("%s: pad kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + const int s_ne0 = src0->ne[0]; + const int s_ne1 = src0->ne[1]; + const int s_ne2 = src0->ne[2]; + + const int d_ne0 = dst->ne[0]; + const int d_ne1 = dst->ne[1]; + const int d_ne2 = dst->ne[2]; + + cl_kernel kernel = backend_ctx->kernel_pad; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2)); + + size_t lws0 = 64; + size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0; + + size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 }; + size_t local_work_size[] = { lws0, 1, 1 }; + + size_t * local_work_size_ptr = local_work_size; + if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr ? local_work_size : (size_t[3]){0,0,0}, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0); + cl_kernel kernel = nullptr; + + if (mode == GGML_SCALE_MODE_NEAREST) { + kernel = backend_ctx->kernel_upscale; + if (kernel == nullptr) { + GGML_LOG_WARN("%s: nearest upscale kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + kernel = backend_ctx->kernel_upscale_bilinear; + if (kernel == nullptr) { + GGML_LOG_WARN("%s: bilinear upscale kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + } else { + GGML_LOG_WARN("%s: unsupported upscale mode %d, skipping OpenCL execution.\n", __func__, mode); + return; + } + + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne00_src = src0->ne[0]; + const int ne01_src = src0->ne[1]; + + const int ne10_dst = dst->ne[0]; + const int ne11_dst = dst->ne[1]; + const int ne12_dst = dst->ne[2]; + const int ne13_dst = dst->ne[3]; + + const float sf0 = (float)dst->ne[0] / src0->ne[0]; + const float sf1 = (float)dst->ne[1] / src0->ne[1]; + const float sf2 = (float)dst->ne[2] / src0->ne[2]; + const float sf3 = (float)dst->ne[3] / src0->ne[3]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03)); + + if (mode == GGML_SCALE_MODE_NEAREST) { + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne10_dst)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11_dst)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12_dst)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13_dst)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3)); + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00_src)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01_src)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10_dst)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11_dst)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12_dst)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13_dst)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3)); + } + + + size_t dst_total_elements = (size_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst; + if (dst_total_elements == 0) { + return; + } + size_t global_work_size[] = { dst_total_elements, 1, 1 }; + size_t local_work_size_pref = 256; + size_t local_work_size[] = { MIN(local_work_size_pref, dst_total_elements), 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (dst_total_elements % local_work_size[0] != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + size_t profiling_gws[3] = {global_work_size[0], 1, 1}; + size_t profiling_lws[3] = {local_work_size_ptr ? local_work_size[0] : 0, 1, 1}; + populateProfilingInfo(g_profiling_info.back(), evt, kernel, profiling_gws, profiling_lws, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + if (backend_ctx->kernel_concat_f32_contiguous == nullptr || backend_ctx->kernel_concat_f32_non_contiguous == nullptr) { + GGML_LOG_WARN("%s: concat kernels not available, skipping OpenCL execution.\n", __func__); + return; + } + + ggml_tensor_extra_cl * extra0_cl = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1_cl = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad_cl = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong off_src0 = extra0_cl->offset + src0->view_offs; + cl_ulong off_src1 = extra1_cl->offset + src1->view_offs; + cl_ulong off_dst = extrad_cl->offset + dst->view_offs; + + const int32_t dim = ((const int32_t *) dst->op_params)[0]; + GGML_ASSERT(dim >= 0 && dim <= 3); + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + if (dim == 3) { + + size_t nbytes_src0 = ggml_nbytes(src0); + size_t nbytes_src1 = ggml_nbytes(src1); + + CL_CHECK(clEnqueueCopyBuffer(queue, extra0_cl->data_device, extrad_cl->data_device, + off_src0, off_dst, nbytes_src0, 0, NULL, NULL)); + CL_CHECK(clEnqueueCopyBuffer(queue, extra1_cl->data_device, extrad_cl->data_device, + off_src1, off_dst + nbytes_src0, nbytes_src1, 0, NULL, NULL)); + } else { + + cl_kernel kernel = backend_ctx->kernel_concat_f32_contiguous; + size_t global_work_size[3]; + + for (int i3 = 0; i3 < dst->ne[3]; ++i3) { + cl_ulong current_off_src0 = off_src0 + (i3 * src0->nb[3]); + cl_ulong current_off_src1 = off_src1 + (i3 * src1->nb[3]); + cl_ulong current_off_dst = off_dst + (i3 * dst->nb[3]); + + int d_ne00 = src0->ne[0]; int d_ne01 = src0->ne[1]; int d_ne02 = src0->ne[2]; + int d_ne10 = src1->ne[0]; int d_ne11 = src1->ne[1]; int d_ne12 = src1->ne[2]; + int d_ne0 = dst->ne[0]; int d_ne1 = dst->ne[1]; int d_ne2 = dst->ne[2]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), ¤t_off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), ¤t_off_src1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), ¤t_off_dst)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &d_ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &d_ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &d_ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dim)); + + global_work_size[0] = d_ne0; + global_work_size[1] = d_ne1; + global_work_size[2] = d_ne2; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL)); + } + } + } else { + cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous; + + long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3]; + cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3]; + + cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3]; + + long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3]; + cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3]; + + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_src1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &off_dst)); + + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(long), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(long), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(long), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(long), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(long), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(long), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(long), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(long), &d_ne3)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &d_nb0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &d_nb1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &d_nb2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(cl_ulong), &d_nb3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &dim)); + + size_t global_work_size_nc[] = { d_ne1 > 0 ? (size_t)d_ne1 : 1, + d_ne2 > 0 ? (size_t)d_ne2 : 1, + d_ne3 > 0 ? (size_t)d_ne3 : 1 }; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size_nc, NULL, 0, NULL, NULL)); + } +} + +static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + if (backend_ctx->kernel_timestep_embedding == nullptr) { + GGML_LOG_WARN("%s: timestep_embedding kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + const int logical_dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + const int dst_nb1_bytes = dst->nb[1]; + + cl_kernel kernel = backend_ctx->kernel_timestep_embedding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &dst_nb1_bytes)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &logical_dim)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &max_period)); + + size_t gws0 = (size_t)(((logical_dim + 1) / 2) + 1); + + size_t gws1 = (size_t)src0->ne[0]; + + size_t global_work_size[] = {gws0, gws1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size, NULL, 0, NULL, &evt)); // Pass 2 for 2D problem + + g_profiling_info.emplace_back(); + size_t profiling_gws[3] = {global_work_size[0], global_work_size[1], 1}; + size_t profiling_lws[3] = {0,0,0}; // Reflects NULL LWS + populateProfilingInfo(g_profiling_info.back(), evt, kernel, profiling_gws, profiling_lws, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size, NULL, 0, NULL, NULL)); // Pass 2 for 2D problem +#endif +} + +static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + +#ifdef GGML_OPENCL_SOA_Q + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; +#endif + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + + int r2 = ne12/ne02; + int r3 = ne13/ne03; + + GGML_ASSERT(ne00 == ne10); + + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + // The number of values produced by each subgroup + int ndst = 4; + + cl_kernel kernel; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_context context = backend_ctx->context; + + if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) { + + // init CL objects + // <--------------------------------------------> // + cl_int status; + cl_image_format img_fmt_1d; + cl_image_desc img_desc_1d; + cl_buffer_region region; + cl_mem A_image1d = nullptr; + cl_mem B_image1d = nullptr; + cl_mem B_sub_buffer = nullptr; + cl_mem C_d = nullptr; + // for B transpose + cl_mem B_d = nullptr; + cl_mem B_d_input_image = nullptr; + // <--------------------------------------------> // + + // define matrix dimensions + // <--------------------------------------------> // + int M = ne01; + int N = ne1; + int K = ne00; + int padding; + // <--------------------------------------------> // + + // q4_0 x fp32 + if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { + // TODO: remove duplicate definitions of image description + format -- move to top + + // create an image for A + // <--------------------------------------------> // + if (N == 1) { + img_fmt_1d = { CL_R, CL_UNSIGNED_INT32}; + } else { + img_fmt_1d = { CL_R, CL_FLOAT}; + } + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 2 / 4; // Divide by 4 for char -> float + img_desc_1d.buffer = extra0_q4_0->q; + A_image1d = clCreateImage( + context, + CL_MEM_READ_ONLY, + &img_fmt_1d, + &img_desc_1d, + NULL, + &status); + CL_CHECK(status); + // <--------------------------------------------> // + + + // create a sub_buffer for B + // <--------------------------------------------> // + region.origin = (extra1->offset); + region.size = K * N * sizeof(float); + B_sub_buffer = clCreateSubBuffer( + extra1->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // <--------------------------------------------> // + + // transpose activation for Skyler's gemm + if (N != 1) { + //how many extra elements beyond multiple of 8 + int extra_elements = N % 8; + + //how much padding to add + padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // Specify the starting offset (in bytes) + region.origin = 0; + // Specify the size of the sub-buffer (divide by 2 for FP16) + region.size = K * (N + padding) * sizeof(float)/2; + B_d = clCreateSubBuffer( + backend_ctx->B_d_max, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + + cl_image_format image_format_B_d_input = { CL_RGBA, CL_FLOAT }; + cl_image_desc image_desc_B_d_input = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(K * N / 4), + 0, 0, 0, 0, 0, 0, 0, { B_sub_buffer } + }; + B_d_input_image = clCreateImage( + context, + 0, + &image_format_B_d_input, + &image_desc_B_d_input, + NULL, + &status); + CL_CHECK(status); + + cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) + cl_image_desc image_desc_B_d_output = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(K * (N + padding)/4), + 0, 0, 0, 0, 0, 0, 0, { B_d } + }; + B_image1d = clCreateImage( + context, + 0, + &image_format_B_d_output, + &image_desc_B_d_output, + NULL, + &status); + CL_CHECK(status); + + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_d_input_image)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_size_t[2] = { 1, 16 }; + //WGS tuning + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_size_t[0]=4; + local_size_t[1]=8; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_size_t[0]=2; + local_size_t[1]=8; + } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_size_t[0]=1; + local_size_t[1]=8; + } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_size_t[0]=2; + local_size_t[1]=8; + } + + size_t global_size_t[2] = { + static_cast(width_B), + static_cast(padded_height_B) + }; + + #ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_size_t, local_size_t, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_size_t, local_size_t, dst); + #else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_size_t, local_size_t, 0, NULL, NULL)); + #endif + } else { + // no need to transpose B in other cases + // create an image for B from sub_buffer + // <--------------------------------------------> // + img_fmt_1d = {CL_RGBA, CL_FLOAT}; + + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_width = K * N / 4; + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.buffer = B_sub_buffer; + B_image1d = clCreateImage( + context, + CL_MEM_READ_ONLY, + &img_fmt_1d, + &img_desc_1d, + NULL, + &status); + CL_CHECK(status); + // <--------------------------------------------> // + } + + // choose gemm or gemv kernel + // <--------------------------------------------> // + if (N == 1) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; + if (M == 4096 && K == 4096) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; + } else if (M == 4096 && K == 11008) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; + } else if (M == 11008 && K == 4096) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; + } else if (M == 32000 && K == 4096) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; + } + } else { + kernel = backend_ctx->CL_mul_mat_Ab_Bi_8x4; + } + // <--------------------------------------------> // + + // set kernel args + // <--------------------------------------------> // + cl_uint k_arg = 0; + + if (N == 1) { + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); + } else { + region.origin = extrad->offset; // Specify the starting offset (in bytes) + region.size = M * N * sizeof(float); // Specify the size of the sub-buffer + C_d = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + int padded_N = ne1 + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); //A_q_dextra0_q4_0->q + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); //A_s_d + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d)); //B_d + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &C_d)); //C_d + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); //M + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &padded_N)); //N with padding + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); //K + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne1)); //N without padding + } + // <--------------------------------------------> // + + // choose workgroup size + // <--------------------------------------------> // + size_t global_work_size[3] = { + 64, static_cast((M+63)/64), static_cast((N+31)/32)}; + size_t local_work_size[3] = {64, 2, 4}; + + global_work_size[0] = (size_t)(ceil((float)ne1/8)); + global_work_size[1] = (size_t)(ne01/4); + global_work_size[2] = (size_t)(1); + + local_work_size[0] = (size_t)(1); //4x32 for FP32 + local_work_size[1] = (size_t)(128); + local_work_size[2] = (size_t)(1); + + //WGS tuning + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 1; + local_work_size[1] = 128; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } + + if (N == 1) { + size_t wavesize = backend_ctx->adreno_wave_size; + local_work_size[0] = wavesize; // localsize + local_work_size[1] = 4; // reduce factor + local_work_size[2] = 1; + + global_work_size[0] = (((M / 2) + wavesize - 1) / wavesize) * wavesize; + global_work_size[1] = 4; // reduce factor + global_work_size[2] = 1; + } + // <--------------------------------------------> // + + // enqueue kernel with profiling + // <--------------------------------------------> // + #ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + // enqueue kernel without profiling + #else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + #endif + // <--------------------------------------------> // + + // deallocate sub buffers and images + // <--------------------------------------------> // + CL_CHECK(clReleaseMemObject(A_image1d)); + CL_CHECK(clReleaseMemObject(B_sub_buffer)); + CL_CHECK(clReleaseMemObject(B_image1d)); + + if (N != 1) { + CL_CHECK(clReleaseMemObject(B_d)); + CL_CHECK(clReleaseMemObject(B_d_input_image)); + CL_CHECK(clReleaseMemObject(C_d)); + } + // <--------------------------------------------> // + + return; + } + } // if (ne01 && ne1) +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + if (!ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && + src1t == GGML_TYPE_F32 && + ne00%32 == 0 && + ne11 > 2) { +#ifdef GGML_OPENCL_SOA_Q + // Set up kernel. + switch(src0t) { + case GGML_TYPE_Q4_0: + // This should have been satisfied. + GGML_ASSERT(ne11 == ne1); + GGML_ASSERT(ne01 == ne0); + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + break; + default: + break; + } + + // Launch kernel. + if (src0t == GGML_TYPE_Q4_0) { + size_t global_work_size[] = {(size_t)(ne01 + 7)/8*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + if (backend_ctx->gpu_family == INTEL) { + // Set global size for Intel. It uses 16x output values. + global_work_size[0] = (size_t)(ne01 + 15)/16*nth0; + global_work_size[1] = (size_t)ne11*nth1; + global_work_size[2] = (size_t)ne12*ne13; + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + return; + } +#else // GGML_OPENCL_SOA_Q + // TODO: add block_q4_0 variant. +#endif // GGML_OPENCL_SOA_Q + } + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + //GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(src1t == GGML_TYPE_F32); + kernel = backend_ctx->kernel_mul_mat_f32_f32; + nrows = 4; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 32; + nth1 = 1; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3)); + break; + case GGML_TYPE_F16: + //GGML_ASSERT(ne02 == ne12); + if (backend_ctx->gpu_family == INTEL) { + nth0 = 32; + nth1 = 1; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + if (src1t == GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + kernel = backend_ctx->kernel_mul_mat_f16_f32_1row; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + kernel = backend_ctx->kernel_mul_mat_f16_f32_l4; + nrows = ne11; + } else { + kernel = backend_ctx->kernel_mul_mat_f16_f32; + nrows = 4; + } + } else { + kernel = backend_ctx->kernel_mul_mat_f16_f16; + nrows = 4; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3)); + break; + case GGML_TYPE_Q4_0: + // This should have been satisfied. + GGML_ASSERT(ne11 == ne1); + GGML_ASSERT(ne01 == ne0); + +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat; + ndst =8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#else // GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + // Use 1D local size. Each workgroup is a SIMD group. Each SIMD + // group produces N_DST (4 for Q4_0 kernel) values in the result. + // The number of workgroups on dim 0 (the leading dimension) is + // the nearest multiple of 4 that covers ne0 (equals ne01). + nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_v; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + kernel = backend_ctx->kernel_mul_mv_q6_K_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 2; + nth1 = 16; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 2; + nth1 = 64; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + if (src0t == GGML_TYPE_Q4_0 || + src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_Q2_K) { + // Each SIMD group produces N_DST values in the result. Assuming each + // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will + // produce N_DST*N_SIMDGROUP values in the result. Hence, the grid size + // (number of workgroups) will be a nearest multiple of + // N_DST*N_SIMDGROUP to cover the size of the dimension. Below, 4 is + // N_DST*N_SIMDGROUP (see the kernel for Q4_0 matmul). + size_t global_work_size[] = {(size_t)(ne01 + ndst-1)/ndst*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else if (src0t == GGML_TYPE_Q4_K) { + GGML_ASSERT(false && "not implemented"); + } else if (src0t == GGML_TYPE_Q3_K) { + GGML_ASSERT(false && "not implemented"); + } else if (src0t == GGML_TYPE_Q5_K) { + GGML_ASSERT(false && "not implemented"); + } else if (src0t == GGML_TYPE_Q6_K) { + size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + int64_t ny = (ne11 + nrows - 1)/nrows; + + size_t global_work_size[] = {(size_t)ne01*nth0, (size_t)ny*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const ggml_tensor * src2 = dst->src[2]; + GGML_ASSERT(src2); + GGML_ASSERT(src2->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + +#ifdef GGML_OPENCL_SOA_Q + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; +#endif + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb02 = src0->nb[2]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + + const int ne20 = src2->ne[0]; + const int ne21 = src2->ne[1]; + + const cl_ulong nb21 = src2->nb[1]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + const int r2 = ne12/ne02; + const int r3 = ne13/ne03; + const int dst_rows = ne20*ne21; // ne20 = n_used_experts, ne21 = n_rows + + GGML_ASSERT(ne00 == ne10); + + int sgs = 32; // subgroup size + int nsg = 1; // number of subgroups + int nrows = 1; // number of row in src1 + int ndst = 4; // number of values produced by each subgroup + + cl_kernel kernel; + + // subgroup mat vec + switch (src0->type) { + case GGML_TYPE_Q4_0: { + kernel = backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat; + + if (backend_ctx->gpu_family == INTEL) { + sgs = 16; + nsg = 1; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + nsg = 1; + ndst = 8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb21)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &r3)); + + break; + } + default: + GGML_ASSERT(false && "not implemented");; + } + + int _ne1 = 1; + int ne123 = dst_rows; + + size_t global_work_size[] = {(size_t)(ne01+ndst*nsg-1)/(ndst*nsg)*sgs, (size_t)(_ne1+nrows-1)/nrows*nsg, (size_t)ne123}; + size_t local_work_size[] = {(size_t)sgs, (size_t)nsg, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel = backend_ctx->kernel_scale; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale)); + + int n = ggml_nelements(dst)/4; + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + + // GGML_OP_CPY happens between src0 and src1. + // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst. + UNUSED(dst); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + + cl_kernel kernel; + + switch (src0t) { + case GGML_TYPE_F32: + switch (src1t) { + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_cpy_f32_f16; + break; + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_cpy_f32_f32; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + break; + case GGML_TYPE_F16: + switch (src1t) { + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_cpy_f16_f16; + break; + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_cpy_f16_f32; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + + const int nth = MIN(64, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, src1); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cl_cpy(backend, src0, dst, nullptr); + UNUSED(src1); +} + +static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + int n_past = ((int32_t *)(dst->op_params))[0]; + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + if (ne00%8 == 0) { + kernel = backend_ctx->kernel_diag_mask_inf_8; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n_past)); + + size_t global_work_size[] = {(size_t)ne00*ne01*ne02/8, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + kernel = backend_ctx->kernel_diag_mask_inf; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n_past)); + + size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (ne00 % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + // Softmax can now fuse KQ mask and KQ scale, which used to be two additional + // ops before softmax. It now also fuses alibi if `max_bias > 0`. For llama, + // alibi is not used; however, for some other models, it is used. + // KQ_mask + if (src1) { + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + } + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + float scale, max_bias; + memcpy(&scale, dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, dst->op_params + 1, sizeof(float)); + + const int nrows_x = ggml_nrows(src0); + const int nrows_y = src0->ne[1]; + + const int n_head = nrows_x/nrows_y; + const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + // Local size must be wave size. Each workgroup is a wave, working on a row, + // where a row corresponds to leading dimension. + int nth = MIN(32, ne00); + + if (backend_ctx->gpu_family == INTEL) { + // This is the same as the initial value. + nth = MIN(32, ne00); + } + else if (backend_ctx->gpu_family == ADRENO) { + nth = 64; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + cl_kernel kernel; + + if (ne00%4 == 0) { + if (use_f16) { + kernel = backend_ctx->kernel_soft_max_4_f16; + } else { + kernel = backend_ctx->kernel_soft_max_4; + } + } else { + if (use_f16) { + kernel = backend_ctx->kernel_soft_max_f16; + } else { + kernel = backend_ctx->kernel_soft_max; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2)); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + ggml_tensor * src2 = dst->src[2]; + ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr; + + cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0; + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; UNUSED(ne11); + const int ne12 = src1 ? src1->ne[2] : 0; UNUSED(ne12); + const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + const int ne2 = dst ? dst->ne[2] : 0; + const int ne3 = dst ? dst->ne[3] : 0; + + const cl_ulong nb0 = dst ? dst->nb[0] : 0; + const cl_ulong nb1 = dst ? dst->nb[1] : 0; + const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const cl_ulong nb3 = dst ? dst->nb[3] : 0; + + GGML_ASSERT(ne10 % ne02 == 0); + GGML_ASSERT(ne10 >= ne02); + + int nth = MIN(64, ne00); + + const int n_past = ((int *) dst->op_params)[0]; + const int n_dims = ((int *) dst->op_params)[1]; + const int mode = ((int *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + int32_t sections[4]; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int32_t)*4); + + const bool is_neox = mode & 2; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } + + cl_kernel kernel; + + if (is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_neox_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_neox_f16; + break; + default: + GGML_ASSERT(false); + }; + } else if (is_mrope && !is_vision) { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_multi_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_multi_f16; + break; + default: + GGML_ASSERT(false); + }; + } else if (is_vision) { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_vision_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_vision_f16; + break; + default: + GGML_ASSERT(false); + } + } else { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_norm_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_norm_f16; + break; + default: + GGML_ASSERT(false); + }; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_past)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &n_dims)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &n_ctx_orig)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &freq_base)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(float), &freq_scale)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &ext_factor)); + CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor)); + CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast)); + CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow)); + if (is_mrope || is_vision) { + CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, §ions)); + } + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + // src0 - filter, src1 - input + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + + const cl_long IC = src1->ne[is_2D ? 2 : 1]; + const cl_long IH = is_2D ? src1->ne[1] : 1; + const cl_long IW = src1->ne[0]; + + const cl_long KH = is_2D ? src0->ne[1] : 1; + const cl_long KW = src0->ne[0]; + + const cl_long OH = is_2D ? dst->ne[2] : 1; + const cl_long OW = dst->ne[1]; + + // nb is byte offset, src is type float32 + const cl_ulong delta_offset = src1->nb[is_2D ? 2 : 1]/4; + const cl_long batch = src1->ne[is_2D ? 3 : 2]; + const cl_ulong batch_offset = src1->nb[is_2D ? 3 : 2]/4; + + const cl_long pelements = OW*KW*KH; + const cl_long CHW = IC*KH*KW; + + cl_kernel kernel; + + if(dst->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_im2col_f16; + } else { + kernel = backend_ctx->kernel_im2col_f32; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &batch_offset)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &delta_offset)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &IW)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &IH)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &IC)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &OW)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_long), &OH)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_long), &KW)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_long), &KH)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_long), &pelements)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_long), &CHW)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &s0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &s1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &p0)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &p1)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &d0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &d1)); + + const int num_blocks = (pelements + 256 - 1) / 256; + size_t global_work_size[] = {(size_t)num_blocks*256, (size_t)OH, (size_t)batch*IC}; + size_t local_work_size[] = {256, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int nrows = ggml_nrows(src0); + + int ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + int order = (enum ggml_sort_order) dst->op_params[0]; + + cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00_padded)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &order)); + CL_CHECK(clSetKernelArg(kernel, 7, ne00_padded*sizeof(int), NULL)); + + size_t global_work_size[] = {(size_t)ne00_padded, (size_t)nrows, (size_t)1}; + size_t local_work_size[] = {(size_t)ne00_padded, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + cl_kernel kernel = backend_ctx->kernel_sum_rows_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); + + size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +//------------------------------------------------------------------------------ +// Op offloading +//------------------------------------------------------------------------------ + +typedef void (*ggml_cl_func_t)(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor) { + ggml_cl_func_t func = nullptr; + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + + const bool any_on_device = tensor->extra + || (src0 != nullptr && src0->extra) + || (src1 != nullptr && src1->extra); + + switch (tensor->op) { + case GGML_OP_GET_ROWS: + if (!any_on_device) { + return false; + } + func = ggml_cl_get_rows; + break; + case GGML_OP_CPY: + if (!any_on_device) { + return false; + } + func = ggml_cl_cpy; + break; + case GGML_OP_DUP: + case GGML_OP_CONT: + if (!any_on_device) { + return false; + } + func = ggml_cl_dup; + break; + case GGML_OP_ADD: + if (!any_on_device) { + return false; + } + func = ggml_cl_add; + break; + case GGML_OP_MUL: + if (!any_on_device) { + return false; + } + func = ggml_cl_mul; + break; + case GGML_OP_DIV: + if (!any_on_device) { + return false; + } + func = ggml_cl_div; + break; + case GGML_OP_SUB: + if (!any_on_device) { + return false; + } + func = ggml_cl_sub; + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_GELU: + if (!any_on_device) { + return false; + } + func = ggml_cl_gelu; + break; + case GGML_UNARY_OP_GELU_QUICK: + if (!any_on_device) { + return false; + } + func = ggml_cl_gelu_quick; + break; + case GGML_UNARY_OP_SILU: + if (!any_on_device) { + return false; + } + func = ggml_cl_silu; + break; + case GGML_UNARY_OP_RELU: + if (!any_on_device) { + return false; + } + func = ggml_cl_relu; + break; + case GGML_UNARY_OP_SIGMOID: + if (!any_on_device) { + return false; + } + func = ggml_cl_sigmoid; + break; + case GGML_UNARY_OP_TANH: + if (!any_on_device) { + return false; + } + func = ggml_cl_tanh; + break; + default: + return false; + } break; + case GGML_OP_CLAMP: + if (!any_on_device) { + return false; + } + func = ggml_cl_clamp; + break; + case GGML_OP_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_norm; + break; + case GGML_OP_RMS_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_rms_norm; + break; + case GGML_OP_GROUP_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_group_norm; + break; + case GGML_OP_REPEAT: + if (!any_on_device) { + return false; + } + func = ggml_cl_repeat; + break; + case GGML_OP_PAD: + if (!any_on_device) { + return false; + } + ggml_cl_pad(backend, tensor->src[0], tensor); + return true; + case GGML_OP_UPSCALE: + if (!any_on_device) { + return false; + } + ggml_cl_upscale(backend, tensor->src[0], tensor); + return true; + case GGML_OP_CONCAT: + if (!any_on_device) { + return false; + } + func = ggml_cl_concat; + break; + case GGML_OP_TIMESTEP_EMBEDDING: + if (!any_on_device) { + return false; + } + ggml_cl_timestep_embedding(backend, tensor->src[0], tensor); + return true; + case GGML_OP_MUL_MAT: + if (!any_on_device && !ggml_cl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) { + return false; + } + func = ggml_cl_mul_mat; + break; + case GGML_OP_MUL_MAT_ID: + if (!any_on_device) { + return false; + } + func = ggml_cl_mul_mat_id; + break; + case GGML_OP_SCALE: + if (!any_on_device) { + return false; + } + func = ggml_cl_scale; + break; + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + if (!any_on_device) { + return false; + } + func = ggml_cl_nop; + break; + case GGML_OP_DIAG_MASK_INF: + if (!any_on_device) { + return false; + } + func = ggml_cl_diag_mask_inf; + break; + case GGML_OP_SOFT_MAX: + if (!any_on_device) { + return false; + } + func = ggml_cl_soft_max; + break; + case GGML_OP_ROPE: + if (!any_on_device) { + return false; + } + func = ggml_cl_rope; + break; + case GGML_OP_IM2COL: + if (!any_on_device) { + return false; + } + func = ggml_cl_im2col; + break; + case GGML_OP_ARGSORT: + if (!any_on_device) { + return false; + } + func = ggml_cl_argsort; + break; + case GGML_OP_SUM_ROWS: + if (!any_on_device) { + return false; + } + func = ggml_cl_sum_rows; + break; + default: + return false; + } + + func(backend, tensor->src[0], tensor->src[1], tensor); + return true; +} diff --git a/ggml/src/ggml-opencl/kernels/add.cl b/ggml/src/ggml-opencl/kernels/add.cl new file mode 100644 index 0000000000000000000000000000000000000000..f73f3c01343887530a8e9372354b24c6ba911399 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/add.cl @@ -0,0 +1,83 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// add +//------------------------------------------------------------------------------ + +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient +kernel void kernel_add( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] + src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/argsort.cl b/ggml/src/ggml-opencl/kernels/argsort.cl new file mode 100644 index 0000000000000000000000000000000000000000..af4adc7b83f0a5787605349921e266e9fcd07e61 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/argsort.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define SWAP(x, y, T) { T tmp = (x); (x) = (y); (y) = tmp; } + +enum ggml_sort_order { + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, +}; + +kernel void kernel_argsort_f32_i32( + global float * src0, + ulong offset0, + global int * dst, + ulong offsetd, + const int ne00, + const int ne00_pad, + const int order, + local int * dst_row +) { + // bitonic sort + int col = get_local_id(0); + int row = get_group_id(1); + + if (col >= ne00_pad) { + return; + } + + src0 = (global char *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + global float * x_row = src0 + row * ne00; + + // initialize indices + dst_row[col] = col; + + barrier(CLK_LOCAL_MEM_FENCE); + + for (int k = 2; k <= ne00_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ne00 || + (dst_row[ixj] < ne00 && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj], int); + } + } else { + if (dst_row[ixj] >= ne00 || + (dst_row[col] < ne00 && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj], int); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + } + + // copy the result to dst without the padding + if (col < ne00) { + dst[row * ne00 + col] = dst_row[col]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/clamp.cl b/ggml/src/ggml-opencl/kernels/clamp.cl new file mode 100644 index 0000000000000000000000000000000000000000..ae6032444e8233a1ec35b3dce22d12b167cdbcc5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/clamp.cl @@ -0,0 +1,20 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// clamp +//------------------------------------------------------------------------------ +kernel void kernel_clamp( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + float min, + float max +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = src0[get_global_id(0)] < min ? + min : + (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]); +} diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl new file mode 100644 index 0000000000000000000000000000000000000000..132758469c6fa3d1e6da778f080e5e05a5613aee --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/concat.cl @@ -0,0 +1,109 @@ +kernel void kernel_concat_f32_contiguous( + global const char * p_src0, ulong off_src0, + global const char * p_src1, ulong off_src1, + global char * p_dst, ulong off_dst, + int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice + int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes) + int d_ne0, int d_ne1, int d_ne2, // dst->ne[0..2] for the slice + int dim +) { + global const float * src0 = (global const float*)((global char*)p_src0 + off_src0); + global const float * src1 = (global const float*)((global char*)p_src1 + off_src1); + global float * dst = (global float*)((global char*)p_dst + off_dst); + + int i0 = get_global_id(0); // Index along dst's 0th dimension + int i1 = get_global_id(1); // Index along dst's 1st dimension + int i2 = get_global_id(2); // Index along dst's 2nd dimension + + if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) { + return; + } + + ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0; + ulong src_idx; + + if (dim == 0) { + if (i0 < d_ne00) { // Data from src0 + src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; + dst[dst_idx] = src0[src_idx]; + } else { // Data from src1 + src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00); + dst[dst_idx] = src1[src_idx]; + } + } else if (dim == 1) { + if (i1 < d_ne01) { // Data from src0 + src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; + dst[dst_idx] = src0[src_idx]; + } else { // Data from src1 + src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0; + dst[dst_idx] = src1[src_idx]; + } + } else if (dim == 2) { + if (i2 < d_ne02) { // Data from src0 + src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; + dst[dst_idx] = src0[src_idx]; + } else { // Data from src1 + + src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0; + dst[dst_idx] = src1[src_idx]; + } + } +} + +kernel void kernel_concat_f32_non_contiguous( + global const char * p_src0, ulong off_src0, + global const char * p_src1, ulong off_src1, + global char * p_dst, ulong off_dst, + + long ne00, long ne01, long ne02, long ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + + ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1 + + long d_ne0, long d_ne1, long d_ne2, long d_ne3, + ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3, + int dim +) { + global const char * src0_base = p_src0 + off_src0; + global const char * src1_base = p_src1 + off_src1; + global char * dst_base = p_dst + off_dst; + + long current_i1 = get_global_id(0); // Index for dst_dim_1 + long current_i2 = get_global_id(1); // Index for dst_dim_2 + long current_i3 = get_global_id(2); // Index for dst_dim_3 + + if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) { + return; + } + + global const float * x_val_ptr; + global float * y_val_ptr; + + for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) { + bool use_src0; + long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3; + + if (dim == 0) { + use_src0 = (current_i0 < ne00); + if (!use_src0) { s_i0 = current_i0 - ne00; } + } else if (dim == 1) { + use_src0 = (current_i1 < ne01); + if (!use_src0) { s_i1 = current_i1 - ne01; } + } else if (dim == 2) { + use_src0 = (current_i2 < ne02); + if (!use_src0) { s_i2 = current_i2 - ne02; } + } else { // dim == 3 + use_src0 = (current_i3 < ne03); + if (!use_src0) { s_i3 = current_i3 - ne03; } + } + + if (use_src0) { + x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00); + } else { + x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10); + } + + y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0); + *y_val_ptr = *x_val_ptr; + } +} diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl new file mode 100644 index 0000000000000000000000000000000000000000..9369351a60c45e09aea98121a94c63262e6e3ab5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -0,0 +1,184 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// cpy +//------------------------------------------------------------------------------ + +kernel void kernel_cpy_f16_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f16_f32( + global half * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + + src0 = (global half*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f16( + global float * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl new file mode 100644 index 0000000000000000000000000000000000000000..fe7975e3dbfc3d4dd402c9f198c85dd1cd004b48 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -0,0 +1,118 @@ +//------------------------------------------------------------------------------ +// This file is contains kernels for data conversion. +// These kernels are used when loading the model, so its performance is less +// important. +//------------------------------------------------------------------------------ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_0 +// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_0( + global struct block_q4_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_0/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_0( + global uchar * src_q, + global half * src_d, + global struct block_q4_0 * dst +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_0/2; ++i) { + b->qs[i] = q[i]; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_0_noshuffle +// Flatten q4_0 weights and unshuffle the bits +//------------------------------------------------------------------------------ + +kernel void kernel_convert_block_q4_0_noshuffle( + global struct block_q4_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + for (int i = 0; i < QK4_0/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + // Workaround for adreno - must have the following printf statement for + // the kernel to work properly. Otherwise it produces incorrect result. + // convert_uchar above also seems necessary. + // Compare against a large number so that it does not print anything. + // get_sub_group_local_id() also works. + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} diff --git a/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl b/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl new file mode 100644 index 0000000000000000000000000000000000000000..36eff0439fa735f11e1a3404dda190c28be9f45a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl @@ -0,0 +1,58 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// diag_mask_inf kernels +//------------------------------------------------------------------------------ +kernel void kernel_diag_mask_inf( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i02 = get_global_id(2); + int i01 = get_global_id(1); + int i00 = get_global_id(0); + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + int i = 2*get_global_id(0); + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int i4 = 4*i; + int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + int i01 = i4/(ne00); i4 -= i01*ne00; + int i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + (&dst[i+1])[k] = -INFINITY; + if (i00 + k > n_past + i01) { + (&dst[i])[k] = -INFINITY; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/div.cl b/ggml/src/ggml-opencl/kernels/div.cl new file mode 100644 index 0000000000000000000000000000000000000000..d453ad99be47d08b32381a502d0338317b6edb07 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/div.cl @@ -0,0 +1,72 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// div +//------------------------------------------------------------------------------ +kernel void kernel_div( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) / *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_div_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] / src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/embed_kernel.py b/ggml/src/ggml-opencl/kernels/embed_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..b5d1d7242b624bae75b03a700d1038449ac7c5fe --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/embed_kernel.py @@ -0,0 +1,26 @@ +# + +import sys +import logging +logger = logging.getLogger("opencl-embed-kernel") + + +def main(): + logging.basicConfig(level=logging.INFO) + + if len(sys.argv) != 3: + logger.info("Usage: python embed_kernel.py ") + sys.exit(1) + + ifile = open(sys.argv[1], "r") + ofile = open(sys.argv[2], "w") + + for i in ifile: + ofile.write('R"({})"\n'.format(i)) + + ifile.close() + ofile.close() + + +if __name__ == "__main__": + main() diff --git a/ggml/src/ggml-opencl/kernels/gelu.cl b/ggml/src/ggml-opencl/kernels/gelu.cl new file mode 100644 index 0000000000000000000000000000000000000000..71c310cc9f986086016affe5e898233b45a0d0c7 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gelu.cl @@ -0,0 +1,62 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// gelu +//------------------------------------------------------------------------------ +#define GELU_COEF_A 0.044715f +#define GELU_QUICK_COEF -1.702f +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f + +kernel void kernel_gelu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl new file mode 100644 index 0000000000000000000000000000000000000000..ee5c79f000d695f0b34f2f0c4cdab6876ddd634c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl @@ -0,0 +1,268 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +// assume +#define QK4_0 32 +#define N_SIMDGROUP 4 + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle( + __read_only image1d_buffer_t src0_q, // quantized A + global half2 * src0_d, // A scales + __read_only image1d_buffer_t src1, // B + ulong offset1, // offset to B (0) + global float * dst, // C + ulong offsetd, // offset to C (0) + uint K, // K + int ne01, // M + int ne02, // 1 + int ne10, // K + int ne12, // 1 + int ne0, // M + int ne1, // N + int r2, // 1 + int r3) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + __private uint4 regA; + __private half2 regS; + __private float8 regB; + + __private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + // reduction in local memory, assumes #wave=4 + __local float2 reduceLM[SIMDGROUP_WIDTH * 3]; + if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum; + if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum; + if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum; + barrier(CLK_LOCAL_MEM_FENCE); + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl new file mode 100644 index 0000000000000000000000000000000000000000..469d3edef00ccdd552894d4916c7e75825f74e1b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl @@ -0,0 +1,274 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +// assume +#define QK4_0 32 +#define N_SIMDGROUP 4 + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle( + __read_only image1d_buffer_t src0_q, // quantized A + global half2 * src0_d, // A scales + __read_only image1d_buffer_t src1, // B + ulong offset1, // offset to B (0) + global float * dst, // C + ulong offsetd, // offset to C (0) + int ne00, // K + int ne01, // M + int ne02, // 1 + int ne10, // K + int ne12, // 1 + int ne0, // M + int ne1, // N + int r2, // 1 + int r3) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = N_SIMDGROUP * M; + + __private uint4 regA; + __private half2 regS; + __private float8 regB; + + __private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + // reduction in local memory, assumes #wave=4 + __local float2 reduceLM[SIMDGROUP_WIDTH * 3]; + if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum; + if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum; + if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum; + barrier(CLK_LOCAL_MEM_FENCE); + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl new file mode 100644 index 0000000000000000000000000000000000000000..b3fea2923df8f91f357249bc1ebbaabea229742d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -0,0 +1,163 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +#define QK4_0 32 + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + + +//------------------------------------------------------------------------------ +// dequantize_q4_0_f32, dequantize_q4_0_f16 +//------------------------------------------------------------------------------ +void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { + global ushort * qs = ((global ushort *)xb + 1); + float d1 = il ? (xb->d / 16.h) : xb->d; + float d2 = d1 / 256.f; + float md = -8.h * xb->d; + ushort mask0 = il ? 0x00F0 : 0x000F; + ushort mask1 = mask0 << 8; + + reg->s0 = d1 * (qs[0] & mask0) + md; + reg->s1 = d2 * (qs[0] & mask1) + md; + + reg->s2 = d1 * (qs[1] & mask0) + md; + reg->s3 = d2 * (qs[1] & mask1) + md; + + reg->s4 = d1 * (qs[2] & mask0) + md; + reg->s5 = d2 * (qs[2] & mask1) + md; + + reg->s6 = d1 * (qs[3] & mask0) + md; + reg->s7 = d2 * (qs[3] & mask1) + md; + + reg->s8 = d1 * (qs[4] & mask0) + md; + reg->s9 = d2 * (qs[4] & mask1) + md; + + reg->sa = d1 * (qs[5] & mask0) + md; + reg->sb = d2 * (qs[5] & mask1) + md; + + reg->sc = d1 * (qs[6] & mask0) + md; + reg->sd = d2 * (qs[6] & mask1) + md; + + reg->se = d1 * (qs[7] & mask0) + md; + reg->sf = d2 * (qs[7] & mask1) + md; +} + + +//------------------------------------------------------------------------------ +// get_rows +//------------------------------------------------------------------------------ +kernel void kernel_get_rows_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_q4_0( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int NL = 2; + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { + float16 temp; + dequantize_q4_0_f32( + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} diff --git a/ggml/src/ggml-opencl/kernels/group_norm.cl b/ggml/src/ggml-opencl/kernels/group_norm.cl new file mode 100644 index 0000000000000000000000000000000000000000..57c9df4d35b0926ea0bff101e7d1a2d6c7868008 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/group_norm.cl @@ -0,0 +1,72 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +// Workgroup must be a subgroup +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_group_norm( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne, + int group_size, + float eps +) { + src0 = (global float *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + int start = get_group_id(0) * group_size; + int end = start + group_size; + + start += get_local_id(0); + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; + + for (int j = start; j < end; j += get_local_size(0)) { + tmp += src0[j]; + } + + tmp = sub_group_reduce_add(tmp); + + const float mean = tmp / group_size; + tmp = 0.0f; + + for (int j = start; j < end; j += get_local_size(0)) { + float xi = src0[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = sub_group_reduce_add(tmp); + + const float variance = tmp / group_size; + const float scale = 1.0f/sqrt(variance + eps); + for (int j = start; j < end; j += get_local_size(0)) { + dst[j] *= scale; + } +} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f16.cl b/ggml/src/ggml-opencl/kernels/im2col_f16.cl new file mode 100644 index 0000000000000000000000000000000000000000..b84c8984653c2bd9d7f6ebeae04b548fed0624be --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/im2col_f16.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_im2col_f16( + global float * src1, + ulong offset1, + global half * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global half*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f32.cl b/ggml/src/ggml-opencl/kernels/im2col_f32.cl new file mode 100644 index 0000000000000000000000000000000000000000..4bf65e4eaafba82f50988e736cfeb495d411845b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/im2col_f32.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_im2col_f32( + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul.cl b/ggml/src/ggml-opencl/kernels/mul.cl new file mode 100644 index 0000000000000000000000000000000000000000..2a2b4eb70a13cb29d366899d45abc18f3178b330 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul.cl @@ -0,0 +1,79 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// mul +//------------------------------------------------------------------------------ +kernel void kernel_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_mul_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] * src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl new file mode 100644 index 0000000000000000000000000000000000000000..ecb577b99333982e03e9ac2ef03cc020f5fd0327 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl @@ -0,0 +1,139 @@ +// src0_q, src0_d, src1 are transposed as a preprocessing step +// 4-bit weights are transposed in groups of 4 (unsigned short int) +// consider weights originally "next to each other", now "on top of each other" +// each fiber computes a 8x4 tile of output elements +// using unshuffled weights + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_mul_mat_Ab_Bi_8x4( + global const ushort * src0_q, // quantized A + global const half * src0_d, // A scales + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // 8x4 output elements + half8 B; // registers for activations + half4 dequantized_weights; // registers for dequantized weights + __global const ushort* weight_ptr = src0_q + gx_2; // pointer for weights + __global const half* scale_ptr = src0_d + gx_2; // pointer for scales + + for(int i=0; i> 4) - 8) * scale.s0; // dequantize a row of the 16 weights + dequantized_weights.s1 = (((bits4.s1 & (0x00F0)) >> 4) - 8) * scale.s1; + dequantized_weights.s2 = (((bits4.s2 & (0x00F0)) >> 4) - 8) * scale.s2; + dequantized_weights.s3 = (((bits4.s3 & (0x00F0)) >> 4) - 8) * scale.s3; + c0 += B * dequantized_weights.s0; //vector-scalar multiplication to accumulate + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1); + dequantized_weights.s0 = (((bits4.s0 & (0x0F00)) >> 8) - 8) * scale.s0; // dequantize a row of the 16 weights + dequantized_weights.s1 = (((bits4.s1 & (0x0F00)) >> 8) - 8) * scale.s1; + dequantized_weights.s2 = (((bits4.s2 & (0x0F00)) >> 8) - 8) * scale.s2; + dequantized_weights.s3 = (((bits4.s3 & (0x0F00)) >> 8) - 8) * scale.s3; + c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1); + dequantized_weights.s0 = (((bits4.s0 & (0xF000)) >> 12) - 8) * scale.s0; // dequantize a row of the 16 weights + dequantized_weights.s1 = (((bits4.s1 & (0xF000)) >> 12) - 8) * scale.s1; + dequantized_weights.s2 = (((bits4.s2 & (0xF000)) >> 12) - 8) * scale.s2; + dequantized_weights.s3 = (((bits4.s3 & (0xF000)) >> 12) - 8) * scale.s3; + c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); // vectorized store 16 elements + + // conditional check if store is to a valid location. Required when N is not a multiple of 8 + // if statements allow registers to be reused for each store + // provides a performance boost due to reduced register footprint, which increases number of concurrent waves + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl new file mode 100644 index 0000000000000000000000000000000000000000..9393b5494158ac099a59b054c9adedcd68d1096d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F16_F16 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3) +{ + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F16; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + global half4 * y4 = (global half4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (half) x4[i].s0 * y4[i].s0; + sumf += (half) x4[i].s1 * y4[i].s1; + sumf += (half) x4[i].s2 * y4[i].s2; + sumf += (half) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (half) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl new file mode 100644 index 0000000000000000000000000000000000000000..e52d3c6d47558c9f203c0fc6c155a4833b50b576 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F16_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += convert_float(x[i]) * y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl new file mode 100644 index 0000000000000000000000000000000000000000..28d30212cda902b8b394886636060ee6fee80350 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl @@ -0,0 +1,94 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_1row( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * x = (global half *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + if (ne00 < 128) { + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + global half4 * x4 = (global half4 *) x; + global float4 * y4 = (global float4 *) y; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl new file mode 100644 index 0000000000000000000000000000000000000000..cdf8197c47058b19c303a72c5cbb6e3974436f99 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl @@ -0,0 +1,84 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +// Assumes row size (ne00) is a multiple of 4 +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_l4( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int nrows = ne11; + int r0 = get_group_id(0); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half4 * x4 = (global half4 *) (src0 + offset_src0); + + for (int r1 = 0; r1 < nrows; ++r1) { + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float4 * y4 = (global float4 *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl new file mode 100644 index 0000000000000000000000000000000000000000..ec71b8756523610c6cf66c1b51e555338d91dfdf --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F32_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f32_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F32_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global float * x = (global float *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global float4 * x4 = (global float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl new file mode 100644 index 0000000000000000000000000000000000000000..7ccf41efbe91831d4240ce711370a8a904220ac9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl @@ -0,0 +1,283 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// This function requires the original shuffled weights. +// As a reminder, the original weights are shuffled so that (q[0], q[16]) are +// packed together in a byte, so are (q[1], q[17]) and so on. +inline float block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +// +// This variant outputs 8 values. +// +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // subgroup size +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global char * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = 0; + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_id_q4_0_f32_8x_flat( + global char * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global char * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb02, + int ne10, + int ne11, + int ne12, + ulong nb11, + ulong nb12, + int ne20, + int ne21, + ulong nb21, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float *)((global char *)src1 + offset1); + src2 = (global char *)((global char *)src2 + offset2); + dst = (global float *)((global char *)dst + offsetd); + + const int iid1 = get_group_id(2)/ne20; + const int idx = get_group_id(2)%ne20; + + const int i02 = ((global int *)(src2 + iid1*nb21))[idx]; + + const int i11 = idx%ne11; + const int i12 = iid1; + + const int i1 = idx; + const int i2 = i12; + + global char * src0_q_cur = src0_q + (i02*nb02/nb00)*(QK4_0/2); + global half * src0_d_cur = src0_d + (i02*nb02/nb00); + global float * src1_cur = (global float *)((global char *) src1 + i11*nb11 + i12*nb12); + global float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + + mul_vec_q_n_f32_8x_flat(src0_q_cur, src0_d_cur, src1_cur, dst_cur, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl new file mode 100644 index 0000000000000000000000000000000000000000..52141e0ed55c2eb1aff562bef38e0d9d0e60b413 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl @@ -0,0 +1,192 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_4_0_dot_y( + global struct block_q4_0 * qb_curr, + float sumy, + private float * yl, + int il +) { + float d = qb_curr->d; + float2 acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc.s0 + acc.s1); +} + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il); + } + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence then entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel. + float tot[N_DST] = { + sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), + sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; + for (int row = 0; row < N_DST; ++row) { + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl new file mode 100644 index 0000000000000000000000000000000000000000..3eebab8f0f2ca6814ac61436bc775c74740ed9b0 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl @@ -0,0 +1,307 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#ifdef INTEL_GPU +#define N_DST 16 // each SIMD group works on 8 rows (in weights matrix) +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 16x output. +// Eeach simdgroup outputs 16 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); + sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); + sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); + sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); + + sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); + sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); + sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); + sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float16 tot = (float16)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), + + sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), + sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), + sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), + sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + + if (first_row + 8 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; + } + if (first_row + 9 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; + } + if (first_row + 10 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; + } + if (first_row + 11 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; + } + + if (first_row + 12 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; + } + if (first_row + 13 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; + } + if (first_row + 14 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; + } + if (first_row + 15 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl new file mode 100644 index 0000000000000000000000000000000000000000..38024d00ad5cc02d43615e2a40bb3d1c04b0d8af --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl @@ -0,0 +1,265 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix) +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 8x output. +// Eeach simdgroup outputs 8 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl new file mode 100644 index 0000000000000000000000000000000000000000..aed1ce7b2609551a26238401d1b359dec857e0d5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl @@ -0,0 +1,272 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// This function requires the original shuffled weights. +// As a reminder, the original weights are shuffled so that (q[0], q[16]) are +// packed together in a byte, so are (q[1], q[17]) and so on. +inline float block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +// +// This variant outputs 8 values. +// +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 32 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl new file mode 100644 index 0000000000000000000000000000000000000000..929552179710e6b89853420d13f3b9bacd97197a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl @@ -0,0 +1,254 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// +// This variant unrolls the loops and uses vector types instead of pointers. +// It improves performance on Adreno but not so much on Intel. +// +inline float block_q_4_0_dot_y_v( + global struct block_q4_0 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_v( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; // src1 vector cache + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il); + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence then entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel. + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_v( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl new file mode 100644 index 0000000000000000000000000000000000000000..8a17b9aae63902d8951abca8825ab09bd171db7b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl @@ -0,0 +1,190 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q6_K +//------------------------------------------------------------------------------ +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +} block_q6_K; + +//------------------------------------------------------------------------------ +// kernel_mul_mv_q6_K_f32 +//------------------------------------------------------------------------------ + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 1 // number of rows each SIMD group works on +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 1 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q6_K_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + uchar kmask1 = 0x03; + uchar kmask2 = 0x0C; + uchar kmask3 = 0x30; + uchar kmask4 = 0xC0; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int row = N_SIMDGROUP * r0 + get_sub_group_id(); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0; + global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf = 0; + + // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a + // block. Values in a subblock shares a scale that is quantized with 8 bits; + // the entire block shares a single floating point scale. + // For work distribution, each thread processes a subblock (16 weights), hence + // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16 + // (super) blocks -- this is the block stride. + // The 16 threads that process a (super) block are split into 2 portions, each has + // 8 threads; each portion works on 8 subblocks. + // For subgroup of 16 threads, the entire subgroup works on a single (super) block + // before moving to the next (super) block. Thread0 - thread7 work on the + // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks. + // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on + // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but + // works on a total of 16 weight values. + int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 + int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int ip = tid/8; // first or second half of (super) block (0 or 1) + int il = tid%8; // each half has 8 parts, one per scale + int n = 4; // 4 scales at a time (and 4 sums) + int l0 = n*il; // offset into half-block, 0..28 + int is = 8*ip + l0/16; // 0, 1, 8, 9 + + int y_offset = 128*ip + l0; + int q_offset_l = 64*ip + l0; + int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += BLOCK_STRIDE) { + + global uint8_t * q1 = x[i].ql + q_offset_l; + global uint8_t * q2 = q1 + QK_K/8; + global uint8_t * qh = x[i].qh + q_offset_h; + global int8_t * sc = x[i].scales + is; + + global float * y = yy + i * QK_K + y_offset; + + float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); + sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); + sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); + sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); + sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); + sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); + sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); + sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); + sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); + sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f); + sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); + sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); + sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); + + sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); + } + + float tot = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } +} diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl new file mode 100644 index 0000000000000000000000000000000000000000..43167ba4d22121fdf7c075ebd6ebbcf35d7eab42 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/norm.cl @@ -0,0 +1,81 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// norm +//------------------------------------------------------------------------------ +kernel void kernel_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global void*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + + // MEAN + // parallel sum + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + sum[get_local_id(0)] += x[i00]; + } + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float mean = sum[0] / ne00; + + // recenter and VARIANCE + barrier(CLK_LOCAL_MEM_FENCE); + global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = x[i00] - mean; + sum[get_local_id(0)] += y[i00] * y[i00]; + } + + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float variance = sum[0] / ne00; + + float scale = 1.0f/sqrt(variance + eps); + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = y[i00] * scale; + } +} diff --git a/ggml/src/ggml-opencl/kernels/pad.cl b/ggml/src/ggml-opencl/kernels/pad.cl new file mode 100644 index 0000000000000000000000000000000000000000..747fa7febcc74c54f665c7da6c09884c3a21d3b0 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/pad.cl @@ -0,0 +1,30 @@ +kernel void kernel_pad( + global const void * src0_ptr, + ulong src0_offset, + global void * dst_ptr, + ulong dst_offset, + int s_ne0, int s_ne1, int s_ne2, + int d_ne0, int d_ne1, int d_ne2 +) { + global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset); + global float * dst = (global float *)((global char *)dst_ptr + dst_offset); + + int nidx = get_global_id(0); + int idx_d1 = get_group_id(1); + int idx_d2 = get_group_id(2); + + if (nidx >= d_ne0) { + return; + } + + int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1; + + bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2); + + if (in_src_bounds) { + int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1; + dst[dst_el_offset] = src0[src_el_offset]; + } else { + dst[dst_el_offset] = 0.0f; + } +} diff --git a/ggml/src/ggml-opencl/kernels/relu.cl b/ggml/src/ggml-opencl/kernels/relu.cl new file mode 100644 index 0000000000000000000000000000000000000000..60ff28a61a09f84216fea4a8d9ffa72118846e07 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/relu.cl @@ -0,0 +1,16 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// relu +//------------------------------------------------------------------------------ +kernel void kernel_relu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]); +} diff --git a/ggml/src/ggml-opencl/kernels/repeat.cl b/ggml/src/ggml-opencl/kernels/repeat.cl new file mode 100644 index 0000000000000000000000000000000000000000..079498f5ab9473f820708234cbe4a35f0825ff07 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/repeat.cl @@ -0,0 +1,39 @@ +kernel void kernel_repeat( + global const char * src0_data_in, + global char * dst_data_in, + ulong src0_offset, + ulong dst_offset, + int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3, + ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3, + int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3, + ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3 +) { + global const char * src0_data = src0_data_in + src0_offset; + global char * dst_data = dst_data_in + dst_offset; + + const int d3 = get_global_id(2); + const int d2 = get_global_id(1); + const int d1 = get_global_id(0); + + if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) { + return; + } + + const int s3 = d3 % src0_ne3; + const int s2 = d2 % src0_ne2; + const int s1 = d1 % src0_ne1; + + const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1; + global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1; + + for (int d0 = 0; d0 < dst_ne0; ++d0) { + // Determine source index for dimension 0 based on tiling/broadcasting. + const int s0 = d0 % src0_ne0; + + const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0; + global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0; + for (int k = 0; k < src0_nb0; ++k) { + current_dst_el_ptr[k] = current_src_el_ptr[k]; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/rms_norm.cl b/ggml/src/ggml-opencl/kernels/rms_norm.cl new file mode 100644 index 0000000000000000000000000000000000000000..9d21f3398ec38bf343ba61512a1253aeaf19efd3 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/rms_norm.cl @@ -0,0 +1,96 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// rms_norm +//------------------------------------------------------------------------------ +// This kernel depends on subgroup size. +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_rms_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum // Note, the size depends on number of subgroups +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * x_scalar = (global float *) x; + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; + all_sum = sub_group_reduce_add(all_sum); + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = all_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + // broadcast + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + if (get_local_id(0) == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } + sum[0] /= ne00; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const float mean = sum[0]; + const float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float * y_scalar = (global float *) y; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = x[i00] * scale; + } + if (get_local_id(0) == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/rope.cl b/ggml/src/ggml-opencl/kernels/rope.cl new file mode 100644 index 0000000000000000000000000000000000000000..0247730c0365f71b28833d51ad83aa34d22f4e13 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/rope.cl @@ -0,0 +1,721 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// kernel_rope +//------------------------------------------------------------------------------ +float rope_yarn_ramp(float low, float high, int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +float2 rope_yarn( + float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + return (float2)(cos(theta) * mscale, sin(theta) * mscale); +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +float2 rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow +) { + // start and end correction dims + return (float2)( + max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), + min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) + ); +} + +kernel void kernel_rope_norm_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_norm_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_vision_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} + +kernel void kernel_rope_vision_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl new file mode 100644 index 0000000000000000000000000000000000000000..8cfd518fa5a3ef5b566558f9564d48d4f8773b24 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -0,0 +1,16 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// scale +//------------------------------------------------------------------------------ +kernel void kernel_scale( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + float scale +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + dst[get_global_id(0)] = src0[get_global_id(0)] * scale; +} diff --git a/ggml/src/ggml-opencl/kernels/sigmoid.cl b/ggml/src/ggml-opencl/kernels/sigmoid.cl new file mode 100644 index 0000000000000000000000000000000000000000..e3f669dde830b777e09c5361afc94fe8e95e446d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sigmoid.cl @@ -0,0 +1,29 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// sigmoid +//------------------------------------------------------------------------------ + +kernel void kernel_sigmoid_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = 1.0f / (1.0f + exp(-src0[get_global_id(0)])); +} + +kernel void kernel_sigmoid_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = 1.0f / (1.0f + exp(-src0[get_global_id(0)])); +} diff --git a/ggml/src/ggml-opencl/kernels/silu.cl b/ggml/src/ggml-opencl/kernels/silu.cl new file mode 100644 index 0000000000000000000000000000000000000000..1d95e1b50fd2a145317c58afe35073247b58e6ee --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/silu.cl @@ -0,0 +1,30 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// silu +//------------------------------------------------------------------------------ +kernel void kernel_silu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl new file mode 100644 index 0000000000000000000000000000000000000000..62c05369a87b14f132ba4e5cf211986245eaa2b8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl @@ -0,0 +1,87 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl new file mode 100644 index 0000000000000000000000000000000000000000..d562774eaba5e6c9dbe4c638fa4f5822ead8595a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl @@ -0,0 +1,87 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl new file mode 100644 index 0000000000000000000000000000000000000000..d38d099671ecf61dc88f2d814b1a17f933228980 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl new file mode 100644 index 0000000000000000000000000000000000000000..001b587abe31e5c00b826c7122dbd40d9d83d50d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/sub.cl b/ggml/src/ggml-opencl/kernels/sub.cl new file mode 100644 index 0000000000000000000000000000000000000000..041e88ad3a08013c169d18cc5c33681753011f64 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sub.cl @@ -0,0 +1,72 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// div +//------------------------------------------------------------------------------ +kernel void kernel_sub( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) - *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_sub_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] - src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/sum_rows.cl b/ggml/src/ggml-opencl/kernels/sum_rows.cl new file mode 100644 index 0000000000000000000000000000000000000000..c5f7c570f9514ba4dd9716d42633175ff44ee086 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sum_rows.cl @@ -0,0 +1,39 @@ + +kernel void kernel_sum_rows_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + int i3 = get_global_id(2); + int i2 = get_global_id(1); + int i1 = get_global_id(0); + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} diff --git a/ggml/src/ggml-opencl/kernels/tanh.cl b/ggml/src/ggml-opencl/kernels/tanh.cl new file mode 100644 index 0000000000000000000000000000000000000000..d9da86b1489218b5cdd1e89cae09e1f7e5c9ea7a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tanh.cl @@ -0,0 +1,63 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +kernel void kernel_tanh_f32_nd( + global void * p_src0_base, ulong off_src0_abs, + global void * p_dst_base, ulong off_dst_abs, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne10, int ne11, int ne12, int ne13, + ulong nb10, ulong nb11, ulong nb12, ulong nb13 +) { + int i0 = get_global_id(0); + int i1 = get_global_id(1); + int i2 = get_global_id(2); + + if (i0 < ne10 && i1 < ne11 && i2 < ne12) { + for (int i3 = 0; i3 < ne13; ++i3) { + ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; + global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + + ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; + global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + + *dst_val_ptr = tanh(*src_val_ptr); + } + } +} + +kernel void kernel_tanh_f16_nd( + global void * p_src0_base, ulong off_src0_abs, + global void * p_dst_base, ulong off_dst_abs, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne10, int ne11, int ne12, int ne13, + ulong nb10, ulong nb11, ulong nb12, ulong nb13 +) { + int i0 = get_global_id(0); + int i1 = get_global_id(1); + int i2 = get_global_id(2); + + if (i0 < ne10 && i1 < ne11 && i2 < ne12) { + for (int i3 = 0; i3 < ne13; ++i3) { + ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; + global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + + ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; + global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + + *dst_val_ptr = tanh(*src_val_ptr); + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl new file mode 100644 index 0000000000000000000000000000000000000000..a11490b304c5b29d97085bc62475da0f9e3b49b5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/transpose.cl @@ -0,0 +1,84 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +// 16-bit transpose, loading/storing a 4x4 tile of elements +kernel void kernel_transpose_16( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + + half4 temp0 = read_imageh(input, (j_2+0)*cols+i); + half4 temp1 = read_imageh(input, (j_2+1)*cols+i); + half4 temp2 = read_imageh(input, (j_2+2)*cols+i); + half4 temp3 = read_imageh(input, (j_2+3)*cols+i); + + write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +} + +// 32-bit transpose, loading/storing a 4x4 tile of elements +kernel void kernel_transpose_32( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + + float4 temp0 = read_imagef(input, (j_2+0)*cols+i); + float4 temp1 = read_imagef(input, (j_2+1)*cols+i); + float4 temp2 = read_imagef(input, (j_2+2)*cols+i); + float4 temp3 = read_imagef(input, (j_2+3)*cols+i); + + write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); + +} + +// 32-bit transpose, loading/storing a 4x4 tile of elements +// Only used for activations +// converts to FP16 +// also adds zero padding for non multiple of 8 prompt lengths +kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + half4 temp0 = {0,0,0,0}; // initialize outputs to 0 + half4 temp1 = {0,0,0,0}; + half4 temp2 = {0,0,0,0}; + half4 temp3 = {0,0,0,0}; + + if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0 + temp0 = read_imageh(input, (j_2+0)*cols+i); + } + if((j_2+1)*cols+i*4+3 < rows*cols*16){ + temp1 = read_imageh(input, (j_2+1)*cols+i); + } + if((j_2+2)*cols+i*4+3 < rows*cols*16){ + temp2 = read_imageh(input, (j_2+2)*cols+i); + } + if((j_2+3)*cols+i*4+3 < rows*cols*16){ + temp3 = read_imageh(input, (j_2+3)*cols+i); + } + + write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding + write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +} diff --git a/ggml/src/ggml-opencl/kernels/tsembd.cl b/ggml/src/ggml-opencl/kernels/tsembd.cl new file mode 100644 index 0000000000000000000000000000000000000000..4b1107f70ba7ad50f1a2a19a371b9626bbba5150 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tsembd.cl @@ -0,0 +1,48 @@ +kernel void kernel_timestep_embedding( + global const void * p_timesteps, + ulong off_timesteps, + global void * p_dst, + ulong off_dst, + int dst_nb1_bytes, + int logical_dim, + int max_period +) { + int local_i; + int local_j; + int local_half_dim; + float local_timestep_val; + float local_freq; + float local_arg; + global float * local_embed_data_ptr; + global const float * local_timesteps_input_ptr; + global float * local_dst_output_base_ptr; + + local_timesteps_input_ptr = (global const float *)((global char *)p_timesteps + off_timesteps); + local_dst_output_base_ptr = (global float *)((global char *)p_dst + off_dst); + + local_i = get_global_id(1); + local_j = get_global_id(0); + + local_half_dim = logical_dim / 2; + local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes); + + if (logical_dim % 2 != 0 && local_j == ((logical_dim + 1) / 2)) { + local_embed_data_ptr[logical_dim] = 0.0f; + } + + if (local_j >= local_half_dim) { + return; + } + + local_timestep_val = local_timesteps_input_ptr[local_i]; + + if (local_half_dim == 0) { + local_freq = 1.0f; + } else { + local_freq = exp(-log((float)max_period) * (float)local_j / (float)local_half_dim); + } + + local_arg = local_timestep_val * local_freq; + local_embed_data_ptr[local_j] = cos(local_arg); + local_embed_data_ptr[local_j + local_half_dim] = sin(local_arg); +} diff --git a/ggml/src/ggml-opencl/kernels/upscale.cl b/ggml/src/ggml-opencl/kernels/upscale.cl new file mode 100644 index 0000000000000000000000000000000000000000..219d31dbb9248c4681fe7a9e6396e1cd81cb42a4 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/upscale.cl @@ -0,0 +1,121 @@ +kernel void kernel_upscale( + global const void * p_src0, + ulong off_src0, + global void * p_dst, + ulong off_dst, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + float sf0, + float sf1, + float sf2, + float sf3 +) { + global const char * src_base = (global const char *)p_src0 + off_src0; + global float * dst_base = (global float *)((global char *)p_dst + off_dst); + + int index = get_global_id(0); + int dst_total_elements = ne10 * ne11 * ne12 * ne13; + + if (index >= dst_total_elements) { + return; + } + + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = index / (ne10 * ne11 * ne12); + + int i00 = (int)(i10 / sf0); + int i01 = (int)(i11 / sf1); + int i02 = (int)(i12 / sf2); + int i03 = (int)(i13 / sf3); + + ulong offset_src_element = (ulong)i03 * nb03 + (ulong)i02 * nb02 + (ulong)i01 * nb01 + (ulong)i00 * nb00; + global const float * src_element_ptr = (global const float *)(src_base + offset_src_element); + + dst_base[index] = *src_element_ptr; +} + +kernel void kernel_upscale_bilinear( + global const void * p_src0, + ulong off_src0, + global void * p_dst, + ulong off_dst, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne00_src, + int ne01_src, + int ne10_dst, + int ne11_dst, + int ne12_dst, + int ne13_dst, + float sf0, + float sf1, + float sf2, + float sf3 +) { + global const char * src_base = (global const char *)p_src0 + off_src0; + global float * dst_base = (global float *)((global char *)p_dst + off_dst); + + int index = get_global_id(0); + int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + int i10_dst = index % ne10_dst; + int i11_dst = (index / ne10_dst) % ne11_dst; + int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + int i02_src = (int)(i12_dst / sf2); + int i03_src = (int)(i13_dst / sf3); + + const float pixel_offset = 0.5f; + + float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + long y0_src = (long)floor(y_src_f); + long y1_src = y0_src + 1; + + y0_src = max(0L, min(y0_src, (long)ne01_src - 1)); + y1_src = max(0L, min(y1_src, (long)ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = max(0.0f, min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + long x0_src = (long)floor(x_src_f); + long x1_src = x0_src + 1; + + x0_src = max(0L, min(x0_src, (long)ne00_src - 1)); + x1_src = max(0L, min(x1_src, (long)ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = max(0.0f, min(dx, 1.0f)); + + global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst_base[index] = result; +} diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3c82d6757714b64f4a46dfb0e5a7a62c22ee566 --- /dev/null +++ b/ggml/src/ggml-opt.cpp @@ -0,0 +1,1037 @@ +#include "ggml-opt.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include +#include +#include + +struct ggml_opt_dataset { + struct ggml_context * ctx = nullptr; + ggml_backend_buffer_t buf = nullptr; + struct ggml_tensor * data = nullptr; + struct ggml_tensor * labels = nullptr; + + int64_t ndata = -1; + int64_t ndata_shard = -1; + size_t nbs_data = -1; + size_t nbs_labels = -1; + + std::vector permutation; +}; + +struct ggml_opt_context { + ggml_backend_sched_t backend_sched = nullptr; + ggml_cgraph * allocated_graph = nullptr; + ggml_cgraph * allocated_graph_copy = nullptr; + struct ggml_context * ctx_static = nullptr; + struct ggml_context * ctx_cpu = nullptr; + struct ggml_context * ctx_compute = nullptr; + struct ggml_context * ctx_copy = nullptr; + ggml_backend_buffer_t buf_static = nullptr; + ggml_backend_buffer_t buf_cpu = nullptr; + std::mt19937 rng; + enum ggml_opt_loss_type loss_type; + enum ggml_opt_build_type build_type; + enum ggml_opt_build_type build_type_alloc; + + struct ggml_tensor * inputs = nullptr; + struct ggml_tensor * outputs = nullptr; + struct ggml_tensor * labels = nullptr; + + struct ggml_tensor * loss = nullptr; + struct ggml_tensor * pred = nullptr; + struct ggml_tensor * ncorrect = nullptr; + + struct ggml_cgraph * gf = nullptr; + struct ggml_cgraph * gb_grad = nullptr; + struct ggml_cgraph * gb_opt = nullptr; + bool static_graphs = false; + bool eval_ready = false; + std::vector grad_accs; + std::vector grad_m; + std::vector grad_v; + + int64_t iter = 1; + int32_t opt_period = 1; + int32_t opt_i = 0; + bool loss_per_datapoint = false; + + ggml_opt_get_optimizer_params get_opt_pars = nullptr; + void * get_opt_pars_ud = nullptr; + struct ggml_tensor * adamw_params = nullptr; +}; + +struct ggml_opt_result { + int64_t ndata = 0; + std::vector loss; + std::vector pred; + int64_t ncorrect = 0; + + int64_t opt_period = -1; + bool loss_per_datapoint = false; +}; + +// ====== Dataset ====== + +ggml_opt_dataset_t ggml_opt_dataset_init( + enum ggml_type type_data, + enum ggml_type type_label, + int64_t ne_datapoint, + int64_t ne_label, + int64_t ndata, + int64_t ndata_shard) { + GGML_ASSERT(ne_datapoint > 0); + GGML_ASSERT(ne_label >= 0); + GGML_ASSERT(ndata > 0); + GGML_ASSERT(ndata_shard > 0); + + ggml_opt_dataset_t result = new ggml_opt_dataset; + result->ndata = ndata; + result->ndata_shard = ndata_shard; + + { + struct ggml_init_params params = { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx = ggml_init(params); + } + + result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata); + result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata; + + if (ne_label > 0) { + result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata); + result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata; + } else { + result->labels = nullptr; + result->nbs_labels = 0; + } + + result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type()); + + const int64_t nshards = ndata/ndata_shard; + result->permutation.resize(nshards); + for (int64_t i = 0; i < nshards; ++i) { + result->permutation[i] = i; + } + return result; +} + +void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) { + ggml_backend_buffer_free(dataset->buf); + ggml_free(dataset->ctx); + delete dataset; +} + +int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) { + return dataset->ndata; +} + +struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) { + return dataset->data; +} + +struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) { + return dataset->labels; +} + +void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) { + GGML_ASSERT(idata <= dataset->ndata); + + if (idata < 0) { + std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng); + return; + } + + GGML_ASSERT(idata % dataset->ndata_shard == 0); + const int64_t ishard_max = idata / dataset->ndata_shard; + std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng); +} + +void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) { + GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch)); + GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch)); + GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT( data_batch->type == dataset->data->type); + GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type); + + const size_t nb_data_batch = ggml_nbytes(data_batch); + GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + if (labels_batch) { + const size_t nb_labels_batch = ggml_nbytes(labels_batch); + GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels); + } + + GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data; + ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data); + + if (!labels_batch) { + continue; + } + + const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels; + ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels); + } +} + +void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) { + GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data; + char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data; + memcpy(ptr_data_batch, ptr_data, dataset->nbs_data); + + if (!labels_batch) { + continue; + } + + const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels; + char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels; + memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels); + } +} + +// ====== Model / Context ====== + +struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) { + GGML_UNUSED(userdata); + + ggml_opt_optimizer_params result; + + result.adamw.alpha = 0.001f; + result.adamw.beta1 = 0.9f; + result.adamw.beta2 = 0.999f; + result.adamw.eps = 1e-8f; + result.adamw.wd = 0.0f; + + return result; +} + +struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) { + return *((struct ggml_opt_optimizer_params *) userdata); +} + +struct ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + enum ggml_opt_loss_type loss_type) { + return { + /*backend_sched =*/ backend_sched, + /*ctx_compute =*/ nullptr, + /*inputs =*/ nullptr, + /*logits =*/ nullptr, + /*loss_type =*/ loss_type, + /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT, + /*opt_period =*/ 1, + /*get_opt_pars =*/ ggml_opt_get_default_optimizer_params, + /*get_opt_pars_ud =*/ nullptr, + }; +} + +static ggml_tensor * map_tensor(std::map & tensor_map, ggml_context * ctx, ggml_tensor * tensor) { + if (!tensor) { + return nullptr; + } + + if (tensor_map.find(tensor) != tensor_map.end()) { + return tensor_map[tensor]; + } + + ggml_tensor * new_tensor = ggml_dup_tensor(ctx, tensor); + tensor_map[tensor] = new_tensor; + + new_tensor->op = tensor->op; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + new_tensor->nb[i] = tensor->nb[i]; + } + new_tensor->flags = tensor->flags; + memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params)); + strcpy(new_tensor->name, tensor->name); + new_tensor->data = tensor->data; + new_tensor->buffer = tensor->buffer; + new_tensor->extra = tensor->extra; + new_tensor->view_offs = tensor->view_offs; + new_tensor->view_src = map_tensor(tensor_map, ctx, tensor->view_src); + for (int i = 0; i < GGML_MAX_SRC; i++) { + new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]); + } + + return new_tensor; +} + +static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) { + std::map tensor_map; + + ggml_cgraph * dst = ggml_new_graph_custom(ctx, src->size, /*grads =*/ true); + + for (int i = 0; i < src->n_leafs; i++) { + ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->leafs[i])); + } + GGML_ASSERT(dst->n_leafs == src->n_leafs); + for (int i = 0; i < src->n_nodes; i++) { + ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->nodes[i])); + } + GGML_ASSERT(dst->n_nodes == src->n_nodes); + for (int i = 0; i < src->n_nodes; ++i) { + const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]); + const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]); + + GGML_ASSERT(igrad_src != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src)); + GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst)); + + dst->grads[igrad_dst] = src->grads[igrad_src]; + dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src]; + } + + return dst; +} + +static void ggml_opt_build(ggml_opt_context_t opt_ctx) { + GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc"); + GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically"); + + const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD && + !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1); + + ggml_set_input(opt_ctx->inputs); + ggml_set_output(opt_ctx->outputs); + + int n_param = 0; + for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) { + const struct ggml_tensor * node = opt_ctx->gf->nodes[i]; + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + n_param++; + } + GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented"); + } + + if (!opt_ctx->ctx_static) { + // The static context is used for: + // - gradients (1 per loss, 1 tensor per param if using gradient accumulation) + // - optimizer momenta (2 tensors per param) + // - labels (if using static graphs) + // - loss (if using static graphs, up to 5 tensors) + // - pred (if using static graphs) + // - ncorrect (if using static graphs, 2 tensors). + constexpr size_t n_loss = 1; + const size_t tensors_per_param = (accumulate ? 1 : 0) + + (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); + const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0; + const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead(); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + opt_ctx->ctx_static = ggml_init(params); + } + GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc); + + { + // The cpu context is allocated statically if using static graphs, dynamically otherwise. + // It is used for: + // - optimizer parameters (1 shared for all optimizer invocations) + const size_t size_meta = 1 * ggml_tensor_overhead(); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_free(opt_ctx->ctx_cpu); + opt_ctx->ctx_cpu = ggml_init(params); + + ggml_backend_buffer_free(opt_ctx->buf_cpu); + opt_ctx->buf_cpu = nullptr; + } + + struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute; + + switch (opt_ctx->loss_type) { + case GGML_OPT_LOSS_TYPE_MEAN: { + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->loss, "loss_sum"); + const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs)); + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale); + ggml_set_name(opt_ctx->loss, "loss_mean"); + opt_ctx->loss_per_datapoint = true; + break; + } + case GGML_OPT_LOSS_TYPE_SUM: { + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->loss, "loss_sum"); + opt_ctx->loss_per_datapoint = false; + break; + } + case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: { + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy"); + if (opt_ctx->opt_period > 1) { + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled"); + } + opt_ctx->loss_per_datapoint = true; + break; + } + case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: { + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels); + ggml_set_name(opt_ctx->loss, "loss_error"); + opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss); + ggml_set_name(opt_ctx->loss, "loss_squared_error"); + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss); + ggml_set_name(opt_ctx->loss, "loss_sum_squared_error"); + const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs)); + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale); + ggml_set_name(opt_ctx->loss, "loss_mean_squared_error"); + opt_ctx->loss_per_datapoint = true; + break; + } + } + ggml_set_output(opt_ctx->loss); + ggml_set_loss(opt_ctx->loss); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss); + + if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) { + opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->pred, "pred"); + ggml_set_output(opt_ctx->pred); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred); + + opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels)); + ggml_set_name(opt_ctx->ncorrect, "ncorrect"); + ggml_set_output(opt_ctx->ncorrect); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect); + } + + if (opt_ctx->buf_static) { + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) { + return; + } + } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors( + opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + return; + } + + if (opt_ctx->grad_accs.empty()) { + GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD); + + const int n_nodes = opt_ctx->gf->n_nodes; + opt_ctx->grad_accs.resize(n_nodes); + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor * node = opt_ctx->gf->nodes[i]; + if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { + opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + } else { + opt_ctx->grad_accs[i] = nullptr; + } + } + + if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) { + opt_ctx->grad_m.resize(n_nodes); + opt_ctx->grad_v.resize(n_nodes); + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor * node = opt_ctx->gf->nodes[i]; + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + } else { + opt_ctx->grad_m[i] = nullptr; + opt_ctx->grad_v[i] = nullptr; + } + } + } + } + + // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. + opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true); + ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data()); + + if (opt_ctx->buf_static) { + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) { + return; + } + } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + ggml_graph_reset(opt_ctx->gb_grad); + } + + GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT); + + // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. + opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true); + + opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7); + ggml_set_input(opt_ctx->adamw_params); + ggml_set_name(opt_ctx->adamw_params, "adamw_params"); + + for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) { + struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node); + + if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) { + struct ggml_tensor * m = opt_ctx->grad_m[i]; + struct ggml_tensor * v = opt_ctx->grad_v[i]; + struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params); + + ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str()); + ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str()); + ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str()); + + ggml_build_forward_expand(opt_ctx->gb_opt, opt_step); + } + } + + if (!opt_ctx->buf_static) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors( + opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + ggml_graph_reset(opt_ctx->gb_opt); + } + + opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type()); +} + +ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { + ggml_opt_context_t result = new struct ggml_opt_context; + result->backend_sched = params.backend_sched; + result->ctx_compute = params.ctx_compute; + result->loss_type = params.loss_type; + result->build_type = params.build_type; + result->build_type_alloc = params.build_type; + result->inputs = params.inputs; + result->outputs = params.outputs; + result->opt_period = params.opt_period; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; + + GGML_ASSERT(result->opt_period >= 1); + + result->static_graphs = result->ctx_compute; + + if (!result->static_graphs) { + GGML_ASSERT(!result->inputs); + GGML_ASSERT(!result->outputs); + return result; + } + + GGML_ASSERT(result->inputs); + GGML_ASSERT(result->outputs); + + result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. + ggml_build_forward_expand(result->gf, result->outputs); + + ggml_opt_build(result); + + return result; +} + +void ggml_opt_free(ggml_opt_context_t opt_ctx) { + if (opt_ctx == nullptr) { + return; + } + ggml_backend_buffer_free(opt_ctx->buf_static); + ggml_backend_buffer_free(opt_ctx->buf_cpu); + ggml_free(opt_ctx->ctx_static); + ggml_free(opt_ctx->ctx_cpu); + delete opt_ctx; +} + +void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) { + if (optimizer) { + ggml_graph_reset(opt_ctx->gb_opt); + opt_ctx->iter = 1; + } else { + ggml_graph_reset(opt_ctx->gb_grad); + } +} + +bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) { + return opt_ctx->static_graphs; +} + +struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) { + return opt_ctx->inputs; +} + +struct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) { + return opt_ctx->outputs; +} + +struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) { + return opt_ctx->labels; +} + +struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) { + return opt_ctx->loss; +} + +struct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) { + return opt_ctx->pred; +} + +struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) { + return opt_ctx->ncorrect; +} + +struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) { + return ggml_graph_get_grad_acc(opt_ctx->gb_opt, node); +} + +// ====== Optimization Result ====== + +ggml_opt_result_t ggml_opt_result_init() { + return new ggml_opt_result; +} + +void ggml_opt_result_free(ggml_opt_result_t result) { + delete result; +} + +void ggml_opt_result_reset(ggml_opt_result_t result) { + result->ndata = 0; + result->loss.clear(); + result->pred.clear(); + result->ncorrect = 0; +} + +void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) { + *ndata = result->ndata; +} + +void ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) { + const int64_t nbatches = result->loss.size(); // Number of physical batches. + + if (nbatches == 0) { + *loss = 0.0; + *unc = NAN; + return; + } + + double sum = 0.0; + double sum_squared = 0.0; + + for (const float & loss : result->loss) { + // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch. + const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss; + sum += loss_scaled; + sum_squared += loss_scaled*loss_scaled; + } + + const double mean = sum/nbatches; + *loss = result->loss_per_datapoint ? mean : sum; + + if (!unc) { + return; + } + + if (nbatches < 2) { + *unc = NAN; + return; + } + + const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1) + *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1)); +} + +void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) { + for (size_t i = 0; i < result->pred.size(); ++i) { + pred[i] = result->pred[i]; + } +} + +void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) { + *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN; + + if (!unc) { + return; + } + + *unc = result->ncorrect >= 0 && result->ndata >= 2 ? + sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN; +} + +// ====== Computation ====== + +void ggml_opt_prepare_alloc( + ggml_opt_context_t opt_ctx, + struct ggml_context * ctx_compute, + struct ggml_cgraph * gf, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs) { + GGML_ASSERT(!opt_ctx->static_graphs); + opt_ctx->ctx_compute = ctx_compute; + opt_ctx->gf = gf; + opt_ctx->inputs = inputs; + opt_ctx->outputs = outputs; +} + +void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) { + GGML_ASSERT(!opt_ctx->eval_ready); + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) { + ggml_graph_reset(opt_ctx->gb_grad); + } + if (backward) { + const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD; + } else { + opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD; + } + + if (!opt_ctx->static_graphs) { + ggml_opt_build(opt_ctx); + } + + struct ggml_cgraph * graph = nullptr; + switch (opt_ctx->build_type) { + case GGML_OPT_BUILD_TYPE_FORWARD: { + graph = opt_ctx->gf; + } break; + case GGML_OPT_BUILD_TYPE_GRAD: { + graph = opt_ctx->gb_grad; + } break; + case GGML_OPT_BUILD_TYPE_OPT: { + graph = opt_ctx->gb_opt; + } break; + } + GGML_ASSERT(graph); + + if (opt_ctx->allocated_graph == graph) { + opt_ctx->eval_ready = true; + return; + } + + ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + + if (opt_ctx->static_graphs) { + ggml_init_params params = { + /*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_free(opt_ctx->ctx_copy); + opt_ctx->ctx_copy = ggml_init(params); + + opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); + } else { + opt_ctx->allocated_graph_copy = graph; + } + + ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->allocated_graph = graph; + + opt_ctx->eval_ready = true; +} + +void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) { + GGML_ASSERT(opt_ctx->eval_ready); + if (opt_ctx->allocated_graph == opt_ctx->gb_opt) { + struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); + + GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); + GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f); + GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f); + GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f); + GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f); + GGML_ASSERT(opt_pars.adamw.eps >= 0.0f); + GGML_ASSERT(opt_pars.adamw.wd >= 0.0f); + GGML_ASSERT(opt_pars.adamw.wd <= 1.0f); + + // beta1, beta2 after applying warmup + const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter)); + const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter)); + + float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params); + adamw_par_data[0] = opt_pars.adamw.alpha; + adamw_par_data[1] = opt_pars.adamw.beta1; + adamw_par_data[2] = opt_pars.adamw.beta2; + adamw_par_data[3] = opt_pars.adamw.eps; + adamw_par_data[4] = opt_pars.adamw.wd; + adamw_par_data[5] = beta1h; + adamw_par_data[6] = beta2h; + } + + ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt; + opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + + if (!opt_ctx->static_graphs) { + opt_ctx->gf = nullptr; + opt_ctx->gb_grad = nullptr; + opt_ctx->gb_opt = nullptr; + opt_ctx->allocated_graph = nullptr; + opt_ctx->allocated_graph_copy = nullptr; + } + + opt_ctx->eval_ready = false; + + if (!result) { + return; + } + + if (result->ndata == 0) { + result->loss_per_datapoint = opt_ctx->loss_per_datapoint; + result->opt_period = opt_ctx->opt_period; + } else { + GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint); + GGML_ASSERT(result->opt_period == opt_ctx->opt_period); + } + + const int64_t ndata = opt_ctx->outputs->ne[1]; + GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported"); + result->ndata += ndata; + + GGML_ASSERT(ggml_is_scalar(opt_ctx->loss)); + GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32); + float loss; + ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss)); + result->loss.push_back(loss); + + if (opt_ctx->pred) { + GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); + std::vector pred(ndata); + ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); + result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + } + + if (!opt_ctx->ncorrect || result->ncorrect < 0) { + result->ncorrect = -1; + return; + } + + GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect)); + GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64); + int64_t ncorrect; + ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect)); + result->ncorrect += ncorrect; +} + +// ====== High-Level Functions ====== + +void ggml_opt_epoch( + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs"); + struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx); + struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); + struct ggml_tensor * data = ggml_opt_dataset_data(dataset); + GGML_ASSERT(data->ne[0] == inputs->ne[0]); + + const int64_t ndata = data->ne[1]; + const int64_t ndata_batch = inputs->ne[1]; + + GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0); + const int64_t nbatches = ndata/ndata_batch; + + idata_split = idata_split < 0 ? ndata : idata_split; + GGML_ASSERT(idata_split % ndata_batch == 0); + const int64_t ibatch_split = idata_split / ndata_batch; + + int64_t ibatch = 0; + int64_t t_loop_start = ggml_time_us(); + for (; ibatch < ibatch_split; ++ibatch) { + ggml_opt_alloc(opt_ctx, /*backward =*/ true); + ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); + ggml_opt_eval(opt_ctx, result_train); + if (callback_train) { + callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start); + } + } + t_loop_start = ggml_time_us(); + for (; ibatch < nbatches; ++ibatch) { + ggml_opt_alloc(opt_ctx, /*backward =*/ false); + ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); + ggml_opt_eval(opt_ctx, result_eval); + if (callback_eval) { + callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start); + } + } +} + +void ggml_opt_epoch_callback_progress_bar( + bool train, + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + int64_t ibatch, + int64_t ibatch_max, + int64_t t_start_us) { + fprintf(stderr, "%s[", train ? "train: " : "val: "); + + // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels. + constexpr int64_t bar_length = 8; + const int64_t ibatch8 = 8 * ibatch; + for (int64_t j = 0; j < bar_length; ++j) { + if (ibatch_max * (8*j + 8) / bar_length < ibatch8) { + fprintf(stderr, "\u2588"); // full block + } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) { + fprintf(stderr, "\u2589"); // 7/8 filled + } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) { + fprintf(stderr, "\u258A"); // 6/8 filled + } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) { + fprintf(stderr, "\u258B"); // 5/8 filled + } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) { + fprintf(stderr, "\u258C"); // 4/8 filled + } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) { + fprintf(stderr, "\u258D"); // 3/8 filled + } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) { + fprintf(stderr, "\u258E"); // 2/8 filled + } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) { + fprintf(stderr, "\u258F"); // 1/8 filled + } else { + fprintf(stderr, " "); + } + } + + const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1]; + const int64_t idata = ibatch*batch_size; + const int64_t idata_max = ibatch_max*batch_size; + + double loss; + double loss_unc; + ggml_opt_result_loss(result, &loss, &loss_unc); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc); + + const int64_t t_ibatch_us = ggml_time_us() - t_start_us; + int64_t t_ibatch_s = t_ibatch_us / 1000000; + const int64_t t_ibatch_h = t_ibatch_s / 3600; + t_ibatch_s -= t_ibatch_h * 3600; + const int64_t t_ibatch_m = t_ibatch_s / 60; + t_ibatch_s -= t_ibatch_m * 60; + + const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch; + int64_t t_eta_s = t_eta_us / 1000000; + const int64_t t_eta_h = t_eta_s / 3600; + t_eta_s -= t_eta_h * 3600; + const int64_t t_eta_m = t_eta_s / 60; + t_eta_s -= t_eta_m * 60; + + fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% " + "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r", + idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc, + t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s); + if (ibatch == ibatch_max) { + fprintf(stderr, "\n"); + } + fflush(stderr); + + GGML_UNUSED(dataset); +} + +void ggml_opt_fit( + ggml_backend_sched_t backend_sched, + ggml_context * ctx_compute, + ggml_tensor * inputs, + ggml_tensor * outputs, + ggml_opt_dataset_t dataset, + enum ggml_opt_loss_type loss_type, + ggml_opt_get_optimizer_params get_opt_pars, + int64_t nepoch, + int64_t nbatch_logical, + float val_split, + bool silent) { + ggml_time_init(); + const int64_t t_start_us = ggml_time_us(); + + const int64_t ndata = ggml_opt_dataset_data(dataset)->ne[1]; + const int64_t nbatch_physical = inputs->ne[1]; + GGML_ASSERT(ndata % nbatch_logical == 0); + GGML_ASSERT(nbatch_logical % nbatch_physical == 0); + + const int64_t opt_period = nbatch_logical / nbatch_physical; + const int64_t nbatches_logical = ndata / nbatch_logical; + + GGML_ASSERT(val_split >= 0.0f); + GGML_ASSERT(val_split < 1.0f); + const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical) + const int64_t idata_split = ibatch_split * nbatch_physical; + + int64_t epoch = 1; + + ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type); + params.ctx_compute = ctx_compute; + params.inputs = inputs; + params.outputs = outputs; + params.opt_period = opt_period; + params.get_opt_pars = get_opt_pars; + params.get_opt_pars_ud = &epoch; + ggml_opt_context_t opt_ctx = ggml_opt_init(params); + + // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch. + if (nbatch_logical < ndata) { + ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation). + } + + ggml_opt_result_t result_train = ggml_opt_result_init(); + ggml_opt_result_t result_val = ggml_opt_result_init(); + + ggml_opt_epoch_callback epoch_callback = silent ? nullptr : ggml_opt_epoch_callback_progress_bar; + + for (; epoch <= nepoch; ++epoch) { + if (nbatch_logical < idata_split) { + ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split); + } + + ggml_opt_result_reset(result_train); + ggml_opt_result_reset(result_val); + + if (!silent) { + fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch); + } + ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback); + if (!silent) { + fprintf(stderr, "\n"); + } + } + + if (!silent) { + int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000; + const int64_t t_total_h = t_total_s / 3600; + t_total_s -= t_total_h * 3600; + const int64_t t_total_m = t_total_s / 60; + t_total_s -= t_total_m * 60; + fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s); + } + + ggml_opt_free(opt_ctx); + ggml_opt_result_free(result_train); + ggml_opt_result_free(result_val); +} diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c new file mode 100644 index 0000000000000000000000000000000000000000..e389a46dbed87450cc5095b7e164daba826807e0 --- /dev/null +++ b/ggml/src/ggml-quants.c @@ -0,0 +1,5230 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu/ggml-cpu-impl.h" +#include "ggml-cpu.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#define UNUSED GGML_UNUSED + +// reference implementation for deterministic creation of model files +void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k) { + const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = (x[i*qk + 0 + j] - min)*id; + const float x1 = (x[i*qk + qk/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + uint32_t qh = 0; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2); + } + + memcpy(&y[i].qh, &qh, sizeof(qh)); + } +} + +void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k) { + const int qk = QK5_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 5) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); + + uint32_t qh = 0; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = (x[i*qk + 0 + j] - min)*id; + const float x1 = (x[i*qk + qk/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +// reference implementation for deterministic creation of model files +void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[i*QK8_0 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[i*QK8_0 + j]*id; + + y[i].qs[j] = roundf(x0); + } + } +} + +// reference implementation for deterministic creation of model files +void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k) { + assert(QK8_1 == 32); + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_1; j++) { + const float v = x[i*QK8_1 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + int sum = 0; + + for (int j = 0; j < QK8_1/2; ++j) { + const float v0 = x[i*QK8_1 + j]*id; + const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id; + + y[i].qs[ j] = roundf(v0); + y[i].qs[QK8_1/2 + j] = roundf(v1); + + sum += y[i].qs[ j]; + sum += y[i].qs[QK8_1/2 + j]; + } + + y[i].s = GGML_FP32_TO_FP16(sum*d); + } +} + +void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F) - 8; + const int x1 = (x[i].qs[j] >> 4) - 8; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F); + const int x1 = (x[i].qs[j] >> 4); + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + +void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK5_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int x0 = (x[i].qs[j] & 0x0F) | xh_0; + const int x1 = (x[i].qs[j] >> 4) | xh_1; + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + +void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK8_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk; ++j) { + y[i*qk + j] = x[i].qs[j]*d; + } + } +} + +// +// 2-6 bit quantization in super-blocks +// + +// +// ===================== Helper functions +// +static inline int nearest_int(float fval) { + assert(fabsf(fval) <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, int rmse_type, + const float * GGML_RESTRICT qw) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (amax < GROUP_MAX_EPS) { // all zero + for (int i = 0; i < n; ++i) { + L[i] = 0; + } + return 0.f; + } + float iscale = -nmax / max; + if (rmse_type == 0) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + return 1/iscale; + } + bool return_early = false; + if (rmse_type < 0) { + rmse_type = -rmse_type; + return_early = true; + } + float sumlx = 0; + float suml2 = 0; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 0; i < n; ++i) { +#else + for (int i = 0; i < n; ++i) { +#endif + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + float scale = suml2 ? sumlx/suml2 : 0.0f; + if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; + float best = scale * sumlx; + for (int is = -9; is <= 9; ++is) { + if (is == 0) { + continue; + } + iscale = -(nmax + 0.1f*is) / max; + sumlx = suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + if (suml2 > 0 && sumlx*sumlx > best*suml2) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + scale = sumlx/suml2; best = scale*sumlx; + } + } + return scale; +} + +static float make_q3_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, bool do_rmse) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (amax < GROUP_MAX_EPS) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = -nmax / max; + if (do_rmse) { + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l; + float w = x[i]*x[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = x[i]*x[i]; + float slx = sumlx - w*x[i]*L[i]; + if (slx > 0) { + float sl2 = suml2 - w*L[i]*L[i]; + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MAX(-nmax, MIN(nmax-1, new_l)); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + for (int i = 0; i < n; ++i) { + L[i] += nmax; + } + return sumlx / suml2; + } + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + } + return 1/iscale; +} + +static float make_qkx1_quants(int n, int nmax, const float * GGML_RESTRICT x, uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, + int ntry, float alpha) { + float min = x[0]; + float max = x[0]; + for (int i = 1; i < n; ++i) { + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + } + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = 0; + return 0.f; + } + if (min > 0) min = 0; + float iscale = nmax/(max - min); + float scale = 1/iscale; + for (int itry = 0; itry < ntry; ++itry) { + float sumlx = 0; int suml2 = 0; + bool did_change = false; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + if (l != L[i]) { + L[i] = l; + did_change = true; + } + sumlx += (x[i] - min)*l; + suml2 += l*l; + } + scale = sumlx/suml2; + float sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] - scale*L[i]; + } + min = alpha*min + (1 - alpha)*sum/n; + if (min > 0) min = 0; + iscale = 1/scale; + if (!did_change) break; + } + *the_min = -min; + return scale; +} + +static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights, + uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, uint8_t * GGML_RESTRICT Laux, + float rmin, float rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights[0]; + float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else + for (int i = 1; i < n; ++i) { +#endif + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + float w = weights[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) min = 0; + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = -min; + return 0.f; + } + float iscale = nmax/(max - min); + float scale = 1/iscale; + float best_mad = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + L[i] = MAX(0, MIN(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta*is + nmax)/(max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + Laux[i] = l; + float w = weights[i]; + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } + } + *the_min = -min; + return scale; +} + +static inline void get_scale_min_k4(int j, const uint8_t * GGML_RESTRICT q, uint8_t * GGML_RESTRICT d, uint8_t * GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; *m = q[j + 4] & 63; + } else { + *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +//========================- 2-bit (de)-quantization + +void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[16]; + float weights[16]; + float mins[QK_K/16]; + float scales[QK_K/16]; + + const float q4scale = 15.f; + + for (int i = 0; i < nb; i++) { + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]); + scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + if (max_scale > 0) { + float iscale = q4scale/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*scales[j]); + y[i].scales[j] = l; + } + y[i].d = GGML_FP32_TO_FP16(max_scale/q4scale); + } else { + for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0; + y[i].d = GGML_FP32_TO_FP16(0.f); + } + if (max_min > 0) { + float iscale = q4scale/max_min; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*mins[j]); + y[i].scales[j] |= (l << 4); + } + y[i].dmin = GGML_FP32_TO_FP16(max_min/q4scale); + } else { + y[i].dmin = GGML_FP32_TO_FP16(0.f); + } + for (int j = 0; j < QK_K/16; ++j) { + const float d = GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF); + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4); + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int((x[16*j + ii] + dm)/d); + l = MAX(0, MIN(3, l)); + L[16*j + ii] = l; + } + } + + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } + + x += QK_K; + } +} + +void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const float min = GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * q = x[i].qs; + + int is = 0; + float dl, ml; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + uint8_t sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; + + sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; + + shift += 2; + } + q += 32; + } + } +} + +static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights, + uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, uint8_t * GGML_RESTRICT Laux, + float rmin, float rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights ? weights[0] : x[0]*x[0]; + float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else + for (int i = 1; i < n; ++i) { +#endif + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + float w = weights ? weights[i] : x[i]*x[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) { + min = 0; + } + if (max <= min) { + memset(L, 0, n); + *the_min = -min; + return 0.f; + } + float iscale = nmax/(max - min); + float scale = 1/iscale; + float best_mad = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + L[i] = MAX(0, MIN(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff*diff; + float w = weights ? weights[i] : x[i]*x[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta*is + nmax)/(max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + Laux[i] = l; + float w = weights ? weights[i] : x[i]*x[i]; + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff*diff; + float w = weights ? weights[i] : x[i]*x[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } + } + *the_min = -min; + return scale; +} + +static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint8_t * GGML_RESTRICT L, const float * quant_weights) { + float max = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + if (!max) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = nmax / max; + for (int i = 0; i < n; ++i) { + L[i] = nearest_int(iscale * x[i]); + } + float scale = 1/iscale; + float best_mse = 0; + for (int i = 0; i < n; ++i) { + float diff = x[i] - scale*L[i]; + float w = quant_weights[i]; + best_mse += w*diff*diff; + } + for (int is = -4; is <= 4; ++is) { + if (is == 0) continue; + float iscale_is = (0.1f*is + nmax)/max; + float scale_is = 1/iscale_is; + float mse = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale_is*x[i]); + l = MIN(nmax, l); + float diff = x[i] - scale_is*l; + float w = quant_weights[i]; + mse += w*diff*diff; + } + if (mse < best_mse) { + best_mse = mse; + iscale = iscale_is; + } + } + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MIN(nmax, l); + L[i] = l; + float w = quant_weights[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = quant_weights[i]; + float slx = sumlx - w*x[i]*L[i]; + float sl2 = suml2 - w*L[i]*L[i]; + if (slx > 0 && sl2 > 0) { + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MIN(nmax, new_l); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + return sumlx/suml2; +} + +static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k, const float * GGML_RESTRICT quant_weights) { + GGML_ASSERT(quant_weights); + assert(k % QK_K == 0); + const int nb = k / QK_K; + const bool requantize = true; + + uint8_t L[QK_K]; + uint8_t Laux[16]; + float mins[QK_K/16]; + float scales[QK_K/16]; + float sw[QK_K/16]; + float weight[16]; + uint8_t Ls[QK_K/16], Lm[QK_K/16]; + + for (int i = 0; i < nb; i++) { + memset(sw, 0, QK_K/16*sizeof(float)); + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j]; + float sigma2 = sumx2/QK_K; + for (int j = 0; j < QK_K/16; ++j) { + const float * GGML_RESTRICT qw = quant_weights + QK_K * i + 16*j; + for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); + for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l]; + scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); + } + + float dm, mm; + dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); + mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); + + y[i].d = GGML_FP32_TO_FP16(dm); + y[i].dmin = GGML_FP32_TO_FP16(mm); + dm = GGML_FP16_TO_FP32(y[i].d); + mm = GGML_FP16_TO_FP32(y[i].dmin); + + for (int j = 0; j < QK_K/16; ++j) { + y[i].scales[j] = Ls[j] | (Lm[j] << 4); + } + + if (requantize) { + for (int j = 0; j < QK_K/16; ++j) { + const float d = dm * (y[i].scales[j] & 0xF); + if (!d) continue; + const float m = mm * (y[i].scales[j] >> 4); + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int((x[16*j + ii] + m)/d); + l = MAX(0, MIN(3, l)); + L[16*j + ii] = l; + } + } + } + + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } + + x += QK_K; + } +} + +size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); + if (!quant_weights) { + quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + +//========================= 3-bit (de)-quantization + +void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); + float scale = fabsf(scales[j]); + if (scale > amax) { + amax = scale; max_scale = scales[j]; + } + } + + memset(y[i].scales, 0, 12); + if (max_scale) { + float iscale = -32.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int8_t l = nearest_int(iscale*scales[j]); + l = MAX(-32, MIN(31, l)) + 32; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = GGML_FP32_TO_FP16(1/iscale); + } else { + y[i].d = GGML_FP32_TO_FP16(0.f); + } + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; + float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } + + memset(y[i].hmask, 0, QK_K/8); + // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. + int m = 0; + uint8_t hm = 1; + for (int j = 0; j < QK_K; ++j) { + if (L[j] > 3) { + y[i].hmask[m] |= hm; + L[j] -= 4; + } + if (++m == QK_K/8) { + m = 0; hm <<= 1; + } + } + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } + + x += QK_K; + } +} + +void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + uint32_t aux[4]; + const int8_t * scales = (const int8_t*)aux; + + for (int i = 0; i < nb; i++) { + + const float d_all = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].hmask; + uint8_t m = 1; + + memcpy(aux, x[i].scales, 12); + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + int is = 0; + float dl; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)); + } + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)); + } + + shift += 2; + m <<= 1; + } + q += 32; + } + + } +} + +static void quantize_row_q3_K_impl(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t n_per_row, const float * GGML_RESTRICT quant_weights) { + assert(n_per_row % QK_K == 0); + const int nb = n_per_row / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + float weight[16]; + float sw[QK_K / 16]; + int8_t Ls[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j]; + float sigma2 = 2*sumx2/QK_K; + + for (int j = 0; j < QK_K/16; ++j) { + if (quant_weights) { + const float * qw = quant_weights + QK_K * i + 16*j; + for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]); + } else { + for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l]; + } + float sumw = 0; + for (int l = 0; l < 16; ++l) sumw += weight[l]; + sw[j] = sumw; + + scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight); + + } + + memset(y[i].scales, 0, 12); + + float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw); + for (int j = 0; j < QK_K/16; ++j) { + int l = Ls[j]; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = GGML_FP32_TO_FP16(d_block); + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; + float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } + + memset(y[i].hmask, 0, QK_K/8); + // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. + int m = 0; + uint8_t hm = 1; + for (int j = 0; j < QK_K; ++j) { + if (L[j] > 3) { + y[i].hmask[m] |= hm; + L[j] -= 4; + } + if (++m == QK_K/8) { + m = 0; hm <<= 1; + } + } + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } + + x += QK_K; + } +} + +size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); + if (!quant_weights) { + quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + +// ====================== 4-bit (de)-quantization + +void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[32]; + float weights[32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + + for (int i = 0; i < nb; i++) { + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); + float sum_x2 = 0; + for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; + float av_x = sqrtf(sum_x2/32); + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = GGML_FP32_TO_FP16(max_scale/63.f); + y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + } + } + + uint8_t * q = y[i].qs; + for (int j = 0; j < QK_K; j += 64) { + for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); + q += 32; + } + + x += QK_K; + } +} + +void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + const uint8_t * q = x[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d); + const float min = GGML_FP16_TO_FP32(x[i].dmin); + + int is = 0; + uint8_t sc, m; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; + q += 32; is += 2; + } + } +} + +static void quantize_row_q4_K_impl(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) { + assert(n_per_row % QK_K == 0); + const int64_t nb = n_per_row / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[32]; + uint8_t Ls[QK_K/32]; + uint8_t Lm[QK_K/32]; + float weights[32]; + float sw[QK_K/32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + + for (int i = 0; i < nb; i++) { + + float sum_x2 = 0; + for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l]; + float sigma2 = 2*sum_x2/QK_K; + float av_x = sqrtf(sigma2); + + for (int j = 0; j < QK_K/32; ++j) { + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 32*j; + for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]); + } else { + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + } + float sumw = 0; + for (int l = 0; l < 32; ++l) sumw += weights[l]; + sw[j] = sumw; + scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); + } + + float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw); + float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw); + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = Ls[j]; + uint8_t lm = Lm[j]; + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = GGML_FP32_TO_FP16(d_block); + y[i].dmin = GGML_FP32_TO_FP16(m_block); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + } + } + uint8_t * q = y[i].qs; + for (int j = 0; j < QK_K; j += 64) { + for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); + q += 32; + } + + x += QK_K; + + } +} + +size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); + if (!quant_weights) { + quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + +// ====================== 5-bit (de)-quantization + +void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + uint8_t L[QK_K]; + float mins[QK_K/32]; + float scales[QK_K/32]; + float weights[32]; + uint8_t Laux[32]; + + for (int i = 0; i < nb; i++) { + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); + float sum_x2 = 0; + for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; + float av_x = sqrtf(sum_x2/32); + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = GGML_FP32_TO_FP16(max_scale/63.f); + y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(31, l)); + L[32*j + ii] = l; + } + } + + uint8_t * GGML_RESTRICT qh = y[i].qh; + uint8_t * GGML_RESTRICT ql = y[i].qs; + memset(qh, 0, QK_K/8); + + uint8_t m1 = 1, m2 = 2; + for (int n = 0; n < QK_K; n += 64) { + for (int j = 0; j < 32; ++j) { + int l1 = L[n + j]; + if (l1 > 15) { + l1 -= 16; qh[j] |= m1; + } + int l2 = L[n + j + 32]; + if (l2 > 15) { + l2 -= 16; qh[j] |= m2; + } + ql[j] = l1 | (l2 << 4); + } + m1 <<= 2; m2 <<= 2; + ql += 32; + } + + x += QK_K; + } +} + +void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + const uint8_t * ql = x[i].qs; + const uint8_t * qh = x[i].qh; + + const float d = GGML_FP16_TO_FP32(x[i].d); + const float min = GGML_FP16_TO_FP32(x[i].dmin); + + int is = 0; + uint8_t sc, m; + uint8_t u1 = 1, u2 = 2; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } + } +} + +static void quantize_row_q5_K_impl(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) { + assert(n_per_row % QK_K == 0); + const int64_t nb = n_per_row / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[32]; + uint8_t Ls[QK_K/32]; + uint8_t Lm[QK_K/32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + float sw[QK_K/32]; + float weights[32]; + + for (int i = 0; i < nb; i++) { + + float sum_x2 = 0; + for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l]; + float sigma2 = 2*sum_x2/QK_K; + float av_x = sqrtf(sigma2); + + for (int j = 0; j < QK_K/32; ++j) { + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 32*j; + for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]); + } else { + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + } + float sumw = 0; + for (int l = 0; l < 32; ++l) sumw += weights[l]; + sw[j] = sumw; + + scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); + } + + float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw); + float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw); + + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = Ls[j]; + uint8_t lm = Lm[j]; + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = GGML_FP32_TO_FP16(d_block); + y[i].dmin = GGML_FP32_TO_FP16(m_block); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(31, l)); + L[32*j + ii] = l; + } + } + + uint8_t * GGML_RESTRICT qh = y[i].qh; + uint8_t * GGML_RESTRICT ql = y[i].qs; + memset(qh, 0, QK_K/8); + + uint8_t m1 = 1, m2 = 2; + for (int n = 0; n < QK_K; n += 64) { + for (int j = 0; j < 32; ++j) { + int l1 = L[n + j]; + if (l1 > 15) { + l1 -= 16; qh[j] |= m1; + } + int l2 = L[n + j + 32]; + if (l2 > 15) { + l2 -= 16; qh[j] |= m2; + } + ql[j] = l1 | (l2 << 4); + } + m1 <<= 2; m2 <<= 2; + ql += 32; + } + + x += QK_K; + + } +} + +size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); + if (!quant_weights) { + quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + +// ====================== 6-bit (de)-quantization + +void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K/16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + + const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL); + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; + } + + } + + if (max_abs_scale < GROUP_MAX_EPS) { + memset(&y[i], 0, sizeof(block_q6_K)); + y[i].d = GGML_FP32_TO_FP16(0.f); + x += QK_K; + continue; + } + + float iscale = -128.f/max_scale; + y[i].d = GGML_FP32_TO_FP16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + } + + for (int j = 0; j < QK_K/16; ++j) { + float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; + } + } + + uint8_t * GGML_RESTRICT ql = y[i].ql; + uint8_t * GGML_RESTRICT qh = y[i].qh; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[j + l + 0] & 0xF; + const uint8_t q2 = L[j + l + 32] & 0xF; + const uint8_t q3 = L[j + l + 64] & 0xF; + const uint8_t q4 = L[j + l + 96] & 0xF; + ql[l+ 0] = q1 | (q3 << 4); + ql[l+32] = q2 | (q4 << 4); + qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6); + } + ql += 64; + qh += 32; + } + + x += QK_K; + } +} + +void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT ql = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT sc = x[i].scales; + + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l + 0] = d * sc[is + 0] * q1; + y[l + 32] = d * sc[is + 2] * q2; + y[l + 64] = d * sc[is + 4] * q3; + y[l + 96] = d * sc[is + 6] * q4; + } + y += 128; + ql += 64; + qh += 32; + sc += 8; + } + } +} + +static void quantize_row_q6_K_impl(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) { + assert(n_per_row % QK_K == 0); + const int64_t nb = n_per_row / QK_K; + + int8_t L[QK_K]; + float scales[QK_K/16]; + //float weights[16]; + + for (int i = 0; i < nb; i++) { + + //float sum_x2 = 0; + //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j]; + //float sigma2 = sum_x2/QK_K; + + float max_scale = 0; + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + + float scale; + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 16*ib; + //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]); + //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights); + scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw); + } else { + scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL); + } + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; + } + + } + + if (max_abs_scale < GROUP_MAX_EPS) { + memset(&y[i], 0, sizeof(block_q6_K)); + y[i].d = GGML_FP32_TO_FP16(0.f); + x += QK_K; + continue; + } + + float iscale = -128.f/max_scale; + y[i].d = GGML_FP32_TO_FP16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + } + + for (int j = 0; j < QK_K/16; ++j) { + float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; + } + } + + uint8_t * GGML_RESTRICT ql = y[i].ql; + uint8_t * GGML_RESTRICT qh = y[i].qh; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[j + l + 0] & 0xF; + const uint8_t q2 = L[j + l + 32] & 0xF; + const uint8_t q3 = L[j + l + 64] & 0xF; + const uint8_t q4 = L[j + l + 96] & 0xF; + ql[l+ 0] = q1 | (q3 << 4); + ql[l+32] = q2 | (q4 << 4); + qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6); + } + ql += 64; + qh += 32; + } + + x += QK_K; + + } +} + +size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); + if (!quant_weights) { + quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + +static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK4_0 == 32, "QK4_0 must be 32"); + + if (!quant_weights) { + quantize_row_q4_0_ref(x, y, n_per_row); + return; + } + + float weight[QK4_0]; + int8_t L[QK4_0]; + + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int64_t nb = n_per_row/QK4_0; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK4_0 * ib; + const float * qw = quant_weights + QK4_0 * ib; + for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight); + y[ib].d = GGML_FP32_TO_FP16(d); + for (int j = 0; j < 16; ++j) { + y[ib].qs[j] = L[j] | (L[j+16] << 4); + } + } +} + +size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + +static void quantize_row_q4_1_impl(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK4_1 == 32, "QK4_1 must be 32"); + + if (!quant_weights) { + quantize_row_q4_1_ref(x, y, n_per_row); + return; + } + + float weight[QK4_1]; + uint8_t L[QK4_1], Laux[QK4_1]; + + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int64_t nb = n_per_row/QK4_1; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK4_1 * ib; + const float * qw = quant_weights + QK4_1 * ib; + for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float min; + float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false); + y[ib].d = GGML_FP32_TO_FP16(d); + y[ib].m = GGML_FP32_TO_FP16(-min); + for (int j = 0; j < 16; ++j) { + y[ib].qs[j] = L[j] | (L[j+16] << 4); + } + } +} + +size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + +static void quantize_row_q5_0_impl(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK5_0 == 32, "QK5_0 must be 32"); + + if (!quant_weights) { + quantize_row_q5_0_ref(x, y, n_per_row); + return; + } + + float weight[QK5_0]; + int8_t L[QK5_0]; + + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int64_t nb = n_per_row/QK5_0; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK5_0 * ib; + const float * qw = quant_weights + QK5_0 * ib; + for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight); + y[ib].d = GGML_FP32_TO_FP16(d); + + uint32_t qh = 0; + + for (int j = 0; j < 16; ++j) { + const uint8_t xi0 = L[j]; + const uint8_t xi1 = L[j+16]; + y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + + memcpy(&y[ib].qh, &qh, sizeof(qh)); + } +} + +size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + +static void quantize_row_q5_1_impl(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK5_1 == 32, "QK5_1 must be 32"); + + if (!quant_weights) { + quantize_row_q5_1_ref(x, y, n_per_row); + return; + } + + float weight[QK5_1]; + uint8_t L[QK5_1], Laux[QK5_1]; + + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int64_t nb = n_per_row/QK5_1; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK5_1 * ib; + const float * qw = quant_weights + QK5_1 * ib; + for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float min; + float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false); + y[ib].d = GGML_FP32_TO_FP16(d); + y[ib].m = GGML_FP32_TO_FP16(-min); + + uint32_t qh = 0; + for (int j = 0; j < 16; ++j) { + const uint8_t xi0 = L[j]; + const uint8_t xi1 = L[j+16]; + y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + memcpy(&y[ib].qh, &qh, sizeof(qh)); + } +} + +size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + +size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); + quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + +// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) + +void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK_K; j++) { + const float v = x[j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + // 5 elements per byte, along 32 bytes + for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = 0; + for (size_t n = 0; n < 5; ++n) { + int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2 + q *= 3; + q += xi; + } + // ceiling division (243 == pow(3, 5)) + q = ((uint16_t)q * 256 + (243 - 1)) / 243; + y[i].qs[j + m] = q; + } + x += 5*32; + } + // along 16 bytes + for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = 0; + for (size_t n = 0; n < 5; ++n) { + int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2 + q *= 3; + q += xi; + } + // ceiling division (243 == pow(3, 5)) + q = ((uint16_t)q * 256 + (243 - 1)) / 243; + y[i].qs[j + m] = q; + } + x += 5*16; + } + // 4 elements per byte + for (size_t j = 0; j < sizeof(y->qh); ++j) { + uint8_t q = 0; + for (size_t m = 0; m < 4; ++m) { + // -1, 0, 1 -> 0, 1, 2 + int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1; + q *= 3; + q += xi; + } + // shift the first value to the most significant trit + q *= 3; + // ceiling division (243 == pow(3, 5)) + q = ((uint16_t)q * 256 + (243 - 1)) / 243; + y[i].qh[j] = q; + } + x += 4*sizeof(y->qh); + } +} + +void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK_K; j++) { + const float v = x[j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (size_t j = 0; j < sizeof(y->qs); j += 32) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = 0; + for (size_t n = 0; n < 4; ++n) { + // -1, 0, 1 -> 0, 1, 2 + int xi = lroundf(x[m + n*32] * id) + 1; + q += (xi & 3) << (2*n); + } + y[i].qs[j + m] = q; + } + x += 4*32; + } + } +} + +size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row); + quantize_row_tq1_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + +size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row); + quantize_row_tq2_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + +void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + for (int64_t i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { + for (size_t n = 0; n < 5; ++n) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[n]; + int16_t xi = ((uint16_t) q * 3) >> 8; + *y++ = (float) (xi - 1) * d; + } + } + } + for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { + for (size_t n = 0; n < 5; ++n) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[n]; + int16_t xi = ((uint16_t) q * 3) >> 8; + *y++ = (float) (xi - 1) * d; + } + } + } + + for (size_t n = 0; n < 4; ++n) { + for (size_t j = 0; j < sizeof(x->qh); ++j) { + uint8_t q = x[i].qh[j] * pow3[n]; + int16_t xi = ((uint16_t) q * 3) >> 8; + *y++ = (float) (xi - 1) * d; + } + } + } +} + +void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + for (size_t l = 0; l < 4; ++l) { + for (size_t m = 0; m < 32; ++m) { + int8_t q = (x[i].qs[j + m] >> (l*2)) & 3; + *y++ = (float) (q - 1) * d; + } + } + } + } +} + +// ====================== "True" 2-bit (de)-quantization + +void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t)); + const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + } + } +} + +// ====================== 2.3125 bpw (de)-quantization + +void dequantize_row_iq2_xs(const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + float db[2]; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f; + db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511)); + const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9]; + for (int j = 0; j < 8; ++j) { + y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + } + } +} + +// ====================== 2.5625 bpw (de)-quantization + +void dequantize_row_iq2_s(const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + float db[2]; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f; + db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const float dl = db[l/2]; + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + qs += 4; + signs += 4; + } + } +} + +// ====================== 3.0625 bpw (de)-quantization + +void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + uint32_t aux32; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * scales_and_signs = qs + QK_K/4; + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t)); + const float db = d * (0.5f + (aux32 >> 28)) * 0.5f; + for (int l = 0; l < 4; ++l) { + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]); + for (int j = 0; j < 4; ++j) { + y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qs += 8; + } + } +} + +// ====================== 3.3125 bpw (de)-quantization + +void dequantize_row_iq3_s(const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = x[i].signs; + + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); + const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qs += 8; + signs += 4; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qh += 2; + qs += 8; + signs += 4; + } + } +} + +// ====================== 1.5625 bpw (de)-quantization + +void dequantize_row_iq1_s(const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const float dl = d * (2*((qh[ib] >> 12) & 7) + 1); + const float delta = qh[ib] & 0x8000 ? -IQ1S_DELTA : IQ1S_DELTA; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + y[j] = dl * (grid[j] + delta); + } + y += 8; + } + qs += 4; + } + } +} + +void dequantize_row_iq1_m(const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + float delta[4]; + uint16_t idx[4]; + + iq1m_scale_t scale; + + for (int i = 0; i < nb; i++) { + + const uint16_t * sc = (const uint16_t *)x[i].scales; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = GGML_FP16_TO_FP32(scale.f16); + + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1); + const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1); + + idx[0] = qs[0] | ((qh[0] << 8) & 0x700); + idx[1] = qs[1] | ((qh[0] << 4) & 0x700); + idx[2] = qs[2] | ((qh[1] << 8) & 0x700); + idx[3] = qs[3] | ((qh[1] << 4) & 0x700); + delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[3] = qh[1] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA; + for (int l = 0; l < 2; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + for (int j = 0; j < 8; ++j) { + y[j] = dl1 * (grid[j] + delta[l]); + } + y += 8; + } + for (int l = 2; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + for (int j = 0; j < 8; ++j) { + y[j] = dl2 * (grid[j] + delta[l]); + } + y += 8; + } + qs += 4; + qh += 2; + } + } +} + +void dequantize_row_iq4_nl(const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK4_NL == 0); + const int64_t nb = k / QK4_NL; + + for (int i = 0; i < nb; i++) { + + const uint8_t * qs = x[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d); + for (int j = 0; j < QK4_NL/2; ++j) { + y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf]; + y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >> 4]; + } + y += QK4_NL; + qs += QK4_NL/2; + } +} + +void dequantize_row_iq4_xs(const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const uint8_t * qs = x[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4); + const float dl = d * (ls - 32); + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf]; + y[j+16] = dl * kvalues_iq4nl[qs[j] >> 4]; + } + y += 32; + qs += 16; + } + } +} + +//===================================== Q8_K ============================================== + +void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + float max = 0; + float amax = 0; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, QK_K); + x += QK_K; + continue; + } + //const float iscale = -128.f/max; + // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward + const float iscale = -127.f/max; + for (int j = 0; j < QK_K; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = MIN(127, v); + } + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = sum; + } + y[i].d = 1/iscale; + x += QK_K; + } +} + +void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK_K; ++j) { + *y++ = x[i].d * x[i].qs[j]; + } + } +} + +// ================================ IQ2 quantization ============================================= + +typedef struct { + uint64_t * grid; + int * map; + uint16_t * neighbours; +} iq2_entry_t; + +static iq2_entry_t iq2_data[4] = { + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, +}; + +static inline int iq2_data_index(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S); + return type == GGML_TYPE_IQ2_XXS ? 0 : + type == GGML_TYPE_IQ2_XS ? 1 : + type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 2 : 3; +} + +static inline int iq2_grid_size(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S); + return type == GGML_TYPE_IQ2_XXS ? 256 : + type == GGML_TYPE_IQ2_XS ? 512 : + type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? NGRID_IQ1S : 1024; +} + +static int iq2_compare_func(const void * left, const void * right) { + const int * l = (const int *)left; + const int * r = (const int *)right; + return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0; +} + +void iq2xs_init_impl(enum ggml_type type) { + const int gindex = iq2_data_index(type); + const int grid_size = iq2_grid_size(type); + if (iq2_data[gindex].grid) { + return; + } + static const uint16_t kgrid_2bit_256[256] = { + 0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97, + 100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642, + 1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288, + 1312, 1350, 1385, 1408, 1425, 1545, 1552, 1600, 1668, 1700, 2048, 2053, 2056, 2068, 2088, 2113, + 2116, 2128, 2130, 2184, 2308, 2368, 2562, 2580, 4097, 4100, 4112, 4129, 4160, 4192, 4228, 4240, + 4245, 4352, 4360, 4384, 4432, 4442, 4480, 4644, 4677, 5120, 5128, 5152, 5157, 5193, 5248, 5400, + 5474, 5632, 5654, 6145, 6148, 6160, 6208, 6273, 6400, 6405, 6560, 6737, 8192, 8194, 8202, 8260, + 8289, 8320, 8322, 8489, 8520, 8704, 8706, 9217, 9220, 9232, 9280, 9302, 9472, 9537, 9572, 9872, + 10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516, + 16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561, + 17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488, + 20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545, + 22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874, + 25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856, + 33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142, + 37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268, + }; + static const uint16_t kgrid_2bit_512[512] = { + 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, + 73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257, + 260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340, + 352, 360, 385, 388, 400, 512, 514, 517, 520, 529, 532, 544, 577, 580, 592, 597, + 640, 650, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1088, 1090, 1093, 1096, + 1105, 1108, 1110, 1120, 1153, 1156, 1168, 1280, 1282, 1285, 1288, 1297, 1300, 1312, 1345, 1348, + 1360, 1377, 1408, 1537, 1540, 1552, 1574, 1600, 1602, 1668, 2048, 2050, 2053, 2056, 2058, 2065, + 2068, 2080, 2085, 2113, 2116, 2128, 2136, 2176, 2208, 2218, 2305, 2308, 2320, 2368, 2433, 2441, + 2560, 2592, 2600, 2710, 2720, 4097, 4100, 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4160, + 4162, 4165, 4168, 4177, 4180, 4192, 4202, 4225, 4228, 4240, 4352, 4354, 4357, 4360, 4369, 4372, + 4384, 4417, 4420, 4432, 4480, 4500, 4502, 4609, 4612, 4614, 4624, 4672, 4704, 5120, 5122, 5125, + 5128, 5137, 5140, 5152, 5185, 5188, 5193, 5200, 5220, 5248, 5377, 5380, 5392, 5440, 5632, 5652, + 5705, 6145, 6148, 6160, 6162, 6208, 6228, 6278, 6400, 6405, 6502, 6737, 6825, 8192, 8194, 8197, + 8200, 8202, 8209, 8212, 8224, 8257, 8260, 8272, 8320, 8352, 8449, 8452, 8464, 8512, 8520, 8549, + 8704, 8738, 8832, 8872, 9217, 9220, 9232, 9257, 9280, 9472, 9537, 9554, 9625, 9729, 9754, 9894, + 10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388, + 16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480, + 16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773, + 16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473, + 17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436, + 18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497, + 20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162, + 21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528, + 22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745, + 24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234, + 32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025, + 33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810, + 33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984, + 35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462, + 37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960, + 40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048, + 42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690, + }; + static const uint16_t kgrid_1bit_2048[NGRID_IQ1S] = { + 0, 2, 5, 8, 10, 17, 21, 32, 34, 40, 42, 69, 81, 84, 86, 101, + 128, 130, 136, 138, 149, 160, 162, 168, 170, 260, 261, 273, 276, 278, 281, 282, + 293, 321, 326, 329, 338, 341, 346, 353, 356, 358, 360, 389, 401, 404, 406, 421, + 512, 514, 520, 522, 533, 544, 546, 552, 554, 581, 593, 601, 612, 617, 640, 642, + 648, 650, 657, 661, 665, 672, 674, 680, 682, 1041, 1044, 1046, 1061, 1089, 1097, 1109, + 1114, 1124, 1125, 1169, 1177, 1189, 1281, 1284, 1285, 1286, 1301, 1304, 1306, 1321, 1344, 1349, + 1354, 1360, 1361, 1364, 1365, 1366, 1369, 1376, 1378, 1381, 1384, 1386, 1409, 1425, 1429, 1432, + 1434, 1441, 1444, 1445, 1446, 1449, 1556, 1561, 1601, 1604, 1616, 1618, 1621, 1624, 1632, 1633, + 1638, 1641, 1669, 1681, 1684, 1689, 2048, 2050, 2056, 2058, 2069, 2080, 2082, 2088, 2090, 2117, + 2129, 2134, 2149, 2176, 2178, 2184, 2186, 2197, 2208, 2210, 2216, 2218, 2309, 2321, 2324, 2329, + 2340, 2341, 2369, 2384, 2385, 2389, 2401, 2404, 2409, 2449, 2452, 2454, 2457, 2469, 2560, 2562, + 2568, 2570, 2581, 2592, 2594, 2600, 2602, 2629, 2641, 2649, 2657, 2661, 2688, 2690, 2693, 2696, + 2698, 2709, 2720, 2722, 2728, 2730, 4112, 4113, 4116, 4121, 4132, 4133, 4161, 4164, 4176, 4181, + 4184, 4193, 4196, 4197, 4201, 4241, 4244, 4246, 4257, 4261, 4353, 4356, 4358, 4361, 4368, 4370, + 4373, 4376, 4385, 4388, 4393, 4421, 4426, 4432, 4433, 4434, 4436, 4437, 4438, 4441, 4448, 4453, + 4484, 4498, 4501, 4513, 4516, 4625, 4628, 4630, 4645, 4672, 4678, 4681, 4690, 4693, 4696, 4698, + 4708, 4710, 4741, 4753, 4756, 4758, 4773, 5121, 5126, 5129, 5140, 5141, 5144, 5145, 5153, 5158, + 5185, 5189, 5190, 5192, 5194, 5201, 5204, 5205, 5206, 5209, 5218, 5221, 5224, 5252, 5257, 5264, + 5268, 5269, 5272, 5273, 5274, 5281, 5284, 5285, 5289, 5378, 5381, 5386, 5393, 5396, 5397, 5398, + 5401, 5408, 5410, 5413, 5416, 5418, 5441, 5444, 5445, 5446, 5457, 5458, 5460, 5461, 5462, 5465, + 5466, 5473, 5476, 5477, 5478, 5481, 5504, 5506, 5508, 5509, 5512, 5514, 5520, 5521, 5524, 5525, + 5526, 5529, 5530, 5536, 5538, 5541, 5633, 5636, 5637, 5638, 5653, 5654, 5656, 5658, 5665, 5670, + 5696, 5698, 5700, 5701, 5704, 5706, 5713, 5717, 5718, 5720, 5721, 5729, 5732, 5733, 5736, 5737, + 5738, 5766, 5770, 5778, 5781, 5796, 5801, 6161, 6166, 6181, 6209, 6212, 6214, 6217, 6224, 6229, + 6232, 6234, 6240, 6241, 6244, 6246, 6249, 6277, 6289, 6292, 6309, 6416, 6418, 6421, 6426, 6433, + 6437, 6466, 6468, 6469, 6472, 6481, 6484, 6485, 6486, 6489, 6490, 6496, 6501, 6506, 6537, 6545, + 6546, 6549, 6552, 6561, 6566, 6569, 6665, 6678, 6692, 6694, 6724, 6726, 6729, 6736, 6738, 6741, + 6744, 6753, 6758, 6761, 6789, 6801, 6806, 6810, 8192, 8194, 8200, 8202, 8213, 8224, 8226, 8229, + 8232, 8234, 8261, 8273, 8281, 8289, 8293, 8320, 8322, 8328, 8330, 8341, 8352, 8354, 8357, 8360, + 8362, 8453, 8465, 8468, 8473, 8485, 8514, 8516, 8521, 8533, 8536, 8538, 8545, 8548, 8549, 8550, + 8581, 8592, 8598, 8601, 8613, 8705, 8712, 8714, 8721, 8725, 8736, 8738, 8744, 8746, 8773, 8785, + 8790, 8793, 8805, 8833, 8840, 8842, 8849, 8853, 8864, 8866, 8872, 8874, 9221, 9236, 9238, 9241, + 9253, 9284, 9285, 9286, 9289, 9298, 9301, 9304, 9306, 9318, 9349, 9361, 9364, 9369, 9377, 9381, + 9481, 9493, 9505, 9513, 9536, 9541, 9544, 9553, 9556, 9557, 9561, 9570, 9573, 9576, 9609, 9616, + 9620, 9621, 9624, 9626, 9633, 9636, 9638, 9641, 9733, 9744, 9746, 9753, 9765, 9793, 9801, 9813, + 9824, 9825, 9833, 9860, 9862, 9872, 9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282, + 10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521, + 10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752, + 10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890, + 10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484, + 16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673, + 16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772, + 16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986, + 16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494, + 17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666, + 17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744, + 17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809, + 17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 17946, 17953, + 17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049, + 18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517, + 18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704, + 18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784, + 18786, 18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012, + 19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501, + 20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617, + 20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761, + 20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822, + 20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896, + 20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078, + 21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526, + 21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589, + 21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653, + 21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780, + 21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832, + 21833, 21840, 21841, 21842, 21844, 21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864, + 21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924, + 21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048, + 22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098, + 22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154, + 22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561, + 22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665, + 22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821, + 22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884, + 22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061, + 23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144, + 23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656, + 24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850, + 24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970, + 24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221, + 25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674, + 25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 25748, 25749, + 25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926, + 25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001, + 26005, 26006, 26009, 26010, 26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176, + 26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250, + 26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721, + 26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949, + 26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044, + 27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270, + 27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852, + 32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046, + 33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161, + 33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369, + 33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877, + 33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117, + 34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192, + 34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394, + 34401, 34406, 34410, 34437, 34449, 34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858, + 34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986, + 35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172, + 35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412, + 35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901, + 36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124, + 37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205, + 37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396, + 37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889, + 37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985, + 37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161, + 38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226, + 38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290, + 38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432, + 38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538, + 38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998, + 39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 39192, 39194, + 39200, 39201, 39204, 39206, 39232, 39234, 39237, 39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269, + 39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497, + 39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994, + 41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130, + 41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349, + 41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561, + 41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068, + 42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278, + 42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386, + 42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592, + 42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048, + 43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284, + 43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530, + 43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690, + }; + static const uint16_t kgrid_2bit_1024[1024] = { + 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, + 73, 80, 82, 85, 88, 97, 100, 102, 105, 128, 130, 133, 136, 145, 148, 160, + 165, 170, 257, 260, 262, 265, 272, 274, 277, 280, 289, 292, 320, 322, 325, 328, + 337, 340, 342, 345, 352, 357, 360, 385, 388, 400, 402, 405, 417, 420, 512, 514, + 517, 520, 529, 532, 544, 554, 577, 580, 582, 585, 592, 597, 640, 645, 650, 660, + 674, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1062, 1065, 1088, 1090, 1093, + 1096, 1098, 1105, 1108, 1110, 1113, 1120, 1122, 1125, 1153, 1156, 1158, 1161, 1168, 1173, 1176, + 1185, 1188, 1280, 1282, 1285, 1288, 1290, 1297, 1300, 1302, 1305, 1312, 1317, 1320, 1345, 1348, + 1350, 1353, 1360, 1362, 1365, 1368, 1377, 1380, 1408, 1410, 1413, 1416, 1425, 1428, 1440, 1537, + 1540, 1542, 1545, 1552, 1557, 1600, 1605, 1608, 1617, 1620, 1632, 1665, 1668, 1680, 2048, 2050, + 2053, 2056, 2065, 2068, 2070, 2073, 2080, 2085, 2090, 2113, 2116, 2118, 2121, 2128, 2130, 2133, + 2136, 2145, 2148, 2176, 2181, 2196, 2218, 2305, 2308, 2320, 2322, 2325, 2328, 2337, 2368, 2373, + 2376, 2385, 2388, 2400, 2433, 2448, 2560, 2577, 2580, 2594, 2600, 2602, 2640, 2713, 4097, 4100, + 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4134, 4160, 4162, 4165, 4168, 4177, 4180, 4182, + 4185, 4192, 4194, 4197, 4200, 4225, 4228, 4230, 4240, 4245, 4248, 4257, 4260, 4352, 4354, 4357, + 4360, 4362, 4369, 4372, 4374, 4377, 4384, 4386, 4389, 4392, 4417, 4420, 4422, 4425, 4432, 4434, + 4437, 4440, 4449, 4452, 4480, 4482, 4485, 4488, 4497, 4500, 4609, 4612, 4617, 4624, 4629, 4641, + 4644, 4672, 4677, 4689, 4692, 4737, 4740, 4752, 5120, 5122, 5125, 5128, 5137, 5140, 5142, 5145, + 5152, 5157, 5160, 5185, 5188, 5190, 5193, 5200, 5202, 5205, 5208, 5217, 5220, 5248, 5250, 5253, + 5256, 5265, 5268, 5280, 5377, 5380, 5382, 5385, 5392, 5394, 5397, 5400, 5409, 5412, 5440, 5442, + 5445, 5448, 5457, 5460, 5472, 5505, 5508, 5520, 5632, 5637, 5640, 5649, 5652, 5664, 5697, 5700, + 5712, 5760, 5802, 6145, 6148, 6150, 6153, 6160, 6165, 6168, 6177, 6208, 6210, 6213, 6216, 6225, + 6228, 6240, 6273, 6276, 6400, 6402, 6405, 6408, 6417, 6420, 6432, 6465, 6468, 6480, 6505, 6562, + 6660, 6672, 6720, 6742, 8192, 8194, 8197, 8200, 8209, 8212, 8214, 8217, 8224, 8229, 8234, 8257, + 8260, 8272, 8274, 8277, 8292, 8320, 8330, 8340, 8362, 8449, 8452, 8464, 8466, 8469, 8481, 8512, + 8514, 8517, 8529, 8532, 8544, 8577, 8580, 8592, 8704, 8714, 8738, 8744, 8746, 8772, 8784, 8840, + 8842, 8872, 9217, 9220, 9222, 9225, 9232, 9237, 9240, 9249, 9252, 9280, 9282, 9285, 9288, 9297, + 9300, 9312, 9345, 9348, 9360, 9472, 9477, 9480, 9489, 9492, 9504, 9537, 9540, 9552, 9574, 9600, + 9729, 9732, 9744, 9792, 9817, 10240, 10245, 10257, 10260, 10305, 10308, 10320, 10378, 10410, 10497, 10500, + 10512, 10645, 10762, 10786, 10852, 10888, 10890, 16385, 16388, 16390, 16393, 16400, 16402, 16405, 16408, 16410, + 16417, 16420, 16422, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16470, 16473, 16480, 16482, 16485, 16513, + 16516, 16528, 16533, 16536, 16545, 16548, 16640, 16642, 16645, 16648, 16657, 16660, 16662, 16665, 16672, 16674, + 16677, 16705, 16708, 16710, 16713, 16720, 16722, 16725, 16728, 16737, 16740, 16768, 16770, 16773, 16776, 16785, + 16788, 16800, 16897, 16900, 16912, 16914, 16917, 16920, 16932, 16960, 16965, 16968, 16977, 16980, 16992, 17025, + 17028, 17408, 17410, 17413, 17416, 17418, 17425, 17428, 17430, 17433, 17440, 17442, 17445, 17448, 17473, 17476, + 17478, 17481, 17488, 17490, 17493, 17496, 17505, 17508, 17536, 17538, 17541, 17544, 17553, 17556, 17568, 17665, + 17668, 17670, 17673, 17680, 17682, 17685, 17688, 17697, 17700, 17728, 17730, 17733, 17736, 17745, 17748, 17760, + 17770, 17793, 17796, 17808, 17920, 17922, 17925, 17928, 17937, 17940, 17952, 17985, 17988, 18000, 18048, 18085, + 18433, 18436, 18441, 18448, 18450, 18453, 18456, 18465, 18468, 18496, 18498, 18501, 18504, 18513, 18516, 18528, + 18564, 18576, 18688, 18690, 18693, 18696, 18705, 18708, 18720, 18753, 18756, 18768, 18816, 18838, 18945, 18948, + 18960, 19008, 20480, 20482, 20485, 20488, 20497, 20500, 20502, 20505, 20512, 20514, 20517, 20520, 20545, 20548, + 20550, 20553, 20560, 20562, 20565, 20568, 20577, 20580, 20608, 20610, 20613, 20616, 20625, 20628, 20737, 20740, + 20742, 20745, 20752, 20754, 20757, 20760, 20769, 20772, 20800, 20802, 20805, 20808, 20817, 20820, 20832, 20865, + 20868, 20880, 20992, 20997, 21000, 21009, 21012, 21024, 21057, 21060, 21072, 21097, 21120, 21505, 21508, 21510, + 21513, 21520, 21522, 21525, 21528, 21537, 21540, 21568, 21570, 21573, 21576, 21585, 21588, 21600, 21633, 21636, + 21648, 21760, 21762, 21765, 21768, 21777, 21780, 21792, 21825, 21828, 21840, 21888, 22017, 22020, 22032, 22054, + 22080, 22528, 22530, 22533, 22536, 22545, 22548, 22560, 22593, 22596, 22608, 22618, 22656, 22785, 22788, 22800, + 22848, 23040, 23065, 23173, 23208, 24577, 24580, 24582, 24592, 24594, 24597, 24600, 24609, 24612, 24640, 24645, + 24648, 24657, 24660, 24672, 24708, 24720, 24832, 24834, 24837, 24840, 24849, 24852, 24864, 24897, 24900, 24912, + 24960, 24985, 25092, 25104, 25152, 25174, 25249, 25600, 25605, 25608, 25617, 25620, 25632, 25665, 25668, 25680, + 25728, 25857, 25860, 25872, 25920, 25930, 25960, 26002, 26112, 26260, 26625, 26628, 26640, 26725, 26776, 26880, + 26922, 27202, 27297, 32768, 32770, 32773, 32776, 32785, 32788, 32793, 32800, 32805, 32833, 32836, 32848, 32850, + 32853, 32856, 32865, 32896, 32901, 32913, 32916, 33025, 33028, 33033, 33040, 33042, 33045, 33048, 33057, 33060, + 33088, 33090, 33093, 33096, 33105, 33108, 33153, 33156, 33168, 33193, 33280, 33285, 33290, 33297, 33300, 33345, + 33348, 33360, 33793, 33796, 33798, 33801, 33808, 33810, 33813, 33816, 33825, 33856, 33858, 33861, 33864, 33873, + 33876, 33888, 33921, 33924, 33936, 34048, 34050, 34053, 34056, 34065, 34068, 34080, 34113, 34116, 34128, 34176, + 34186, 34305, 34308, 34320, 34345, 34368, 34816, 34821, 34833, 34836, 34881, 34884, 34896, 34978, 35073, 35076, + 35136, 35173, 35362, 35416, 35418, 35458, 35490, 36865, 36868, 36873, 36880, 36882, 36885, 36888, 36900, 36928, + 36930, 36933, 36936, 36945, 36948, 36960, 36993, 36996, 37008, 37120, 37125, 37137, 37140, 37185, 37188, 37200, + 37210, 37377, 37380, 37392, 37440, 37542, 37888, 37890, 37893, 37896, 37905, 37908, 37920, 37953, 37956, 37968, + 38016, 38038, 38145, 38148, 38160, 38208, 38296, 38305, 38400, 38470, 38500, 38913, 38916, 38928, 38950, 38976, + 39081, 39168, 39241, 39250, 39568, 40960, 40965, 40970, 40980, 40994, 41002, 41025, 41028, 41040, 41122, 41130, + 41280, 41317, 41474, 41482, 41506, 41512, 41514, 41602, 41608, 41610, 41640, 41985, 41988, 42000, 42048, 42121, + 42148, 42240, 42265, 42577, 43018, 43048, 43170, 43348, 43398, 43528, 43530, 43552, 43554, 43560, 43656, 43690, + }; + + const int kmap_size = 43692; + //const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2; + const int nwant = type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2; + const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 : + type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : + type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? kgrid_1bit_2048 : kgrid_2bit_1024; + uint64_t * kgrid_q2xs; + int * kmap_q2xs; + uint16_t * kneighbors_q2xs; + + //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size); + uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t)); + for (int k = 0; k < grid_size; ++k) { + int8_t * pos = (int8_t *)(the_grid + k); + for (int i = 0; i < 8; ++i) { + int l = (kgrid[k] >> 2*i) & 0x3; + pos[i] = 2*l + 1; + } + } + kgrid_q2xs = the_grid; + iq2_data[gindex].grid = the_grid; + kmap_q2xs = (int *)malloc(kmap_size*sizeof(int)); + iq2_data[gindex].map = kmap_q2xs; + for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1; + uint64_t aux64; + uint8_t * aux8 = (uint8_t *)&aux64; + for (int i = 0; i < grid_size; ++i) { + aux64 = kgrid_q2xs[i]; + uint16_t index = 0; + for (int k=0; k<8; ++k) { + uint16_t q = (aux8[k] - 1)/2; + index |= (q << 2*k); + } + kmap_q2xs[index] = i; + } + int8_t pos[8]; + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + int num_neighbors = 0, num_not_in_map = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + ++num_not_in_map; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + num_neighbors += n; + } + //printf("%s: %d neighbours in total\n", __func__, num_neighbors); + kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); + iq2_data[gindex].neighbours = kneighbors_q2xs; + int counter = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + kmap_q2xs[i] = -(counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q2xs[counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q2xs[counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); +} + +void iq2xs_free_impl(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S); + const int gindex = iq2_data_index(type); + if (iq2_data[gindex].grid) { + free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL; + free(iq2_data[gindex].map); iq2_data[gindex].map = NULL; + free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL; + } +} + +static int iq2_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid, + const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, int8_t * GGML_RESTRICT L) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_d2 = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 8; ++i) { + float q = pg[i]; + float diff = scale*q - xval[i]; + d2 += weight[i]*diff*diff; + } + if (d2 < best_d2) { + best_d2 = d2; grid_index = neighbours[j]; + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} + +static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) { + + const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS); + + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); + + const int kMaxQ = 3; + + const int64_t nbl = n/QK_K; + + block_iq2_xxs * y = vy; + + float scales[QK_K/32]; + float weight[32]; + float xval[32]; + int8_t L[32]; + int8_t Laux[32]; + float waux[32]; + uint8_t block_signs[4]; + uint32_t q2[2*(QK_K/32)]; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + y[ibl].d = GGML_FP32_TO_FP16(0.f); + memset(q2, 0, QK_K/4); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = sumx2/QK_K; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const float * xb = xbl + 32*ib; + const float * qw = quant_weights + QK_K*ibl + 32*ib; + for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 4; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + if (max < GROUP_MAX_EPS) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } + float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight); + float eff_max = scale*kMaxQ; + float best = 0; + for (int is = -6; is <= 6; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/eff_max; + float this_scale = 1/id; + for (int k = 0; k < 4; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + memcpy(L, Laux, 32); + } + } + if (scale > 0) { + float id = 1/scale; + for (int k = 0; k < 4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index); + for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. + scale = -scale; + for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + GGML_ABORT("fatal error"); + } + q2[2*ib+0] |= ((uint32_t) grid_index << 8*k); + q2[2*ib+1] |= (block_signs[k] << 7*k); + } + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + memset(y[ibl].qs, 0, QK_K/4); + continue; + } + + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d); + float id = 1/d; + for (int ib = 0; ib < QK_K/32; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + q2[2*ib+1] |= ((uint32_t)l << 28); + } + memcpy(y[ibl].qs, q2, QK_K/4); + } +} + +static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) { + + const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS); + + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); + + const int kMaxQ = 3; + + const int64_t nbl = n/QK_K; + + block_iq2_xs * y = vy; + + float scales[QK_K/16]; + float weight[16]; + float xval[16]; + int8_t L[16]; + int8_t Laux[16]; + float waux[16]; + bool is_on_grid[2]; + bool is_on_grid_aux[2]; + uint8_t block_signs[2]; + uint16_t q2[2*(QK_K/16)]; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + y[ibl].d = GGML_FP32_TO_FP16(0.f); + memset(q2, 0, QK_K/4); + memset(y[ibl].scales, 0, QK_K/32); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = sumx2/QK_K; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + const float * qw = quant_weights + QK_K*ibl + 16*ib; + for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 2; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + if (max < GROUP_MAX_EPS) { + scales[ib] = 0; + memset(L, 0, 16); + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + is_on_grid[0] = is_on_grid[1] = true; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/max; + float this_scale = 1/id; + for (int k = 0; k < 2; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 16; ++i) L[i] = Laux[i]; + for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 2; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + L[8*k + i] = l; + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + scale = -scale; + for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 2; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + GGML_ABORT("fatal error"); + } + q2[2*ib+k] = grid_index | (block_signs[k] << 9); + } + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + memset(y[ibl].qs, 0, QK_K/4); + continue; + } + + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d); + float id = 1/d; + for (int ib = 0; ib < QK_K/16; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + if (ib%2 == 0) y[ibl].scales[ib/2] = l; + else y[ibl].scales[ib/2] |= (l << 4); + } + memcpy(y[ibl].qs, q2, QK_K/4); + + } +} + +size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_xxs); + } + return nrow * nblock * sizeof(block_iq2_xxs); +} + +size_t quantize_iq2_xs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_xs); + } + return nrow * nblock * sizeof(block_iq2_xs); +} + +// +// ============================================= 3-bit using D4 lattice +// + +typedef struct { + uint32_t * grid; + int * map; + uint16_t * neighbours; +} iq3_entry_t; + +static iq3_entry_t iq3_data[2] = { + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, +}; + +static inline int iq3_data_index(int grid_size) { + (void)grid_size; + GGML_ASSERT(grid_size == 256 || grid_size == 512); + return grid_size == 256 ? 0 : 1; +} + +static int iq3_compare_func(const void * left, const void * right) { + const int * l = (const int *)left; + const int * r = (const int *)right; + return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0; +} + +void iq3xs_init_impl(int grid_size) { + const int gindex = iq3_data_index(grid_size); + if (iq3_data[gindex].grid) { + return; + } + static const uint16_t kgrid_256[256] = { + 0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74, + 81, 85, 88, 90, 97, 108, 120, 128, 130, 132, 137, 144, 146, 153, 155, 159, + 169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, 292, 303, 315, 317, 321, + 327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, 529, 531, + 536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664, + 698, 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978, + 992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105, + 1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228, + 1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553, + 1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722, + 1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063, + 2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389, + 2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746, + 2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153, + 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610, + 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992, + }; + static const uint16_t kgrid_512[512] = { + 0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34, + 37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77, + 80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142, + 145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210, + 217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288, + 291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393, + 395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514, + 516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576, + 577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653, + 655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727, + 728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833, + 840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977, + 989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047, + 1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103, + 1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199, + 1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296, + 1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415, + 1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561, + 1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648, + 1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761, + 1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877, + 1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068, + 2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177, + 2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269, + 2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520, + 2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634, + 2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805, + 2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083, + 3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276, + 3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591, + 3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729, + 3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032, + }; + + const int kmap_size = 4096; + const int nwant = grid_size == 256 ? 2 : 3; + const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512; + uint32_t * kgrid_q3xs; + int * kmap_q3xs; + uint16_t * kneighbors_q3xs; + + //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size); + uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t)); + for (int k = 0; k < grid_size; ++k) { + int8_t * pos = (int8_t *)(the_grid + k); + for (int i = 0; i < 4; ++i) { + int l = (kgrid[k] >> 3*i) & 0x7; + pos[i] = 2*l + 1; + } + } + kgrid_q3xs = the_grid; + iq3_data[gindex].grid = the_grid; + kmap_q3xs = (int *)malloc(kmap_size*sizeof(int)); + iq3_data[gindex].map = kmap_q3xs; + for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1; + uint32_t aux32; + uint8_t * aux8 = (uint8_t *)&aux32; + for (int i = 0; i < grid_size; ++i) { + aux32 = kgrid_q3xs[i]; + uint16_t index = 0; + for (int k=0; k<4; ++k) { + uint16_t q = (aux8[k] - 1)/2; + index |= (q << 3*k); + } + kmap_q3xs[index] = i; + } + int8_t pos[4]; + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + int num_neighbors = 0, num_not_in_map = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + ++num_not_in_map; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + num_neighbors += n; + } + //printf("%s: %d neighbours in total\n", __func__, num_neighbors); + kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); + iq3_data[gindex].neighbours = kneighbors_q3xs; + int counter = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + kmap_q3xs[i] = -(counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q3xs[counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q3xs[counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); +} + +void iq3xs_free_impl(int grid_size) { + GGML_ASSERT(grid_size == 256 || grid_size == 512); + const int gindex = iq3_data_index(grid_size); + if (iq3_data[gindex].grid) { + free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL; + free(iq3_data[gindex].map); iq3_data[gindex].map = NULL; + free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL; + } +} + +static int iq3_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint32_t * GGML_RESTRICT grid, + const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, int8_t * GGML_RESTRICT L) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_d2 = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 4; ++i) { + float q = pg[i]; + float diff = scale*q - xval[i]; + d2 += weight[i]*diff*diff; + } + if (d2 < best_d2) { + best_d2 = d2; grid_index = neighbours[j]; + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 4; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} + +static void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, + const float * GGML_RESTRICT quant_weights) { + + const int gindex = iq3_data_index(grid_size); + + const uint32_t * kgrid_q3xs = iq3_data[gindex].grid; + const int * kmap_q3xs = iq3_data[gindex].map; + const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours; + + //GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); + + const int kMaxQ = 8; + + const int64_t nbl = n/QK_K; + + ggml_fp16_t * dh; + uint8_t * qs; + int block_size; + if (grid_size == 256) { + block_iq3_xxs * y = vy; + dh = &y->d; + qs = y->qs; + block_size = sizeof(block_iq3_xxs); + } else { + block_iq3_s * y = vy; + dh = &y->d; + qs = y->qs; + block_size = sizeof(block_iq3_s); + } + int quant_size = block_size - sizeof(ggml_fp16_t); + + float scales[QK_K/32]; + float weight[32]; + float xval[32]; + int8_t L[32]; + int8_t Laux[32]; + float waux[32]; + bool is_on_grid[8]; + bool is_on_grid_aux[8]; + uint8_t block_signs[8]; + uint8_t q3[3*(QK_K/8)+QK_K/32]; + uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4); + uint8_t * qh = q3 + 3*(QK_K/8); + + for (int ibl = 0; ibl < nbl; ++ibl) { + + dh[0] = GGML_FP32_TO_FP16(0.f); + memset(q3, 0, 3*QK_K/8+QK_K/32); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const float * xb = xbl + 32*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + 32*ib; + for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i]; + } + for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 4; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + if (max < GROUP_MAX_EPS_IQ3_XXS) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + for (int is = -15; is <= 15; ++is) { + float id = (2*kMaxQ-1+is*0.2f)/max; + float this_scale = 1/id; + for (int k = 0; k < 8; ++k) { + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 32; ++i) L[i] = Laux[i]; + for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 8; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 3*i); + } + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k); + } + const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index); + for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. + scale = -scale; + for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 8; ++k) { + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]); + printf("\n"); + GGML_ABORT("fatal error"); + } + if (grid_size == 256) { + q3[8*ib+k] = grid_index; + } else { + q3[8*ib+k] = grid_index & 255; + qh[ib] |= ((grid_index >> 8) << k); + } + + } + scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21); + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + memset(qs, 0, quant_size); + dh += block_size/sizeof(ggml_fp16_t); + qs += block_size; + continue; + } + + float d = max_scale/31; + dh[0] = GGML_FP32_TO_FP16(d * 1.0125f); // small improvement via this fudge factor + float id = 1/d; + for (int ib = 0; ib < QK_K/32; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + scales_and_signs[ib] |= ((uint32_t)l << 28); + } + memcpy(qs, q3, quant_size); + + dh += block_size/sizeof(ggml_fp16_t); + qs += block_size; + + } +} + +size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq3_xxs); + } + return nrow * nblock * sizeof(block_iq3_xxs); +} + +void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_row_iq3_xxs_impl(256, x, y, k, NULL); +} + +static void quantize_row_iq3_s_impl(int block_size, const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int n, + const float * GGML_RESTRICT quant_weights, + float * scales, + float * weight, + float * xval, + int8_t * L, + int8_t * Laux, + float * waux, + bool * is_on_grid, + bool * is_on_grid_aux, + uint8_t * block_signs) { + + const int gindex = iq3_data_index(512); + + const uint32_t * kgrid_q3xs = iq3_data[gindex].grid; + const int * kmap_q3xs = iq3_data[gindex].map; + const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours; + + //GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); + + const int kMaxQ = 8; + + const int64_t nbl = n/QK_K; + + block_iq3_s * y = vy; + + const int bs4 = block_size/4; + const int bs8 = block_size/8; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq3_s)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + uint8_t * qs = y[ibl].qs; + uint8_t * qh = y[ibl].qh; + uint8_t * signs = y[ibl].signs; + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/block_size; ++ib) { + const float * xb = xbl + block_size*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } + for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < bs8; ++k) { + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; s |= (1 << i); + } + } + block_signs[k] = s; + } + float max = xval[0]; + for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + for (int k = 0; k < bs4; ++k) is_on_grid[k] = false; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.2f)/max; + float this_scale = 1/id; + for (int k = 0; k < bs4; ++k) { + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < block_size; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < block_size; ++i) L[i] = Laux[i]; + for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < bs4; ++k) { + //if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 3*i); + } + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k); + } + const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index); + for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < block_size; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. + scale = -scale; + for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k]; + } + for (int k = 0; k < bs4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]); + printf("\n"); + GGML_ABORT("fatal error"); + } + qs[k] = grid_index & 255; + qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8)); + } + qs += bs4; + for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k]; + signs += bs8; + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + continue; + } + + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f); + float id = 1/d; + for (int ib = 0; ib < QK_K/block_size; ib += 2) { + int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); + l1 = MAX(0, MIN(15, l1)); + int l2 = nearest_int(0.5f*(id*scales[ib+1]-1)); + l2 = MAX(0, MIN(15, l2)); + y[ibl].scales[ib/2] = l1 | (l2 << 4); + } + + } +} + +#define IQ3S_BLOCK_SIZE 32 +size_t quantize_iq3_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + float scales[QK_K/IQ3S_BLOCK_SIZE]; + float weight[IQ3S_BLOCK_SIZE]; + float xval[IQ3S_BLOCK_SIZE]; + int8_t L[IQ3S_BLOCK_SIZE]; + int8_t Laux[IQ3S_BLOCK_SIZE]; + float waux[IQ3S_BLOCK_SIZE]; + bool is_on_grid[IQ3S_BLOCK_SIZE/4]; + bool is_on_grid_aux[IQ3S_BLOCK_SIZE/4]; + uint8_t block_signs[IQ3S_BLOCK_SIZE/8]; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights, + scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs); + src += n_per_row; + qrow += nblock*sizeof(block_iq3_s); + } + return nrow * nblock * sizeof(block_iq3_s); +} + +void quantize_row_iq3_s_ref(const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq3_s(x, y, 1, k, NULL); +} + + +// =================================== 1.5 bpw =================================================== + +static int iq1_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid, + const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float * scale, int8_t * GGML_RESTRICT L, int ngrid) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_score = -FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + *scale = sumqx/sumq2; best_score = *scale * sumqx; + grid_index = neighbours[j]; + } + } + if (grid_index < 0) { + for (int i = 0; i < ngrid; ++i) { + const int8_t * grid_i = (const int8_t *)(grid + i); + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < 8; ++j) { + float w = weight[j]; + float q = (grid_i[j] - 3)/2; + sumqx += w*q*xval[j]; + sumq2 += w*q*q; + } + if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + *scale = sumqx/sumq2; best_score = *scale*sumqx; + grid_index = i; + } + } + } + if (grid_index < 0) { + printf("Oops, did not find grid point\n"); + printf("Have %d neighbours\n", num_neighbors); + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2); + } + } + GGML_ASSERT(grid_index >= 0); + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + *scale *= 1.05f; // This is a fudge factor. Don't ask me why it improves the result. + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} + +static int iq1_find_best_neighbour2(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid, + const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, const float * GGML_RESTRICT xg, int8_t * GGML_RESTRICT L, int ngrid) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_score = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 8; ++i) { + float q = xg[(pg[i] - 1)/2]; + float w = weight[i]; + float diff = scale*q - xval[i]; + d2 += w*diff*diff; + } + if (d2 < best_score) { + best_score = d2; + grid_index = neighbours[j]; + } + } + if (grid_index < 0) { + for (int i = 0; i < ngrid; ++i) { + const int8_t * grid_i = (const int8_t *)(grid + i); + float d2 = 0; + for (int j = 0; j < 8; ++j) { + float w = weight[j]; + float q = xg[(grid_i[j] - 1)/2]; + float diff = scale*q - xval[i]; + d2 += w*diff*diff; + } + if (d2 < best_score) { + best_score = d2; + grid_index = i; + } + } + } + if (grid_index < 0) { + printf("Oops, did not find grid point\n"); + printf("Have %d neighbours\n", num_neighbors); + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = xg[(pg[i] - 1)/2]; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2); + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} + +static int iq1_sort_helper(const void * left, const void * right) { + const float * l = left; + const float * r = right; + return *l < *r ? -1 : *l > *r ? 1 : 0; +} + +#define IQ1S_BLOCK_SIZE 32 +#define IQ1M_BLOCK_SIZE 16 +static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights, + float * scales, + float * weight, + float * sumx, + float * sumw, + float * pairs, + int8_t * L, + uint16_t * index, + int8_t * shifts) { + + const int gindex = iq2_data_index(GGML_TYPE_IQ1_S); + + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); + + block_iq1_s * y = vy; + + const int64_t nbl = n/QK_K; + + const int block_size = IQ1S_BLOCK_SIZE; + + const float x_p[3] = {-1 + IQ1S_DELTA, IQ1S_DELTA, 1 + IQ1S_DELTA}; + const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA}; + + + int * idx = (int *)(pairs + 1); + + for (int ibl = 0; ibl < nbl; ++ibl) { + + y[ibl].d = GGML_FP32_TO_FP16(0.f); + memset(y[ibl].qs, 0, QK_K/8); + memset(y[ibl].qh, 0, QK_K/16); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/block_size; ++ib) { + const float * xb = xbl + block_size*ib; + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + float max = fabsf(xb[0]); + for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); + if (max < GROUP_MAX_EPS_IQ1_S) { + scales[ib] = 0; + memset(L, 1, block_size); + continue; + } + // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. + // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two + // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights + // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and + // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale + // for each possible and score for each split. + for (int j = 0; j < block_size; ++j) { + pairs[2*j] = xb[j]; + idx[2*j] = j; + } + qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); + { + sumx[0] = sumw[0] = 0; + for (int j = 0; j < block_size; ++j) { + int i = idx[2*j]; + sumx[j+1] = sumx[j] + weight[i]*xb[i]; + sumw[j+1] = sumw[j] + weight[i]; + } + } + float best_score = -FLT_MIN, scale = max; + int besti1 = -1, besti2 = -1, best_shift = 0; + for (int i1 = 0; i1 <= block_size; ++i1) { + for (int i2 = i1; i2 <= block_size; ++i2) { + float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2]; + float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2]; + if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + scale = sumqx/sumq2; best_score = scale*sumqx; + besti1 = i1; besti2 = i2; best_shift = 1; + } + sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2]; + sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2]; + if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + scale = sumqx/sumq2; best_score = scale*sumqx; + besti1 = i1; besti2 = i2; best_shift = -1; + } + } + } + GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); + for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; + for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; + for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; + if (scale < 0) { + for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; + scale = -scale; best_shift = -best_shift; + } + bool all_on_grid = true; + const float * xx = best_shift == 1 ? x_p : x_m; + for (int k = 0; k < block_size/8; ++k) { + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + all_on_grid = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); + GGML_ASSERT(grid_index >= 0); + } + index[k] = grid_index; + } + if (!all_on_grid) { + float sumqx = 0, sumq2 = 0; + for (int k = 0; k < block_size/8; ++k) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]; + sumqx += w*q*xb[8*k+j]; + sumq2 += w*q*q; + } + } + if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2; + } + uint16_t h = 0; + for (int k = 0; k < block_size/8; ++k) { + y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255; + h |= (index[k] >> 8) << 3*k; + } + y[ibl].qh[ib] = h; + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + shifts[ib] = best_shift; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + continue; + } + + float d = max_scale/15; + y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. Don't ask me why it is needed. + float id = 1/d; + for (int ib = 0; ib < QK_K/block_size; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(7, l)); + if (shifts[ib] == -1) l |= 8; + y[ibl].qh[ib] |= (l << 12); + } + } +} + +size_t quantize_iq1_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + float scales[QK_K/IQ1S_BLOCK_SIZE]; + float weight[IQ1S_BLOCK_SIZE]; + int8_t L[IQ1S_BLOCK_SIZE]; + float sumx[IQ1S_BLOCK_SIZE+1]; + float sumw[IQ1S_BLOCK_SIZE+1]; + float pairs[2*IQ1S_BLOCK_SIZE]; + uint16_t index[IQ1S_BLOCK_SIZE/8]; + int8_t shifts[QK_K/IQ1S_BLOCK_SIZE]; + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts); + src += n_per_row; + qrow += nblock*sizeof(block_iq1_s); + } + return nrow * nblock * sizeof(block_iq1_s); +} + +static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights, + float * scales, + float * weight, + float * pairs, + int8_t * L, + uint16_t * index, + int8_t * shifts) { + + const int gindex = iq2_data_index(GGML_TYPE_IQ1_M); + + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + //GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); + + block_iq1_m * y = vy; + + const int64_t nbl = n/QK_K; + + const int block_size = IQ1M_BLOCK_SIZE; + + const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA}; + const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA}; + const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88}; + + int * idx = (int *)(pairs + 1); + + float sumqx[4], sumq2[4]; + + iq1m_scale_t s; + const float * xx; + + for (int ibl = 0; ibl < nbl; ++ibl) { + memset(y[ibl].qs, 0, QK_K/8); + memset(y[ibl].qh, 0, QK_K/16); + memset(y[ibl].scales, 0, QK_K/32); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/block_size; ++ib) { + const float * xb = xbl + block_size*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } + float max = fabsf(xb[0]); + for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); + if (max < GROUP_MAX_EPS_IQ1_M) { + scales[ib] = 0; + memset(L, 1, block_size); + continue; + } + // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. + // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two + // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights + // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and + // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale + // for each possible and score for each split. + for (int j = 0; j < block_size; ++j) { + pairs[2*j] = xb[j]; + idx[2*j] = j; + } + qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); + float best_score = -FLT_MIN, scale = max; + int besti1 = -1, besti2 = -1, best_k = -1; + // 0: +, + + // 1: +, - + // 2: -, + + // 3: -, - + for (int i1 = 0; i1 <= block_size; ++i1) { + for (int i2 = i1; i2 <= block_size; ++i2) { + memset(sumqx, 0, 4*sizeof(float)); + memset(sumq2, 0, 4*sizeof(float)); + for (int j = 0; j < i1; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[0]*xb[i]; + sumqx[1] += weight[i]*x_p[0]*xb[i]; + sumqx[2] += weight[i]*x_m[0]*xb[i]; + sumqx[3] += weight[i]*x_m[0]*xb[i]; + sumq2[0] += weight[i]*x_p[0]*x_p[0]; + sumq2[1] += weight[i]*x_p[0]*x_p[0]; + sumq2[2] += weight[i]*x_m[0]*x_m[0]; + sumq2[3] += weight[i]*x_m[0]*x_m[0]; + } else { + sumqx[0] += weight[i]*x_p[0]*xb[i]; + sumqx[2] += weight[i]*x_p[0]*xb[i]; + sumqx[1] += weight[i]*x_m[0]*xb[i]; + sumqx[3] += weight[i]*x_m[0]*xb[i]; + sumq2[0] += weight[i]*x_p[0]*x_p[0]; + sumq2[2] += weight[i]*x_p[0]*x_p[0]; + sumq2[1] += weight[i]*x_m[0]*x_m[0]; + sumq2[3] += weight[i]*x_m[0]*x_m[0]; + } + } + for (int j = i1; j < i2; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[1]*xb[i]; + sumqx[1] += weight[i]*x_p[1]*xb[i]; + sumqx[2] += weight[i]*x_m[1]*xb[i]; + sumqx[3] += weight[i]*x_m[1]*xb[i]; + sumq2[0] += weight[i]*x_p[1]*x_p[1]; + sumq2[1] += weight[i]*x_p[1]*x_p[1]; + sumq2[2] += weight[i]*x_m[1]*x_m[1]; + sumq2[3] += weight[i]*x_m[1]*x_m[1]; + } else { + sumqx[0] += weight[i]*x_p[1]*xb[i]; + sumqx[2] += weight[i]*x_p[1]*xb[i]; + sumqx[1] += weight[i]*x_m[1]*xb[i]; + sumqx[3] += weight[i]*x_m[1]*xb[i]; + sumq2[0] += weight[i]*x_p[1]*x_p[1]; + sumq2[2] += weight[i]*x_p[1]*x_p[1]; + sumq2[1] += weight[i]*x_m[1]*x_m[1]; + sumq2[3] += weight[i]*x_m[1]*x_m[1]; + } + } + for (int j = i2; j < block_size; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[2]*xb[i]; + sumqx[1] += weight[i]*x_p[2]*xb[i]; + sumqx[2] += weight[i]*x_m[2]*xb[i]; + sumqx[3] += weight[i]*x_m[2]*xb[i]; + sumq2[0] += weight[i]*x_p[2]*x_p[2]; + sumq2[1] += weight[i]*x_p[2]*x_p[2]; + sumq2[2] += weight[i]*x_m[2]*x_m[2]; + sumq2[3] += weight[i]*x_m[2]*x_m[2]; + } else { + sumqx[0] += weight[i]*x_p[2]*xb[i]; + sumqx[2] += weight[i]*x_p[2]*xb[i]; + sumqx[1] += weight[i]*x_m[2]*xb[i]; + sumqx[3] += weight[i]*x_m[2]*xb[i]; + sumq2[0] += weight[i]*x_p[2]*x_p[2]; + sumq2[2] += weight[i]*x_p[2]*x_p[2]; + sumq2[1] += weight[i]*x_m[2]*x_m[2]; + sumq2[3] += weight[i]*x_m[2]*x_m[2]; + } + } + for (int k = 0; k < 4; ++k) { + if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) { + scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; + besti1 = i1; besti2 = i2; best_k = k; + } + } + } + } + GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0); + for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; + for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; + for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; + if (scale < 0) { + for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; + scale = -scale; + best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0; + } + bool all_on_grid = true; + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = best_k < 2 ? x_p : x_m; + else xx = best_k%2 == 0 ? x_p : x_m; + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + all_on_grid = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); + GGML_ASSERT(grid_index >= 0); + } + index[k] = grid_index; + } + if (!all_on_grid) { + float sumqx_f = 0, sumq2_f = 0; + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = best_k < 2 ? x_p : x_m; + else xx = best_k%2 == 0 ? x_p : x_m; + const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]; + sumqx_f += w*q*xb[8*k+j]; + sumq2_f += w*q*q; + } + } + if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f; + } + y[ibl].qs[2*ib + 0] = index[0] & 255; + y[ibl].qs[2*ib + 1] = index[1] & 255; + y[ibl].qh[ib] = (index[0] >> 8) | ((index[1] >> 8) << 4); + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + shifts[ib] = best_k; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + continue; + } + + uint16_t * sc = (uint16_t *)y[ibl].scales; + float d = max_scale/15; + float id = 1/d; + float sumqx_f = 0, sumq2_f = 0; + for (int ib = 0; ib < QK_K/block_size; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib+0]-1)); + l = MAX(0, MIN(7, l)); + sc[ib/4] |= (l << 3*(ib%4)); + y[ibl].qh[ib] |= masks[shifts[ib]]; + const float * xb = xbl + block_size*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = shifts[ib] < 2 ? x_p : x_m; + else xx = shifts[ib]%2 == 0 ? x_p : x_m; + const int8_t * pg = (const int8_t *)(kgrid_q2xs + y[ibl].qs[2*ib+k] + ((y[ibl].qh[ib] << (8 - 4*k)) & 0x700)); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]*(2*l+1); + sumqx_f += w*q*xb[8*k+j]; + sumq2_f += w*q*q; + } + } + } + if (sumq2_f > 0) d = sumqx_f/sumq2_f; + s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed. + sc[0] |= ((s.u16 & 0x000f) << 12); + sc[1] |= ((s.u16 & 0x00f0) << 8); + sc[2] |= ((s.u16 & 0x0f00) << 4); + sc[3] |= ((s.u16 & 0xf000) << 0); + } +} + +size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + float scales[QK_K/IQ1M_BLOCK_SIZE]; + float weight[IQ1M_BLOCK_SIZE]; + int8_t L[IQ1M_BLOCK_SIZE]; + float pairs[2*IQ1M_BLOCK_SIZE]; + uint16_t index[IQ1M_BLOCK_SIZE/8]; + int8_t shifts[QK_K/IQ1M_BLOCK_SIZE]; + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts); + src += n_per_row; + qrow += nblock*sizeof(block_iq1_m); + } + return nrow * nblock * sizeof(block_iq1_m); +} + +// ============================ 4-bit non-linear quants + +static inline int best_index_int8(int n, const int8_t * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + +static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x, + ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l, + float * scales, float * weight, uint8_t * L, + const int8_t * values, + const float * quant_weights, + const int ntry) { + + float sigma2 = 0; + for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j]; + sigma2 *= 2.f/super_block_size; + + memset(q4, 0, super_block_size/2); + dh[0] = GGML_FP32_TO_FP16(0.f); + + float max_scale = 0, amax_scale = 0; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + const float * xb = x + ib*block_size; + uint8_t * Lb = L + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < block_size; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (amax < GROUP_MAX_EPS) { + scales[ib] = 0; + continue; + } + float d = ntry > 0 ? -max/values[0] : max/values[0]; + float id = 1/d; + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + int l = best_index_int8(16, values, al); + Lb[j] = l; + float q = values[l]; + float w = weight[j]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + d = sumqx/sumq2; + float best = d*sumqx; + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + values[0])/max; + sumqx = sumq2 = 0; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + int l = best_index_int8(16, values, al); + float q = values[l]; + float w = weight[j]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d * sumqx; + } + } + scales[ib] = d; + float abs_d = fabsf(d); + if (abs_d > amax_scale) { + amax_scale = abs_d; max_scale = d; + } + } + + if (super_block_size/block_size > 1) { + int nb = super_block_size/block_size; + memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t)); + float d = -max_scale/32; + dh[0] = GGML_FP32_TO_FP16(d); + float id = d ? 1/d : 0.f; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + int l = nearest_int(id*scales[ib]); + l = MAX(-32, MIN(31, l)); + float dl = d * l; + float idl = dl ? 1/dl : 0.f; + uint8_t * Lb = L + ib*block_size; + const float * xb = x + ib*block_size; + for (int j = 0; j < block_size; ++j) { + Lb[j] = best_index_int8(16, values, idl*xb[j]); + } + l += 32; + uint8_t l_l = l & 0xf; + uint8_t l_h = l >> 4; + if (ib%2 == 0) scales_l[ib/2] = l_l; + else scales_l[ib/2] |= (l_l << 4); + scales_h[ib/8] |= (l_h << 2*(ib%8)); + } + } else { + dh[0] = GGML_FP32_TO_FP16(scales[0]); + if (ntry > 0) { + float id = scales[0] ? 1/scales[0] : 0; + for (int j = 0; j < super_block_size; ++j) { + L[j] = best_index_int8(16, values, id*x[j]); + } + } + } + + for (int i = 0; i < super_block_size/32; ++i) { + for (int j = 0; j < 16; ++j) { + q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4); + } + } +} + +size_t quantize_iq4_nl(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK4_NL == 0); + int64_t nblock = n_per_row/QK4_NL; + char * qrow = (char *)dst; + uint8_t L[QK4_NL]; + float weight[QK4_NL]; + uint16_t unused_h; + uint8_t * unused_l = NULL; + float scale; + for (int64_t row = 0; row < nrow; ++row) { + block_iq4_nl * iq4 = (block_iq4_nl *)qrow; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL; + quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l, + &scale, weight, L, kvalues_iq4nl, qw, 7); + } + src += n_per_row; + qrow += nblock*sizeof(block_iq4_nl); + } + return nrow * nblock * sizeof(block_iq4_nl); +} + +//void quantize_row_iq4_nl_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +void quantize_row_iq4_nl_ref(const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k) { + GGML_ASSERT(k%QK4_NL == 0); + int64_t nblock = k/QK4_NL; + uint8_t L[QK4_NL]; + float weight[QK4_NL]; + uint16_t unused_h; + uint8_t * unused_l = NULL; + float scale; + block_iq4_nl * iq4 = y; + for (int ibl = 0; ibl < nblock; ++ibl) { + quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l, + &scale, weight, L, kvalues_iq4nl, NULL, -1); + } +} + +size_t quantize_iq4_xs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + uint8_t L[QK_K]; + float weight[32]; + float scales[QK_K/32]; + for (int64_t row = 0; row < nrow; ++row) { + block_iq4_xs * iq4 = (block_iq4_xs *)qrow; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL; + quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l, + scales, weight, L, kvalues_iq4nl, qw, 7); + } + src += n_per_row; + qrow += nblock*sizeof(block_iq4_xs); + } + return nrow * nblock * sizeof(block_iq4_xs); +} + +void quantize_row_iq4_xs_ref(const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_xs(x, y, 1, k, NULL); +} + +// =============================== 2.5625 bpw + +static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) { + + const int gindex = iq2_data_index(GGML_TYPE_IQ2_S); + + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); + + const int kMaxQ = 3; + + const int64_t nbl = n/QK_K; + + block_iq2_s * y = vy; + + float scales[QK_K/16]; + float weight[16]; + float xval[16]; + int8_t L[16]; + int8_t Laux[16]; + float waux[16]; + bool is_on_grid[2]; + bool is_on_grid_aux[2]; + uint8_t block_signs[2]; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq2_s)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + 16*ib; + for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < 16; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i]; + } + for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 2; ++k) { + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; s |= (1 << i); + } + } + block_signs[k] = s; + } + float max = xval[0]; + for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + if (max < GROUP_MAX_EPS_IQ2_S) { + scales[ib] = 0; + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + is_on_grid[0] = is_on_grid[1] = true; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/max; + float this_scale = 1/id; + for (int k = 0; k < 2; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 16; ++i) L[i] = Laux[i]; + for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 2; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + L[8*k + i] = l; + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + scale = -scale; + for (int k = 0; k < 2; ++k) block_signs[k] = ~block_signs[k]; + } + for (int k = 0; k < 2; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + GGML_ABORT("fatal error"); + } + const int i8 = 2*ib + k; + y[ibl].qs[i8] = grid_index & 255; + y[ibl].qh[i8/4] |= ((grid_index >> 8) << 2*(i8%4)); + y[ibl].qs[QK_K/8 + i8] = block_signs[k]; + } + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + continue; + } + + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d * 0.9875f); + float id = 1/d; + for (int ib = 0; ib < QK_K/16; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + if (ib%2 == 0) y[ibl].scales[ib/2] = l; + else y[ibl].scales[ib/2] |= (l << 4); + } + } +} + +size_t quantize_iq2_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_s); + } + return nrow * nblock * sizeof(block_iq2_s); +} + +void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq2_s(x, y, 1, k, NULL); +} + +// =============================== data validation + +static bool validate_float(float f, size_t i) { + if (isinf(f)) { + fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i); + return false; + } + + if (isnan(f)) { + fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i); + return false; + } + + return true; +} + +static bool isinf_fp16(ggml_fp16_t f) { + return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0; +} + +static bool isnan_fp16(ggml_fp16_t f) { + return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0; +} + +static bool validate_fp16(ggml_fp16_t f, size_t i) { + if (isinf_fp16(f)) { + fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i); + return false; + } + + if (isnan_fp16(f)) { + fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i); + return false; + } + + return true; +} + +#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + if (!validate_fp16(q[i].d, i)) { \ + return false; \ + } \ + } + +#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \ + return false; \ + } \ + } + +#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + for (size_t j = 0; j < (nr); ++j) { \ + if (!validate_fp16(q[i].d[j], i)) { \ + return false; \ + } \ + } \ + } + +bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) { + if (type < 0 || type >= GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid type %d\n", __func__, type); + return false; + } + + if (nbytes % ggml_type_size(type) != 0) { + fprintf(stderr, "%s: invalid size %zu for type %s (type size = %zu)\n", __func__, nbytes, ggml_type_name(type), ggml_type_size(type)); + return false; + } + + const size_t nb = nbytes/ggml_type_size(type); + + switch (type) { + case GGML_TYPE_BF16: + { + int nans = 0; + int infs = 0; + const unsigned short * f = (const unsigned short *) data; + for (size_t i = 0; i < nb; ++i) { + nans += (f[i] & 0x7fff) > 0x7f80; + infs += (f[i] & 0x7fff) == 0x7f80; + } + if (nans) { + fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb); + return false; + } + if (infs) { + fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb); + return false; + } + } break; + case GGML_TYPE_F16: + { + const ggml_fp16_t * f = (const ggml_fp16_t *) data; + size_t i = 0; +#if defined(__AVX2__) + for (; i + 15 < nb; i += 16) { + __m256i v = _mm256_loadu_si256((const __m256i *)(f + i)); + __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00)); + __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00)); + int mask = _mm256_movemask_epi8(cmp); + if (mask) { + for (size_t j = 0; j < 16; ++j) { + if (!validate_fp16(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#elif defined(__ARM_NEON) + for (; i + 7 < nb; i += 8) { + uint16x8_t v = vld1q_u16(f + i); + uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00)); + uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00)); + uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0); + if (mask) { + for (size_t j = 0; j < 8; ++j) { + if (!validate_fp16(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#endif + for (; i < nb; ++i) { + if (!validate_fp16(f[i], i)) { + return false; + } + } + } break; + case GGML_TYPE_F32: + { + const float * f = (const float *) data; + size_t i = 0; +#if defined(__AVX2__) + for (; i + 7 < nb; i += 8) { + __m256i v = _mm256_loadu_si256((const __m256i *)(f + i)); + __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000)); + __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000)); + int mask = _mm256_movemask_epi8(cmp); + if (mask) { + for (size_t j = 0; j < 8; ++j) { + if (!validate_float(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#elif defined(__ARM_NEON) + for (; i + 3 < nb; i += 4) { + uint32x4_t v = vld1q_u32((const uint32_t *)f + i); + uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000)); + uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000)); + uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0); + if (mask) { + for (size_t j = 0; j < 4; ++j) { + if (!validate_float(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#endif + for (; i < nb; ++i) { + if (!validate_float(f[i], i)) { + return false; + } + } + } break; + case GGML_TYPE_F64: + { + const double * f = (const double *) data; + for (size_t i = 0; i < nb; ++i) { + if (!validate_float(f[i], i)) { + return false; + } + } + } break; + case GGML_TYPE_Q4_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); + } break; + case GGML_TYPE_Q4_1: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m); + } break; + case GGML_TYPE_Q5_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb); + } break; + case GGML_TYPE_Q5_1: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m); + } break; + case GGML_TYPE_Q8_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb); + } break; + case GGML_TYPE_Q2_K: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin); + } break; + case GGML_TYPE_Q3_K: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb); + } break; + case GGML_TYPE_Q4_K: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin); + } break; + case GGML_TYPE_Q5_K: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin); + } break; + case GGML_TYPE_Q6_K: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb); + } break; + case GGML_TYPE_Q8_K: + { + const block_q8_K * q = (const block_q8_K *) data; + for (size_t i = 0; i < nb; ++i) { + if (!validate_float(q[i].d, i)) { + return false; + } + } + } break; + case GGML_TYPE_TQ1_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb); + } break; + case GGML_TYPE_TQ2_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb); + } break; + case GGML_TYPE_IQ1_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb); + } break; + case GGML_TYPE_IQ1_M: + { + const block_iq1_m * q = (const block_iq1_m *) data; + for (size_t i = 0; i < nb; ++i) { + iq1m_scale_t scale; + const uint16_t * sc = (const uint16_t *)q[i].scales; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + if (!validate_fp16(scale.f16, i)) { + return false; + } + } + } break; + case GGML_TYPE_IQ2_XXS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb); + } break; + case GGML_TYPE_IQ2_XS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb); + } break; + case GGML_TYPE_IQ2_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb); + } break; + case GGML_TYPE_IQ3_XXS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb); + } break; + + case GGML_TYPE_IQ3_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb); + } break; + case GGML_TYPE_IQ4_XS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb); + } break; + case GGML_TYPE_IQ4_NL: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); + } break; + + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_I64: + // nothing to validate + break; + default: + { + fprintf(stderr, "%s: invalid type %d\n", __func__, type); + return false; + } + } + + return true; +} diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h new file mode 100644 index 0000000000000000000000000000000000000000..d09173e11161aa954f4b3258f35ab0ef2a3bb5f7 --- /dev/null +++ b/ggml/src/ggml-quants.h @@ -0,0 +1,100 @@ +#pragma once + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" + +#include "ggml.h" + +// GGML internal header + +#ifdef __cplusplus +extern "C" { +#endif + +// NOTE: these functions are defined as GGML_API because they used by the CPU backend + +// Quantization +GGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); + +GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); + +GGML_API void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k); + +GGML_API void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); + +// Dequantization +GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +GGML_API void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +GGML_API void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") +GGML_API size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +GGML_API size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +GGML_API size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +GGML_API void iq2xs_init_impl(enum ggml_type type); +GGML_API void iq2xs_free_impl(enum ggml_type type); +GGML_API void iq3xs_init_impl(int grid_size); +GGML_API void iq3xs_free_impl(int grid_size); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f5acb8ec2cb280f307ef89ceeb54334d5c06b644 --- /dev/null +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -0,0 +1,9 @@ +message(STATUS "Using RPC backend") + +ggml_add_backend_library(ggml-rpc + ggml-rpc.cpp + ) + +if (WIN32) + target_link_libraries(ggml-rpc PRIVATE ws2_32) +endif() diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f468f796d57736790441ddb7c390d8fdd840471e --- /dev/null +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -0,0 +1,1816 @@ +#include "ggml-rpc.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" +#include "ggml-cpp.h" + +#include +#include +#include +#include +#include +#include +#include +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +# include +# include +# include +# include +# include +#endif +#include +#include +#include + +namespace fs = std::filesystem; + +#ifdef _WIN32 +typedef SOCKET sockfd_t; +using ssize_t = __int64; +#else +typedef int sockfd_t; +#endif + +// cross-platform socket +struct socket_t { + sockfd_t fd; + socket_t(sockfd_t fd) : fd(fd) {} + ~socket_t() { + GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd); +#ifdef _WIN32 + closesocket(this->fd); +#else + close(this->fd); +#endif + } +}; + +// macro for nicer error messages on server crash +#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") + +// all RPC structures must be packed +#pragma pack(push, 1) +// ggml_tensor is serialized into rpc_tensor +struct rpc_tensor { + uint64_t id; + uint32_t type; + uint64_t buffer; + uint32_t ne[GGML_MAX_DIMS]; + uint32_t nb[GGML_MAX_DIMS]; + uint32_t op; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + int32_t flags; + uint64_t src[GGML_MAX_SRC]; + uint64_t view_src; + uint64_t view_offs; + uint64_t data; + char name[GGML_MAX_NAME]; + + char padding[4]; +}; + +static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8"); + +// RPC commands +enum rpc_cmd { + RPC_CMD_ALLOC_BUFFER = 0, + RPC_CMD_GET_ALIGNMENT, + RPC_CMD_GET_MAX_SIZE, + RPC_CMD_BUFFER_GET_BASE, + RPC_CMD_FREE_BUFFER, + RPC_CMD_BUFFER_CLEAR, + RPC_CMD_SET_TENSOR, + RPC_CMD_SET_TENSOR_HASH, + RPC_CMD_GET_TENSOR, + RPC_CMD_COPY_TENSOR, + RPC_CMD_GRAPH_COMPUTE, + RPC_CMD_GET_DEVICE_MEMORY, + RPC_CMD_INIT_TENSOR, + RPC_CMD_GET_ALLOC_SIZE, + RPC_CMD_HELLO, + RPC_CMD_COUNT, +}; + +// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold +const size_t HASH_THRESHOLD = 10 * 1024 * 1024; + +struct rpc_msg_hello_rsp { + uint8_t major; + uint8_t minor; + uint8_t patch; +}; + +struct rpc_msg_get_alloc_size_req { + rpc_tensor tensor; +}; + +struct rpc_msg_get_alloc_size_rsp { + uint64_t alloc_size; +}; + +struct rpc_msg_init_tensor_req { + rpc_tensor tensor; +}; + +struct rpc_msg_alloc_buffer_req { + uint64_t size; +}; + +struct rpc_msg_alloc_buffer_rsp { + uint64_t remote_ptr; + uint64_t remote_size; +}; + +struct rpc_msg_get_alignment_rsp { + uint64_t alignment; +}; + +struct rpc_msg_get_max_size_rsp { + uint64_t max_size; +}; + +struct rpc_msg_buffer_get_base_req { + uint64_t remote_ptr; +}; + +struct rpc_msg_buffer_get_base_rsp { + uint64_t base_ptr; +}; + +struct rpc_msg_free_buffer_req { + uint64_t remote_ptr; +}; + +struct rpc_msg_buffer_clear_req { + uint64_t remote_ptr; + uint8_t value; +}; + +struct rpc_msg_set_tensor_hash_req { + rpc_tensor tensor; + uint64_t offset; + uint64_t hash; +}; + +struct rpc_msg_set_tensor_hash_rsp { + uint8_t result; +}; + +struct rpc_msg_get_tensor_req { + rpc_tensor tensor; + uint64_t offset; + uint64_t size; +}; + +struct rpc_msg_copy_tensor_req { + rpc_tensor src; + rpc_tensor dst; +}; + +struct rpc_msg_copy_tensor_rsp { + uint8_t result; +}; + +struct rpc_msg_graph_compute_rsp { + uint8_t result; +}; + +struct rpc_msg_get_device_memory_rsp { + uint64_t free_mem; + uint64_t total_mem; +}; +#pragma pack(pop) + +// RPC data structures + +static ggml_guid_t ggml_backend_rpc_guid() { + static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03}; + return &guid; +} + +struct ggml_backend_rpc_buffer_type_context { + std::string endpoint; + std::string name; + size_t alignment; + size_t max_size; +}; + +struct ggml_backend_rpc_context { + std::string endpoint; + std::string name; +}; + +struct ggml_backend_rpc_buffer_context { + std::shared_ptr sock; + void * base_ptr; + uint64_t remote_ptr; +}; + +// RPC helper functions + +// Computes FNV-1a hash of the data +static uint64_t fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return hash; +} + +static std::shared_ptr make_socket(sockfd_t fd) { +#ifdef _WIN32 + if (fd == INVALID_SOCKET) { + return nullptr; + } +#else + if (fd < 0) { + return nullptr; + } +#endif + return std::make_shared(fd); +} + +static bool set_no_delay(sockfd_t sockfd) { + int flag = 1; + // set TCP_NODELAY to disable Nagle's algorithm + int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); + return ret == 0; +} + +static bool set_reuse_addr(sockfd_t sockfd) { + int flag = 1; + int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); + return ret == 0; +} + +static std::shared_ptr socket_connect(const char * host, int port) { + struct sockaddr_in addr; + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + auto sock_ptr = make_socket(sockfd); + if (sock_ptr == nullptr) { + return nullptr; + } + if (!set_no_delay(sockfd)) { + fprintf(stderr, "Failed to set TCP_NODELAY\n"); + return nullptr; + } + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + struct hostent * server = gethostbyname(host); + if (server == NULL) { + fprintf(stderr, "Cannot resolve host '%s'\n", host); + return nullptr; + } + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); + if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + return nullptr; + } + return sock_ptr; +} + +static std::shared_ptr socket_accept(sockfd_t srv_sockfd) { + auto client_socket_fd = accept(srv_sockfd, NULL, NULL); + auto client_socket = make_socket(client_socket_fd); + if (client_socket == nullptr) { + return nullptr; + } + if (!set_no_delay(client_socket_fd)) { + fprintf(stderr, "Failed to set TCP_NODELAY\n"); + return nullptr; + } + return client_socket; +} + +static std::shared_ptr create_server_socket(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + auto sock = make_socket(sockfd); + if (sock == nullptr) { + return nullptr; + } + if (!set_reuse_addr(sockfd)) { + fprintf(stderr, "Failed to set SO_REUSEADDR\n"); + return nullptr; + } + if (inet_addr(host) == INADDR_NONE) { + fprintf(stderr, "Invalid host address: %s\n", host); + return nullptr; + } + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = inet_addr(host); + serv_addr.sin_port = htons(port); + + if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + return nullptr; + } + if (listen(sockfd, 1) < 0) { + return nullptr; + } + return sock; +} + +static bool send_data(sockfd_t sockfd, const void * data, size_t size) { + size_t bytes_sent = 0; + while (bytes_sent < size) { + ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0); + if (n < 0) { + return false; + } + bytes_sent += n; + } + return true; +} + +static bool recv_data(sockfd_t sockfd, void * data, size_t size) { + size_t bytes_recv = 0; + while (bytes_recv < size) { + ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0); + if (n <= 0) { + return false; + } + bytes_recv += n; + } + return true; +} + +static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) { + if (!send_data(sockfd, &msg_size, sizeof(msg_size))) { + return false; + } + return send_data(sockfd, msg, msg_size); +} + +static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) { + uint64_t size; + if (!recv_data(sockfd, &size, sizeof(size))) { + return false; + } + if (size != msg_size) { + return false; + } + return recv_data(sockfd, msg, msg_size); +} + +static bool recv_msg(sockfd_t sockfd, std::vector & input) { + uint64_t size; + if (!recv_data(sockfd, &size, sizeof(size))) { + return false; + } + try { + input.resize(size); + } catch (const std::bad_alloc & e) { + fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size); + return false; + } + return recv_data(sockfd, input.data(), size); +} + +static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { + size_t pos = endpoint.find(':'); + if (pos == std::string::npos) { + return false; + } + host = endpoint.substr(0, pos); + port = std::stoi(endpoint.substr(pos + 1)); + return true; +} + +// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | +// No response +static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { + uint8_t cmd_byte = cmd; + if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { + return false; + } + if (!send_data(sock->fd, &input_size, sizeof(input_size))) { + return false; + } + if (!send_data(sock->fd, input, input_size)) { + return false; + } + return true; +} + +// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | +// RPC response: | response_size (8 bytes) | response_data (response_size bytes) | +static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { + if (!send_rpc_cmd(sock, cmd, input, input_size)) { + return false; + } + // TODO: currently the output_size is always known, do we need support for commands with variable output size? + // even if we do, we can skip sending output_size from the server for commands with known output size + uint64_t out_size; + if (!recv_data(sock->fd, &out_size, sizeof(out_size))) { + return false; + } + if (out_size != output_size) { + return false; + } + if (!recv_data(sock->fd, output, output_size)) { + return false; + } + return true; +} + +// RPC client-side implementation + +static bool check_server_version(const std::shared_ptr & sock) { + rpc_msg_hello_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { + fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + return false; + } + if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) { + fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + } + return true; +} + +static std::shared_ptr get_socket(const std::string & endpoint) { + static std::mutex mutex; + std::lock_guard lock(mutex); + static std::unordered_map> sockets; + static bool initialized = false; + + auto it = sockets.find(endpoint); + if (it != sockets.end()) { + if (auto sock = it->second.lock()) { + return sock; + } + } + std::string host; + int port; + if (!parse_endpoint(endpoint, host, port)) { + return nullptr; + } +#ifdef _WIN32 + if (!initialized) { + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return nullptr; + } + initialized = true; + } +#else + GGML_UNUSED(initialized); +#endif + auto sock = socket_connect(host.c_str(), port); + if (sock == nullptr) { + return nullptr; + } + if (!check_server_version(sock)) { + return nullptr; + } + GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); + sockets[endpoint] = sock; + return sock; +} + +static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_free_buffer_req request = {ctx->remote_ptr}; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0); + RPC_STATUS_ASSERT(status); + delete ctx; +} + +static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + if (ctx->base_ptr != nullptr) { + return ctx->base_ptr; + } + rpc_msg_buffer_get_base_req request = {ctx->remote_ptr}; + rpc_msg_buffer_get_base_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + ctx->base_ptr = reinterpret_cast(response.base_ptr); + return ctx->base_ptr; +} + +static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { + rpc_tensor result; + result.id = reinterpret_cast(tensor); + result.type = tensor->type; + if (tensor->buffer) { + ggml_backend_buffer_t buffer = tensor->buffer; + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + result.buffer = ctx->remote_ptr; + } else { + result.buffer = 0; + } + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result.ne[i] = tensor->ne[i]; + result.nb[i] = tensor->nb[i]; + } + result.op = tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result.op_params[i] = tensor->op_params[i]; + } + result.flags = tensor->flags; + for (uint32_t i = 0; i < GGML_MAX_SRC; i++) { + result.src[i] = reinterpret_cast(tensor->src[i]); + } + result.view_src = reinterpret_cast(tensor->view_src); + result.view_offs = tensor->view_offs; + result.data = reinterpret_cast(tensor->data); + + // Avoid sending uninitialized data over the wire + memset(result.name, 0, sizeof(result.name)); + memset(result.padding, 0, sizeof(result.padding)); + + snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name); + return result; +} + +static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + + // CUDA backend on the server pads everything to 512 due to CUDA limitations. + // Due to bandwidth constraints, we only call the server init tensor functions if necessary. + // In particular, only quantized tensors need padding + if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) { + rpc_msg_init_tensor_req request; + + request.tensor = serialize_tensor(tensor); + + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0); + RPC_STATUS_ASSERT(status); + } + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_tensor rpc_tensor = serialize_tensor(tensor); + if (size > HASH_THRESHOLD) { + rpc_msg_set_tensor_hash_req request; + request.tensor = rpc_tensor; + request.offset = offset; + request.hash = fnv_hash((const uint8_t*)data, size); + rpc_msg_set_tensor_hash_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + if (response.result) { + // the server has the same data, no need to send it + return; + } + } + // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) + size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size; + std::vector input(input_size, 0); + memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); + memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); + memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size()); + RPC_STATUS_ASSERT(status); +} + +static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_get_tensor_req request; + request.tensor = serialize_tensor(tensor); + request.offset = offset; + request.size = size; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size); + RPC_STATUS_ASSERT(status); +} + +static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + // check if src and dst are on the same server + ggml_backend_buffer_t src_buffer = src->buffer; + ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; + ggml_backend_buffer_t dst_buffer = dst->buffer; + ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; + if (src_ctx->sock != dst_ctx->sock) { + return false; + } + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_copy_tensor_req request; + request.src = serialize_tensor(src); + request.dst = serialize_tensor(dst); + rpc_msg_copy_tensor_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.result; +} + +static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value}; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0); + RPC_STATUS_ASSERT(status); +} + +static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { + /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer, + /* .get_base = */ ggml_backend_rpc_buffer_get_base, + /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, + /* .clear = */ ggml_backend_rpc_buffer_clear, + /* .reset = */ NULL, +}; + +static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + rpc_msg_alloc_buffer_req request = {size}; + rpc_msg_alloc_buffer_rsp response; + auto sock = get_socket(buft_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + if (response.remote_ptr != 0) { + ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, + ggml_backend_rpc_buffer_interface, + new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr}, + response.remote_size); + return buffer; + } else { + return nullptr; + } +} + +static size_t get_alignment(const std::shared_ptr & sock) { + rpc_msg_get_alignment_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.alignment; +} + +static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->alignment; +} + +static size_t get_max_size(const std::shared_ptr & sock) { + rpc_msg_get_max_size_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.max_size; +} + +static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->max_size; +} + +static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + // See comments in init_tensor. + if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + auto sock = get_socket(buft_ctx->endpoint); + + rpc_msg_get_alloc_size_req request; + + request.tensor = serialize_tensor(tensor); + + rpc_msg_get_alloc_size_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + + return response.alloc_size; + } else { + return ggml_nbytes(tensor); + } +} + +static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { + /* .get_name = */ ggml_backend_rpc_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_rpc_get_max_size, + /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +static const char * ggml_backend_rpc_name(ggml_backend_t backend) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + + return rpc_ctx->name.c_str(); +} + +static void ggml_backend_rpc_free(ggml_backend_t backend) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + delete rpc_ctx; + delete backend; +} + +static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { + GGML_UNUSED(backend); + // this is no-op because we don't have any async operations +} + +static void add_tensor(ggml_tensor * tensor, std::vector & tensors, std::unordered_set & visited) { + if (tensor == nullptr) { + return; + } + if (visited.find(tensor) != visited.end()) { + return; + } + visited.insert(tensor); + for (int i = 0; i < GGML_MAX_SRC; i++) { + add_tensor(tensor->src[i], tensors, visited); + } + add_tensor(tensor->view_src, tensors, visited); + tensors.push_back(serialize_tensor(tensor)); +} + +static void serialize_graph(const ggml_cgraph * cgraph, std::vector & output) { + uint32_t n_nodes = cgraph->n_nodes; + std::vector tensors; + std::unordered_set visited; + for (uint32_t i = 0; i < n_nodes; i++) { + add_tensor(cgraph->nodes[i], tensors, visited); + } + // serialization format: + // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + uint32_t n_tensors = tensors.size(); + int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); + output.resize(output_size, 0); + memcpy(output.data(), &n_nodes, sizeof(n_nodes)); + for (uint32_t i = 0; i < n_nodes; i++) { + memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); + } + uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t)); + *out_ntensors = n_tensors; + rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t)); + memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); +} + +static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + std::vector input; + serialize_graph(cgraph, input); + rpc_msg_graph_compute_rsp response; + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return (enum ggml_status)response.result; +} + +static ggml_backend_i ggml_backend_rpc_interface = { + /* .get_name = */ ggml_backend_rpc_name, + /* .free = */ ggml_backend_rpc_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ ggml_backend_rpc_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_rpc_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { + static std::mutex mutex; + std::lock_guard lock(mutex); + // NOTE: buffer types are allocated and never freed; this is by design + static std::unordered_map buft_map; + auto it = buft_map.find(endpoint); + if (it != buft_map.end()) { + return it->second; + } + auto sock = get_socket(endpoint); + if (sock == nullptr) { + fprintf(stderr, "Failed to connect to %s\n", endpoint); + return nullptr; + } + size_t alignment = get_alignment(sock); + size_t max_size = get_max_size(sock); + ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { + /* .endpoint = */ endpoint, + /* .name = */ "RPC[" + std::string(endpoint) + "]", + /* .alignment = */ alignment, + /* .max_size = */ max_size + }; + + ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { + /* .iface = */ ggml_backend_rpc_buffer_type_interface, + /* .device = */ ggml_backend_rpc_add_device(endpoint), + /* .context = */ buft_ctx + }; + buft_map[endpoint] = buft; + return buft; +} + +ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { + ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { + /* .endpoint = */ endpoint, + /* .name = */ "RPC[" + std::string(endpoint) + "]", + }; + + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_rpc_guid(), + /* .interface = */ ggml_backend_rpc_interface, + /* .device = */ ggml_backend_rpc_add_device(endpoint), + /* .context = */ ctx + }; + return backend; +} + +bool ggml_backend_is_rpc(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); +} + +static void get_device_memory(const std::shared_ptr & sock, size_t * free, size_t * total) { + rpc_msg_get_device_memory_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + *free = response.free_mem; + *total = response.total_mem; +} + +void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { + auto sock = get_socket(endpoint); + if (sock == nullptr) { + *free = 0; + *total = 0; + return; + } + get_device_memory(sock, free, total); +} + +// RPC server-side implementation + +class rpc_server { +public: + rpc_server(ggml_backend_t backend, const char * cache_dir) + : backend(backend), cache_dir(cache_dir) { + } + ~rpc_server(); + + void hello(rpc_msg_hello_rsp & response); + void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); + void get_alignment(rpc_msg_get_alignment_rsp & response); + void get_max_size(rpc_msg_get_max_size_rsp & response); + bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response); + bool free_buffer(const rpc_msg_free_buffer_req & request); + bool buffer_clear(const rpc_msg_buffer_clear_req & request); + bool set_tensor(const std::vector & input); + bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response); + bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); + bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); + bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); + bool init_tensor(const rpc_msg_init_tensor_req & request); + bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); + +private: + bool get_cached_file(uint64_t hash, std::vector & data); + ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); + ggml_tensor * create_node(uint64_t id, + struct ggml_context * ctx, + const std::unordered_map & tensor_ptrs, + std::unordered_map & tensor_map); + + + ggml_backend_t backend; + const char * cache_dir; + std::unordered_set buffers; +}; + +void rpc_server::hello(rpc_msg_hello_rsp & response) { + response.major = RPC_PROTO_MAJOR_VERSION; + response.minor = RPC_PROTO_MINOR_VERSION; + response.patch = RPC_PROTO_PATCH_VERSION; + GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch); +} + +bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { + ggml_backend_buffer_type_t buft; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + + if (tensor == nullptr) { + GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); + return false; + } + + if (tensor->buffer == nullptr) { + //No buffer allocated. + buft = ggml_backend_get_default_buffer_type(backend); + } else { + buft = tensor->buffer->buft; + } + + response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor); + + return true; +} + +void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size); + response.remote_ptr = 0; + response.remote_size = 0; + if (buffer != nullptr) { + response.remote_ptr = reinterpret_cast(buffer); + response.remote_size = buffer->size; + GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size); + buffers.insert(buffer); + } else { + GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size); + } +} + +void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) { + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + size_t alignment = ggml_backend_buft_get_alignment(buft); + GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment); + response.alignment = alignment; +} + +void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) { + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + size_t max_size = ggml_backend_buft_get_max_size(buft); + GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size); + response.max_size = max_size; +} + +bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) { + GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr); + ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); + if (buffers.find(buffer) == buffers.end()) { + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); + return false; + } + void * base = ggml_backend_buffer_get_base(buffer); + response.base_ptr = reinterpret_cast(base); + return true; +} + +bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) { + GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr); + ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); + if (buffers.find(buffer) == buffers.end()) { + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); + return false; + } + ggml_backend_buffer_free(buffer); + buffers.erase(buffer); + return true; +} + +bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) { + GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value); + ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); + if (buffers.find(buffer) == buffers.end()) { + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); + return false; + } + ggml_backend_buffer_clear(buffer, request.value); + return true; +} + +ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) { + // Validate tensor type before using it + if (tensor->type >= GGML_TYPE_COUNT) { + GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type); + return nullptr; + } + + ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, + tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + + // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type + if (result == nullptr) { + GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type); + return nullptr; + } + + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result->nb[i] = tensor->nb[i]; + } + result->buffer = reinterpret_cast(tensor->buffer); + if (result->buffer && buffers.find(result->buffer) == buffers.end()) { + result->buffer = nullptr; + } + + if (result->buffer) { + // require that the tensor data does not go beyond the buffer end + uint64_t tensor_size = (uint64_t) ggml_nbytes(result); + uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer); + uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer); + GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow + GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size); + } + + result->op = (ggml_op) tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result->op_params[i] = tensor->op_params[i]; + } + result->flags = tensor->flags; + result->data = reinterpret_cast(tensor->data); + ggml_set_name(result, tensor->name); + return result; +} + + +bool rpc_server::set_tensor(const std::vector & input) { + // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) | + if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) { + return false; + } + const rpc_tensor * in_tensor = (const rpc_tensor *)input.data(); + uint64_t offset; + memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset)); + const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset); + + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); + if (tensor == nullptr) { + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); + return false; + } + GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); + + // sanitize tensor->data + { + const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); + const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); + + if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) { + GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n", + __func__, in_tensor->data, offset, size, p0, p1); + return false; + } + } + + const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset); + if (cache_dir && size > HASH_THRESHOLD) { + uint64_t hash = fnv_hash((const uint8_t*)data, size); + char hash_str[17]; + snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash); + // save to cache_dir/hash_str + fs::path cache_file = fs::path(cache_dir) / hash_str; + std::ofstream ofs(cache_file, std::ios::binary); + ofs.write((const char *)data, size); + printf("[%s] saved to '%s'\n", __func__, cache_file.c_str()); + } + ggml_backend_tensor_set(tensor, data, offset, size); + return true; +} + +bool rpc_server::get_cached_file(uint64_t hash, std::vector & data) { + if (!cache_dir) { + return false; + } + char hash_str[17]; + snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash); + fs::path cache_file = fs::path(cache_dir) / hash_str; + if (!fs::exists(cache_file)) { + return false; + } + std::ifstream ifs(cache_file, std::ios::binary); + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + data.resize(size); + ifs.read((char *)data.data(), size); + return true; +} + +bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response) +{ + std::vector cached_file; + if (!get_cached_file(request.hash, cached_file)) { + response.result = 0; + return true; + } + size_t size = cached_file.size(); + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + if (tensor == nullptr) { + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); + return false; + } + GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", + __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash); + + // sanitize tensor->data + { + const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); + const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); + + if (request.tensor.data + request.offset < p0 + || request.tensor.data + request.offset >= p1 + || size > (p1 - request.tensor.data - request.offset)) { + GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n", + __func__, request.tensor.data, request.offset, size, request.hash, p0, p1); + return false; + } + } + ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size); + response.result = 1; + return true; +} + +bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + if (tensor == nullptr) { + GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n"); + return false; + } + + // Call the backend's buffer_init_tensor function + ggml_backend_buffer_t buffer = tensor->buffer; + if (buffer && buffer->iface.init_tensor) { + buffer->iface.init_tensor(buffer, tensor); + } else { + GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n"); + } + + if (tensor->extra != nullptr) { + // This pointer can either be passed around client/server, or probably better stored server-side and kept track of. + // Currently unimplemented. + GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n"); + return false; + } + + return true; +} + +bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response) { + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + if (tensor == nullptr) { + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); + return false; + } + GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size); + + // sanitize tensor->data + { + const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); + const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); + + if (request.tensor.data + request.offset < p0 || + request.tensor.data + request.offset >= p1 || + request.size > (p1 - request.tensor.data - request.offset)) { + GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n", + __func__, request.tensor.data, request.offset, request.size, p0, p1); + return false; + } + } + + response.resize(request.size, 0); + ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size); + return true; +} + +bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) { + struct ggml_init_params params { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * src = deserialize_tensor(ctx, &request.src); + ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); + if (src == nullptr || dst == nullptr) { + GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); + return false; + } + + uint64_t src_size = (uint64_t) ggml_nbytes(src); + uint64_t dst_data = (uint64_t) dst->data; + uint64_t dst_base = (uint64_t) ggml_backend_buffer_get_base(dst->buffer); + uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer); + + if (dst_data + src_size > dst_base + dst_buf_sz) { + GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n" + " write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n" + " buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n", + __func__, + dst_data, + dst_data + src_size, + dst_base, + dst_base + dst_buf_sz); + return false; + } + + GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", + __func__, (void*) src->buffer, (void*) dst->buffer); + + response.result = ggml_backend_buffer_copy_tensor(src, dst); + return true; +} + +ggml_tensor * rpc_server::create_node(uint64_t id, + struct ggml_context * ctx, + const std::unordered_map & tensor_ptrs, + std::unordered_map & tensor_map) { + if (tensor_map.find(id) != tensor_map.end()) { + return tensor_map[id]; + } + // Safely find the tensor pointer + auto it_ptr = tensor_ptrs.find(id); + if (it_ptr == tensor_ptrs.end()) { + return nullptr; + } + const rpc_tensor * tensor = it_ptr->second; + + struct ggml_tensor * result = deserialize_tensor(ctx, tensor); + if (result == nullptr) { + return nullptr; + } + tensor_map[id] = result; + for (int i = 0; i < GGML_MAX_SRC; i++) { + // Check if the source ID is 0 before calling create_node recursively + if (tensor->src[i] == 0) { + result->src[i] = nullptr; + } else { + result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map); + // If the recursive call failed for a non-zero ID, propagate the error + if (result->src[i] == nullptr) { + GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n", + __func__, i, tensor->src[i], id); + // Must return nullptr to signal failure up the call stack + return nullptr; + } + } + } + + // Handle view_src similarly + if (tensor->view_src == 0) { + result->view_src = nullptr; + } else { + result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map); + // If the recursive call failed for a non-zero ID, propagate the error + if (result->view_src == nullptr) { + GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n", + __func__, tensor->view_src, id); + // Must return nullptr to signal failure up the call stack + return nullptr; + } + } + result->view_offs = tensor->view_offs; + return result; +} + +bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response) { + // serialization format: + // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + if (input.size() < sizeof(uint32_t)) { + return false; + } + uint32_t n_nodes; + memcpy(&n_nodes, input.data(), sizeof(n_nodes)); + if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) { + return false; + } + const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes)); + uint32_t n_tensors; + memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors)); + if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) { + return false; + } + const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors)); + GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors); + + size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); + + struct ggml_init_params params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); + graph->n_nodes = n_nodes; + std::unordered_map tensor_ptrs; + for (uint32_t i = 0; i < n_tensors; i++) { + tensor_ptrs[tensors[i].id] = &tensors[i]; + } + std::unordered_map tensor_map; + for (uint32_t i = 0; i < n_nodes; i++) { + int64_t id; + memcpy(&id, &nodes[i], sizeof(id)); + graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map); + + // Check if create_node failed for a *non-zero* ID. + // If id was 0, create_node returning nullptr is expected. + // If id was non-zero and create_node returned nullptr, it indicates a deserialization error. + if (graph->nodes[i] == nullptr && id != 0) { + GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id); + return false; + } + } + ggml_status status = ggml_backend_graph_compute(backend, graph); + response.result = status; + return true; +} + +rpc_server::~rpc_server() { + for (auto buffer : buffers) { + ggml_backend_buffer_free(buffer); + } +} + +static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, + sockfd_t sockfd, size_t free_mem, size_t total_mem) { + rpc_server server(backend, cache_dir); + uint8_t cmd; + if (!recv_data(sockfd, &cmd, 1)) { + return; + } + // the first command sent by the client must be HELLO + if (cmd != RPC_CMD_HELLO) { + fprintf(stderr, "Expected HELLO command, update client\n"); + return; + } + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_hello_rsp response; + server.hello(response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + while (true) { + if (!recv_data(sockfd, &cmd, 1)) { + break; + } + if (cmd >= RPC_CMD_COUNT) { + // fail fast if the command is invalid + fprintf(stderr, "Unknown command: %d\n", cmd); + break; + } + switch (cmd) { + case RPC_CMD_HELLO: { + // HELLO command is handled above + return; + } + case RPC_CMD_ALLOC_BUFFER: { + rpc_msg_alloc_buffer_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_alloc_buffer_rsp response; + server.alloc_buffer(request, response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_GET_ALLOC_SIZE: { + rpc_msg_get_alloc_size_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_get_alloc_size_rsp response; + if (!server.get_alloc_size(request, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_GET_ALIGNMENT: { + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_get_alignment_rsp response; + server.get_alignment(response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_GET_MAX_SIZE: { + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_get_max_size_rsp response; + server.get_max_size(response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_BUFFER_GET_BASE: { + rpc_msg_buffer_get_base_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_buffer_get_base_rsp response; + if (!server.buffer_get_base(request, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_FREE_BUFFER: { + rpc_msg_free_buffer_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + if (!server.free_buffer(request)) { + return; + } + if (!send_msg(sockfd, nullptr, 0)) { + return; + } + break; + } + case RPC_CMD_BUFFER_CLEAR: { + rpc_msg_buffer_clear_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + if (!server.buffer_clear(request)) { + return; + } + if (!send_msg(sockfd, nullptr, 0)) { + return; + } + break; + } + case RPC_CMD_SET_TENSOR: { + std::vector input; + if (!recv_msg(sockfd, input)) { + return; + } + if (!server.set_tensor(input)) { + return; + } + break; + } + case RPC_CMD_SET_TENSOR_HASH: { + rpc_msg_set_tensor_hash_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_set_tensor_hash_rsp response; + if (!server.set_tensor_hash(request, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_INIT_TENSOR: { + rpc_msg_init_tensor_req request; + if (!recv_msg(sockfd, &request,sizeof(request))) { + return; + } + if (!server.init_tensor(request)) { + return; + } + if (!send_msg(sockfd, nullptr, 0)) { + return; + } + break; + } + case RPC_CMD_GET_TENSOR: { + rpc_msg_get_tensor_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + std::vector response; + if (!server.get_tensor(request, response)) { + return; + } + if (!send_msg(sockfd, response.data(), response.size())) { + return; + } + break; + } + case RPC_CMD_COPY_TENSOR: { + rpc_msg_copy_tensor_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_copy_tensor_rsp response; + if (!server.copy_tensor(request, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_GRAPH_COMPUTE: { + std::vector input; + if (!recv_msg(sockfd, input)) { + return; + } + rpc_msg_graph_compute_rsp response; + if (!server.graph_compute(input, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_GET_DEVICE_MEMORY: { + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_get_device_memory_rsp response; + response.free_mem = free_mem; + response.total_mem = total_mem; + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } + default: { + fprintf(stderr, "Unknown command: %d\n", cmd); + return; + } + } + } +} + +void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, + const char * cache_dir, + size_t free_mem, size_t total_mem) { + printf("Starting RPC server v%d.%d.%d\n", + RPC_PROTO_MAJOR_VERSION, + RPC_PROTO_MINOR_VERSION, + RPC_PROTO_PATCH_VERSION); + printf(" endpoint : %s\n", endpoint); + printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a"); + printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024)); + + std::string host; + int port; + if (!parse_endpoint(endpoint, host, port)) { + return; + } +#ifdef _WIN32 + { + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + fprintf(stderr, "WSAStartup failed: %d\n", res); + return; + } + } +#endif + auto server_socket = create_server_socket(host.c_str(), port); + if (server_socket == nullptr) { + fprintf(stderr, "Failed to create server socket\n"); + return; + } + while (true) { + auto client_socket = socket_accept(server_socket->fd); + if (client_socket == nullptr) { + fprintf(stderr, "Failed to accept client connection\n"); + return; + } + printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); + fflush(stdout); + rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem); + printf("Client connection closed\n"); + fflush(stdout); + } +#ifdef _WIN32 + WSACleanup(); +#endif +} + +// device interface + +struct ggml_backend_rpc_device_context { + std::string endpoint; + std::string name; +}; + +static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ctx->name.c_str(); +} + +static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ctx->name.c_str(); +} + +static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total); + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) { + // TODO: obtain value from the server + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_rpc_device_get_name(dev); + props->description = ggml_backend_rpc_device_get_description(dev); + props->type = ggml_backend_rpc_device_get_type(dev); + ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ggml_backend_rpc_init(ctx->endpoint.c_str()); + + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); + + GGML_UNUSED(dev); +} + +static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + GGML_UNUSED(dev); + GGML_UNUSED(op); + //TODO: call the remote backend and cache the results + return true; +} + +static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) { + return false; + } + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context; + return buft_ctx->endpoint == dev_ctx->endpoint; +} + +static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { + /* .get_name = */ ggml_backend_rpc_device_get_name, + /* .get_description = */ ggml_backend_rpc_device_get_description, + /* .get_memory = */ ggml_backend_rpc_device_get_memory, + /* .get_type = */ ggml_backend_rpc_device_get_type, + /* .get_props = */ ggml_backend_rpc_device_get_props, + /* .init_backend = */ ggml_backend_rpc_device_init, + /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_rpc_device_supports_op, + /* .supports_buft = */ ggml_backend_rpc_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend reg interface + +static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { + return "RPC"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) { + return 0; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) { + GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead"); + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) { + return (void *)ggml_backend_rpc_add_device; + } + if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) { + return (void *)ggml_backend_rpc_start_server; + } + return NULL; + + GGML_UNUSED(reg); +} + +static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = { + /* .get_name = */ ggml_backend_rpc_reg_get_name, + /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count, + /* .get_device = */ ggml_backend_rpc_reg_get_device, + /* .get_proc_address = */ ggml_backend_rpc_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_rpc_reg(void) { + static struct ggml_backend_reg ggml_backend_rpc_reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_rpc_reg_i, + /* .context = */ NULL, + }; + + return &ggml_backend_rpc_reg; +} + +ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) { + static std::unordered_map dev_map; + + static std::mutex mutex; + std::lock_guard lock(mutex); + + if (dev_map.find(endpoint) != dev_map.end()) { + return dev_map[endpoint]; + } + + ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context { + /* .endpoint = */ endpoint, + /* .name = */ "RPC[" + std::string(endpoint) + "]", + }; + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_rpc_device_i, + /* .reg = */ ggml_backend_rpc_reg(), + /* .context = */ ctx, + }; + + dev_map[endpoint] = dev; + + return dev; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg) diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..efd78b912cc65cb66fccb0a10b956e44eeafd8a7 --- /dev/null +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -0,0 +1,189 @@ +message(STATUS "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}") + +if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$") + message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD") +endif() + +check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL) + +if (DEFINED ENV{ONEAPI_ROOT}) + message(STATUS "Using oneAPI Release SYCL compiler (icpx).") +elseif(SUPPORTS_SYCL) + message(WARNING "Using open-source SYCL compiler (clang++). Didn't detect ENV {ONEAPI_ROOT}. + If you expected the oneAPI Release compiler, please install oneAPI & source it, like: + source /opt/intel/oneapi/setvars.sh") +else() + message(FATAL_ERROR "C++ compiler lacks SYCL support.") +endif() +message(STATUS "SYCL found") +#todo: AOT + +ggml_add_backend_library(ggml-sycl + ggml-sycl.cpp + ../../include/ggml-sycl.h + ) + +file(GLOB GGML_HEADERS_SYCL "*.hpp") +file(GLOB GGML_SOURCES_SYCL "*.cpp") +target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL}) + +if (WIN32) + # To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory + if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C")) + set_target_properties(ggml-sycl PROPERTIES VS_PLATFORM_TOOLSET "Intel C++ Compiler 2025") + set(CMAKE_CXX_COMPILER "icx") + set(CMAKE_CXX_COMPILER_ID "IntelLLVM") + endif() +endif() + +find_package(IntelSYCL) +if (IntelSYCL_FOUND) + # Use oneAPI CMake when possible + target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX) +else() + # Fallback to the simplest way of enabling SYCL when using intel/llvm nightly for instance + target_compile_options(ggml-sycl PRIVATE "-fsycl") + target_link_options(ggml-sycl PRIVATE "-fsycl") +endif() + +target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") + +# Link against oneDNN +set(GGML_SYCL_DNNL 0) +if(GGML_SYCL_DNN) + find_package(DNNL) + if(DNNL_FOUND) + if (NOT DEFINED DNNL_GPU_VENDOR) + # default to intel target + set(DNNL_GPU_VENDOR "INTEL") + if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL") + message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target") + endif() + endif() + + # Verify oneDNN was compiled for the same target as llama + if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}") + target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl) + set(GGML_SYCL_DNNL 1) + get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS) + foreach(CONFIG ${CONFIGS}) + get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG}) + message(STATUS "Found oneDNN: ${DNNL_LIB}") + endforeach() + else() + message(WARNING + "oneDNN must be compiled for the same target as llama.cpp. + llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}. + Disabling oneDNN support.") + endif() + else() + message(STATUS "oneDNN not found, disabling oneDNN support") + endif() +else() + message(STATUS "oneDNN support disabled by the user") +endif() +target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL}) + +if (GGML_SYCL_F16) + if (GGML_SYCL_TARGET STREQUAL "AMD") + message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.") + endif() + add_compile_definitions(GGML_SYCL_F16) +endif() + +if (GGML_SYCL_TARGET STREQUAL "NVIDIA") + add_compile_definitions(GGML_SYCL_WARP_SIZE=32) +elseif (GGML_SYCL_TARGET STREQUAL "AMD") + # INFO: Allowed Sub_group_sizes are not consistent through all + # hip targets. For example, 64 is used for certain models, but the backend + # does not support it. + # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32) + add_compile_definitions(GGML_SYCL_WARP_SIZE=32) +else() + add_compile_definitions(GGML_SYCL_WARP_SIZE=16) +endif() + +if (GGML_SYCL_GRAPH) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) +endif() + +# Link against Intel oneMKL or oneMath +if (GGML_SYCL_TARGET STREQUAL "INTEL") + # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically + # See https://github.com/uxlfoundation/oneMath/issues/654 + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(SYCL_COMPILER ON) + endif() + find_package(MKL REQUIRED) + target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL) +else() + find_package(oneMath QUIET) + if (NOT oneMath_FOUND) + message(STATUS "oneMath not found: oneMath will be automatically downloaded") + # Use FetchContent to automatically pull and build oneMath + include(FetchContent) + set(BUILD_FUNCTIONAL_TESTS False) + set(BUILD_EXAMPLES False) + set(TARGET_DOMAINS blas) + if (GGML_SYCL_TARGET STREQUAL "NVIDIA") + set(ENABLE_MKLCPU_BACKEND False) + set(ENABLE_MKLGPU_BACKEND False) + set(ENABLE_CUBLAS_BACKEND True) + elseif (GGML_SYCL_TARGET STREQUAL "AMD") + set(ENABLE_MKLCPU_BACKEND False) + set(ENABLE_MKLGPU_BACKEND False) + set(ENABLE_ROCBLAS_BACKEND True) + # Ensure setting a string variable here is not overriden by oneMath CACHE variables + cmake_policy(SET CMP0126 NEW) + # Setting the device architecture is only needed and useful for AMD devices in oneMath + set(HIP_TARGETS ${GGML_SYCL_DEVICE_ARCH} CACHE STRING "oneMath HIP target" FORCE) + endif() + FetchContent_Declare( + ONEMATH + GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git + GIT_TAG 8efe85f5aaebb37f1d8c503b7af66315feabf142 + ) + FetchContent_MakeAvailable(ONEMATH) + # Create alias to match with find_package targets name + function(onemath_alias target) + if (TARGET ${target}_obj) + # Silence verbose warnings from external libraries + target_compile_options(${target}_obj PRIVATE -w) + endif() + if (TARGET ${target}) + add_library(ONEMATH::${target} ALIAS ${target}) + endif() + endfunction() + onemath_alias(onemath) + onemath_alias(onemath_blas_mklcpu) + onemath_alias(onemath_blas_mklgpu) + onemath_alias(onemath_blas_cublas) + onemath_alias(onemath_blas_rocblas) + endif() + + # Below oneMath compile-time dispatching is used for better performance + if (GGML_SYCL_TARGET STREQUAL "NVIDIA") + target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_cublas) + target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda") + target_link_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda") + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA) + elseif (GGML_SYCL_TARGET STREQUAL "AMD") + if (NOT GGML_SYCL_DEVICE_ARCH) + message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") + endif() + target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas) + target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") + target_link_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_AMD) + else() + # Fallback to oneMath runtime dispatcher + target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GENERIC) + endif() +endif() + +if (GGML_SYCL_DEVICE_ARCH) + target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) + target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) +endif() diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f78a36ddf8f669a77bd9f42eb86337c83344c7ca --- /dev/null +++ b/ggml/src/ggml-sycl/backend.hpp @@ -0,0 +1,37 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_BACKEND_HPP +#define GGML_SYCL_BACKEND_HPP + +#include "binbcast.hpp" +#include "common.hpp" +#include "concat.hpp" +#include "conv.hpp" +#include "convert.hpp" +#include "cpy.hpp" +#include "dequantize.hpp" +#include "dmmv.hpp" +#include "element_wise.hpp" +#include "gla.hpp" +#include "im2col.hpp" +#include "mmq.hpp" +#include "mmvq.hpp" +#include "norm.hpp" +#include "outprod.hpp" +#include "quants.hpp" +#include "rope.hpp" +#include "softmax.hpp" +#include "tsembd.hpp" +#include "wkv.hpp" + +#endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a3883ae1eda57017c864be9bc60ab231be8cdce --- /dev/null +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -0,0 +1,345 @@ +#include "binbcast.hpp" + +#include +#include +#include + +#include "ggml.h" + +template +static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) / + ne3; + const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) % + ne3; + + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; + i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + } +} + +template +static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + const int i3 = i/(ne2*ne1*ne0); + const int i2 = (i/(ne1*ne0)) % ne2; + const int i1 = (i/ne0) % ne1; + const int i0 = i % ne0; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); +} + + +template +struct bin_bcast_sycl { + template + void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00, + const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, + const int64_t ne12, const int64_t ne13, const int64_t ne0, const int64_t ne1, const int64_t ne2, + const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, + const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, + const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { + int nr0 = ne10 / ne0; + int nr1 = ne11/ne1; + int nr2 = ne12/ne2; + int nr3 = ne13/ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + // collapse dimensions until first broadcast dimension + int64_t cne[] = {ne0, ne1, ne2, ne3}; + int64_t cne0[] = {ne00, ne01, ne02, ne03}; + int64_t cne1[] = {ne10, ne11, ne12, ne13}; + size_t cnb[] = {nb0, nb1, nb2, nb3}; + size_t cnb0[] = {nb00, nb01, nb02, nb03}; + size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } + } + } + { + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; + + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; + + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; + + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; + + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + GGML_UNUSED(s00); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0/2LL, 1LL); + + sycl::range<3> block_dims(1, 1, 1); + block_dims[2] = std::min(hne0, block_size); + block_dims[1] = std::min( + ne1, block_size / (unsigned int)block_dims[2]); + block_dims[0] = std::min( + std::min( + ne2 * ne3, block_size / (unsigned int)block_dims[2] / + (unsigned int)block_dims[1]), + 64U); + + sycl::range<3> block_nums( + (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], + (ne1 + block_dims[1] - 1) / block_dims[1], + (hne0 + block_dims[2] - 1) / block_dims[2]); + + if (block_nums[0] > 65535) { + // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * + sycl::range<3>(1, 1, block_size), + sycl::range<3>(1, 1, block_size)), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast_unravel( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, + s03, s11, s12, s13, item_ct1); + }); + } + } else { + /* + DPCT1049:16: The work-group size passed to the SYCL kernel may + exceed the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if + needed. + */ + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, + ne2, ne3, ne10, ne11, ne12, ne13, + s1, s2, s3, s01, s02, s03, s11, s12, s13, + item_ct1); + }); + } + } + } +}; + +template +inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, + ggml_tensor * dst) { + dpct::queue_ptr main_stream = ctx.stream(); + GGML_TENSOR_BINARY_OP_LOCALS + + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10, + ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, + ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01, + ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, + nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), + main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02, + ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, + nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { + op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), + ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} + +inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + ggml_sycl_op_bin_bcast>(ctx, dst, dst->src[0], dst); +} + + +void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_add(ctx, dst); +} + +void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_sub(ctx, dst); +} + +void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_mul(ctx, dst); +} + +void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_div(ctx, dst); +} + +void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_repeat(ctx, dst); +} + diff --git a/ggml/src/ggml-sycl/binbcast.hpp b/ggml/src/ggml-sycl/binbcast.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9cce0f053a5826949a5f354fb2b9877b5eeaa374 --- /dev/null +++ b/ggml/src/ggml-sycl/binbcast.hpp @@ -0,0 +1,39 @@ +#ifndef GGML_SYCL_BINBCAST_HPP +#define GGML_SYCL_BINBCAST_HPP +#include "common.hpp" + + +static __dpct_inline__ float op_repeat(const float a, const float b) { + return b; + GGML_UNUSED(a); +} + +static __dpct_inline__ float op_add(const float a, const float b) { + return a + b; +} + +static __dpct_inline__ float op_sub(const float a, const float b) { + return a - b; +} + +static __dpct_inline__ float op_mul(const float a, const float b) { + return a * b; +} + +static __dpct_inline__ float op_div(const float a, const float b) { + return a / b; +} + +void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + + +#endif //GGML_SYCL_BINBCAST_HPP + diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..05fd5ef46c76afaa6858a052836e0bc756f9b690 --- /dev/null +++ b/ggml/src/ggml-sycl/common.cpp @@ -0,0 +1,83 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "common.hpp" + +#include "ggml-backend-impl.h" +#include "ggml-impl.h" + +int get_current_device_id() { + return dpct::dev_mgr::instance().current_device_id(); +} + +void* ggml_sycl_host_malloc(size_t size) try { + if (getenv("GGML_SYCL_NO_PINNED") != nullptr) { + return nullptr; + } + + void* ptr = nullptr; + // allow to use dpct::get_in_order_queue() for host malloc + dpct::err0 err = CHECK_TRY_ERROR( + ptr = (void*)sycl::malloc_host(size, dpct::get_in_order_queue())); + + if (err != 0) { + // clear the error + GGML_LOG_ERROR("WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size / 1024.0 / 1024.0, "syclGetErrorString is not supported"); + return nullptr; + } + + return ptr; +} catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void ggml_sycl_host_free(void* ptr) try { + // allow to use dpct::get_in_order_queue() for host malloc + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue()))); +} catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +bool gpu_has_xmx(sycl::device &dev) { + return dev.has(sycl::aspect::ext_intel_matrix); +} + +int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) { + const int64_t max_range = std::numeric_limits::max(); + int64_t sycl_down_blk_size = block_size; + int64_t global_range = accumulate_block_num * sycl_down_blk_size; + while(global_range > max_range) { + sycl_down_blk_size /= 2; + global_range = accumulate_block_num * sycl_down_blk_size; + } + return sycl_down_blk_size; +} + +void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { + if (extra->events[i][is] != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(dpct::destroy_event(extra->events[i][is]))); + } + } + if (extra->data_device[i] != nullptr && streams.size()>0) { + ggml_sycl_set_device(i); + SYCL_CHECK( + CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i])))); + } + } + delete extra; +} diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..753b4af143622d16f504b0d9c64a609627b72e7b --- /dev/null +++ b/ggml/src/ggml-sycl/common.hpp @@ -0,0 +1,584 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_COMMON_HPP +#define GGML_SYCL_COMMON_HPP + +#include +#include +#include +#include + +#include "dpct/helper.hpp" +#include "ggml-sycl.h" +#include "presets.hpp" +#include "sycl_hw.hpp" + + +#if GGML_SYCL_DNNL +#include "dnnl.hpp" +#include "dnnl_sycl.hpp" +#endif + +#define GGML_COMMON_DECL_SYCL +#define GGML_COMMON_IMPL_SYCL +/* suppress warning spam */ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnested-anon-types" +#include "ggml-common.h" +#pragma clang diagnostic pop +#include "ggml-impl.h" + +void* ggml_sycl_host_malloc(size_t size); +void ggml_sycl_host_free(void* ptr); + + +extern int g_ggml_sycl_debug; +extern int g_ggml_sycl_disable_optimize; +extern int g_ggml_sycl_prioritize_dmmv; + +#if defined(__clang__) && __has_builtin(__builtin_expect) +// Hint the optimizer to pipeline the more likely following instruction in branches +# define LIKELY(expr) __builtin_expect(expr, true) +# define UNLIKELY(expr) __builtin_expect(expr, false) +#else +# define LIKELY(expr) (expr) +# define UNLIKELY(expr) (expr) +#endif + +#define GGML_SYCL_DEBUG(...) \ + do { \ + if (UNLIKELY(g_ggml_sycl_debug)) \ + fprintf(stderr, __VA_ARGS__); \ + } while (0) + +#define CHECK_TRY_ERROR(expr) \ + [&]() { \ + try { \ + expr; \ + return dpct::success; \ + } catch (std::exception const& e) { \ + std::cerr << e.what() << "\nException caught at file:" << __FILE__ \ + << ", line:" << __LINE__ << ", func:" << __func__ \ + << std::endl; \ + return dpct::default_error; \ + } \ + }() + + +#define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP +#define VER_4VEC 610 // todo for hardward optimize. +#define VER_GEN9 700 // todo for hardward optimize. +#define VER_GEN12 1000000 // todo for hardward optimize. +#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardward optimize. + +#define GGML_SYCL_MAX_NODES 8192 // TODO: adapt to hardwares + +// define for XMX in Intel GPU +// TODO: currently, it's not used for XMX really. +#if !defined(GGML_SYCL_FORCE_MMQ) + #define SYCL_USE_XMX +#endif + +// max batch size to use MMQ kernels when tensor cores are available +#define MMQ_MAX_BATCH_SIZE 32 + +// dmmv = dequantize_mul_mat_vec +#ifndef GGML_SYCL_DMMV_X +#define GGML_SYCL_DMMV_X 32 +#endif +#ifndef GGML_SYCL_MMV_Y +#define GGML_SYCL_MMV_Y 1 +#endif + +typedef sycl::queue *queue_ptr; + +enum ggml_sycl_backend_gpu_mode { + SYCL_UNSET_GPU_MODE = -1, + SYCL_SINGLE_GPU_MODE = 0, + SYCL_MUL_GPU_MODE +}; + +static_assert(sizeof(sycl::half) == sizeof(ggml_fp16_t), "wrong fp16 size"); + +static void crash() { + int* ptr = NULL; + *ptr = 0; +} + +[[noreturn]] static void ggml_sycl_error( + const char* stmt, + const char* func, + const char* file, + const int line, + const char* msg) { + fprintf(stderr, "SYCL error: %s: %s\n", stmt, msg); + fprintf(stderr, " in function %s at %s:%d\n", func, file, line); + GGML_ABORT("SYCL error"); +} + +#define SYCL_CHECK(err) \ + do { \ + auto err_ = (err); \ + if (err_ != 0) \ + ggml_sycl_error(#err, __func__, __FILE__, __LINE__, "Exception caught in this line of code."); \ + } while (0) + +#if DPCT_COMPAT_RT_VERSION >= 11100 +#define GGML_SYCL_ASSUME(x) __builtin_assume(x) +#else +#define GGML_SYCL_ASSUME(x) +#endif // DPCT_COMPAT_RT_VERSION >= 11100 + +#ifdef GGML_SYCL_F16 +typedef sycl::half dfloat; // dequantize float +typedef sycl::half2 dfloat2; +#else +typedef float dfloat; // dequantize float +typedef sycl::float2 dfloat2; +#endif // GGML_SYCL_F16 + +#define MMVQ_MAX_BATCH_SIZE 8 + +static int g_all_sycl_device_count = -1; +static bool g_ggml_backend_sycl_buffer_type_initialized = false; + +static ggml_sycl_backend_gpu_mode g_ggml_sycl_backend_gpu_mode = + SYCL_UNSET_GPU_MODE; + +static void* g_scratch_buffer = nullptr; +static size_t g_scratch_size = 0; // disabled by default +static size_t g_scratch_offset = 0; + +[[noreturn]] static inline void bad_arch(const sycl::stream& stream_ct1) { + stream_ct1 << "ERROR: ggml-sycl was compiled without support for the " + "current GPU architecture.\n"; + // __trap(); + std::exit(1); + + (void)bad_arch; // suppress unused function warning +} + +int get_current_device_id(); + +inline dpct::err0 ggml_sycl_set_device(const int device) try { + int current_device_id; + SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id())); + + // GGML_SYCL_DEBUG("ggml_sycl_set_device device_id=%d, + // current_device_id=%d\n", device, current_device); + if (device == current_device_id) { + return 0; + } + + return CHECK_TRY_ERROR(dpct::select_device(device)); +} catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + crash(); + std::exit(1); +} + +////////////////////// +struct optimize_feature { + bool reorder=false; +}; + +struct sycl_device_info { + int cc; // compute capability + // int nsm; // number of streaming multiprocessors + // size_t smpb; // max. shared memory per block + bool vmm; // virtual memory support + size_t total_vram; + sycl_hw_info hw_info; + optimize_feature opt_feature; +}; + + +struct ggml_sycl_device_info { + int device_count; + + sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {}; + + std::array default_tensor_split = {}; + + int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0}; +}; + +const ggml_sycl_device_info & ggml_sycl_info(); + +struct ggml_sycl_pool { + virtual ~ggml_sycl_pool() = default; + + virtual void * alloc(size_t size, size_t * actual_size) = 0; + virtual void free(void * ptr, size_t size) = 0; +}; + +template +struct ggml_sycl_pool_alloc { + ggml_sycl_pool * pool = nullptr; + T * ptr = nullptr; + size_t actual_size = 0; + + explicit ggml_sycl_pool_alloc(ggml_sycl_pool & pool) : pool(&pool) { + } + + ggml_sycl_pool_alloc(ggml_sycl_pool & pool, size_t size) : pool(&pool) { + alloc(size); + } + + ~ggml_sycl_pool_alloc() { + if (ptr != nullptr) { + pool->free(ptr, actual_size); + } + } + + T * realloc(size_t size) { + GGML_ASSERT(pool != nullptr); + if (ptr) + pool->free(ptr, actual_size); + ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size); + return ptr; + } + + // size is in number of elements + T * alloc(size_t size) { + GGML_ASSERT(pool != nullptr); + GGML_ASSERT(ptr == nullptr); + ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size); + return ptr; + } + + T * alloc(ggml_sycl_pool & pool, size_t size) { + this->pool = &pool; + return alloc(size); + } + + T * get() { + return ptr; + } + + ggml_sycl_pool_alloc() = default; + ggml_sycl_pool_alloc(const ggml_sycl_pool_alloc &) = delete; + ggml_sycl_pool_alloc(ggml_sycl_pool_alloc &&) = delete; + ggml_sycl_pool_alloc& operator=(const ggml_sycl_pool_alloc &) = delete; + ggml_sycl_pool_alloc& operator=(ggml_sycl_pool_alloc &&) = delete; +}; + +// backend interface + +struct ggml_tensor_extra_gpu { + void* data_device[GGML_SYCL_MAX_DEVICES]; // 1 pointer for each device for split + // tensors + dpct::event_ptr events[GGML_SYCL_MAX_DEVICES] + [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs + optimize_feature optimized_feature; +}; + +void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams={}); + +inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) { + optimize_feature opt; + + opt.reorder = + (arch == syclex::architecture::intel_gpu_dg1 || + arch == syclex::architecture::intel_gpu_acm_g10 || + arch == syclex::architecture::intel_gpu_acm_g11 || + arch == syclex::architecture::intel_gpu_acm_g12 || + arch == syclex::architecture::intel_gpu_pvc || + arch == syclex::architecture::intel_gpu_pvc_vg || + arch == syclex::architecture::intel_gpu_mtl_u || + arch == syclex::architecture::intel_gpu_mtl_s || + arch == syclex::architecture::intel_gpu_mtl_h || + arch == syclex::architecture::intel_gpu_arl_u || + arch == syclex::architecture::intel_gpu_arl_s || + arch == syclex::architecture::intel_gpu_arl_h || + arch == syclex::architecture::intel_gpu_bmg_g21 || + arch == syclex::architecture::intel_gpu_lnl_m + ); + + return opt; +} + +namespace sycl_ex = sycl::ext::oneapi::experimental; +struct ggml_backend_sycl_context { + int device; + std::string name; + optimize_feature opt_feature; + + queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } }; + + explicit ggml_backend_sycl_context(int device) : + device(device), + name(GGML_SYCL_NAME + std::to_string(device)) { + opt_feature = ggml_sycl_info().devices[device].opt_feature; + } + + queue_ptr stream(int device, int stream) { + if (qptrs[device][stream] == nullptr) { + qptrs[device][stream] = &(dpct::get_device(device).default_queue()); + } + return qptrs[device][stream]; + } + + queue_ptr stream() { + return stream(device, 0); + } + +#if GGML_SYCL_DNNL + dnnl::engine make_engine(sycl::queue* q) { + // Get the device associated with the queue + sycl::device dev = q->get_device(); + // Get the context associated with the queue + sycl::context ctx = q->get_context(); + const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); + return eng; + } + + std::unordered_map stream_map; + std::unordered_map engine_map; + dnnl::stream stream_dnnl(int device, int _stream) { + auto q = stream(device, _stream); + return stream_dnnl(q); + } + dnnl::engine engine_dnnl(sycl::queue* qptr) { + auto it = engine_map.find(qptr); + if (it == engine_map.end()) { + auto eng = make_engine(qptr); + engine_map[qptr] = eng; + return eng; + } + else + { + return it->second; + } + } + dnnl::stream stream_dnnl(sycl::queue* qptr) { + auto it = stream_map.find(qptr); + if (it == stream_map.end()) { + auto eng = engine_dnnl(qptr); + auto stream = dnnl::sycl_interop::make_stream(eng, *qptr); + stream_map[qptr] = stream; + return stream; + } + else + { + return it->second; + } + } + dnnl::stream stream_dnnl() { + return stream_dnnl(device, 0); + } + dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md, + const dnnl::engine & eng, const queue_ptr q) { + ggml_sycl_pool_alloc * pool; + auto it = scratchpad_map.find(q); + if (it == scratchpad_map.end()) { + scratchpad_map[q] = std::make_unique>(this->pool()); + pool = scratchpad_map[q].get(); + } else { + pool = it->second.get(); + } + + size_t scratchpad_size = scratchpad_md.get_size(); + if (scratchpad_size > pool->actual_size) { + pool->realloc(scratchpad_size); + } + void * mem_ptr = pool->get(); + return dnnl::memory(scratchpad_md, eng, mem_ptr); + } +#endif + + // pool + std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; + std::unordered_map>> scratchpad_map; + + std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES]; + + static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); + + static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device); + + ggml_sycl_pool & pool(int device) { + if (pools[device] == nullptr) { + pools[device] = new_pool_for_device(stream(device,0), device); + } + return *pools[device]; + } + + ggml_sycl_pool & pool() { + return pool(device); + } + +#ifdef GGML_SYCL_GRAPH + std::unique_ptr> exec_graph = nullptr; +#endif + + ggml_sycl_pool & host_pool(int device) { + if (host_pools[device] == nullptr) { + host_pools[device] = new_pool_for_host(stream(device, 0), device); + } + return *host_pools[device]; + } + + ggml_sycl_pool & host_pool() { return host_pool(device); } +}; + +// common device functions + +static __dpct_inline__ float warp_reduce_sum(float x, + const sycl::nd_item<3>& item_ct1) { +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + /* + DPCT1096:98: The right-most dimension of the work-group used in the SYCL + kernel that calls this function may be less than "32". The function + "dpct::permute_sub_group_by_xor" may return an unexpected result on the + CPU device. Modify the size of the work-group to ensure that the value + of the right-most dimension is a multiple of "32". + */ + x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask); + } + return x; +} + +static __dpct_inline__ sycl::float2 +warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) { +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + a.x() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.x(), + mask); + a.y() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.y(), + mask); + } + return a; +} + +static __dpct_inline__ float warp_reduce_max(float x, + const sycl::nd_item<3>& item_ct1) { +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + /* + DPCT1096:97: The right-most dimension of the work-group used in the SYCL + kernel that calls this function may be less than "32". The function + "dpct::permute_sub_group_by_xor" may return an unexpected result on the + CPU device. Modify the size of the work-group to ensure that the value + of the right-most dimension is a multiple of "32". + */ + x = sycl::fmax(x, dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), x, mask)); + } + return x; +} + +/* Helper for Computing the linear offset of a ggml_tensor given +per-dimension sizes, strides, and indices */ +template +__dpct_inline__ size_t calculate_offset(const std::array & strides, const std::array & indices) { + size_t offset = 0; +#pragma unroll + for (int i = 0; i < N; i++) { + auto index_i = indices[i]; + offset += strides[i] * index_i; + } + return offset; +} + +// Helper for vec loading aligned data +template +inline sycl::vec vec_aligned_load(const Tp* aligned_ptr) { + return *reinterpret_cast*>(aligned_ptr); +} + +// Helper for accessing pointers with no warnings +template +static __dpct_inline__ Tp* get_pointer(sycl::local_accessor acc) { + return acc.template get_multi_ptr().get(); +} + +int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size); + +constexpr size_t ceil_div(const size_t m, const size_t n) { + return (m + n - 1) / n; +} + +bool gpu_has_xmx(sycl::device &dev); + +template std::string debug_get_array_str(const std::string & prefix, const T array[N]) { + if (LIKELY(!g_ggml_sycl_debug)) { + return ""; + } + std::stringstream ss; + ss << prefix << "=["; + for (std::size_t i = 0; i < N - 1; ++i) { + ss << array[i] << ", "; + } + if constexpr (N > 0) { + ss << array[N - 1]; + } + ss << "]"; + return ss.str(); +} + +inline std::string debug_get_tensor_str(const std::string &prefix, + const ggml_tensor *tensor, const std::string &suffix = "") { + std::stringstream ss; + if (LIKELY(!g_ggml_sycl_debug)) { return ss.str(); } + ss << prefix.c_str() << "="; + if (tensor) { + ss << "'" << tensor->name << "':type=" << ggml_type_name(tensor->type); + ss << debug_get_array_str(";ne", tensor->ne); + ss << debug_get_array_str(";nb", tensor->nb); + + if (!ggml_is_contiguous(tensor)) { ss << ";strided"; } + if (ggml_is_permuted(tensor)) { ss << ";permuted"; } + } else { + ss << "nullptr"; + } + ss << suffix; + return ss.str(); +} + +// Use scope_op_debug_print to log operations coming from running a model +struct scope_op_debug_print { + // Use string_views to avoid the cost of creating a string and concatenating them + // string_views must be alive for as long as the object is alive + // scope_op_debug_print are used with string literals in practice which are stored in constant space so always accessible + scope_op_debug_print(const std::string_view & func, const std::string_view & func_suffix, const ggml_tensor * dst, + std::size_t num_src, const std::string_view & suffix = "") : + func(func), + func_suffix(func_suffix) { + if (LIKELY(!g_ggml_sycl_debug)) { + return; + } + GGML_SYCL_DEBUG("[SYCL][OP] call %s%s:", func.data(), func_suffix.data()); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" dst", dst).c_str()); + if (dst) { + for (std::size_t i = 0; i < num_src; ++i) { + GGML_SYCL_DEBUG("%s", debug_get_tensor_str("\tsrc" + std::to_string(i), dst->src[i]).c_str()); + } + } + GGML_SYCL_DEBUG("%s\n", suffix.data()); + } + + scope_op_debug_print(const std::string_view & func, const ggml_tensor * dst, std::size_t num_src, + const std::string_view & suffix = "") : + scope_op_debug_print(func, "", dst, num_src, suffix) {} + + ~scope_op_debug_print() { GGML_SYCL_DEBUG("[SYCL][OP] call %s%s done\n", func.data(), func_suffix.data()); } + + private: + std::string_view func; + std::string_view func_suffix; +}; + +#endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7aa91c861d58398200a63527e97a76e6d5eddf4b --- /dev/null +++ b/ggml/src/ggml-sycl/concat.cpp @@ -0,0 +1,195 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "concat.hpp" +#include "common.hpp" + +static void concat_f32_dim0(const float *x, const float *y, float *dst, + const int ne0, const int ne00, + const sycl::nd_item<3> &item_ct1) { + int nidx = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (nidx >= ne0) { + return; + } + // operation + int offset_dst = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + if (nidx < ne00) { // src0 + int offset_src = nidx + item_ct1.get_group(1) * ne00 + + item_ct1.get_group(0) * ne00 * item_ct1.get_group_range(1); + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx - ne00 + item_ct1.get_group(1) * (ne0 - ne00) + + item_ct1.get_group(0) * (ne0 - ne00) * item_ct1.get_group_range(1); + dst[offset_dst] = y[offset_src]; + } +} + +static void concat_f32_dim1(const float *x, const float *y, float *dst, + const int ne0, const int ne01, + const sycl::nd_item<3> &item_ct1) { + int nidx = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (nidx >= ne0) { + return; + } + // operation + int offset_dst = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + if (item_ct1.get_group(1) < (size_t) ne01) { // src0 + int offset_src = + nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + (item_ct1.get_group(1) - ne01) * ne0 + + item_ct1.get_group(0) * ne0 * (item_ct1.get_group_range(1) - ne01); + dst[offset_dst] = y[offset_src]; + } +} + +static void concat_f32_dim2(const float *x, const float *y, float *dst, + const int ne0, const int ne02, + const sycl::nd_item<3> &item_ct1) { + int nidx = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (nidx >= ne0) { + return; + } + // operation + int offset_dst = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + if (item_ct1.get_group(0) < (size_t) ne02) { // src0 + int offset_src = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + item_ct1.get_group(1) * ne0 + + (item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1); + dst[offset_dst] = y[offset_src]; + } +} + +static void concat_f32_sycl(const float *x, const float *y, float *dst, + int ne00, int ne01, int ne02, int ne0, int ne1, + int ne2, int dim, queue_ptr stream) { + int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE; + sycl::range<3> gridDim(ne2, ne1, num_blocks); + switch (dim) { + case 0: + stream->parallel_for( + sycl::nd_range<3>(gridDim * + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); + }); + break; + case 1: + stream->parallel_for( + sycl::nd_range<3>(gridDim * + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); + }); + break; + // dim >=2 will be dispatched to the default path + default: + stream->parallel_for( + sycl::nd_range<3>(gridDim * + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); + }); + break; + } +} + +// non-contiguous kernel (slow) +static void concat_f32_sycl_non_cont( + queue_ptr stream, const char *src0, const char *src1, char *dst, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00, + uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/, + int64_t /*ne11*/, int64_t /*ne12*/, int64_t /*ne13*/, uint64_t nb10, + uint64_t nb11, uint64_t nb12, uint64_t nb13, int64_t ne0, int64_t ne1, + int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2, + uint64_t nb3, int32_t dim) { + sycl::range<3> gridDim(ne3, ne2, ne1); + stream->parallel_for( + sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + int64_t i3 = item_ct1.get_group(0); + int64_t i2 = item_ct1.get_group(1); + int64_t i1 = item_ct1.get_group(2); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + + const float *x; + + for (int i0 = item_ct1.get_local_id(2); i0 < ne0; + i0 += item_ct1.get_local_range(2)) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 + + (i0)*nb00); + } else { + x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + + (i1 - o[1]) * nb11 + (i0 - o[0]) * nb10); + } + + float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); + + *y = *x; + } + }); +} + +void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + queue_ptr stream = ctx.stream(); + + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + + float * dst_d = (float *) dst->data; + + if (dim != 3) { + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_sycl(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], + dst->ne[1], dst->ne[2], dim, stream); + } + } else { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + + SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait())); + SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait())); + } + } else { + concat_f32_sycl_non_cont(stream, (const char *) src0->data, (const char *) src1->data, (char *) dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], + src0->nb[2], src0->nb[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], + dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); + } +} diff --git a/ggml/src/ggml-sycl/concat.hpp b/ggml/src/ggml-sycl/concat.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e5cb7314c9f336195c0d49d35224ca05646c8f1e --- /dev/null +++ b/ggml/src/ggml-sycl/concat.hpp @@ -0,0 +1,20 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_CONCAT_HPP +#define GGML_SYCL_CONCAT_HPP + +#include "common.hpp" + +void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst); + +#endif // GGML_SYCL_CONCAT_HPP diff --git a/ggml/src/ggml-sycl/conv.cpp b/ggml/src/ggml-sycl/conv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..475bd34a25d5626cba8cd58cc803bbec5a6316bf --- /dev/null +++ b/ggml/src/ggml-sycl/conv.cpp @@ -0,0 +1,101 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "conv.hpp" + +static void conv_transpose_1d_kernel( + const int s0, const int output_size, + const int src0_ne0, const int src0_ne1, const int src0_ne2, + const int src1_ne0, const int dst_ne0, + const float * src0, const float * src1, float * dst, + const sycl::nd_item<3> &item_ct1) { + int global_index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (global_index >= output_size) { + return; + } + + int out_index = global_index / dst_ne0; + + float accumulator = 0; + + for (int c = 0; c < src0_ne2; c++) { + int idx = global_index % dst_ne0; + + int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0); + int input_offset = src1_ne0 * c; + + for (int i = 0; i < src1_ne0; i++) { + if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) { + continue; + } + int weight_idx = idx - i*s0; + + float kernel_weight = src0[kernel_offset + weight_idx]; + float input_value = src1[input_offset+i]; + + accumulator += kernel_weight * input_value; + } + } + dst[global_index] = accumulator; +} + +static void conv_transpose_1d_f32_f32_sycl( + const int s0, const int output_size, + const int src0_ne0, const int src0_ne1, const int src0_ne2, + const int src1_ne0, const int dst_ne0, + const float *src0, const float *src1, float *dst, + const queue_ptr& stream) { + + const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE; + const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE); + const sycl::range<3> block_nums(1, 1, num_blocks); + stream->parallel_for( + sycl::nd_range<3>( + block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + conv_transpose_1d_kernel( + s0, output_size, + src0_ne0, src0_ne1, src0_ne2, + src1_ne0, dst_ne0, + src0, src1, dst, item_ct1); + }); +} + +void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const int32_t * opts = (const int32_t *)dst->op_params; + + const int s0 = opts[0]; + + const int64_t output_size = ggml_nelements(dst); + + conv_transpose_1d_f32_f32_sycl(s0, output_size, + src0->ne[0], src0->ne[1], src0->ne[2], + src1->ne[0], dst->ne[0], + src0_d, src1_d, dst_d, stream); +} + diff --git a/ggml/src/ggml-sycl/conv.hpp b/ggml/src/ggml-sycl/conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f9e60dc758029ef02f4dd9feb86ea2735d4b5155 --- /dev/null +++ b/ggml/src/ggml-sycl/conv.hpp @@ -0,0 +1,20 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_CONV_HPP +#define GGML_SYCL_CONV_HPP + +#include "common.hpp" + +void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst); + +#endif // GGML_SYCL_CONV_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp new file mode 100644 index 0000000000000000000000000000000000000000..96d2583b13b83a295d254a5ee2261cc772ccba50 --- /dev/null +++ b/ggml/src/ggml-sycl/convert.cpp @@ -0,0 +1,642 @@ +#include "convert.hpp" +#include "dequantize.hpp" +#include "presets.hpp" + +template +static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, + const sycl::nd_item<3> &item_ct1) { + const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2)); + + if (i >= k) { + return; + } + + const int64_t ib = i/qk; // block index + const int64_t iqs = (i%qk)/qr; // quant index + const int64_t iybs = i - i%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel(vx, ib, iqs, v); + + y[iybs + iqs + 0] = v.x(); + y[iybs + iqs + y_offset] = v.y(); +} + +template +static void dequantize_block_sycl(const void *__restrict__ vx, + dst_t *__restrict__ y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block(vx, y, k, item_ct1); + }); + } +} + +template +static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; +#if QK_K == 256 + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 64), + sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q2_K(vx, y, item_ct1); + }); + } +#else + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q2_K(vx, y, item_ct1); + }); + } + +#endif +} + +template +static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; +#if QK_K == 256 + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 64), + sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q3_K(vx, y, item_ct1); + }); + } +#else + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q3_K(vx, y, item_ct1); + }); + } +#endif +} + +template +static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb32 = k / 32; + const int64_t nb = (k + 255) / 256; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q4_0(vx, y, nb32, item_ct1); + }); + } +} + +template +static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + int constexpr WARP_K = WARP_SIZE * QK4_0; + const int n_warp = (k + WARP_K - 1) / WARP_K; + GGML_ASSERT(k % 2 == 0); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * + sycl::range<3>(1, 1, WARP_SIZE), + sycl::range<3>(1, 1, WARP_SIZE)), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{ + dequantize_block_q4_0_reorder(vx, y, k, item_ct1); + }); + +} + +template +static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb32 = k / 32; + const int64_t nb = (k + 255) / 256; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q4_1(vx, y, nb32, item_ct1); + }); + } +} + + +template +static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor scale_local_acc(sycl::range<1>(12), cgh); + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1); + }); + }); + } +} + +template +static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + const size_t local_size = 32; + const size_t global_size = nb * local_size; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor scale_local_acc(sycl::range<1>(12), cgh); + + cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), + [=](sycl::nd_item<1> item_ct1) { + dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb); + }); + }); +} + +template +static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; +#if QK_K == 256 + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 64), + sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q5_K(vx, y, item_ct1); + }); + } +#else + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q5_K(vx, y, item_ct1); + }); + } + +#endif +} + +template +static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; +#if QK_K == 256 + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 64), + sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q6_K(vx, y, item_ct1); + }); + } +#else + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q6_K(vx, y, item_ct1); + }); + } + +#endif +} + +template +static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); }); +} + +template +static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq1_s( + vx, y, item_ct1, iq1s_grid_gpu + ); + }); + }); + } +} + +template +static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq1_m( + vx, y, item_ct1, iq1s_grid_gpu + ); + }); + }); + } +} + +template +static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq2_xxs( + vx, y, item_ct1, iq2xxs_grid, + ksigns_iq2xs, kmask_iq2xs); + }); + }); + } +} + +template +static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq2_xs( + vx, y, item_ct1, iq2xs_grid, + ksigns_iq2xs, kmask_iq2xs); + }); + }); + } +} + +template +static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq2_s(vx, y, item_ct1); + }); + }); + } +} + + +template +static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq3_xxs( + vx, y, item_ct1, iq3xxs_grid, + ksigns_iq2xs, kmask_iq2xs); + }); + }); + } +} + +template +static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq3_s( + vx, y, item_ct1, kmask_iq2xs, iq3s_grid); + }); + }); + } +} + +template +static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = (k + QK_K - 1) / QK_K; +#if QK_K == 64 + dequantize_row_iq4_nl_sycl(vx, y, k, stream); +#else + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq4_xs(vx, y, item_ct1); + }); + }); + } +#endif +} + +template +static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = (k + QK_K - 1) / QK_K; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq4_nl(vx, y, item_ct1); + }); + }); + } +} + +template +static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, + const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03, + const sycl::nd_item<3> & item_ct1) { + + const int64_t work_group_size = item_ct1.get_local_range(2); + const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + + const int64_t i01 = item_ct1.get_group(1); + const int64_t i02 = item_ct1.get_group(0) % ne02; + const int64_t i03 = item_ct1.get_group(0) / ne02; + + // make each work-item deal with more elements since sycl global range can not exceed max int + const src_t * x = static_cast(vx); + const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01; + const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00; + +#pragma unroll + for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) { + y[iy + i00] = static_cast(x[ix + i00]); + } +} + +template +static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) { + dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 }); + + sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE)); + + // decrease global range when it exceeds the max int + // TODO: Downsample logic is separated from the kernel, a rewrite is desirable + int64_t downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE); + sycl::range<3> workgroup_size(1, 1, downsized_workgroup); + + queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) { + convert_unary_nc(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1); + }); +} + +template +static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) { + convert_unary_nc_sycl(vx, y, k, 1, 1, 1, k, k, k, queue); +} + +to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { + switch (type) { + case GGML_TYPE_Q4_0: + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_0_sycl_reorder; + } else { + return dequantize_block_sycl; + } + case GGML_TYPE_Q4_1: + return dequantize_block_sycl; + case GGML_TYPE_Q5_0: + return dequantize_block_sycl; + case GGML_TYPE_Q5_1: + return dequantize_block_sycl; + case GGML_TYPE_Q8_0: + return dequantize_block_sycl; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_sycl; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_sycl; + case GGML_TYPE_Q4_K: + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_K_sycl_reorder; + } else { + return dequantize_row_q4_K_sycl; + } + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_sycl; + case GGML_TYPE_Q6_K: + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q6_K_sycl_reorder; + } else { + return dequantize_row_q6_K_sycl; + } + case GGML_TYPE_IQ1_S: + return dequantize_row_iq1_s_sycl; + case GGML_TYPE_IQ1_M: + return dequantize_row_iq1_m_sycl; + case GGML_TYPE_IQ2_XXS: + return dequantize_row_iq2_xxs_sycl; + case GGML_TYPE_IQ2_XS: + return dequantize_row_iq2_xs_sycl; + case GGML_TYPE_IQ2_S: + return dequantize_row_iq2_s_sycl; + case GGML_TYPE_IQ3_XXS: + return dequantize_row_iq3_xxs_sycl; + case GGML_TYPE_IQ3_S: + return dequantize_row_iq3_s_sycl; + case GGML_TYPE_IQ4_XS: + return dequantize_row_iq4_xs_sycl; + case GGML_TYPE_IQ4_NL: + return dequantize_row_iq4_nl_sycl; + case GGML_TYPE_F32: + return convert_unary_sycl; + default: + return nullptr; + } +} + +to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { + switch (type) { + case GGML_TYPE_Q4_0: + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_0_sycl_reorder; + } else { + return dequantize_row_q4_0_sycl; + } + case GGML_TYPE_Q4_1: + return dequantize_row_q4_1_sycl; + case GGML_TYPE_Q5_0: + return dequantize_block_sycl; + case GGML_TYPE_Q5_1: + return dequantize_block_sycl; + case GGML_TYPE_Q8_0: + return dequantize_block_sycl; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_sycl; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_sycl; + case GGML_TYPE_Q4_K: + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_K_sycl_reorder; + } else { + return dequantize_row_q4_K_sycl; + } + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_sycl; + case GGML_TYPE_Q6_K: + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q6_K_sycl_reorder; + } else { + return dequantize_row_q6_K_sycl; + } + case GGML_TYPE_IQ1_S: + return dequantize_row_iq1_s_sycl; + case GGML_TYPE_IQ1_M: + return dequantize_row_iq1_m_sycl; + case GGML_TYPE_IQ2_XXS: + return dequantize_row_iq2_xxs_sycl; + case GGML_TYPE_IQ2_XS: + return dequantize_row_iq2_xs_sycl; + case GGML_TYPE_IQ2_S: + return dequantize_row_iq2_s_sycl; + case GGML_TYPE_IQ3_XXS: + return dequantize_row_iq3_xxs_sycl; + case GGML_TYPE_IQ3_S: + return dequantize_row_iq3_s_sycl; + case GGML_TYPE_IQ4_XS: + return dequantize_row_iq4_xs_sycl; + case GGML_TYPE_IQ4_NL: + return dequantize_row_iq4_nl_sycl; + case GGML_TYPE_F16: + return convert_unary_sycl; + default: + return nullptr; + } +} + +to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_nc_sycl; + default: + return nullptr; + } +} diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8cb573e3688bc470aecac9f5bbddf232a1c028b --- /dev/null +++ b/ggml/src/ggml-sycl/convert.hpp @@ -0,0 +1,34 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_CONVERT_HPP +#define GGML_SYCL_CONVERT_HPP + +#include "common.hpp" + +template +using to_t_sycl_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, dpct::queue_ptr stream); +typedef to_t_sycl_t to_fp32_sycl_t; +typedef to_t_sycl_t to_fp16_sycl_t; + +to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst); +to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst); + +// Nc = Non-contiguous +template +using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, + int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue); + +typedef to_t_nc_sycl_t to_fp16_nc_sycl_t; +to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type); + +#endif // GGML_SYCL_CONVERT_HPP diff --git a/ggml/src/ggml-sycl/cpy.cpp b/ggml/src/ggml-sycl/cpy.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bec13714019558ee862dfa98460b15957a51188c --- /dev/null +++ b/ggml/src/ggml-sycl/cpy.cpp @@ -0,0 +1,817 @@ +#include "cpy.hpp" + +#include +#include + +#include "dequantize.hpp" +#include "ggml-sycl/common.hpp" +#include "ggml-sycl/presets.hpp" +#include "ggml.h" + +static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) { + if (x <= val[0]) { + return 0; + } + if (x >= val[n - 1]) { + return n - 1; + } + int ml = 0, mu = n - 1; + while (mu - ml > 1) { + int mav = (ml + mu) / 2; + if (x < val[mav]) { + mu = mav; + } else { + ml = mav; + } + } + return x - val[mu - 1] < val[mu] - x ? mu - 1 : mu; +} + +static void cpy_1_f32_f32(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + float * dsti = (float *) cdsti; + + *dsti = *xi; +} + +static void cpy_1_f32_f16(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + sycl::half * dsti = (sycl::half *) cdsti; + + *dsti = sycl::vec(*xi).convert()[0]; +} + +static void cpy_1_f16_f16(const char * cxi, char * cdsti) { + const sycl::half * xi = (const sycl::half *) cxi; + sycl::half * dsti = (sycl::half *) cdsti; + + *dsti = *xi; +} + +static void cpy_1_f16_f32(const char * cxi, char * cdsti) { + const sycl::half * xi = (const sycl::half *) cxi; + float * dsti = (float *) cdsti; + + *dsti = *xi; +} + +static void cpy_1_i16_i16(const char * cxi, char * cdsti) { + const int16_t * xi = (const int16_t *) cxi; + int16_t * dsti = (int16_t *) cdsti; + + *dsti = *xi; +} + +static void cpy_1_i32_i32(const char * cxi, char * cdsti) { + const int32_t * xi = (const int32_t *) cxi; + int32_t * dsti = (int32_t *) cdsti; + + *dsti = *xi; +} + +template +static void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, + const sycl::nd_item<3> & item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + + if (i >= ne) { + return; + } + + // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor + // then combine those indices with the corresponding byte offsets to get the total offsets + const int i03 = i / (ne00 * ne01 * ne02); + const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00; + const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00; + const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03; + + const int i13 = i / (ne10 * ne11 * ne12); + const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11); + const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10; + const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; + const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; + + cpy_1(cx + x_offset, cdst + dst_offset); +} + +static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q8_0 * dsti = (block_q8_0 *) cdsti; + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = xi[j]; + amax = sycl::fmax(amax, sycl::fabs((float) v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f / d : 0.0f; + + dsti->d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = xi[j] * id; + + dsti->qs[j] = sycl::round((float) x0); + } +} + +/* quantized type same copy */ +template +static void cpy_blck_q_q(const char * cxi, char * cdsti) { + const T * xi = (const T *) cxi; + T * dsti = (T *) cdsti; + *dsti = *xi; +} + + +static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { + float * cdstf = (float *) (cdsti); + + for (int j = 0; j < QK8_0; j += 2) { + dfloat2 dq; + dequantize_q8_0(cxi, 0, j, dq); + *(cdstf + j) = dq.x(); + *(cdstf + j + 1) = dq.y(); + } +} + +static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q4_0 * dsti = (block_q4_0 *) cdsti; + + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_0; ++j) { + const float v = xi[j]; + if (amax < sycl::fabs((float) v)) { + amax = sycl::fabs((float) v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = d ? 1.0f / d : 0.0f; + + dsti->d = d; + + for (int j = 0; j < QK4_0 / 2; ++j) { + const float x0 = xi[0 + j] * id; + const float x1 = xi[QK4_0 / 2 + j] * id; + + const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 8.5f)); + + dsti->qs[j] = xi0; + dsti->qs[j] |= xi1 << 4; + } +} + +static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q4_1 * dsti = (block_q4_1 *) cdsti; + + float vmin = FLT_MAX; + float vmax = -FLT_MAX; + + for (int j = 0; j < QK4_1; ++j) { + const float v = xi[j]; + + if (v < vmin) { + vmin = v; + } + if (v > vmax) { + vmax = v; + } + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = d ? 1.0f / d : 0.0f; + + dsti->dm.x() = d; + dsti->dm.y() = vmin; + + for (int j = 0; j < QK4_1 / 2; ++j) { + const float x0 = (xi[0 + j] - vmin) * id; + const float x1 = (xi[QK4_1 / 2 + j] - vmin) * id; + + const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 0.5f)); + const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 0.5f)); + + dsti->qs[j] = xi0; + dsti->qs[j] |= xi1 << 4; + } +} + +static void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q5_0 * dsti = (block_q5_0 *) cdsti; + + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK5_0; ++j) { + const float v = xi[j]; + if (amax < sycl::fabs((float) v)) { + amax = sycl::fabs((float) v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = d ? 1.0f / d : 0.0f; + + dsti->d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0 / 2; ++j) { + const float x0 = xi[0 + j] * id; + const float x1 = xi[QK5_0 / 2 + j] * id; + + const uint8_t xi0 = dpct::min(31, (int8_t) (x0 + 16.5f)); + const uint8_t xi1 = dpct::min(31, (int8_t) (x1 + 16.5f)); + + dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0 / 2); + } + memcpy(dsti->qh, &qh, sizeof(qh)); +} + +static void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q5_1 * dsti = (block_q5_1 *) cdsti; + + float min = xi[0]; + float max = xi[0]; + + for (int j = 1; j < QK5_1; ++j) { + const float v = xi[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f / d : 0.0f; + + dsti->dm.x() = d; + dsti->dm.y() = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1 / 2; ++j) { + const float x0 = (xi[0 + j] - min) * id; + const float x1 = (xi[QK5_1 / 2 + j] - min) * id; + + const uint8_t xi0 = (uint8_t) (x0 + 0.5f); + const uint8_t xi1 = (uint8_t) (x1 + 0.5f); + + dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1 / 2); + } + memcpy(dsti->qh, &qh, sizeof(qh)); +} + +static void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_iq4_nl * dsti = (block_iq4_nl *) cdsti; + + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_NL; ++j) { + const float v = xi[j]; + if (amax < sycl::fabs((float) v)) { + amax = sycl::fabs((float) v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = d ? 1.0f / d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL / 2; ++j) { + const float x0 = xi[0 + j] * id; + const float x1 = xi[QK4_NL / 2 + j] * id; + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1); + dsti->qs[j] = xi0 | (xi1 << 4); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = xi[0 + j] * xi[0 + j]; + const float w1 = xi[QK4_NL / 2 + j] * xi[QK4_NL / 2 + j]; + sumqx += w0 * v0 * xi[j] + w1 * v1 * xi[QK4_NL / 2 + j]; + sumq2 += w0 * v0 * v0 + w1 * v1 * v1; + } + + dsti->d = sumq2 > 0 ? sumqx / sumq2 : d; +} + +template static void cpy_blck_q_f32(const char * cxi, char * cdsti) { + float * cdstf = (float *) (cdsti); + + for (int j = 0; j < qk / 2; j++) { + dfloat2 dq; + dequant(cxi, 0, j, dq); + *(cdstf + j) = dq.x(); + *(cdstf + j + qk / 2) = dq.y(); + } +} + + +template +static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, + const sycl::nd_item<3> & item_ct1) { + const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk; + + if (i >= ne) { + return; + } + + const int i03 = i / (ne00 * ne01 * ne02); + const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00; + const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00; + const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03; + + + const int i13 = i / (ne10 * ne11 * ne12); + const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11); + const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10; + const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; + const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; + + cpy_blck_q_q(cx + x_offset, cdst + dst_offset); +} + +template +static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, + const sycl::nd_item<3> & item_ct1) { + const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk; + + if (i >= ne) { + return; + } + + + const int i03 = i / (ne00 * ne01 * ne02); + const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00; + const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00; + const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03; + + const int i13 = i / (ne10 * ne11 * ne12); + const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11); + const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10; + const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; + const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; + + cpy_blck(cx + x_offset, cdst + dst_offset); +} + +template +static void cpy_q_f32(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, + const sycl::nd_item<3> & item_ct1) { + const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk; + + if (i >= ne) { + return; + } + + const int i03 = i / (ne00 * ne01 * ne02); + const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00; + const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00; + const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03; + + const int i13 = i / (ne10 * ne11 * ne12); + const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11); + const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10; + const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; + const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; + + cpy_blck(cx + x_offset, cdst + dst_offset); +} + +static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK8_0 == 0); + const int num_blocks = ne / QK8_0; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ne; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK4_0 == 0); + const int num_blocks = ne / QK4_0; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ne; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, + item_ct1); + }); +} + +static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK4_1 == 0); + const int num_blocks = ne / QK4_1; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ne; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, + item_ct1); + }); +} + +static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK5_0 == 0); + const int num_blocks = ne / QK5_0; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ne; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, + item_ct1); + }); +} + +static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK5_1 == 0); + const int num_blocks = ne / QK5_1; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ne; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, + item_ct1); + }); +} + +static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK4_NL == 0); + const int num_blocks = ne / QK4_NL; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, + ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + // dpct::has_capability_or_fail(stream->get_device(), + // {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + // dpct::has_capability_or_fail(stream->get_device(), + // {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + + +static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + + +static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + + +static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + + +static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + + const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try { + // Unlike other operators ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and rely on its src field + scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0, debug_get_tensor_str("\tsrc0", src0)); + const int64_t ne = ggml_nelements(src0); + GGML_ASSERT(ne == ggml_nelements(src1)); + + GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); + GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); + + GGML_TENSOR_BINARY_OP_LOCALS01; + + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + queue_ptr main_stream = ctx.stream(); + + char * src0_ddc = (char *) src0->data; + char * src1_ddc = (char *) src1->data; + if ((src0->type == src1->type) && (ggml_is_contiguous(src0) && ggml_is_contiguous(src1))) { + GGML_SYCL_DEBUG("%s: memcpy path\n", __func__); + main_stream->memcpy(src1_ddc, src0_ddc, ggml_nbytes(src0)); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { + ggml_cpy_f32_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { + ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { + ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { + ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + ggml_cpy_f16_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + ggml_cpy_f16_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) { + ggml_cpy_i16_i16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { + ggml_cpy_i32_i32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q8_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { + ggml_cpy_f32_q5_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { + ggml_cpy_f32_q5_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { + ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { + ggml_cpy_q8_0_q8_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_Q5_0) { + ggml_cpy_q5_0_q5_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_Q5_1) { + ggml_cpy_q5_1_q5_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_Q4_0) { + ggml_cpy_q4_0_q4_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_Q4_1) { + ggml_cpy_q4_1_q4_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), + ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_cpy(ctx, dst->src[0], dst); +} diff --git a/ggml/src/ggml-sycl/cpy.hpp b/ggml/src/ggml-sycl/cpy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a0f561d2309a059694e7d9bdbe5f75443b59fc6 --- /dev/null +++ b/ggml/src/ggml-sycl/cpy.hpp @@ -0,0 +1,11 @@ +#ifndef GGML_SYCL_CPY_HPP +#define GGML_SYCL_CPY_HPP + +#include "common.hpp" + +typedef void (*cpy_kernel_t)(const char * cx, char * cdst); + +void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1); +void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_CPY_HPP diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp new file mode 100644 index 0000000000000000000000000000000000000000..540539bb2238102c85635af265029ee5ac64f293 --- /dev/null +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -0,0 +1,823 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_DEQUANTIZE_HPP +#define GGML_SYCL_DEQUANTIZE_HPP + +#include "common.hpp" + +typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); +typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs, + const int iqs, dfloat2 &v); + +static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_q4_0 * x = (const block_q4_0 *) vx; + + const dfloat d = x[ib].d; + + const int vui = x[ib].qs[iqs]; + + v.x() = vui & 0xF; + v.y() = vui >> 4; + +#ifdef GGML_SYCL_F16 + // v = v - {8.0f, 8.0f}; + // v = v * {d, d}; + v.s0() = (v.s0() - 8.0f) * d; + v.s1() = (v.s1() - 8.0f) * d; + +#else + v.x() = (v.x() - 8.0f) * d; + v.y() = (v.y() - 8.0f) * d; +#endif // GGML_SYCL_F16 +} + +static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs, + const int iqs, dfloat2 &v) { + // const block_q4_0 * x = (const block_q4_0 *) vx; + + const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib); + + const int vui = *((const uint8_t *)qs+iqs); + + v.x() = vui & 0xF; + v.y() = vui >> 4; + +#ifdef GGML_SYCL_F16 + // v = v - {8.0f, 8.0f}; + // v = v * {d, d}; + v.s0() = (v.s0() - 8.0f) * d; + v.s1() = (v.s1() - 8.0f) * d; + +#else + v.x() = (v.x() - 8.0f) * d; + v.y() = (v.y() - 8.0f) * d; +#endif // GGML_SYCL_F16 +} + +static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_q4_1 * x = (const block_q4_1 *) vx; + + const dfloat d = x[ib].dm[0]; + const dfloat m = x[ib].dm[1]; + + const int vui = x[ib].qs[iqs]; + + v.x() = vui & 0xF; + v.y() = vui >> 4; + +#ifdef GGML_SYCL_F16 + // v = v * {d, d}; + // v = v + {m, m}; + v.s0() = sycl::fma(v.s0(), d, m); + v.s1() = sycl::fma(v.s1(), d, m); + +#else + v.x() = sycl::fma(v.x(), d, m); + v.y() = sycl::fma(v.y(), d, m); +#endif // GGML_SYCL_F16 +} + +static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const dfloat d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y() = ((x[ib].qs[iqs] >> 4) | xh_1); + +#ifdef GGML_SYCL_F16 + // v = v - {16.0f, 16.0f}; + // v = v * {d, d}; + v.s0() = (v.s0() - 16.0f) * d; + v.s1() = (v.s1() - 16.0f) * d; + +#else + v.x() = (v.x() - 16.0f) * d; + v.y() = (v.y() - 16.0f) * d; +#endif // GGML_SYCL_F16 +} + +static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_q5_1 * x = (const block_q5_1 *) vx; + + const dfloat d = x[ib].dm[0]; + const dfloat m = x[ib].dm[1]; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y() = ((x[ib].qs[iqs] >> 4) | xh_1); + +#ifdef GGML_SYCL_F16 + // v = v * {d, d}; + // v = v + {m, m}; + v.s0() = sycl::fma(v.s0(), d, m); + v.s1() = sycl::fma(v.s1(), d, m); +#else + v.x() = sycl::fma(v.x(), d, m); + v.y() = sycl::fma(v.y(), d, m); +#endif // GGML_SYCL_F16 +} + +static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_q8_0 * x = (const block_q8_0 *) vx; + + const dfloat d = x[ib].d; + + v.x() = x[ib].qs[iqs + 0]; + v.y() = x[ib].qs[iqs + 1]; + +#ifdef GGML_SYCL_F16 + // v = v * {d, d}; + v.s0() *= d; + v.s1() *= d; +#else + v.x() *= d; + v.y() *= d; +#endif // GGML_SYCL_F16 +} + +template +static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + + // assume 32 threads + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + dst_t * y = yy + 256*i + 32*ir + 4*il; + + const block_q4_0 * x = (const block_q4_0 *)vx + ib; + const float d = sycl::vec(x->d) + .convert()[0]; + const float dm = -8*d; + + const uint8_t * q = x->qs + 4*il; + + for (int l = 0; l < 4; ++l) { + y[l+ 0] = d * (q[l] & 0xF) + dm; + y[l+16] = d * (q[l] >> 4) + dm; + } +} + +template +static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + auto k=nb32; + // assume 32 threads + const int64_t tid = item_ct1.get_local_id(2); + const int lane_ib = i * WARP_SIZE + tid; + + if (lane_ib >= k / QK4_0) { + return; + } + + dst_t * y_ptr = yy + lane_ib * QK4_0; + + auto qs = (const uint8_t*)vx + lane_ib * QK4_0 / 2; + auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib; + + const float d = float(*s_ptr); + +#pragma unroll + for (int l = 0; l < QK4_0 / 2; ++l) { + int vq = qs[l]; + y_ptr[l + 0] = d * ((vq & 0xF) - 8); + y_ptr[l + 16] = d * ((vq >> 4) - 8); + } + +} + +template +static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + + // assume 32 threads + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + dst_t * y = yy + 256*i + 32*ir + 4*il; + + const block_q4_1 * x = (const block_q4_1 *)vx + ib; + const sycl::float2 d = + x->dm.convert(); + + const uint8_t * q = x->qs + 4*il; + + for (int l = 0; l < 4; ++l) { + y[l + 0] = d.x() * (q[l] & 0xF) + d.y(); + y[l + 16] = d.x() * (q[l] >> 4) + d.y(); + } +} + + +//================================== k-quants + +template +static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + const block_q2_K * x = (const block_q2_K *) vx; + + const int64_t tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int64_t n = tid/32; + const int64_t l = tid - 32*n; + const int64_t is = 8*n + l/16; + + const uint8_t q = x[i].qs[32*n + l]; + dst_t * y = yy + i*QK_K + 128*n; + + float dall = x[i].dm[0]; + float dmin = x[i].dm[1]; + y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); + y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); + y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); +#else + const int64_t is = tid/16; // 0 or 1 + const int64_t il = tid%16; // 0...15 + const uint8_t q = x[i].qs[il] >> (2*is); + dst_t * y = yy + i*QK_K + 16*is + il; + + float dall = x[i].dm[0]; + float dmin = x[i].dm[1]; + y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); +#endif + +} + +template +static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + const block_q3_K * x = (const block_q3_K *) vx; + +#if QK_K == 256 + const int64_t r = item_ct1.get_local_id(2) / 4; + const int64_t tid = r/2; + const int64_t is0 = r%2; + const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4); + const int64_t n = tid / 4; + const int64_t j = tid - 4*n; + + uint8_t m = 1 << (4*n + j); + int64_t is = 8*n + 2*j + is0; + int shift = 2*j; + + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + float d_all = x[i].d; + float dl = d_all * (us - 32); + + dst_t * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); +#else + const int64_t tid = item_ct1.get_local_id(2); + const int64_t is = tid/16; // 0 or 1 + const int64_t il = tid%16; // 0...15 + const int64_t im = il/8; // 0...1 + const int64_t in = il%8; // 0...7 + + dst_t * y = yy + i*QK_K + 16*is + il; + + const uint8_t q = x[i].qs[il] >> (2*is); + const uint8_t h = x[i].hmask[in] >> (2*is + im); + const float d = (float)x[i].d; + + if (is == 0) { + y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } else { + y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } +#endif + +} + +#if QK_K == 256 +static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; + m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} +#endif + +template +inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall, + const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) { + const int is = 2 * il; + constexpr int n = 4; + + uint8_t sc, m; + get_scale_min_k4(is + 0, scales_local, sc, m); + const float d1 = dall * sc; + const float m1 = dmin * m; + + get_scale_min_k4(is + 1, scales_local, sc, m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + sycl::vec q_vec = vec_aligned_load(qs_ptr + 32 * il + n * ir); + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; + y[l + 32] = d2 * (q_vec[l] >> 4) - m2; + } +} + +template +static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy, + uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) { + const block_q4_K * x = (const block_q4_K *) vx; + + const int64_t i = item_ct1.get_group(2); + +#if QK_K == 256 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid / 8; + const int64_t ir = tid % 8; + + dst_t * y = yy + i * QK_K + 64 * il + 4 * ir; + + const sycl::half2 dm = x[i].dm; + const float dall = dm[0]; + const float dmin = dm[1]; + + if (tid < 12) { + scales_local[tid] = x[i].scales[tid]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir); +#else + const int64_t tid = item_ct1.get_local_id(2); + const uint8_t * q = x[i].qs; + dst_t * y = yy + i*QK_K; + const float d = (float)x[i].dm[0]; + const float m = (float)x[i].dm[1]; + y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); + y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4); +#endif +} + +template +static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local, + const sycl::nd_item<1> & item_ct1, int64_t nb) { + const int64_t i = item_ct1.get_group(0); // block index + const int64_t tid = item_ct1.get_local_id(0); // thread index within block + const int64_t il = tid / 8; + const int64_t ir = tid % 8; + + dst_t * y = yy + i * QK_K + 64 * il + 4 * ir; + + const uint8_t * base = static_cast(vx); + const size_t qs_offset = i * (QK_K / 2); + const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE; + const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2); + + const uint8_t * qs_ptr = base + qs_offset; + const uint8_t * scales_ptr = base + scales_offset; + ggml_half2 dm_values = *reinterpret_cast(base + dm_offset); + + const float dall = dm_values.x(); + const float dmin = dm_values.y(); + + if (tid < 12) { + scales_local[tid] = scales_ptr[tid]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir); +} + +template +static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1) { + const block_q5_K * x = (const block_q5_K *) vx; + + const int64_t i = item_ct1.get_group(2); + +#if QK_K == 256 + // assume 64 threads - this is very slightly better than the one below + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/16; // il is in 0...3 + const int64_t ir = tid%16; // ir is in 0...15 + const int64_t is = 2*il; // is is in 0...6 + + dst_t * y = yy + i*QK_K + 64*il + 2*ir; + + const float dall = x[i].dm[0]; + const float dmin = x[i].dm[1]; + + const uint8_t * ql = x[i].qs + 32*il + 2*ir; + const uint8_t * qh = x[i].qh + 2*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + + uint8_t hm = 1 << (2*il); + y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; + y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + const int64_t tid = item_ct1.get_local_id(2); + const uint8_t q = x[i].qs[tid]; + const int64_t im = tid/8; // 0...3 + const int64_t in = tid%8; // 0...7 + const int64_t is = tid/16; // 0 or 1 + const uint8_t h = x[i].qh[in] >> im; + const float d = x[i].d; + dst_t * y = yy + i*QK_K + tid; + y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); + y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16)); +#endif +} + +template +static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1) { + const block_q6_K * x = (const block_q6_K *) vx; + + const int64_t i = item_ct1.get_group(2); +#if QK_K == 256 + + // assume 64 threads - this is very slightly better than the one below + const int64_t tid = item_ct1.get_local_id(2); + const int64_t ip = tid/32; // ip is 0 or 1 + const int64_t il = tid - 32*ip; // 0...32 + const int64_t is = 8*ip + il/16; + + dst_t * y = yy + i*QK_K + 128*ip + il; + + const float d = x[i].d; + + const uint8_t * ql = x[i].ql + 64*ip + il; + const uint8_t qh = x[i].qh[32*ip + il]; + const int8_t * sc = x[i].scales + is; + + y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +#else + + // assume 32 threads + const int64_t tid = item_ct1.get_local_id(2); + const int64_t ip = tid/16; // 0 or 1 + const int64_t il = tid - 16*ip; // 0...15 + + dst_t * y = yy + i*QK_K + 16*ip + il; + + const float d = x[i].d; + + const uint8_t ql = x[i].ql[16*ip + il]; + const uint8_t qh = x[i].qh[il] >> (2*ip); + const int8_t * sc = x[i].scales; + + y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32); +#endif +} + +template +static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> & item_ct1, int64_t n_blocks) { + const int64_t ib = item_ct1.get_group(2); + + const int64_t tid = item_ct1.get_local_id(2); + const int64_t ip = tid / 32; // ip is 0 or 1 + const int64_t il = tid - 32 * ip; // 0...32 + const int64_t is = 8 * ip + il / 16; + + const uint8_t * base_ptr = static_cast(vx); + const auto ql_offset = ib * (QK_K / 2); + const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib; + const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib; + const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks; + const uint8_t * ql_ptr = base_ptr + ql_offset; + const uint8_t * qh_ptr = base_ptr + qh_offset; + const uint8_t * scales_ptr = base_ptr + base_scales_offset; + const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib; + + dst_t * y = yy + ib * QK_K + 128 * ip + il; + + const uint8_t * ql = ql_ptr + 64 * ip + il; + const uint8_t qh = *(qh_ptr + 32 * ip + il); + const int8_t * sc = reinterpret_cast(scales_ptr + is); + + y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +} + +template +static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1, + const uint64_t *iq2xxs_grid_ptr, + const uint8_t *ksigns_iq2xs_ptr, + const uint8_t *kmask_iq2xs_ptr) { + + const int64_t i = item_ct1.get_group(2); + const block_iq2_xxs * x = (const block_iq2_xxs *) vx; + + const int64_t tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * q2 = x[i].qs + 4*ib; + const uint8_t * aux8 = (const uint8_t *)q2; + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid_ptr + aux8[il]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f; + const uint8_t signs = ksigns_iq2xs_ptr[(aux32 >> 7*il) & 127]; + for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs_ptr[j] ? -1.f : 1.f); +#else + assert(false); +#endif + +} + +template +static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1, + const uint64_t *iq2xs_grid, + const uint8_t *ksigns_iq2xs, + const uint8_t *kmask_iq2xs) { + + const int64_t i = item_ct1.get_group(2); + const block_iq2_xs * x = (const block_iq2_xs *) vx; + + const int64_t tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * q2 = x[i].qs + 4*ib; + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); + const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; + for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); +#else + assert(false); +#endif + +} + +template +__dpct_inline__ static void +dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + const block_iq2_s * x = (const block_iq2_s *) vx; + + const int64_t tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300))); + const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; + const uint8_t signs = x[i].qs[QK_K/8+4*ib+il]; +#pragma unroll + for (int j = 0; j < 8; ++j) + y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); +#else + assert(false); + +#endif + +} + +template +static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1, + const uint32_t *iq3xxs_grid, + const uint8_t *ksigns_iq2xs, + const uint8_t *kmask_iq2xs) { + + const int64_t i = item_ct1.get_group(2); + const block_iq3_xxs * x = (const block_iq3_xxs *) vx; + + const int64_t tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * q3 = x[i].qs + 8*ib; + const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib; + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]); + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; + for (int j = 0; j < 4; ++j) { + y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } +#else + assert(false); +#endif + +} + +template +__dpct_inline__ static void +dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy, + const sycl::nd_item<3> &item_ct1, + const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) { + + const int64_t i = item_ct1.get_group(2); + const block_iq3_s * x = (const block_iq3_s *) vx; + + const int64_t tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * qs = x[i].qs + 8*ib; + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); + const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)); + const uint8_t signs = x[i].signs[4*ib + il]; +#pragma unroll + for (int j = 0; j < 4; ++j) { + y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } +#else + assert(false); +#endif + +} + +template +__dpct_inline__ static void +dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy, + const sycl::nd_item<3> &item_ct1, + const uint32_t *iq1s_grid_gpu) { + + const int64_t i = item_ct1.get_group(2); + const block_iq1_s * x = (const block_iq1_s *) vx; + + const int64_t tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA; + const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1); + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; +#pragma unroll + for (int j = 0; j < 8; ++j) { + y[j] = d * (q[j] + delta); + } +#else + assert(false); +#endif + +} + +template +__dpct_inline__ static void +dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy, + const sycl::nd_item<3> &item_ct1, + const uint32_t *iq1s_grid_gpu) { + + const int64_t i = item_ct1.get_group(2); + const block_iq1_m * x = (const block_iq1_m *) vx; + + const int64_t tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * sc = (const uint16_t *)x[i].scales; + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4); + const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1); + const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA; + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; +#pragma unroll + for (int j = 0; j < 8; ++j) { + y[j] = d * (q[j] + delta); + } +#else + assert(false); +#endif + +} + +template +__dpct_inline__ static void +dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL); + + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[ib].qs + 4*il; + const float d = (float)x[ib].d; +#pragma unroll + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf]; + y[j+16] = d * kvalues_iq4nl[q4[j] >> 4]; + } + +} + + +template +__dpct_inline__ static void +dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy, + const sycl::nd_item<3> &item_ct1) { + const int64_t i = item_ct1.get_group(2); + const block_iq4_xs * x = (const block_iq4_xs *)vx; + + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[i].qs + 16*ib + 4*il; + const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32); +#pragma unroll + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf]; + y[j+16] = d * kvalues_iq4nl[q4[j] >> 4]; + } +} + + +#endif // GGML_SYCL_DEQUANTIZE_HPP diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4f2760110c212c68a0f6146acd63d484b480a9e9 --- /dev/null +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -0,0 +1,1162 @@ +#include "convert.hpp" +#include "dmmv.hpp" +#include "dequantize.hpp" +#include "presets.hpp" + +static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ + const sycl::half *x = (const sycl::half *)vx; + + // automatic half -> float type cast if dfloat == float + v.x() = x[ib + iqs + 0]; + v.y() = x[ib + iqs + 1]; +} + +static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ + const float * x = (const float *) vx; + + // automatic half -> float type cast if dfloat == float + v.x() = x[ib + iqs + 0]; + v.y() = x[ib + iqs + 1]; +} + +template +static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, + const sycl::nd_item<3> &item_ct1) { + // qk = quantized weights per x block + // qr = number of quantized weights per data value in x block + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int tid = item_ct1.get_local_id(2); + + const int iter_stride = 2*GGML_SYCL_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter + const int y_offset = qr == 1 ? 1 : qk/2; + +// partial sum for each thread +#ifdef GGML_SYCL_F16 + sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics +#else + float tmp = 0.0f; +#endif // GGML_SYCL_F16 + + for (int i = 0; i < ncols; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel(vx, ib, iqs + j/qr, v); + + // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[iybs + iqs + j / qr + 0], + y[iybs + iqs + j / qr + y_offset]}; + + tmp += v * t1; +#else + tmp += v.x() * y[iybs + iqs + j / qr + 0]; + tmp += v.y() * y[iybs + iqs + j / qr + y_offset]; +#endif // GGML_SYCL_F16 + } + } + + // sum up partial sums and write back result + const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; + for (int mask = mask_start; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { +#ifdef GGML_SYCL_F16 + dst[row] = tmp.x() + tmp.y(); +#else + dst[row] = tmp; +#endif // GGML_SYCL_F16 + } +} + +template +static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, + const sycl::nd_item<3> &item_ct1) { + // qk = quantized weights per x block + // qr = number of quantized weights per data value in x block + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int tid = item_ct1.get_local_id(2); + + + const int ncols_left = ncols % (QK4_0*WARP_SIZE); + const int ncols_align = ncols - ncols_left; + const int iter_stride = 8*2*GGML_SYCL_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16 + const int y_offset = qr == 1 ? 1 : qk/2; + +// partial sum for each thread +#ifdef GGML_SYCL_F16 + sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics +#else + float tmp = 0.0f; +#endif // GGML_SYCL_F16 + const char *d_ptr = (const char*)vx+ncols*nrows/2; + int i=0; + for (i = 0; i < ncols_align; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v); + + // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[iybs + iqs + j / qr + 0], + y[iybs + iqs + j / qr + y_offset]}; + + tmp += v * t1; +#else + tmp += v.x() * y[iybs + iqs + j / qr + 0]; + tmp += v.y() * y[iybs + iqs + j / qr + y_offset]; +#endif // GGML_SYCL_F16 + } + } + + for (; i < ncols; i += iter_stride) { + if (tid>=ncols_left/QK4_0) continue; + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v); + + // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[iybs + iqs + j / qr + 0], + y[iybs + iqs + j / qr + y_offset]}; + + tmp += v * t1; +#else + tmp += v.x() * y[iybs + iqs + j / qr + 0]; + tmp += v.y() * y[iybs + iqs + j / qr + y_offset]; +#endif // GGML_SYCL_F16 + } + } + + // sum up partial sums and write back result + const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; + for (int mask = mask_start; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { +#ifdef GGML_SYCL_F16 + dst[row] = tmp.x() + tmp.y(); +#else + dst[row] = tmp; +#endif // GGML_SYCL_F16 + } +} + +static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, + nrows, item_ct1); + }); + } +} + +/* +DPCT1110:4: The total declared local variable size in device function +dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register +pressure. Consult with your hardware vendor to find the total register size +available and adjust the code, or use smaller sub-group size to avoid high +register pressure. +*/ +static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q2_K * x = (const block_q2_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 16/K_QUANTS_PER_ITERATION; + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int s_offset = 8*im; + const int y_offset = 128*im + l0; + + uint32_t aux[4]; + const uint8_t * d = (const uint8_t *)aux; + const uint8_t * m = (const uint8_t *)(aux + 2); + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + + const float dall = x[i].dm[0]; + const float dmin = x[i].dm[1]; + + const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = a[1] & 0x0f0f0f0f; + aux[2] = (a[0] >> 4) & 0x0f0f0f0f; + aux[3] = (a[1] >> 4) & 0x0f0f0f0f; + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) + + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) + + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) + + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) + + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) + + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) + +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); + sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] + + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; + + } + tmp += dall * sum1 - dmin * sum2; + + } +#else + const int tid = item_ct1.get_local_id(2) / + (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = item_ct1.get_local_id(2) % + (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; + + uint32_t uaux[2]; + const uint8_t * d = (const uint8_t *)uaux; + + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint32_t * s = (const uint32_t *)x[i].scales; + + uaux[0] = s[0] & 0x0f0f0f0f; + uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; + + const sycl::float2 dall = + x[i].dm.convert(); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t ql = q[l]; + sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) + + y[l+16] * d[1] * ((ql >> 2) & 3) + + y[l+32] * d[2] * ((ql >> 4) & 3) + + y[l+48] * d[3] * ((ql >> 6) & 3); + sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; + } + tmp += dall.x() * sum1 - dall.y() * sum2; + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +/* +DPCT1110:5: The total declared local variable size in device function +dequantize_mul_mat_vec_q3_k exceeds 128 bytes and may cause high register +pressure. Consult with your hardware vendor to find the total register size +available and adjust the code, or use smaller sub-group size to avoid high +register pressure. +*/ +static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q3_K * x = (const block_q3_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 + + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + const uint8_t * h = x[i].hmask + l0; + + const uint16_t * a = (const uint16_t *)x[i].scales; + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = x[i].d; + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); + } + tmp += d * sum; + + } +#else + + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 + const int in = offset/8; // 0 or 1 + const int im = offset%8; // 0...7 + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint8_t * s = x[i].scales; + + const float dall = (float)x[i].d; + + float sum = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t hl = x[i].hmask[im+l] >> in; + const uint8_t ql = q[l]; + sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) + + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) + + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) + + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4)); + } + tmp += sum; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +/* +DPCT1110:6: The total declared local variable size in device function +dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register +pressure. Consult with your hardware vendor to find the total register size +available and adjust the code, or use smaller sub-group size to avoid high +register pressure. +*/ +static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q4_K * x = (const block_q4_K *)vx + ib0; + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + + const int il = tid/step; // 0...3 + const int ir = tid - step*il; // 0...7 or 0...3 + const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + +#if K_QUANTS_PER_ITERATION == 2 + uint32_t q32[4]; + const uint8_t * q4 = (const uint8_t *)q32; +#else + uint16_t q16[4]; + const uint8_t * q4 = (const uint8_t *)q16; +#endif + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = x[i].dm[0]; + const float dmin = x[i].dm[1]; + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + +#if K_QUANTS_PER_ITERATION == 2 + const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); + const uint32_t * q2 = q1 + 16; + + q32[0] = q1[0] & 0x0f0f0f0f; + q32[1] = q1[0] & 0xf0f0f0f0; + q32[2] = q2[0] & 0x0f0f0f0f; + q32[3] = q2[0] & 0xf0f0f0f0; + + sycl::float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 4; ++l) { + s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4]; + s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f + + s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) - + dmin * smin; +#else + const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); + const uint16_t * q2 = q1 + 32; + + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[0] & 0xf0f0; + q16[2] = q2[0] & 0x0f0f; + q16[3] = q2[0] & 0xf0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 2; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; + s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#endif + + } +#else + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); + + const int step = tid * K_QUANTS_PER_ITERATION; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const float * y = yy + i*QK_K + step; + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + const float d = (float)x[i].dm[0]; + const float m = (float)x[i].dm[1]; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); + } + tmp += sum; + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +/* +DPCT1110:7: The total declared local variable size in device function +dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register +pressure. Consult with your hardware vendor to find the total register size +available and adjust the code, or use smaller sub-group size to avoid high +register pressure. +*/ +static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2); + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q5_K * x = (const block_q5_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = item_ct1.get_local_id(2) / 2; // 0...15 + const int ix = item_ct1.get_local_id(2) % 2; + + const int il = tid/4; // 0...3 + const int ir = tid - 4*il;// 0...3 + const int n = 2; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + const uint8_t hm1 = 1 << (2*im); + const uint8_t hm2 = hm1 << 4; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + + uint16_t q16[8]; + const uint8_t * q4 = (const uint8_t *)q16; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + + const uint8_t * ql1 = x[i].qs + q_offset; + const uint8_t * qh = x[i].qh + l0; + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = x[i].dm[0]; + const float dmin = x[i].dm[1]; + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + + sycl::float4 sum = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + const uint16_t * q1 = (const uint16_t *)ql1; + const uint16_t * q2 = q1 + 32; + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[8] & 0x0f0f; + q16[2] = (q1[0] >> 4) & 0x0f0f; + q16[3] = (q1[8] >> 4) & 0x0f0f; + q16[4] = q2[0] & 0x0f0f; + q16[5] = q2[8] & 0x0f0f; + q16[6] = (q2[0] >> 4) & 0x0f0f; + q16[7] = (q2[8] >> 4) & 0x0f0f; + for (int l = 0; l < n; ++l) { + sum.x() += + y1[l + 0] * (q4[l + 0] + (qh[l + 0] & (hm1 << 0) ? 16 : 0)) + + y1[l + 16] * (q4[l + 2] + (qh[l + 16] & (hm1 << 0) ? 16 : 0)); + sum.y() += + y1[l + 32] * (q4[l + 4] + (qh[l + 0] & (hm1 << 1) ? 16 : 0)) + + y1[l + 48] * (q4[l + 6] + (qh[l + 16] & (hm1 << 1) ? 16 : 0)); + sum.z() += + y2[l + 0] * (q4[l + 8] + (qh[l + 0] & (hm2 << 0) ? 16 : 0)) + + y2[l + 16] * (q4[l + 10] + (qh[l + 16] & (hm2 << 0) ? 16 : 0)); + sum.w() += + y2[l + 32] * (q4[l + 12] + (qh[l + 0] & (hm2 << 1) ? 16 : 0)) + + y2[l + 48] * (q4[l + 14] + (qh[l + 16] & (hm2 << 1) ? 16 : 0)); + smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] + + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; + } + tmp += dall * (sum.x() * sc[0] + sum.y() * sc[1] + sum.z() * sc[4] + + sum.w() * sc[5]) - + dmin * smin; + } + +#else + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); + const int step = tid * K_QUANTS_PER_ITERATION; + const int im = step/8; + const int in = step%8; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const int8_t * s = x[i].scales; + const float * y = yy + i*QK_K + step; + const float d = x[i].d; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + const uint8_t h = x[i].qh[in+j] >> im; + sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) + + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) + + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16)) + + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16)); + } + tmp += sum; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q6_K * x = (const block_q6_K *)vx + ib0; + +#if QK_K == 256 + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = x[i].ql + ql_offset; + const uint8_t * qh = x[i].qh + qh_offset; + const int8_t * s = x[i].scales + s_offset; + + const float d = x[i].d; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + +#else + + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3 + + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = x[i].ql + step; + const uint8_t * qh = x[i].qh + step; + const int8_t * s = x[i].scales; + + const float d = x[i+0].d; + + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } + tmp += sum; + + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec_reorder( + vx, y, dst, ncols, nrows, item_ct1); + }); + } +} + + +static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec( + vx, y, dst, ncols, nrows, item_ct1); + }); + } +} + +static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec( + vx, y, dst, ncols, nrows, item_ct1); + }); + } +} + +static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec( + vx, y, dst, ncols, nrows, item_ct1); + }); + } +} + +static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec( + vx, y, dst, ncols, nrows, item_ct1); + }); + } +} + +static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec( + vx, y, dst, ncols, nrows, item_ct1); + }); + } +} + +static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1); + }); +} + +static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1); + }); +} + +static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1); + }); +} + +static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1); + }); +} + +static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1); + }); +} + +void ggml_sycl_op_dequantize_mul_mat_vec( + ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, + const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, + float *dst_dd_i, const int64_t row_low, const int64_t row_high, + const int64_t src1_ncols, const int64_t src1_padded_row_size, + const dpct::queue_ptr &stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + GGML_ASSERT(src1->type == GGML_TYPE_F32); + // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics +#ifdef GGML_SYCL_F16 + ggml_sycl_pool_alloc src1_dfloat_a(ctx.pool()); + sycl::half *src1_dfloat = nullptr; // dfloat == half + + bool src1_convert_f16 = + src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || + src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + + if (src1_convert_f16) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, + " : converting src1 to fp16"); + src1_dfloat = src1_dfloat_a.alloc(ne00); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); + GGML_ASSERT(to_fp16_sycl != nullptr); + to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream); + } +#else + const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion +#endif // GGML_SYCL_F16 + + switch (src0->type) { + case GGML_TYPE_Q4_0: + if ((ggml_tensor_extra_gpu*)dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } + break; + case GGML_TYPE_Q4_1: + dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + // reorder is currently not supported for dmmv + GGML_ABORT("Unimplemented dequantize case case for q4_k reorder"); + } else { + dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } + break; + case GGML_TYPE_Q5_K: + dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_F16: + convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + default: + printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); + GGML_ABORT("fatal error"); + } + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_ncols); + GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED(ctx); +} diff --git a/ggml/src/ggml-sycl/dmmv.hpp b/ggml/src/ggml-sycl/dmmv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bd837356415330d87cb59f36d1a7a5d1b1685ce6 --- /dev/null +++ b/ggml/src/ggml-sycl/dmmv.hpp @@ -0,0 +1,27 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_DMMV_HPP +#define GGML_SYCL_DMMV_HPP + +#include "common.hpp" + + +void ggml_sycl_op_dequantize_mul_mat_vec( + ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, + const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, + float *dst_dd_i, const int64_t row_low, const int64_t row_high, + const int64_t src1_ncols, const int64_t src1_padded_row_size, + const dpct::queue_ptr &stream); + +#endif // GGML_SYCL_DMMV_HPP diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d538965b096bf3e14ef483b4e0a2ee466aece834 --- /dev/null +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -0,0 +1,2957 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_DPCT_HELPER_HPP +#define GGML_SYCL_DPCT_HELPER_HPP + +#include +#include +#include +#include + +#ifdef GGML_SYCL_USE_INTEL_ONEMKL +#include +// Allow to use the same namespace for Intel oneMKL and oneMath +namespace oneapi { + namespace math = mkl; +} +#else +#include +#endif + +#include "ggml.h" + +#if defined(__linux__) +#include +#elif defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#else +#error "Only support Windows and Linux." +#endif + +#if defined(__linux__) +#include +#include +#endif +#if defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif + +#define DPCT_COMPATIBILITY_TEMP (900) + +#if defined(_MSC_VER) +#define __dpct_align__(n) __declspec(align(n)) +#define __dpct_inline__ __forceinline +#else +#define __dpct_align__(n) __attribute__((aligned(n))) +#define __dpct_inline__ __inline__ __attribute__((always_inline)) +#endif + +#if defined(_MSC_VER) +#define __dpct_noinline__ __declspec(noinline) +#else +#define __dpct_noinline__ __attribute__((noinline)) +#endif + +inline std::string get_device_type_name(const sycl::device &Device) { + auto DeviceType = Device.get_info(); + switch (DeviceType) { + case sycl::info::device_type::cpu: + return "cpu"; + case sycl::info::device_type::gpu: + return "gpu"; + case sycl::info::device_type::host: + return "host"; + case sycl::info::device_type::accelerator: + return "acc"; + default: + return "unknown"; + } +} + +inline std::string get_device_backend_and_type(const sycl::device &device) { + std::stringstream device_type; + sycl::backend backend = device.get_backend(); + device_type << backend << ":" << get_device_type_name(device); + return device_type.str(); +} + +template struct matrix_info_t { + oneapi::math::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; +}; + +inline auto get_onemath_backend(sycl::queue& queue) +#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL) + -> sycl::queue& +#endif +{ +// If the backend is known at compile-time, use oneMath backend_selector to use +// compile-time dispatching and avoid the need to dlopen libraries. Otherwise +// fallback to runtime dispatching. +#if defined(GGML_SYCL_NVIDIA) + return oneapi::math::backend_selector{ queue }; +#elif defined(GGML_SYCL_AMD) + return oneapi::math::backend_selector{ queue }; +#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL) + return queue; +#else + static_assert(false, "Unsupported backend"); +#endif +} + +namespace dpct +{ + typedef sycl::queue *queue_ptr; + typedef sycl::event *event_ptr; + typedef char *device_ptr; + typedef uint8_t byte_t; + typedef sycl::buffer buffer_t; + + /// SYCL default exception handler + inline auto exception_handler = [](sycl::exception_list exceptions) + { + for (std::exception_ptr const &e : exceptions) + { + try + { + std::rethrow_exception(e); + } + catch (sycl::exception const &e) + { + std::cerr << "Caught asynchronous SYCL exception:" << std::endl + << e.what() << std::endl + << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + } + } + }; + + enum error_code + { + success = 0, + default_error = 999 + }; + + enum memcpy_direction + { + host_to_host, + host_to_device, + device_to_host, + device_to_device, + automatic + }; + + enum memory_region + { + global = 0, // device global memory + constant, // device constant memory + local, // device local memory + shared, // memory which can be accessed by host and device + }; + + enum class library_data_t : unsigned char + { + real_float = 0, + complex_float, + real_double, + complex_double, + real_half, + complex_half, + real_bfloat16, + complex_bfloat16, + real_int4, + complex_int4, + real_uint4, + complex_uint4, + real_int8, + complex_int8, + real_uint8, + complex_uint8, + real_int16, + complex_int16, + real_uint16, + complex_uint16, + real_int32, + complex_int32, + real_uint32, + complex_uint32, + real_int64, + complex_int64, + real_uint64, + complex_uint64, + real_int8_4, + real_int8_32, + real_uint8_4, + library_data_t_size + }; + + template + struct DataType + { + using T2 = T; + }; + template + struct DataType> + { + using T2 = std::complex; + }; + + static void destroy_event(event_ptr event) + { + delete event; + } + + static inline unsigned int get_tid() + { +#if defined(__linux__) + return syscall(SYS_gettid); +#elif defined(_WIN64) + return GetCurrentThreadId(); +#else +#error "Only support Windows and Linux." +#endif + } + + namespace detail + { + static void get_version(const sycl::device &dev, int &major, int &minor) + { + // Version string has the following format: + // a. OpenCL + // b. + // c. e.g gfx1030 + std::string ver; + ver = dev.get_info(); + std::string::size_type i = 0; + while (i < ver.size()) { + if (isdigit(ver[i])) + break; + i++; + } + major = std::stoi(&(ver[i])); + while (i < ver.size()) { + if (ver[i] == '.') + break; + i++; + } + if (i < ver.size()) { + // a. and b. + i++; + minor = std::stoi(&(ver[i])); + } else { + // c. + minor = 0; + } + } + + template + class generic_error_type + { + public: + generic_error_type() = default; + generic_error_type(T value) : value{value} {} + operator T() const { return value; } + + private: + T value; + }; + + } // namespace detail + + /// Pitched 2D/3D memory data. + class pitched_data + { + public: + pitched_data() : pitched_data(nullptr, 0, 0, 0) {} + pitched_data(void *data, size_t pitch, size_t x, size_t y) + : _data(data), _pitch(pitch), _x(x), _y(y) {} + + void *get_data_ptr() { return _data; } + void set_data_ptr(void *data) { _data = data; } + + size_t get_pitch() { return _pitch; } + void set_pitch(size_t pitch) { _pitch = pitch; } + + size_t get_x() { return _x; } + void set_x(size_t x) { _x = x; } + + size_t get_y() { return _y; } + void set_y(size_t y) { _y = y; } + + private: + void *_data; + size_t _pitch, _x, _y; + }; + + class device_info + { + public: + // get interface + const char *get_name() const { return _name; } + char *get_name() { return _name; } + template , + std::enable_if_t> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() const + { + if constexpr (std::is_same_v>) + return sycl::range<3>(_max_work_item_sizes_i[0], + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); + else + { + return _max_work_item_sizes_i; + } + } + template , + std::enable_if_t> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() + { + if constexpr (std::is_same_v>) + return sycl::range<3>(_max_work_item_sizes_i[0], + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); + else + { + return _max_work_item_sizes_i; + } + } + bool get_host_unified_memory() const { return _host_unified_memory; } + int get_major_version() const { return _major; } + int get_minor_version() const { return _minor; } + int get_integrated() const { return _integrated; } + int get_max_clock_frequency() const { return _frequency; } + int get_max_compute_units() const { return _max_compute_units; } + int get_max_work_group_size() const { return _max_work_group_size; } + int get_max_sub_group_size() const { return _max_sub_group_size; } + int get_max_work_items_per_compute_unit() const + { + return _max_work_items_per_compute_unit; + } + int get_max_register_size_per_work_group() const + { + return _max_register_size_per_work_group; + } + template || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() const + { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + template || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() + { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + size_t get_global_mem_size() const { return _global_mem_size; } + size_t get_local_mem_size() const { return _local_mem_size; } + size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; } + /// Returns the maximum clock rate of device's global memory in kHz. If + /// compiler does not support this API then returns default value 3200000 kHz. + unsigned int get_memory_clock_rate() const { return _memory_clock_rate; } + /// Returns the maximum bus width between device and memory in bits. If + /// compiler does not support this API then returns default value 64 bits. + unsigned int get_memory_bus_width() const { return _memory_bus_width; } + uint32_t get_device_id() const { return _device_id; } + std::array get_uuid() const { return _uuid; } + /// Returns global memory cache size in bytes. + unsigned int get_global_mem_cache_size() const + { + return _global_mem_cache_size; + } + + // set interface + void set_name(const char *name) + { + size_t length = strlen(name); + if (length < 256) + { + std::memcpy(_name, name, length + 1); + } + else + { + std::memcpy(_name, name, 255); + _name[255] = '\0'; + } + } + void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes) + { + for (int i = 0; i < 3; ++i) + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + [[deprecated]] void + set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) + { + for (int i = 0; i < 3; ++i) + { + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + } + void set_host_unified_memory(bool host_unified_memory) + { + _host_unified_memory = host_unified_memory; + } + void set_major_version(int major) { _major = major; } + void set_minor_version(int minor) { _minor = minor; } + void set_integrated(int integrated) { _integrated = integrated; } + void set_max_clock_frequency(int frequency) { _frequency = frequency; } + void set_max_compute_units(int max_compute_units) + { + _max_compute_units = max_compute_units; + } + void set_global_mem_size(size_t global_mem_size) + { + _global_mem_size = global_mem_size; + } + void set_local_mem_size(size_t local_mem_size) + { + _local_mem_size = local_mem_size; + } + void set_max_mem_alloc_size(size_t max_mem_alloc_size) + { + _max_mem_alloc_size = max_mem_alloc_size; + } + void set_max_work_group_size(int max_work_group_size) + { + _max_work_group_size = max_work_group_size; + } + void set_max_sub_group_size(int max_sub_group_size) + { + _max_sub_group_size = max_sub_group_size; + } + void + set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) + { + _max_work_items_per_compute_unit = max_work_items_per_compute_unit; + } + void set_max_nd_range_size(int max_nd_range_size[]) + { + for (int i = 0; i < 3; i++) + { + _max_nd_range_size[i] = max_nd_range_size[i]; + _max_nd_range_size_i[i] = max_nd_range_size[i]; + } + } + void set_memory_clock_rate(unsigned int memory_clock_rate) + { + _memory_clock_rate = memory_clock_rate; + } + void set_memory_bus_width(unsigned int memory_bus_width) + { + _memory_bus_width = memory_bus_width; + } + void + set_max_register_size_per_work_group(int max_register_size_per_work_group) + { + _max_register_size_per_work_group = max_register_size_per_work_group; + } + void set_device_id(uint32_t device_id) + { + _device_id = device_id; + } + void set_uuid(std::array uuid) + { + _uuid = std::move(uuid); + } + void set_global_mem_cache_size(unsigned int global_mem_cache_size) + { + _global_mem_cache_size = global_mem_cache_size; + } + + private: + char _name[256]; + int _max_work_item_sizes_i[3]; + bool _host_unified_memory = false; + int _major; + int _minor; + int _integrated = 0; + int _frequency; + // Set estimated value 3200000 kHz as default value. + unsigned int _memory_clock_rate = 3200000; + // Set estimated value 64 bits as default value. + unsigned int _memory_bus_width = 64; + unsigned int _global_mem_cache_size; + int _max_compute_units; + int _max_work_group_size; + int _max_sub_group_size; + int _max_work_items_per_compute_unit; + int _max_register_size_per_work_group; + size_t _global_mem_size; + size_t _local_mem_size; + size_t _max_mem_alloc_size; + size_t _max_nd_range_size[3]; + int _max_nd_range_size_i[3]; + uint32_t _device_id; + std::array _uuid; + }; + + static int get_major_version(const sycl::device &dev) + { + int major, minor; + detail::get_version(dev, major, minor); + return major; + } + + static int get_minor_version(const sycl::device &dev) + { + int major, minor; + detail::get_version(dev, major, minor); + return minor; + } + + static void get_device_info(device_info &out, const sycl::device &dev) + { + device_info prop; + prop.set_name(dev.get_info().c_str()); + + int major, minor; + detail::get_version(dev, major, minor); + prop.set_major_version(major); + prop.set_minor_version(minor); + + prop.set_max_work_item_sizes( +#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902) + // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes + // is an enum class element + dev.get_info()); +#else + // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by + // an int + dev.get_info>()); +#endif + prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations)); + + prop.set_max_clock_frequency( + dev.get_info() * 1000); + + prop.set_max_compute_units( + dev.get_info()); + prop.set_max_work_group_size( + dev.get_info()); + prop.set_global_mem_size(dev.get_info()); + prop.set_local_mem_size(dev.get_info()); + prop.set_max_mem_alloc_size(dev.get_info()); + +#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6) + if (dev.has(sycl::aspect::ext_intel_memory_clock_rate)) + { + unsigned int tmp = + dev.get_info(); + if (tmp != 0) + prop.set_memory_clock_rate(1000 * tmp); + } + if (dev.has(sycl::aspect::ext_intel_memory_bus_width)) + { + prop.set_memory_bus_width( + dev.get_info()); + } + if (dev.has(sycl::aspect::ext_intel_device_id)) + { + prop.set_device_id( + dev.get_info()); + } + if (dev.has(sycl::aspect::ext_intel_device_info_uuid)) + { + prop.set_uuid(dev.get_info()); + } +#elif defined(_MSC_VER) && !defined(__clang__) +#pragma message("get_device_info: querying memory_clock_rate and \ + memory_bus_width are not supported by the compiler used. \ + Use 3200000 kHz as memory_clock_rate default value. \ + Use 64 bits as memory_bus_width default value.") +#else +#warning "get_device_info: querying memory_clock_rate and \ + memory_bus_width are not supported by the compiler used. \ + Use 3200000 kHz as memory_clock_rate default value. \ + Use 64 bits as memory_bus_width default value." +#endif + + size_t max_sub_group_size = 1; + std::vector sub_group_sizes = + dev.get_info(); + + for (const auto &sub_group_size : sub_group_sizes) + { + if (max_sub_group_size < sub_group_size) + max_sub_group_size = sub_group_size; + } + + prop.set_max_sub_group_size(max_sub_group_size); + + prop.set_max_work_items_per_compute_unit( + dev.get_info()); + int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; + prop.set_max_nd_range_size(max_nd_range_size); + + // Estimates max register size per work group, feel free to update the value + // according to device properties. + prop.set_max_register_size_per_work_group(65536); + + prop.set_global_mem_cache_size( + dev.get_info()); + out = prop; + } + + /// dpct device extension + class device_ext : public sycl::device { + typedef std::mutex mutex_type; + + public: + device_ext() : sycl::device() {} + ~device_ext() { + std::lock_guard lock(m_mutex); + clear_queues(); + } + device_ext(const sycl::device &base) : sycl::device(base) { + std::lock_guard lock(m_mutex); + init_queues(); + } + + int is_native_atomic_supported() { return 0; } + int get_major_version() const { return dpct::get_major_version(*this); } + + int get_minor_version() const { return dpct::get_minor_version(*this); } + + int get_max_compute_units() const { + return get_device_info().get_max_compute_units(); + } + + /// Return the maximum clock frequency of this device in KHz. + int get_max_clock_frequency() const { + return get_device_info().get_max_clock_frequency(); + } + + int get_integrated() const { return get_device_info().get_integrated(); } + + int get_max_sub_group_size() const { + return get_device_info().get_max_sub_group_size(); + } + + int get_max_register_size_per_work_group() const { + return get_device_info().get_max_register_size_per_work_group(); + } + + int get_max_work_group_size() const { + return get_device_info().get_max_work_group_size(); + } + + int get_mem_base_addr_align() const { + return get_info(); + } + + size_t get_global_mem_size() const { + return get_device_info().get_global_mem_size(); + } + + size_t get_max_mem_alloc_size() const { + return get_device_info().get_max_mem_alloc_size(); + } + + /// Get the number of bytes of free and total memory on the SYCL device. + /// \param [out] free_memory The number of bytes of free memory on the + /// SYCL device. \param [out] total_memory The number of bytes of total + /// memory on the SYCL device. + void get_memory_info(size_t &free_memory, size_t &total_memory) { + total_memory = get_device_info().get_global_mem_size(); + const char *warning_info = + "get_memory_info: [warning] ext_intel_free_memory is not " + "supported (export/set ZES_ENABLE_SYSMAN=1 to support), " + "use total memory as free memory"; +#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) + if (!has(sycl::aspect::ext_intel_free_memory)) { + std::cerr << warning_info << std::endl; + free_memory = total_memory; + } else { + free_memory = get_info(); + } +#else + std::cerr << warning_info << std::endl; + free_memory = total_memory; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma message("Querying the number of bytes of free memory is not supported") +#else +#warning "Querying the number of bytes of free memory is not supported" +#endif +#endif + } + + void get_device_info(device_info &out) const { + dpct::get_device_info(out, *this); + } + + device_info get_device_info() const { + device_info prop; + dpct::get_device_info(prop, *this); + return prop; + } + + void reset() { + std::lock_guard lock(m_mutex); + clear_queues(); + init_queues(); + } + + sycl::queue &in_order_queue() { return _q_in_order; } + + sycl::queue &out_of_order_queue() { return _q_out_of_order; } + + sycl::queue &default_queue() { return in_order_queue(); } + + void queues_wait_and_throw() { + std::unique_lock lock(m_mutex); + lock.unlock(); + for (auto &q : _queues) { + q.wait_and_throw(); + } + // Guard the destruct of current_queues to make sure the ref count is + // safe. + lock.lock(); + } + + sycl::queue create_queue(bool enable_exception_handler = false) { + return create_in_order_queue(enable_exception_handler); + } + + sycl::queue create_queue(sycl::device device, + bool enable_exception_handler = false) { + return create_in_order_queue(device, enable_exception_handler); + } + + sycl::queue create_in_order_queue(bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl(enable_exception_handler, + sycl::property::queue::in_order()); + } + + sycl::queue create_in_order_queue(sycl::device device, + bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl(device, enable_exception_handler, + sycl::property::queue::in_order()); + } + + sycl::queue create_out_of_order_queue( + bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl(enable_exception_handler); + } + + void destroy_queue(sycl::queue queue) { + std::lock_guard lock(m_mutex); + _queues.erase(std::remove_if(_queues.begin(), _queues.end(), + [=](const sycl::queue &q) -> bool + { + return q == queue; + }), + _queues.end()); + } + void set_saved_queue(sycl::queue q) { + std::lock_guard lock(m_mutex); + _saved_queue = q; + } + sycl::queue get_saved_queue() const { + std::lock_guard lock(m_mutex); + return _saved_queue; + } + + private: + void clear_queues() { _queues.clear(); } + + void init_queues() { + _q_in_order = + create_queue_impl(true, sycl::property::queue::in_order()); + _q_out_of_order = create_queue_impl(true); + _saved_queue = default_queue(); + } + + /// Caller should acquire resource \p m_mutex before calling this + /// function. + template + sycl::queue create_queue_impl(bool enable_exception_handler, + Properties... properties) { + sycl::async_handler eh = {}; + if (enable_exception_handler) { + eh = exception_handler; + } + _queues.push_back(sycl::queue( + *this, eh, + sycl::property_list( +#ifdef DPCT_PROFILING_ENABLED + sycl::property::queue::enable_profiling(), +#endif + properties...))); + + return _queues.back(); + } + + template + sycl::queue create_queue_impl(sycl::device device, + bool enable_exception_handler, + Properties... properties) { + sycl::async_handler eh = {}; + if (enable_exception_handler) { + eh = exception_handler; + } + _queues.push_back(sycl::queue( + device, eh, + sycl::property_list( +#ifdef DPCT_PROFILING_ENABLED + sycl::property::queue::enable_profiling(), +#endif + properties...))); + + return _queues.back(); + } + + void get_version(int &major, int &minor) const { + detail::get_version(*this, major, minor); + } + sycl::queue _q_in_order, _q_out_of_order; + sycl::queue _saved_queue; + std::vector _queues; + mutable mutex_type m_mutex; + }; + + + /// device manager + class dev_mgr + { + public: + device_ext ¤t_device() + { + unsigned int dev_id = current_device_id(); + check_id(dev_id); + return *_devs[dev_id]; + } + device_ext &cpu_device() const + { + std::lock_guard lock(m_mutex); + if (_cpu_device == -1) + { + throw std::runtime_error("no valid cpu device"); + } + else + { + return *_devs[_cpu_device]; + } + } + device_ext &get_device(unsigned int id) const + { + std::lock_guard lock(m_mutex); + check_id(id); + return *_devs[id]; + } + unsigned int current_device_id() const + { + std::lock_guard lock(m_mutex); + auto it = _thread2dev_map.find(get_tid()); + if (it != _thread2dev_map.end()) + return it->second; + return DEFAULT_DEVICE_ID; + } + + /// Select device with a device ID. + /// \param [in] id The id of the device which can + /// be obtained through get_device_id(const sycl::device). + void select_device(unsigned int id) + { + std::lock_guard lock(m_mutex); + check_id(id); + _thread2dev_map[get_tid()] = id; + } + unsigned int device_count() { return _devs.size(); } + + unsigned int get_device_id(const sycl::device &dev) + { + unsigned int id = 0; + for (auto &dev_item : _devs) + { + if (*dev_item == dev) + { + return id; + } + id++; + } + return -1; + } + + inline std::string get_preferred_gpu_platform_name() { + std::string result; + + std::string filter = ""; + char* env = getenv("ONEAPI_DEVICE_SELECTOR"); + if (env) { + if (std::strstr(env, "level_zero")) { + filter = "level-zero"; + } + else if (std::strstr(env, "opencl")) { + filter = "opencl"; + } + else if (std::strstr(env, "cuda")) { + filter = "cuda"; + } + else if (std::strstr(env, "hip")) { + filter = "hip"; + } + else { + throw std::runtime_error("invalid device filter: " + std::string(env)); + } + } else { + auto default_device = sycl::device(sycl::default_selector_v); + auto default_platform_name = default_device.get_platform().get_info(); + + if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) { + filter = "level-zero"; + } + else if (std::strstr(default_platform_name.c_str(), "CUDA")) { + filter = "cuda"; + } + else if (std::strstr(default_platform_name.c_str(), "HIP")) { + filter = "hip"; + } + } + + auto platform_list = sycl::platform::get_platforms(); + + for (const auto& platform : platform_list) { + auto devices = platform.get_devices(); + auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) { + return d.is_gpu(); + }); + + if (gpu_dev == devices.end()) { + // cout << "platform [" << platform_name + // << "] does not contain GPU devices, skipping\n"; + continue; + } + + auto platform_name = platform.get_info(); + std::string platform_name_low_case; + platform_name_low_case.resize(platform_name.size()); + + std::transform( + platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower); + + if (platform_name_low_case.find(filter) == std::string::npos) { + // cout << "platform [" << platform_name + // << "] does not match with requested " + // << filter << ", skipping\n"; + continue; + } + + result = platform_name; + } + + if (result.empty()) + throw std::runtime_error("can not find preferred GPU platform"); + + return result; + } + + template + std::enable_if_t< + std::is_invocable_r_v> + select_device(const DeviceSelector &selector = sycl::gpu_selector_v) + { + sycl::device selected_device = sycl::device(selector); + unsigned int selected_device_id = get_device_id(selected_device); + select_device(selected_device_id); + } + + /// Returns the instance of device manager singleton. + static dev_mgr &instance() + { + static dev_mgr d_m; + return d_m; + } + dev_mgr(const dev_mgr &) = delete; + dev_mgr &operator=(const dev_mgr &) = delete; + dev_mgr(dev_mgr &&) = delete; + dev_mgr &operator=(dev_mgr &&) = delete; + + private: + mutable std::recursive_mutex m_mutex; + static bool compare_dev(sycl::device &device1, sycl::device &device2) + { + sycl::backend backend1 = device1.get_backend(); + sycl::backend backend2 = device2.get_backend(); + // levelzero backends always come first + if(backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true; + if(backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false; + dpct::device_info prop1; + dpct::get_device_info(prop1, device1); + dpct::device_info prop2; + dpct::get_device_info(prop2, device2); + return prop1.get_max_compute_units() > prop2.get_max_compute_units(); + } + static int convert_backend_index(std::string & backend) { + if (backend == "ext_oneapi_level_zero:gpu") return 0; + if (backend == "opencl:gpu") return 1; + if (backend == "ext_oneapi_cuda:gpu") return 2; + if (backend == "ext_oneapi_hip:gpu") return 3; + if (backend == "opencl:cpu") return 4; + if (backend == "opencl:acc") return 5; + printf("convert_backend_index: can't handle backend=%s\n", backend.c_str()); + GGML_ABORT("fatal error"); + } + static bool compare_backend(std::string &backend1, std::string &backend2) { + return convert_backend_index(backend1) < convert_backend_index(backend2); + } + dev_mgr() + { + sycl::device default_device = + sycl::device(sycl::default_selector_v); + _devs.push_back(std::make_shared(default_device)); + + std::vector sycl_all_devs; + // Collect other devices except for the default device. + if (default_device.is_cpu()) + _cpu_device = 0; + + auto Platforms = sycl::platform::get_platforms(); + // Keep track of the number of devices per backend + std::map DeviceNums; + std::map> backend_devices; + auto preferred_platform_name = get_preferred_gpu_platform_name(); + + while (!Platforms.empty()) { + auto Platform = Platforms.back(); + Platforms.pop_back(); + auto platform_name = Platform.get_info(); + if (platform_name.compare(preferred_platform_name) != 0) { + continue; + } + auto devices = Platform.get_devices(); + std::string backend_type = get_device_backend_and_type(devices[0]); + for (const auto &device : devices) { + backend_devices[backend_type].push_back(device); + } + } + + std::vector keys; + for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) { + keys.push_back(it->first); + } + std::sort(keys.begin(), keys.end(), compare_backend); + + for (auto &key : keys) { + std::vector devs = backend_devices[key]; + std::sort(devs.begin(), devs.end(), compare_dev); + for (const auto &dev : devs) { + sycl_all_devs.push_back(dev); + } + } + + for (auto &dev : sycl_all_devs) + { + if (dev == default_device) + { + continue; + } + _devs.push_back(std::make_shared(dev)); + if (_cpu_device == -1 && dev.is_cpu()) + { + _cpu_device = _devs.size() - 1; + } + } + } + void check_id(unsigned int id) const + { + if (id >= _devs.size()) + { + throw std::runtime_error("invalid device id"); + } + } + std::vector> _devs; + /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current + /// thread id in _thread2dev_map, which means default device should be used + /// for the current thread. + const unsigned int DEFAULT_DEVICE_ID = 0; + /// thread-id to device-id map. + std::map _thread2dev_map; + int _cpu_device = -1; + }; + + static inline sycl::queue &get_default_queue() + { + return dev_mgr::instance().current_device().default_queue(); + } + + namespace detail + { + enum class pointer_access_attribute + { + host_only = 0, + device_only, + host_device, + end + }; + + static pointer_access_attribute get_pointer_attribute(sycl::queue &q, + const void *ptr) + { + switch (sycl::get_pointer_type(ptr, q.get_context())) + { + case sycl::usm::alloc::unknown: + return pointer_access_attribute::host_only; + case sycl::usm::alloc::device: + return pointer_access_attribute::device_only; + case sycl::usm::alloc::shared: + case sycl::usm::alloc::host: + return pointer_access_attribute::host_device; + } + } + + template + inline constexpr std::uint64_t get_type_combination_id(ArgT Val) + { + static_assert((unsigned char)library_data_t::library_data_t_size <= + std::numeric_limits::max() && + "library_data_t size exceeds limit."); + static_assert(std::is_same_v, "Unsupported ArgT"); + return (std::uint64_t)Val; + } + + template + inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal, + RestT... RestVal) + { + static_assert((std::uint8_t)library_data_t::library_data_t_size <= + std::numeric_limits::max() && + "library_data_t size exceeds limit."); + static_assert(sizeof...(RestT) <= 8 && "Too many parameters"); + static_assert(std::is_same_v, "Unsupported FirstT"); + return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal); + } + + class mem_mgr + { + mem_mgr() + { + // Reserved address space, no real memory allocation happens here. +#if defined(__linux__) + mapped_address_space = + (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); +#elif defined(_WIN64) + mapped_address_space = (byte_t *)VirtualAlloc( + NULL, // NULL specified as the base address parameter + mapped_region_size, // Size of allocation + MEM_RESERVE, // Allocate reserved pages + PAGE_NOACCESS); // Protection = no access +#else +#error "Only support Windows and Linux." +#endif + next_free = mapped_address_space; + } + + public: + using buffer_id_t = int; + + struct allocation + { + buffer_t buffer; + byte_t *alloc_ptr; + size_t size; + }; + + ~mem_mgr() + { +#if defined(__linux__) + munmap(mapped_address_space, mapped_region_size); +#elif defined(_WIN64) + VirtualFree(mapped_address_space, 0, MEM_RELEASE); +#else +#error "Only support Windows and Linux." +#endif + } + + mem_mgr(const mem_mgr &) = delete; + mem_mgr &operator=(const mem_mgr &) = delete; + mem_mgr(mem_mgr &&) = delete; + mem_mgr &operator=(mem_mgr &&) = delete; + + /// Allocate + void *mem_alloc(size_t size) + { + if (!size) + return nullptr; + std::lock_guard lock(m_mutex); + if (next_free + size > mapped_address_space + mapped_region_size) + { + throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool"); + } + // Allocation + sycl::range<1> r(size); + buffer_t buf(r); + allocation A{buf, next_free, size}; + // Map allocation to device pointer + void *result = next_free; + m_map.emplace(next_free + size, A); + // Update pointer to the next free space. + next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1); + + return result; + } + + /// Deallocate + void mem_free(const void *ptr) + { + if (!ptr) + return; + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + m_map.erase(it); + } + + /// map: device pointer -> allocation(buffer, alloc_ptr, size) + allocation translate_ptr(const void *ptr) + { + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + return it->second; + } + + /// Check if the pointer represents device pointer or not. + bool is_device_ptr(const void *ptr) const + { + std::lock_guard lock(m_mutex); + return (mapped_address_space <= ptr) && + (ptr < mapped_address_space + mapped_region_size); + } + + /// Returns the instance of memory manager singleton. + static mem_mgr &instance() + { + static mem_mgr m; + return m; + } + + private: + std::map m_map; + mutable std::mutex m_mutex; + byte_t *mapped_address_space; + byte_t *next_free; + const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024; + const size_t alignment = 256; + /// This padding may be defined to some positive value to debug + /// out of bound accesses. + const size_t extra_padding = 0; + + std::map::iterator get_map_iterator(const void *ptr) + { + auto it = m_map.upper_bound(const_cast(reinterpret_cast(ptr))); + if (it == m_map.end()) + { + // Not a virtual pointer. + throw std::runtime_error("can not get buffer from non-virtual pointer"); + } + const allocation &alloc = it->second; + if (ptr < alloc.alloc_ptr) + { + // Out of bound. + // This may happen if there's a gap between allocations due to alignment + // or extra padding and pointer points to this gap. + throw std::runtime_error("invalid virtual pointer"); + } + return it; + } + }; + + template + class accessor; + template + class memory_traits + { + public: + static constexpr sycl::access::target target = + sycl::access::target::device; + static constexpr sycl::access_mode mode = + (Memory == constant) ? sycl::access_mode::read + : sycl::access_mode::read_write; + static constexpr size_t type_size = sizeof(T); + using element_t = + typename std::conditional::type; + using value_t = typename std::remove_cv::type; + template + using accessor_t = typename std::conditional< + Memory == local, sycl::local_accessor, + sycl::accessor>::type; + using pointer_t = T *; + }; + + static inline void *dpct_malloc(size_t size, sycl::queue &q) + { + return sycl::malloc_device(size, q.get_device(), q.get_context()); + } + +#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F)) + static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z, + sycl::queue &q) + { + pitch = PITCH_DEFAULT_ALIGN(x); + return dpct_malloc(pitch * y * z, q); + } + + /** + * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @return An event representing the memset operation. + */ + template + static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr, + valueT value, size_t size) + { + return q.fill(dev_ptr, value, size); + } + + /** + * @brief Sets \p value to the 3D memory region pointed by \p data in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] data Pointer to the pitched device memory region. + * @param [in] value The value to be set. + * @param [in] size 3D memory region by number of elements. + * @return An event list representing the memset operations. + */ + template + static inline std::vector + dpct_memset(sycl::queue &q, pitched_data data, valueT value, + sycl::range<3> size) + { + std::vector event_list; + size_t slice = data.get_pitch() * data.get_y(); + unsigned char *data_surface = (unsigned char *)data.get_data_ptr(); + for (size_t z = 0; z < size.get(2); ++z) + { + unsigned char *data_ptr = data_surface; + for (size_t y = 0; y < size.get(1); ++y) + { + event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0))); + data_ptr += data.get_pitch(); + } + data_surface += slice; + } + return event_list; + } + + /** + * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @return An event list representing the memset operations. + */ + template + static inline std::vector + dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x, + size_t y) + { + return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val, + sycl::range<3>(x, y, 1)); + } + + static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr, + const void *from_ptr, + memcpy_direction dir) + { + switch (dir) + { + case memcpy_direction::host_to_host: + case memcpy_direction::host_to_device: + case memcpy_direction::device_to_host: + case memcpy_direction::device_to_device: + return dir; + case memcpy_direction::automatic: + { + // table[to_attribute][from_attribute] + static const memcpy_direction + direction_table[static_cast(pointer_access_attribute::end)] + [static_cast(pointer_access_attribute::end)] = + {{memcpy_direction::host_to_host, + memcpy_direction::device_to_host, + memcpy_direction::host_to_host}, + {memcpy_direction::host_to_device, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}, + {memcpy_direction::host_to_host, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}}; + return direction_table[static_cast(get_pointer_attribute( + q, to_ptr))][static_cast(get_pointer_attribute(q, from_ptr))]; + } + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + } + + static sycl::event + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction, + const std::vector &dep_events = {}) + { + if (!size) + return sycl::event{}; + return q.memcpy(to_ptr, from_ptr, size, dep_events); + GGML_UNUSED(direction); + } + + // Get actual copy range and make sure it will not exceed range. + static inline size_t get_copy_range(sycl::range<3> size, size_t slice, + size_t pitch) + { + return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); + } + + static inline size_t get_offset(sycl::id<3> id, size_t slice, + size_t pitch) + { + return slice * id.get(2) + pitch * id.get(1) + id.get(0); + } + + /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr + /// and \p from_range to another specified by \p to_ptr and \p to_range. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + sycl::range<3> to_range, sycl::range<3> from_range, + sycl::id<3> to_id, sycl::id<3> from_id, + sycl::range<3> size, memcpy_direction direction, + const std::vector &dep_events = {}) + { + // RAII for host pointer + class host_buffer + { + void *_buf; + size_t _size; + sycl::queue &_q; + const std::vector &_deps; // free operation depends + + public: + host_buffer(size_t size, sycl::queue &q, + const std::vector &deps) + : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} + void *get_ptr() const { return _buf; } + size_t get_size() const { return _size; } + ~host_buffer() + { + if (_buf) + { + _q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); }); + } + } + }; + std::vector event_list; + + size_t to_slice = to_range.get(1) * to_range.get(0), + from_slice = from_range.get(1) * from_range.get(0); + unsigned char *to_surface = + (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char *from_surface = + (const unsigned char *)from_ptr + + get_offset(from_id, from_slice, from_range.get(0)); + + if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) + { + return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events)}; + } + direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + size_t size_slice = size.get(1) * size.get(0); + switch (direction) + { + case host_to_host: + for (size_t z = 0; z < size.get(2); ++z) + { + unsigned char *to_ptr = to_surface; + const unsigned char *from_ptr = from_surface; + if (to_range.get(0) == from_range.get(0) && + to_range.get(0) == size.get(0)) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, + direction, dep_events)); + } + else + { + for (size_t y = 0; y < size.get(1); ++y) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), + direction, dep_events)); + to_ptr += to_range.get(0); + from_ptr += from_range.get(0); + } + } + to_surface += to_slice; + from_surface += from_slice; + } + break; + case host_to_device: + { + host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, + event_list); + std::vector host_events; + if (to_slice == size_slice) + { + // Copy host data to a temp host buffer with the shape of target. + host_events = + dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + host_to_host, dep_events); + } + else + { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, + // If has padding data, not sure whether it is useless. So fill temp + // buffer with it. + std::vector{ + dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), + device_to_host, dep_events)}); + } + // Copy from temp host buffer to device with only one submit. + event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), + buf.get_size(), host_to_device, + host_events)); + break; + } + case device_to_host: + { + host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, + event_list); + // Copy from host temp buffer to host target with reshaping. + event_list = dpct_memcpy( + q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), size, host_to_host, + // Copy from device to temp host buffer with only one submit. + std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, + buf.get_size(), + device_to_host, dep_events)}); + break; + } + case device_to_device: + event_list.push_back(q.submit([&](sycl::handler &cgh){ + cgh.depends_on(dep_events); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); })); + break; + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + return event_list; + } + + /// memcpy 2D/3D matrix specified by pitched_data. + static inline std::vector + dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, + pitched_data from, sycl::id<3> from_id, sycl::range<3> size, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, + size, direction); + } + + /// memcpy 2D matrix with pitch. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + size_t to_pitch, size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), direction); + } + + namespace deprecated + { + + template + class usm_allocator + { + private: + using Alloc = sycl::usm_allocator; + Alloc _impl; + + public: + using value_type = typename std::allocator_traits::value_type; + using pointer = typename std::allocator_traits::pointer; + using const_pointer = typename std::allocator_traits::const_pointer; + using void_pointer = typename std::allocator_traits::void_pointer; + using const_void_pointer = + typename std::allocator_traits::const_void_pointer; + using reference = typename std::allocator_traits::value_type &; + using const_reference = + const typename std::allocator_traits::value_type &; + using difference_type = + typename std::allocator_traits::difference_type; + using size_type = typename std::allocator_traits::size_type; + using propagate_on_container_copy_assignment = typename std::allocator_traits< + Alloc>::propagate_on_container_copy_assignment; + using propagate_on_container_move_assignment = typename std::allocator_traits< + Alloc>::propagate_on_container_move_assignment; + using propagate_on_container_swap = + typename std::allocator_traits::propagate_on_container_swap; + using is_always_equal = + typename std::allocator_traits::is_always_equal; + + template + struct rebind + { + typedef usm_allocator other; + }; + + usm_allocator() : _impl(dpct::get_default_queue()) {} + ~usm_allocator() {} + usm_allocator(const usm_allocator &other) : _impl(other._impl) {} + usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {} + pointer address(reference r) { return &r; } + const_pointer address(const_reference r) { return &r; } + pointer allocate(size_type cnt, const_void_pointer hint = nullptr) + { + return std::allocator_traits::allocate(_impl, cnt, hint); + } + void deallocate(pointer p, size_type cnt) + { + std::allocator_traits::deallocate(_impl, p, cnt); + } + size_type max_size() const + { + return std::allocator_traits::max_size(_impl); + } + bool operator==(const usm_allocator &other) const { return _impl == other._impl; } + bool operator!=(const usm_allocator &other) const { return _impl != other._impl; } + }; + + } // namespace deprecated + + inline void dpct_free(void *ptr, + const sycl::queue &q) + { + if (ptr) + { + sycl::free(ptr, q.get_context()); + } + } + + template + inline auto get_memory(const void *x) + { + T *new_x = reinterpret_cast(const_cast(x)); + return new_x; + } + + template + inline typename DataType::T2 get_value(const T *s, sycl::queue &q) + { + using Ty = typename DataType::T2; + Ty s_h; + if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only) + detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host) + .wait(); + else + s_h = *reinterpret_cast(s); + return s_h; + } + + } // namespace detail + + template + inline auto get_value(const T *s, sycl::queue &q) + { + return detail::get_value(s, q); + } + + namespace detail + { + template + inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb, + const void * beta, void * c, int ldc) { + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a, + lda, data_b, ldb, beta_value, data_c, ldc); + } + + template + class vectorized_binary + { + public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) + { + VecT v4; + for (size_t i = 0; i < v4.size(); ++i) + { + v4[i] = binary_op(a[i], b[i]); + } + return v4; + } + }; + + template + class vectorized_binary< + VecT, BinaryOperation, + std::void_t>> + { + public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) + { + return binary_op(a, b).template as(); + } + }; + + template + inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, + int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, + int ldb, const void * beta, void ** c, int ldc, int batch_size, + matrix_info_t * matrix_info) { + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + + matrix_info->transpose_info[0] = a_trans; + matrix_info->transpose_info[1] = b_trans; + matrix_info->value_info[0] = alpha_value; + matrix_info->value_info[1] = beta_value; + matrix_info->size_info[0] = m; + matrix_info->size_info[1] = n; + matrix_info->size_info[2] = k; + matrix_info->ld_info[0] = lda; + matrix_info->ld_info[1] = ldb; + matrix_info->ld_info[2] = ldc; + matrix_info->groupsize_info = batch_size; + + sycl::event e = oneapi::math::blas::column_major::gemm_batch( + get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1, + matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2, + reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, + reinterpret_cast(b), matrix_info->ld_info + 1, + reinterpret_cast(matrix_info->value_info + 1), reinterpret_cast(c), + matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); + } + + template + inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, + int m, int n, int k, const void * alpha, const void * a, int lda, + long long int stride_a, const void * b, int ldb, long long int stride_b, + const void * beta, void * c, int ldc, long long int stride_c, int batch_size) { + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, + data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, + data_c, ldc, stride_c, batch_size); + } + + } // namespace detail + + template + inline unsigned vectorized_binary(unsigned a, unsigned b, + const BinaryOperation binary_op) + { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.as(); + auto v3 = v1.as(); + auto v4 = + detail::vectorized_binary()(v2, v3, binary_op); + v0 = v4.template as>(); + return v0; + } + + static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction = automatic, + sycl::queue &q = dpct::get_default_queue()) + { + detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction); + } + + static inline unsigned int select_device(unsigned int id) + { + dev_mgr::instance().select_device(id); + return id; + } + + template + T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask, + unsigned int logical_sub_group_size = 32) + { + unsigned int id = g.get_local_linear_id(); + unsigned int start_index = + id / logical_sub_group_size * logical_sub_group_size; + unsigned int target_offset = (id % logical_sub_group_size) ^ mask; + return sycl::select_from_group(g, x, + target_offset < logical_sub_group_size + ? start_index + target_offset + : id); + } + + template + inline auto dp4a(T1 a, T2 b, T3 c) + { + return syclcompat::dp4a(a, b, c); + } + + struct sub_sat + { + template + auto operator()(const T x, const T y) const + { + return sycl::sub_sat(x, y); + } + }; + + template + inline T vectorized_min(T a, T b) + { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + auto v4 = sycl::min(v2, v3); + v0 = v4.template as>(); + return v0; + } + + inline float pow(const float a, const int b) { return sycl::pown(a, b); } + inline double pow(const double a, const int b) { return sycl::pown(a, b); } + inline float pow(const float a, const float b) { return sycl::pow(a, b); } + inline double pow(const double a, const double b) { return sycl::pow(a, b); } + template + inline typename std::enable_if_t, T> + pow(const T a, const U b) + { + return sycl::pow(a, static_cast(b)); + } + template + inline typename std::enable_if_t, double> + pow(const T a, const U b) + { + return sycl::pow(static_cast(a), static_cast(b)); + } + + inline double min(const double a, const float b) + { + return sycl::fmin(a, static_cast(b)); + } + inline double min(const float a, const double b) + { + return sycl::fmin(static_cast(a), b); + } + inline float min(const float a, const float b) { return sycl::fmin(a, b); } + inline double min(const double a, const double b) { return sycl::fmin(a, b); } + inline std::uint32_t min(const std::uint32_t a, const std::int32_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint32_t min(const std::int32_t a, const std::uint32_t b) + { + return sycl::min(static_cast(a), b); + } + inline std::int32_t min(const std::int32_t a, const std::int32_t b) + { + return sycl::min(a, b); + } + inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b) + { + return sycl::min(a, b); + } + inline std::uint64_t min(const std::uint64_t a, const std::int64_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint64_t min(const std::int64_t a, const std::uint64_t b) + { + return sycl::min(static_cast(a), b); + } + inline std::int64_t min(const std::int64_t a, const std::int64_t b) + { + return sycl::min(a, b); + } + inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b) + { + return sycl::min(a, b); + } + inline std::uint64_t min(const std::uint64_t a, const std::int32_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint64_t min(const std::int32_t a, const std::uint64_t b) + { + return sycl::min(static_cast(a), b); + } + inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b) + { + return sycl::min(static_cast(a), b); + } + // max function overloads. + // For floating-point types, `float` or `double` arguments are acceptable. + // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or + // `std::int64_t` type arguments are acceptable. + inline double max(const double a, const float b) + { + return sycl::fmax(a, static_cast(b)); + } + inline double max(const float a, const double b) + { + return sycl::fmax(static_cast(a), b); + } + inline float max(const float a, const float b) { return sycl::fmax(a, b); } + inline double max(const double a, const double b) { return sycl::fmax(a, b); } + inline std::uint32_t max(const std::uint32_t a, const std::int32_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint32_t max(const std::int32_t a, const std::uint32_t b) + { + return sycl::max(static_cast(a), b); + } + inline std::int32_t max(const std::int32_t a, const std::int32_t b) + { + return sycl::max(a, b); + } + inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b) + { + return sycl::max(a, b); + } + inline std::uint64_t max(const std::uint64_t a, const std::int64_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint64_t max(const std::int64_t a, const std::uint64_t b) + { + return sycl::max(static_cast(a), b); + } + inline std::int64_t max(const std::int64_t a, const std::int64_t b) + { + return sycl::max(a, b); + } + inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b) + { + return sycl::max(a, b); + } + inline std::uint64_t max(const std::uint64_t a, const std::int32_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint64_t max(const std::int32_t a, const std::uint64_t b) + { + return sycl::max(static_cast(a), b); + } + inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b) + { + return sycl::max(static_cast(a), b); + } + + inline void + has_capability_or_fail(const sycl::device &dev, + const std::initializer_list &props) + { + for (const auto &it : props) + { + if (dev.has(it)) + continue; + switch (it) + { + case sycl::aspect::fp64: + throw std::runtime_error("'double' is not supported in '" + + dev.get_info() + + "' device"); + break; + case sycl::aspect::fp16: + throw std::runtime_error("'half' is not supported in '" + + dev.get_info() + + "' device"); + break; + default: +#define __SYCL_ASPECT(ASPECT, ID) \ + case sycl::aspect::ASPECT: \ + return #ASPECT; +#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID) +#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE) + auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string + { + switch (AspectNum) + { +#include +#include + default: + return "unknown aspect"; + } + }; +#undef __SYCL_ASPECT_DEPRECATED_ALIAS +#undef __SYCL_ASPECT_DEPRECATED +#undef __SYCL_ASPECT + throw std::runtime_error( + "'" + getAspectNameStr(it) + "' is not supported in '" + + dev.get_info() + "' device"); + } + break; + } + } + + static inline unsigned int get_current_device_id() + { + return dev_mgr::instance().current_device_id(); + } + + static inline device_ext &get_current_device() + { + return dev_mgr::instance().current_device(); + } + + static inline device_ext &get_device(unsigned int id) + { + return dev_mgr::instance().get_device(id); + } + + static inline sycl::queue &get_in_order_queue() + { + return dev_mgr::instance().current_device().in_order_queue(); + } + + static sycl::event + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction, + const std::vector &dep_events = {}) + { + if (!size) + return sycl::event{}; + return q.memcpy(to_ptr, from_ptr, size, dep_events); + GGML_UNUSED(direction); + } + + // Get actual copy range and make sure it will not exceed range. + static inline size_t get_copy_range(sycl::range<3> size, size_t slice, + size_t pitch) + { + return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); + } + + static inline size_t get_offset(sycl::id<3> id, size_t slice, + size_t pitch) + { + return slice * id.get(2) + pitch * id.get(1) + id.get(0); + } + + /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr + /// and \p from_range to another specified by \p to_ptr and \p to_range. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + sycl::range<3> to_range, sycl::range<3> from_range, + sycl::id<3> to_id, sycl::id<3> from_id, + sycl::range<3> size, memcpy_direction direction, + const std::vector &dep_events = {}) + { + // RAII for host pointer + class host_buffer + { + void *_buf; + size_t _size; + sycl::queue &_q; + const std::vector &_deps; // free operation depends + + public: + host_buffer(size_t size, sycl::queue &q, + const std::vector &deps) + : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} + void *get_ptr() const { return _buf; } + size_t get_size() const { return _size; } + ~host_buffer() + { + if (_buf) + { + _q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); }); + } + } + }; + std::vector event_list; + + size_t to_slice = to_range.get(1) * to_range.get(0), + from_slice = from_range.get(1) * from_range.get(0); + unsigned char *to_surface = + (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char *from_surface = + (const unsigned char *)from_ptr + + get_offset(from_id, from_slice, from_range.get(0)); + + if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) + { + return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events)}; + } + direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + size_t size_slice = size.get(1) * size.get(0); + switch (direction) + { + case host_to_host: + for (size_t z = 0; z < size.get(2); ++z) + { + unsigned char *to_ptr = to_surface; + const unsigned char *from_ptr = from_surface; + if (to_range.get(0) == from_range.get(0) && + to_range.get(0) == size.get(0)) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, + direction, dep_events)); + } + else + { + for (size_t y = 0; y < size.get(1); ++y) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), + direction, dep_events)); + to_ptr += to_range.get(0); + from_ptr += from_range.get(0); + } + } + to_surface += to_slice; + from_surface += from_slice; + } + break; + case host_to_device: + { + host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, + event_list); + std::vector host_events; + if (to_slice == size_slice) + { + // Copy host data to a temp host buffer with the shape of target. + host_events = + dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + host_to_host, dep_events); + } + else + { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, + // If has padding data, not sure whether it is useless. So fill temp + // buffer with it. + std::vector{ + dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), + device_to_host, dep_events)}); + } + // Copy from temp host buffer to device with only one submit. + event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), + buf.get_size(), host_to_device, + host_events)); + break; + } + case device_to_host: + { + host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, + event_list); + // Copy from host temp buffer to host target with reshaping. + event_list = dpct_memcpy( + q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), size, host_to_host, + // Copy from device to temp host buffer with only one submit. + std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, + buf.get_size(), + device_to_host, dep_events)}); + break; + } + case device_to_device: + event_list.push_back(q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); })); + break; + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + return event_list; + } + + /// memcpy 2D/3D matrix specified by pitched_data. + static inline std::vector + dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, + pitched_data from, sycl::id<3> from_id, sycl::range<3> size, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, + size, direction); + } + + /// memcpy 2D matrix with pitch. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + size_t to_pitch, size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), direction); + } + + inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n, + int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b, + library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc, + library_data_t scaling_type) { + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) + { + scaling_type = library_data_t::complex_float; + } + else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) + { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) + { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc); + break; + } +#ifdef __INTEL_MKL__ + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, + a, lda, b, ldb, &beta_half, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); + break; + } +#endif // __INTEL_MKL__ + default: + throw std::runtime_error("the combination of data type is unsupported"); + } + } // gemm() + + /// Computes a batch of matrix-matrix product with general matrices. + /// \param [in] q The queue where the routine should be executed. + /// \param [in] a_trans Specifies the operation applied to A. + /// \param [in] b_trans Specifies the operation applied to B. + /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. + /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. + /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). + /// \param [in] alpha Scaling factor for the matrix-matrix product. + /// \param [in] a Input matrix A. + /// \param [in] a_type Data type of the matrix A. + /// \param [in] lda Leading dimension of A. + /// \param [in] b Input matrix B. + /// \param [in] b_type Data type of the matrix B. + /// \param [in] ldb Leading dimension of B. + /// \param [in] beta Scaling factor for matrix C. + /// \param [in, out] c Input/Output matrix C. + /// \param [in] c_type Data type of the matrix C. + /// \param [in] ldc Leading dimension of C. + /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. + /// \param [in] scaling_type Data type of the scaling factors. + inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, + const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], + library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, + matrix_info_t * matrix_info) { + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) + { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); + break; + } +#ifdef __INTEL_MKL__ + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); + break; + } +#endif + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size, + matrix_info); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } + } + + /// Computes a batch of matrix-matrix product with general matrices. + /// \param [in] q The queue where the routine should be executed. + /// \param [in] a_trans Specifies the operation applied to A. + /// \param [in] b_trans Specifies the operation applied to B. + /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. + /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. + /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). + /// \param [in] alpha Scaling factor for the matrix-matrix product. + /// \param [in] a Input matrix A. + /// \param [in] a_type Data type of the matrix A. + /// \param [in] lda Leading dimension of A. + /// \param [in] stride_a Stride between the different A matrices. + /// \param [in] b Input matrix B. + /// \param [in] b_type Data type of the matrix B. + /// \param [in] ldb Leading dimension of B. + /// \param [in] stride_b Stride between the different B matrices. + /// \param [in] beta Scaling factor for matrix C. + /// \param [in, out] c Input/Output matrix C. + /// \param [in] c_type Data type of the matrix C. + /// \param [in] ldc Leading dimension of C. + /// \param [in] stride_c Stride between the different C matrices. + /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. + /// \param [in] scaling_type Data type of the scaling factors. + inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda, + long long int stride_a, const void * b, library_data_t b_type, int ldb, + long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc, + long long int stride_c, int batch_size, library_data_t scaling_type) { + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) + { + scaling_type = library_data_t::complex_float; + } + else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) + { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) + { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } +#ifdef __INTEL_MKL__ + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size); + break; + } +#endif + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, + &beta_half, c, ldc, stride_c, batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } + } + + static inline void + async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr, + size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic, + sycl::queue &q = get_default_queue()) + { + detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y, + direction); + } + + using err0 = detail::generic_error_type; + using err1 = detail::generic_error_type; + + static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) { + detail::dpct_free(ptr, q); + } + + /// dpct accessor used as device function parameter. + template class accessor; + template class accessor { + public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<3>; + accessor(pointer_t data, const sycl::range<3> &in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type &acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t &acc, const sycl::range<3> &in_range) + : accessor(acc.get_pointer(), in_range) {} + accessor operator[](size_t index) const { + sycl::range<2> sub(_range.get(1), _range.get(2)); + return accessor(_data + index * sub.size(), sub); + } + + pointer_t get_ptr() const { return _data; } + + private: + pointer_t _data; + sycl::range<3> _range; + }; + template class accessor { + public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<2>; + accessor(pointer_t data, const sycl::range<2> &in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type &acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t &acc, const sycl::range<2> &in_range) + : accessor(acc.get_pointer(), in_range) {} + + pointer_t operator[](size_t index) const { + return _data + _range.get(1) * index; + } + + pointer_t get_ptr() const { return _data; } + + private: + pointer_t _data; + sycl::range<2> _range; + }; + + namespace detail { + /// Device variable with address space of shared, global or constant. + template class device_memory { + public: + using accessor_t = + typename detail::memory_traits::template accessor_t; + using value_t = typename detail::memory_traits::value_t; + using dpct_accessor_t = dpct::accessor; + + device_memory() : device_memory(sycl::range(1)) {} + + /// Constructor of 1-D array with initializer list + device_memory(const sycl::range &in_range, + std::initializer_list &&init_list) + : device_memory(in_range) { + assert(init_list.size() <= in_range.size()); + _host_ptr = (value_t *)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T)); + } + + /// Constructor of 2-D array with initializer list + template + device_memory( + const typename std::enable_if>::type &in_range, + std::initializer_list> &&init_list) + : device_memory(in_range) { + assert(init_list.size() <= in_range[0]); + _host_ptr = (value_t *)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + auto tmp_data = _host_ptr; + for (auto sub_list : init_list) { + assert(sub_list.size() <= in_range[1]); + std::memcpy(tmp_data, sub_list.begin(), + sub_list.size() * sizeof(T)); + tmp_data += in_range[1]; + } + } + + /// Constructor with range + device_memory(const sycl::range &range_in) + : _size(range_in.size() * sizeof(T)), _range(range_in), + _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) { + static_assert( + (Memory == global) || (Memory == constant) || (Memory == shared), + "device memory region should be global, constant or shared"); + // Make sure that singleton class mem_mgr and dev_mgr will destruct + // later than this. + detail::mem_mgr::instance(); + dev_mgr::instance(); + } + + /// Constructor with range + template + device_memory(Args... Arguments) + : device_memory(sycl::range(Arguments...)) {} + + ~device_memory() { + if (_device_ptr && !_reference) + dpct::dpct_free(_device_ptr); + if (_host_ptr) + std::free(_host_ptr); + } + + /// Allocate memory with default queue, and init memory if has initial + /// value. + void init() { init(dpct::get_default_queue()); } + /// Allocate memory with specified queue, and init memory if has initial + /// value. + void init(sycl::queue &q) { + if (_device_ptr) + return; + if (!_size) + return; + allocate_device(q); + if (_host_ptr) + detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size, + host_to_device); + } + + /// The variable is assigned to a device pointer. + void assign(value_t *src, size_t size) { + this->~device_memory(); + new (this) device_memory(src, size); + } + + /// Get memory pointer of the memory object, which is virtual pointer when + /// usm is not used, and device pointer when usm is used. + value_t *get_ptr() { return get_ptr(get_default_queue()); } + /// Get memory pointer of the memory object, which is virtual pointer when + /// usm is not used, and device pointer when usm is used. + value_t *get_ptr(sycl::queue &q) { + init(q); + return _device_ptr; + } + + /// Get the device memory object size in bytes. + size_t get_size() { return _size; } + + template + typename std::enable_if::type &operator[](size_t index) { + init(); + return _device_ptr[index]; + } + + /// Get dpct::accessor with dimension info for the device memory object + /// when usm is used and dimension is greater than 1. + template + typename std::enable_if::type + get_access([[maybe_unused]] sycl::handler &cgh) { + return dpct_accessor_t((T *)_device_ptr, _range); + } + + private: + device_memory(value_t *memory_ptr, size_t size) + : _size(size), _range(size / sizeof(T)), _reference(true), + _device_ptr(memory_ptr) {} + + void allocate_device(sycl::queue &q) { + #ifndef DPCT_USM_LEVEL_NONE + if (Memory == shared) { + _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(), + q.get_context()); + return; + } + #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY + if (Memory == constant) { + _device_ptr = (value_t *)sycl::malloc_device( + _size, q.get_device(), q.get_context(), + sycl::ext::oneapi::property::usm::device_read_only()); + return; + } + #endif + #endif + _device_ptr = (value_t *)detail::dpct_malloc(_size, q); + } + + size_t _size; + sycl::range _range; + bool _reference; + value_t *_host_ptr; + value_t *_device_ptr; + }; + template + class device_memory : public device_memory { + public: + using base = device_memory; + using value_t = typename base::value_t; + using accessor_t = + typename detail::memory_traits::template accessor_t<0>; + + /// Constructor with initial value. + device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {} + + /// Default constructor + device_memory() : base(1) {} + }; + } // namespace detail + + template + using global_memory = detail::device_memory; + template + using constant_memory = detail::device_memory; + template + using shared_memory = detail::device_memory; + + + template + inline T atomic_fetch_add(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_add(operand); + } + + template + inline T1 atomic_fetch_add(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_add(operand); + } + + template + inline T atomic_fetch_add(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_add(addr, operand); + case sycl::memory_order::acq_rel: + return atomic_fetch_add(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_add(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } + } + + template + inline T1 atomic_fetch_add(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_fetch_add(addr, operand, memoryOrder); + } + +} // COPY from DPCT head files + +#endif // GGML_SYCL_DPCT_HELPER_HPP diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b7c4f0b4f003c55c8fadf4e37386cebb87074fa --- /dev/null +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -0,0 +1,1571 @@ +#include "common.hpp" +#include "ggml.h" +#include "element_wise.hpp" + +static void acc_f32(const float * x, const float * y, float * dst, const int ne, + const int ne10, const int ne11, const int ne12, + const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= ne) { + return; + } + int src1_idx = i - offset; + int oz = src1_idx / nb2; + int oy = (src1_idx - (oz * nb2)) / nb1; + int ox = src1_idx % nb1; + if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { + dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; + } else { + dst[i] = x[i]; + } +} + +template +static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { + for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { + dst[i] = x[i] > static_cast(0.f) ? static_cast(1.f) : ((x[i] < static_cast(0.f) ? static_cast(-1.f) : static_cast(0.f))); + } +} + +template +static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { + for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { + dst[i] = sycl::fabs(x[i]); + } +} + +template +static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { + for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { + dst[i] = (x[i] > static_cast(0.f)) ? x[i] : sycl::expm1(x[i]); + } +} + +template +static void gelu(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const T GELU_COEF_A = static_cast(0.044715f); + const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + + float xi = x[i]; + dst[i] = static_cast(0.5f) * xi * + (static_cast(1.0f) + + sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast(1.0f) + GELU_COEF_A * xi * xi))); +} + +template +static void silu(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] / (static_cast(1.0f) + sycl::native::exp(-x[i])); +} + +template +static void gelu_quick(const T *x, T *dst, int k, + const sycl::nd_item<3> &item_ct1) { + const float GELU_QUICK_COEF = -1.702f; + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= k) { + return; + } + dst[i] = x[i] * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i]))); +} + +template +static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { + const T SQRT_2_INV = static_cast(0.70710678118654752440084436210484f); + for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { + auto x_i = x[i]; + dst[i] = static_cast(0.5f) * x_i * (static_cast(1.0f) + sycl::erf(x_i * SQRT_2_INV)); + } +} + +template +static void tanh(const T *x, T *dst, int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= k) { + return; + } + dst[i] = sycl::tanh((x[i])); +} + +template +static void relu(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::fmax((x[i]), static_cast(0)); +} + +template +static void sigmoid(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = 1.0f / (static_cast(1.0f) + sycl::native::exp(-x[i])); +} + +template +static void sqrt(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::sqrt(x[i]); +} + +template +static void sin(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::sin(x[i]); +} + +template +static void cos(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::cos(x[i]); +} + +template +static void hardsigmoid(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); +} + +template +static void hardswish(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); +} + +template +static void exp(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::exp(x[i]); +} + +template +static void log(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + T xi = x[i]; + if (xi <= 0) { + dst[i] = neg_infinity(); + } else { + dst[i] = sycl::log(xi); + } +} + +template +static void neg(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = -x[i]; +} + +template +static void step(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] > static_cast(0.0f); +} + +template +static void leaky_relu(const T *x, T *dst, const int k, const float negative_slope, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= k) { + return; + } + dst[i] = sycl::fmax((x[i]), static_cast(0)) + + sycl::fmin((x[i]), static_cast(0.0f)) * negative_slope; +} + +template +static void sqr(const T * x, T * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] * x[i]; +} + +template +static void upscale(const T *x, T *dst, const int nb00, const int nb01, + const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int ne13, const float sf0, const float sf1, + const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { + int index = item_ct1.get_local_id(0) + + item_ct1.get_group(0) * item_ct1.get_local_range(0); + if (index >= ne10 * ne11 * ne12 * ne13) { + return; + } + // operation + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = (index / (ne10 * ne11 * ne12)) % ne13; + + int i00 = i10 / sf0; + int i01 = i11 / sf1; + int i02 = i12 / sf2; + int i03 = i13 / sf3; + + dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); +} + +template +static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02, + const sycl::nd_item<3> &item_ct1) { + int nidx = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (nidx >= ne0) { + return; + } + + // operation + int offset_dst = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) { + int offset_src = nidx + item_ct1.get_group(1) * ne00 + + item_ct1.get_group(0) * ne00 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + dst[offset_dst] = static_cast(0.0f); + } +} + + +template +static void clamp(const T * x, T * dst, const float min, const float max, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + + dst[i] = x[i] < static_cast(min) ? static_cast(min) : (x[i] > static_cast(max) ? static_cast(max) : x[i]); +} + +static void acc_f32_sycl(const float *x, const float *y, float *dst, + const int n_elements, const int ne10, const int ne11, + const int ne12, const int nb1, const int nb2, + const int offset, queue_ptr stream) { + int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, + item_ct1); + }); +} + +template +static void gelu_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + gelu(x, dst, k, item_ct1); + }); +} + +template +static void silu_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + silu(x, dst, k, item_ct1); + }); +} + +template +static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) { + // hard code for now + const int num_blocks = ceil_div(k, 256); + stream->parallel_for( + sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { + sgn(x, dst, k, item_ct1); + }); +} + +template +static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) { + // hard code for now + const int num_blocks = ceil_div(k, 256); + stream->parallel_for( + sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { + abs_op(x, dst, k, item_ct1); + }); +} + + +template +static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) { + // hard code for now + const int num_blocks = ceil_div(k, 256); + stream->parallel_for( + sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { + elu_op(x, dst, k, item_ct1); + }); +} + +template +static void gelu_quick_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + gelu_quick(x, dst, k, item_ct1); + }); +} + + +template +static void gelu_erf_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + gelu_erf(x, dst, k, item_ct1); + }); +} + +template +static void tanh_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + tanh(x, dst, k, item_ct1); + }); +} + +template +static void relu_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + relu(x, dst, k, item_ct1); + }); +} + +template +static void hardsigmoid_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + hardsigmoid(x, dst, k, item_ct1); + }); +} + +template +static void hardswish_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + hardswish(x, dst, k, item_ct1); + }); +} + +template +static void exp_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + exp(x, dst, k, item_ct1); + }); +} + +template +static void log_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + log(x, dst, k, item_ct1); + }); +} + +template +static void neg_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + neg(x, dst, k, item_ct1); + }); +} + +template +static void step_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + step(x, dst, k, item_ct1); + }); +} + +template +static void sigmoid_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sigmoid(x, dst, k, item_ct1); + }); +} + +template +static void sqrt_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sqrt(x, dst, k, item_ct1); + }); +} + +template +static void sin_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sin(x, dst, k, item_ct1); + }); +} + +template +static void cos_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cos(x, dst, k, item_ct1); + }); +} + +template +static void leaky_relu_sycl(const T *x, T *dst, const int k, + const float negative_slope, + queue_ptr stream) { + const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + leaky_relu(x, dst, k, negative_slope, item_ct1); + }); +} + +template +static void sqr_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sqr(x, dst, k, item_ct1); + }); +} + +template +static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, + const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int ne13, const float sf0, const float sf1, + const float sf2, const float sf3, queue_ptr stream) { + int dst_size = ne10 * ne11 * ne12 * ne13; + int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); + }); +} + +template +static void pad_sycl(const T *x, T *dst, const int ne00, + const int ne01, const int ne02, const int ne0, + const int ne1, const int ne2, queue_ptr stream) { + int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; + sycl::range<3> gridDim(ne2, ne1, num_blocks); + stream->parallel_for( + sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); + }); +} + +template +static void clamp_sycl(const T *x, T *dst, const float min, + const float max, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + clamp(x, dst, min, max, k, item_ct1); + }); +} + +inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + + +inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + + +inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + + GGML_ASSERT(dst->src[0]->type == dst->type); + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + #if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0]; + const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1]; + const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; + const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], + dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, + main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], + dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, + main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], + dst->ne[1], dst->ne[2], main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], + dst->ne[1], dst->ne[2], main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined(GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + float min; + float max; + memcpy(&min, dst->op_params, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); + + switch (dst->type) { +#if defined(GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + const float * src1_dd = static_cast(dst->src[1]->data); + float * dst_dd = static_cast(dst->data); + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream); +} + + +void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_sqrt(ctx, dst); +} + +void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_sin(ctx, dst); +} + +void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_cos(ctx, dst); +} + +void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_acc(ctx, dst); +} + +void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_gelu(ctx, dst); +} + +void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_silu(ctx, dst); +} + +void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_gelu_quick(ctx, dst); +} + +void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_gelu_erf(ctx, dst); +} + +void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_tanh(ctx, dst); +} + +void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_relu(ctx, dst); +} + +void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_sigmoid(ctx, dst); +} + +void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_hardsigmoid(ctx, dst); +} + +void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_hardswish(ctx, dst); +} + +void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_exp(ctx, dst); +} + +void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_log(ctx, dst); +} + +void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_neg(ctx, dst); +} + +void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_step(ctx, dst); +} + +void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_leaky_relu(ctx, dst); +} + +void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_sqr(ctx, dst); +} + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_upscale(ctx, dst); +} + +void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_pad(ctx, dst); +} + +void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_clamp(ctx, dst); +} + +void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_sgn(ctx, dst); +} + +void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_abs(ctx, dst); +} + +void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_elu(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bd40113f0970560317caa4d4fc4bc506c5485495 --- /dev/null +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -0,0 +1,77 @@ +#ifndef GGML_SYCL_ELEMENTWISE_HPP +#define GGML_SYCL_ELEMENTWISE_HPP + +#include "common.hpp" +#include "ggml.h" +#include + +template +T neg_infinity() { + return -std::numeric_limits::infinity(); +} + +template +struct typed_data { + const T * src; + T * dst; +}; + +template +typed_data cast_data(ggml_tensor * dst) { + return { + /* .src = */ static_cast(dst->src[0]->data), + /* .dst = */ static_cast(dst->data) + }; +} + +void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +#endif // GGML_SYCL_ELEMENTWISE_HPP + diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5efe03d364b1b2d11661827139fbd792849259cd --- /dev/null +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -0,0 +1,102 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_GEMM_HPP +#define GGML_SYCL_GEMM_HPP + +#include "ggml-sycl.h" + +#if GGML_SYCL_DNNL + +#include "dnnl.hpp" +#include "dnnl_sycl.hpp" + +class DnnlGemmWrapper { +public: + using dt = dnnl::memory::data_type; + using tag = dnnl::memory::format_tag; + + template + static constexpr dt to_dt() { + if constexpr (std::is_same_v) return dt::f32; + else if constexpr (std::is_same_v) return dt::f16; + else static_assert(0); + } + + // matrix A has m rows, k columns + // matrix B has k rows, n columns + // nra - number of elements to skip when moving into next row in A + // nrb - number of elements to skip when moving into next row in B + // nca - number of elements to skip when moving into next column in A + // ncb - number of elements to skip when moving into next column in B + // stride_a - number of elements to skip when moving to next A matrix + // stride_b - number of elements to skip when moving to next B matrix + // batches_a - number of A matrices + // batches_b - number of B matrices + static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, + const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a, + const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b, + void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) { + + auto stream = ctx.stream_dnnl(q); + auto eng = ctx.engine_dnnl(q); + + // { # strides, # rows, # columns } + dnnl::memory::dims a_dims = { batches_a, m, k }; + dnnl::memory::dims b_dims = { batches_b, k, n }; + dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n }; + + // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column } + dnnl::memory::dims a_strides = { stride_a, nra, nca }; + dnnl::memory::dims b_strides = { stride_b, nrb, ncb }; + + const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides); + const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides); + const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc); + + dnnl::primitive_attr primitive_attr; + primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); +#ifdef GGML_SYCL_F16 + primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16); +#endif + + auto a_mem = dnnl::memory(a_in_md, eng, const_cast(a)); + auto b_mem = dnnl::memory(b_in_md, eng, const_cast(b)); + auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr); + auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); + + auto scratchpad_md = matmul_pd.scratchpad_desc(); + auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q); + auto matmul_prim = dnnl::matmul(matmul_pd); + + std::unordered_map matmul_args; + matmul_args.insert({ DNNL_ARG_SRC, a_mem }); + matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); + matmul_args.insert({ DNNL_ARG_DST, c_mem }); + matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem }); + + matmul_prim.execute(stream, matmul_args); + } + + // matrices A and B are column major, both having k rows + // matrix A has m column, matrix B has n columns + // output: column major matrix C = A transposed * B + static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, + const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) { + + gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1); + } +}; + +#endif + +#endif // GGML_SYCL_GEMM_HPP diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4a7712781364e9a07bbcd7f3d0503a3eb809f067 --- /dev/null +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -0,0 +1,309 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "ggml-impl.h" +#include "common.hpp" +#include "dequantize.hpp" +#include "getrows.hpp" + + +template +static void k_get_rows( + const void * src0, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12, + const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { + + const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2)) * + 2; + const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) / + ne12; + const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) % + ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; + + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel(src0_row, ib, iqs, v); + + dst_row[iybs + iqs + 0] = v.x(); + dst_row[iybs + iqs + y_offset] = v.y(); +} + +template +static void k_get_rows_reorder( + const void * src0, const void *src0_dq, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12, + const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { + + const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2)) * + 2; + const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) / + ne12; + const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) % + ne12; + + if (i00 >= ne00) { + return; + } + auto ncols = ne00; + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + + const int src0_off = i01 * ncols + i00; + const int ib = src0_off / QK4_0; // block index + const int iqs = (i00%qk)/qr; // x quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel_recorder((const void *)src0_dq, ib, (const void *)src0, src0_off/2, v); + + dst_row[iybs + iqs + 0] = v.x(); + dst_row[iybs + iqs + y_offset] = v.y(); + + GGML_UNUSED(nb01); + GGML_UNUSED(nb02); + GGML_UNUSED(nb03); +} + +template +static void k_get_rows_float( + const src0_t * src0, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12, + const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { + + const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2); + const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) / + ne12; + const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) % + ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); + + dst_row[i00] = src0_row[i00]; +} + +template +static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const void *src0_dd, + const int32_t *src1_dd, float *dst_dd, + queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); + const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE); + const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + GGML_ASSERT(ne00 % 2 == 0); + + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_get_rows( + src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, + s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); + }); + + GGML_UNUSED(dst); + GGML_UNUSED(ctx); +} + +template +static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const void *src0_dd, + const int32_t *src1_dd, float *dst_dd, + queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); + const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE); + const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + GGML_ASSERT(ne00 % 2 == 0); + + const uint8_t* src0_q = (const uint8_t*)src0_dd; + const size_t ncols = ne00; + const size_t nrows = ne01; + const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{ + k_get_rows_reorder( + src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2, + s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); + }); + + GGML_UNUSED(dst); + GGML_UNUSED(ctx); +} + + +template +static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const src0_t *src0_dd, const int32_t *src1_dd, + float *dst_dd, queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); + const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE; + const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, + s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); + }); + } + + GGML_UNUSED(dst); + GGML_UNUSED(ctx); +} + +void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type)); + GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type)); + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); + + const int32_t * src1_i32 = (const int32_t *) dst->src[1]->data; + /* TODO: Refactor and remove duplicates */ + switch (dst->src[0]->type) { + case GGML_TYPE_F16: + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_F32: + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q4_0: + if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) { + get_rows_sycl_reorder(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + } else { + get_rows_sycl(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + } + break; + case GGML_TYPE_Q4_1: + get_rows_sycl(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q5_0: + get_rows_sycl(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q5_1: + get_rows_sycl(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q8_0: + get_rows_sycl(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + default: + // TODO: k-quants + GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(dst->src[0]->type)); + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-sycl/getrows.hpp b/ggml/src/ggml-sycl/getrows.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1c560cd9f8941e409c70cdac60d4b6fa02f7e515 --- /dev/null +++ b/ggml/src/ggml-sycl/getrows.hpp @@ -0,0 +1,20 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_GETROWS_HPP +#define GGML_SYCL_GETROWS_HPP + +#include "common.hpp" + +void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst); + +#endif // GGML_SYCL_GETROWS_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b7610362b6080923373d4a3affbdf6fde243615 --- /dev/null +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -0,0 +1,4615 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "ggml-sycl.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-sycl/backend.hpp" +#include "ggml-sycl/common.hpp" +#include "ggml-sycl/element_wise.hpp" +#include "ggml-sycl/presets.hpp" +#include "ggml-sycl/gemm.hpp" +#include "ggml-sycl/sycl_hw.hpp" +#include "ggml-sycl/getrows.hpp" +#include "ggml.h" + +static bool g_sycl_loaded = false; +int g_ggml_sycl_debug = 0; +int g_ggml_sycl_disable_optimize = 0; +int g_ggml_sycl_disable_graph = 0; +int g_ggml_sycl_disable_dnn = 0; +int g_ggml_sycl_prioritize_dmmv = 0; + +static ggml_sycl_device_info ggml_sycl_init() { + ggml_sycl_device_info info = {}; + + info.device_count = dpct::dev_mgr::instance().device_count(); + if (info.device_count == 0) { + GGML_LOG_ERROR("%s: failed to initialize: %s\n", GGML_SYCL_NAME, __func__); + return info; + } + + GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES); + + int64_t total_vram = 0; +/* This is a bit misleading; reserved for later */ +// #if defined(SYCL_USE_XMX) +// GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__); +// #else +// GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__); +// #endif + for (int i = 0; i < info.device_count; ++i) { + info.devices[i].vmm = 0; + dpct::device_info prop; + sycl::device device = dpct::dev_mgr::instance().get_device(i); + + SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( + prop, device))); + + info.default_tensor_split[i] = total_vram; + total_vram += prop.get_global_mem_size(); + + info.devices[i].cc = + 100 * prop.get_major_version() + 10 * prop.get_minor_version(); + info.devices[i].hw_info = get_device_hw_info(&device); + info.devices[i].opt_feature = check_gpu_optimize_feature(info.devices[i].hw_info.arch); + + info.max_work_group_sizes[i] = prop.get_max_work_group_size(); + } + + for (int id = 0; id < info.device_count; ++id) { + info.default_tensor_split[id] /= total_vram; + } + return info; +} + +const ggml_sycl_device_info & ggml_sycl_info() { + static ggml_sycl_device_info info = ggml_sycl_init(); + return info; +} + +static void print_device_detail(int id, sycl::device &device, std::string device_type) { + + dpct::device_info prop; + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::get_device_info(prop, device))); + + std::string version; + version += std::to_string(prop.get_major_version()); + version += "."; + version += std::to_string(prop.get_minor_version()); + + device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), ""); + std::string name = std::string(prop.get_name()); + name = std::regex_replace(name, std::regex("\\(R\\)"), ""); + name = std::regex_replace(name, std::regex("\\(TM\\)"), ""); + + auto global_mem_size = prop.get_global_mem_size()/1000000; + GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(), + name.c_str(), version.c_str(), prop.get_max_compute_units(), + prop.get_max_work_group_size(), prop.get_max_sub_group_size(), + global_mem_size, device.get_info().c_str()); +} + +static void print_device_opt_feature(int device_count) { + GGML_LOG_INFO("SYCL Optimization Feature:\n"); + GGML_LOG_INFO( + "|ID| Device Type|Reorder|\n"); + GGML_LOG_INFO( + "|--|-------------------|-------|\n"); + std::map DeviceNums; + for (int id = 0; id < device_count; ++id) { + sycl::device device = dpct::dev_mgr::instance().get_device(id); + std::string backend_type = get_device_backend_and_type(device); + int type_id = DeviceNums[backend_type]++; + std::stringstream device_type; + device_type << "[" << backend_type << ":" << std::to_string(type_id) + << "]"; + std::string device_type_s = device_type.str(); + device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), ""); + GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(), + ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N"); + } + +} +void ggml_backend_sycl_print_sycl_devices() { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n"); + int device_count = dpct::dev_mgr::instance().device_count(); + std::map DeviceNums; + GGML_LOG_INFO("Found %d SYCL devices:\n", device_count); + + GGML_LOG_INFO( + "| | | | " + " |Max | |Max |Global | |\n"); + GGML_LOG_INFO( + "| | | | " + " |compute|Max work|sub |mem | |\n"); + GGML_LOG_INFO( + "|ID| Device Type| " + "Name|Version|units |group |group|size | Driver version|\n"); + GGML_LOG_INFO( + "|--|-------------------|---------------------------------------|------" + "-|-------|--------|-----|-------|---------------------|\n"); + + for (int id = 0; id < device_count; ++id) { + sycl::device device = dpct::dev_mgr::instance().get_device(id); + std::string backend_type = get_device_backend_and_type(device); + int type_id = DeviceNums[backend_type]++; + std::stringstream device_type; + device_type << "[" << backend_type << ":" << std::to_string(type_id) + << "]"; + print_device_detail(id, device, device_type.str()); + } + + print_device_opt_feature(device_count); +} + +static inline int get_sycl_env(const char *env_name, int default_val) { + char *user_device_string = getenv(env_name); + int user_number = default_val; + + unsigned n; + if (user_device_string != NULL && + sscanf(user_device_string, " %u", &n) == 1) { + user_number = (int)n; + } else { + user_number = default_val; + } + return user_number; +} + +static void ggml_check_sycl() try { + static bool initialized = false; + + if (!initialized) { + g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); + g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1); + g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); + g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); + g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); + GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); + GGML_LOG_INFO("Running with Environment Variables:\n"); + GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); + GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); +#ifdef GGML_SYCL_GRAPH + GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph); +#else + GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n"); +#endif +#if GGML_SYCL_DNNL + GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); +#else + GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); +#endif + GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); + GGML_LOG_INFO("Build with Macros:\n"); +#if defined(GGML_SYCL_FORCE_MMQ) + GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n"); +#endif +#if defined(GGML_SYCL_F16) + GGML_LOG_INFO(" GGML_SYCL_F16: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_F16: no\n"); +#endif + +/* NOT REMOVE, keep it for next optimize for XMX. +#if defined(SYCL_USE_XMX) + fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__); +#else + fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); +#endif +*/ + + if (CHECK_TRY_ERROR(g_all_sycl_device_count = + dpct::dev_mgr::instance().device_count()) != 0) { + initialized = true; + g_sycl_loaded = false; + return; + } + GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES); + + initialized = true; + g_sycl_loaded = true; + ggml_backend_sycl_print_sycl_devices(); + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +/* +device_index: device index from 0 to n (continue numbers). + It is used for device select/set in SYCL backend internal data structure. +*/ +inline void check_allow_gpu_index(const int device_index) { + if (device_index >= ggml_sycl_info().device_count) { + char error_buf[256]; + snprintf( + error_buf, + sizeof(error_buf), + "%s error: device_index:%d is out of range: [0-%d]", + __func__, + device_index, + ggml_sycl_info().device_count - 1); + GGML_LOG_ERROR("%s\n", error_buf); + assert(false); + } +} + +GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len) try { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_gpu_list\n"); + for(int i=0;i=max_len) break; + id_list[i] = i; + } + return; +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +// sycl buffer + +struct ggml_backend_sycl_buffer_context { + int device; + void * dev_ptr = nullptr; + queue_ptr stream; + std::string name; + optimize_feature opt_feature; + std::vector tensor_extras; + + ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) : + device(device), dev_ptr(dev_ptr), stream(stream) { + check_allow_gpu_index(device); + name = (GGML_SYCL_NAME + std::to_string(device)); + opt_feature = ggml_sycl_info().devices[device].opt_feature; + } + + ~ggml_backend_sycl_buffer_context() { + if (dev_ptr != nullptr) { + ggml_sycl_set_device(device); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); + } + + //release extra used by tensors + for (ggml_tensor_extra_gpu * extra : tensor_extras) { + release_extra_gpu(extra); + } + + } +}; + +static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft); + +static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_sycl_buffer_type_get_name; +} + +static void +ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { + ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + ggml_sycl_set_device(ctx->device); + + delete ctx; +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + return ctx->dev_ptr; +} + +static enum ggml_status +ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, + ggml_tensor *tensor) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str()); + ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context; + + if (tensor->view_src != NULL) { + assert(tensor->view_src->buffer->buft == buffer->buft); + return GGML_STATUS_SUCCESS; + } + if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && + !g_ggml_sycl_disable_optimize) { + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + tensor->extra = extra; + ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. + } + + if (ggml_is_quantized(tensor->type)) { + // initialize padding to 0 to avoid possible NaN values + size_t original_size = ggml_nbytes(tensor); + size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); + + if (padded_size > original_size && tensor->view_src == nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset( + (char *)tensor->data + original_size, 0, + padded_size - original_size).wait())); + } + } + return GGML_STATUS_SUCCESS; +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor *tensor, + const void *data, size_t offset, + size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); + ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + ggml_sycl_set_device(ctx->device); + auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); + SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); +#ifndef _WIN32 + // Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU. + // This function will be called during load model from disk. Use memory buffer replace dynamic won't save more time and brings potential memory leak risk here. + char * host_buf = (char *) malloc(size); + memcpy(host_buf, data, size); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait())); + free(host_buf); +#else + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait())); +#endif +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor *tensor, + void *data, size_t offset, + size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); + ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + + ggml_sycl_set_device(ctx->device); + auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue(); + + SYCL_CHECK(CHECK_TRY_ERROR( + stream.memcpy(data, (const char *)tensor->data + offset, size) + .wait())); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, + const void *ptr_src, size_t size) { + char *host_buf = (char *)malloc(size); + q_src.memcpy(host_buf, (const char *)ptr_src, size).wait(); + q_dst.memcpy((char *)ptr_dst, host_buf, size).wait(); + free(host_buf); +} + +static bool +ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor *src, + ggml_tensor *dst) try { + bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer); + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str()); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str()); + GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported); + if (is_cpy_supported) { + ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context; + ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context; + + ggml_sycl_set_device(src_ctx->device); + /* + DPCT1009:198: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw())); + ggml_sycl_set_device(dst_ctx->device); + /* + DPCT1009:199: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw())); + /* + DPCT1009:200: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + + queue_ptr stream_dst = dst_ctx->stream; + queue_ptr stream_src = src_ctx->stream; + size_t size = ggml_nbytes(src); + + //todo. it's dirty solutino to walkaroud known issue:device2device cross GPUs. + dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size); + +//todo, it's known issue:error in device2device cross GPUs. reused when the issue is fixed. DON"T remove +#if 0 + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy( + (char *)dst->data, (const char *)src->data, size).wait())); + + /* + DPCT1009:201: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw())); +#endif + return true; + } + return false; + GGML_UNUSED(buffer); +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, + uint8_t value) try { + GGML_SYCL_DEBUG("[SYCL] call %s: size=%zu\n", __func__, buffer->size); + ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; + + ggml_sycl_set_device(ctx->device); + queue_ptr stream = ctx->stream; + SYCL_CHECK( + CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw())); + + SYCL_CHECK(CHECK_TRY_ERROR((*stream) + .memset(ctx->dev_ptr, value, buffer->size) + .wait())); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, + size_t offset, size_t size) { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value); + ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; + SYCL_CHECK(ggml_sycl_set_device(ctx->device)); + auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); + if (size == 0) { + return; // Nothing to do + } + if (tensor->data == nullptr) { + GGML_ABORT("Error: Tensor data pointer is null.\n"); + } + void * target_ptr = static_cast(tensor->data) + offset; + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size))); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait())); +} + +static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) { + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); + if (buffer == nullptr) { + return; + } + + ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; + + if (ctx != nullptr) { + for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) { + release_extra_gpu(extra); + } + ctx->tensor_extras.clear(); // reset the tensor_extras vector + } +} + +static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { + /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer, + /* .get_base = */ ggml_backend_sycl_buffer_get_base, + /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, + /* .clear = */ ggml_backend_sycl_buffer_clear, + /* .reset = */ ggml_backend_sycl_buffer_reset, +}; + +// sycl buffer type +struct ggml_backend_sycl_buffer_type_context { + int device; + std::string name; + + // each buffer type has its own stream + queue_ptr stream = nullptr; +}; + +static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; + + return ctx->name.c_str(); +} + +static ggml_backend_buffer_t +ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) try { + ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; + ggml_sycl_set_device(buft_ctx->device); + const queue_ptr stream = buft_ctx->stream; + size = std::max(size, (size_t)1); // syclMalloc returns null for size 0 + + void * dev_ptr; + SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device( + size, *stream))); + if (!dev_ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, size); + return nullptr; + } + ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream); + return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 128; + GGML_UNUSED(buft); +} + +static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + return dpct::get_current_device().get_max_mem_alloc_size(); + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + size_t size = ggml_nbytes(tensor); + int64_t ne0 = tensor->ne[0]; + + if (ggml_is_quantized(tensor->type)) { + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + } + + return size; + + GGML_UNUSED(buft); +} + +static const ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = { + /* .get_name = */ ggml_backend_sycl_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_sycl_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_sycl_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_sycl_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_sycl_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + + auto dev_count = ggml_backend_sycl_get_device_count(); + + if (device>=dev_count or device<0) { + GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", + device, dev_count-1); + GGML_ASSERT(devicedevice; + if (device>=ggml_sycl_info().device_count or device<0) { + GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", + device, ggml_sycl_info().device_count-1); + GGML_ASSERT(devicestream(i, 0)}, + }; + } + ggml_backend_sycl_buffer_type_initialized = true; + } + return &ggml_backend_sycl_buffer_types[device]; +} + +// sycl split buffer + +static int64_t get_row_rounding(ggml_type type, const std::array & tensor_split) { + int64_t min_compute_capability = INT_MAX; + int64_t max_compute_capability = INT_MIN; + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) { + if (min_compute_capability > ggml_sycl_info().devices[i].cc) { + min_compute_capability = ggml_sycl_info().devices[i].cc; + } + if (max_compute_capability < ggml_sycl_info().devices[i].cc) { + max_compute_capability = ggml_sycl_info().devices[i].cc; + } + } + } + + switch(type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return max_compute_capability >= VER_GEN9 ? 128 : 64; + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return 64; + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return 1; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + return max_compute_capability >= VER_GEN9 ? 128 : 64; + case GGML_TYPE_IQ3_S: + return max_compute_capability >= VER_GEN9 ? 128 : 64; + case GGML_TYPE_Q6_K: + return 64; + default: + GGML_ABORT("fatal error"); + } +} + +static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array & tensor_split, int id) { + const int64_t nrows = ggml_nrows(tensor); + const int64_t rounding = get_row_rounding(tensor->type, tensor_split); + + *row_low = id == 0 ? 0 : nrows*tensor_split[id]; + *row_low -= *row_low % rounding; + if (id == ggml_sycl_info().device_count - 1) { + *row_high = nrows; + } else { + *row_high = nrows*tensor_split[id + 1]; + *row_high -= *row_high % rounding; + } +} + +static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]); +} + +struct ggml_backend_sycl_split_buffer_type_context { + std::array tensor_split; +}; + +struct ggml_backend_sycl_split_buffer_context { + ~ggml_backend_sycl_split_buffer_context() try { + for (ggml_tensor_extra_gpu * extra : tensor_extras) { + release_extra_gpu(extra, streams); + } + } + catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); + } + + std::vector tensor_extras; + std::vector streams; +}; + +static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; + delete ctx; +} + +static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) { + // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced + return (void *)0x1000; + + GGML_UNUSED(buffer); +} + +static enum ggml_status +ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, + ggml_tensor *tensor) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str()); + GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported + + ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; + ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; + + const int64_t ne0 = tensor->ne[0]; + + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + + ctx->tensor_extras.push_back(extra); + ctx->streams.push_back(&(dpct::get_current_device().default_queue())); + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + + // FIXME: do not crash if SYCL Buffer alloc fails + // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first + ggml_sycl_set_device(i); + const queue_ptr stream = ctx->streams[i]; + char * buf; + /* + DPCT1009:208: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device( + size, *stream))); + if (!buf) { + char err_buf[1024]; + snprintf(err_buf, 1023, "%s: can't allocate %lu Bytes of memory on device\n", __func__, size); + throw std::runtime_error(err_buf); + } + // set padding to 0 to avoid possible NaN values + if (size > original_size) { + /* + DPCT1009:209: SYCL uses exceptions to report errors and does not use + the error codes. The original code was commented out and a warning + string was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + (*stream) + .memset(buf + original_size, 0, size - original_size) + .wait())); + } + + extra->data_device[i] = buf; + + for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { + /* + DPCT1009:210: SYCL uses exceptions to report errors and does not use + the error codes. The original code was commented out and a warning + string was inserted. You need to rewrite this code. + */ + SYCL_CHECK( + CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event())); + } + } + tensor->extra = extra; + return GGML_STATUS_SUCCESS; +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void +ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor *tensor, const void *data, + size_t offset, size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); + // split tensors must always be set in their entirety at once + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; + ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; + + const int64_t ne0 = tensor->ne[0]; + const size_t nb1 = tensor->nb[1]; + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + const size_t offset_split = row_low*nb1; + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + + const char * buf_host = (const char *)data + offset_split; + /* + DPCT1009:211: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + ggml_sycl_set_device(i); + const queue_ptr stream = ctx->streams[i]; + SYCL_CHECK(CHECK_TRY_ERROR( + (*stream) + .memcpy(extra->data_device[i], buf_host, original_size) + .wait())); + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void +ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor *tensor, void *data, + size_t offset, size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); + // split tensors must always be set in their entirety at once + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; + ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; + + const int64_t ne0 = tensor->ne[0]; + const size_t nb1 = tensor->nb[1]; + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + const size_t offset_split = row_low*nb1; + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + + char * buf_host = (char *)data + offset_split; + /* + DPCT1009:212: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + ggml_sycl_set_device(i); + const queue_ptr stream = ctx->streams[i]; + SYCL_CHECK(CHECK_TRY_ERROR( + (*stream) + .memcpy(buf_host, extra->data_device[i], original_size) + .wait())); + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_UNUSED(buffer); + GGML_UNUSED(value); +} + +static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = { + /* .free_buffer = */ ggml_backend_sycl_split_buffer_free_buffer, + /* .get_base = */ ggml_backend_sycl_split_buffer_get_base, + /* .init_tensor = */ ggml_backend_sycl_split_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_sycl_split_buffer_clear, + /* .reset = */ NULL, +}; + +// sycl split buffer type + +static const char * ggml_backend_sycl_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return GGML_SYCL_NAME "_Split"; + + GGML_UNUSED(buft); +} + +static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_sycl_split_buffer_type_get_name; +} + +static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point + // instead, we allocate them for each tensor separately in init_tensor + // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated, + // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct. + ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context(); + + return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size); +} + +static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 128; + GGML_UNUSED(buft); +} + +static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context; + + size_t total_size = 0; + + const int64_t ne0 = tensor->ne[0]; + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + total_size += ggml_nbytes_split(tensor, nrows_split); + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + } + + return total_size; +} + +static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = { + /* .get_name = */ ggml_backend_sycl_split_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_sycl_split_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_sycl_split_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_sycl_split_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_sycl_split_buffer_type_is_host, +}; + +ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n"); + ggml_check_sycl(); + // FIXME: this is not thread safe + static std::map, struct ggml_backend_buffer_type> buft_map; + + std::array tensor_split_arr = {}; + + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; }); + if (all_zero) { + tensor_split_arr = ggml_sycl_info().default_tensor_split; + } else { + float split_sum = 0.0f; + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + tensor_split_arr[i] = split_sum; + split_sum += tensor_split[i]; + } + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + tensor_split_arr[i] /= split_sum; + } + } + + auto it = buft_map.find(tensor_split_arr); + if (it != buft_map.end()) { + return &it->second; + } + + struct ggml_backend_buffer_type buft { + /* .iface = */ ggml_backend_sycl_split_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0), + /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr}, + }; + + auto result = buft_map.emplace(tensor_split_arr, buft); + return &result.first->second; +} + +// host buffer type + +static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) { + return GGML_SYCL_NAME "_Host"; + + GGML_UNUSED(buft); +} + +static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_sycl_host_free(buffer->context); +} + +static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * ptr = ggml_sycl_host_malloc(size); + + if (ptr == nullptr) { + // fallback to cpu buffer + return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + } + + // FIXME: this is a hack to avoid having to implement a new buffer type + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer; + + return buffer; +} + +ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n"); + static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = { + /* .iface = */ { + /* .get_name = */ ggml_backend_sycl_host_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_sycl_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, + /* .get_max_size = */ NULL, // TODO: return device.maxBufferLength + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0), + /* .context = */ nullptr, + }; + + return &ggml_backend_sycl_buffer_type_host; +} + +// buffer pool for sycl (legacy) +struct ggml_sycl_pool_leg : public ggml_sycl_pool { + static const int MAX_SYCL_BUFFERS = 256; + + int device; + queue_ptr qptr; + struct ggml_sycl_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {}; + size_t pool_size = 0; + + explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {} + + ~ggml_sycl_pool_leg() { + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + pool_size -= b.size; + } + } + GGML_ASSERT(pool_size == 0); + } + + void * alloc(size_t size, size_t * actual_size) override { +#ifdef DEBUG_sycl_MALLOC + int nnz = 0; + size_t max_size = 0; +#endif + size_t best_diff = 1ull << 36; + int ibest = -1; + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + ggml_sycl_buffer& b = buffer_pool[i]; + if (b.ptr != nullptr) { +#ifdef DEBUG_sycl_MALLOC + ++nnz; + if (b.size > max_size) max_size = b.size; +#endif + if (b.size >= size) { + size_t diff = b.size - size; + if (diff < best_diff) { + best_diff = diff; + ibest = i; + if (!best_diff) { + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + } + } + } + } + if (ibest >= 0) { + ggml_sycl_buffer& b = buffer_pool[ibest]; + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + void * ptr; + size_t look_ahead_size = (size_t) (1.05 * size); + + SYCL_CHECK( + CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( + look_ahead_size, *qptr))); + if (!ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device/GPU\n", __func__, look_ahead_size); + return nullptr; + } + + *actual_size = look_ahead_size; + pool_size += look_ahead_size; + +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_DEBUG("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz, + (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024)); +#endif + + // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr); + return ptr; + } + + void free(void * ptr, size_t size) override { + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + ggml_sycl_buffer& b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n"); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + pool_size -= size; + } +}; + +struct ggml_sycl_pool_host : public ggml_sycl_pool { + queue_ptr qptr; + int device; + + inline static int counter{ 0 }; + + struct ggml_sycl_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + // Set arbitrarly to 64 + static constexpr int MAX_POOL_SIZE{ 64 }; + std::vector buffer_pool = std::vector(MAX_POOL_SIZE); + size_t pool_size = 0; + + explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} + + ~ggml_sycl_pool_host() { + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + b.ptr = nullptr; + pool_size -= b.size; + b.size = 0; + } + } + counter = 0; + } + + void * alloc(size_t size, size_t * actual_size) override { + if (counter == MAX_POOL_SIZE) { + ggml_sycl_buffer b = buffer_pool[0]; + void * ptr = b.ptr; + *actual_size = b.size; + counter = 1; + return ptr; + } + ggml_sycl_buffer & b = buffer_pool[counter]; + + if (b.ptr == nullptr) { + void * ptr; + + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr))); + if (!ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); + return nullptr; + } + pool_size += size; + *actual_size = size; + counter = counter + 1; + return ptr; + } else { + ++counter; + b.size = size; + return b.ptr; + } + } + + void free(void * ptr, size_t size) override { + // if the pool is not completed add the pointer to it in place of the first nullptr found. + // Otherwise do nothing, pointers will be freed once the pool is deallocated. + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + } +}; + +std::unique_ptr ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) { + // return pool for the host to speed up memory management + return std::unique_ptr(new ggml_sycl_pool_host(qptr, device)); +} + +std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { + // TBD: NO VMM support + // if (ggml_sycl_info().devices[device].vmm) { + // return std::unique_ptr(new ggml_sycl_pool_vmm(device)); + // } + return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device)); +} + +// TBD pool with virtual memory management +// struct ggml_sycl_pool_vmm : public ggml_sycl_pool + +/// kernels +typedef void (*ggml_sycl_op_mul_mat_t)( + ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, + const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, + float *dst_dd_i, const int64_t row_low, const int64_t row_high, + const int64_t src1_ncols, const int64_t src1_padded_row_size, + const queue_ptr &stream); + + + +template +static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, + const sycl::nd_item<3> &item_ct1) { + const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE; + + if (ix >= kx_padded) { + return; + } + + const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + + const int i_padded = iy*kx_padded + ix; + + block_q8_1 * y = (block_q8_1 *) vy; + + const int ib = i_padded / QK8_1; // block index + const int iqs = i_padded % QK8_1; // quant index + typedef sycl::vec TC; + typedef sycl::vec TQ; + TC zeros; + TQ qzeros; +#pragma unroll + for (int i = 0; i < QUANT_BLOCK_TILE; i++) + { + zeros[i] = 0.f; + qzeros[i] = 0; + } + const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros; + float sum = xi[0]; + float amax = sycl::fabs(xi[0]); +#pragma unroll + for (int i = 1; i < QUANT_BLOCK_TILE; i++) + { + sum += xi[i]; + amax = sycl::fmax(sycl::fabs(xi[i]), amax); + } + sum = warp_reduce_sum(sum, item_ct1); + amax = warp_reduce_max(amax, item_ct1); + + const float d = amax / 127; + TQ q = qzeros; + if (amax != 0.0f) + { +#pragma unroll + for (int i = 0; i < QUANT_BLOCK_TILE; i++) { + q[i] = sycl::round(xi[i] / d); + } + } + + *(TQ *)&y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + reinterpret_cast(y[ib].ds.x()) = d; + reinterpret_cast(y[ib].ds.y()) = sum; +} + +template +static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor, + const int kx, const int kx_padded, const sycl::nd_item<1> & it) { + /* + Quantizes and reorders the resultant q8 tensor in a per row fashion + Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values + */ + + auto subgroup_id = it.get_group(0); + auto wi_id = it.get_local_id(0); + + const int num_blocks_per_row = kx / QK8_1; + auto row = subgroup_id / num_blocks_per_row; + auto col = subgroup_id % num_blocks_per_row; + + auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1); + auto col_offset = QK8_1 * col + wi_id * ElementsPerWI; + + auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset); + auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2)); + + sycl::vec wi_f32_vals; + sycl::vec quantized_values; + + auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id; + wi_f32_vals = *reinterpret_cast *>(x + float_ptr_offset); + + float sum = 0.0f; + float amax = 0.0f; + +#pragma unroll(ElementsPerWI) + for (int i = 0; i < ElementsPerWI; i++) { + sum += wi_f32_vals[i]; + amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i])); + quantized_values[i] = 0; + } + sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus()); + amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum()); + float d = amax == 0 ? 1 : amax / 127; + +#pragma unroll(ElementsPerWI) + for (int i = 0; i < ElementsPerWI; i++) { + quantized_values[i] = sycl::round(wi_f32_vals[i] / d); + } + + d = amax == 0 ? 0 : d; + + *reinterpret_cast *>(quant_ptr) = quantized_values; + if (wi_id == 0) { + *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum)); + } +} + +static void mul_mat_p021_f16_f32( + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, + const sycl::nd_item<3> &item_ct1) { + + const sycl::half *x = (const sycl::half *)vx; + + const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0); + const int channel_x = channel / (nchannels_y / nchannels_x); + + const int nrows_y = ncols_x; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + float tmp = 0.0f; + + for (int col_x0 = 0; col_x0 < ncols_x; + col_x0 += item_ct1.get_local_range(2)) { + const int col_x = col_x0 + item_ct1.get_local_id(2); + + if (col_x >= ncols_x) { + break; + } + + // x is transposed and permuted + const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; + const float xi = + sycl::vec(x[ix]) + .convert()[0]; + + const int row_y = col_x; + + + // y is not transposed but permuted + const int iy = channel*nrows_y + row_y; + + tmp += xi * y[iy]; + } + + // dst is not transposed and not permuted + const int idst = channel*nrows_dst + row_dst; + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[idst] = tmp; + } +} + +static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, + const int row_stride_x, const int channel_stride_x, const int channel_x_divisor, + const sycl::nd_item<3> &item_ct1) { + + const sycl::half *x = (const sycl::half *)vx; + + const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0); + const int channel_x = channel / channel_x_divisor; + + const int nrows_y = ncols_x; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + const int idst = channel*nrows_dst + row_dst; + + float tmp = 0.0f; + + for (int col_x0 = 0; col_x0 < ncols_x; + col_x0 += item_ct1.get_local_range(2)) { + const int col_x = col_x0 + item_ct1.get_local_id(2); + + if (col_x >= ncols_x) { + break; + } + + const int row_y = col_x; + + const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; + const int iy = channel*nrows_y + row_y; + + const float xi = + sycl::vec(x[ix]) + .convert()[0]; + + tmp += xi * y[iy]; + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[idst] = tmp; + } +} + +static void k_sum_rows_f32(const float * x, float * dst, const int ncols, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(1); + const int col = item_ct1.get_local_id(2); + + float sum = 0.0f; + for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) { + sum += x[row * ncols + i]; + } + + sum = warp_reduce_sum(sum, item_ct1); + + if (col == 0) { + dst[row] = sum; + } +} + + +template +static inline void ggml_sycl_swap(T & a, T & b) { + T tmp = a; + a = b; + b = tmp; +} + +template +__dpct_inline__ static void +k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad, + const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) { + // bitonic sort + int col = item_ct1.get_local_id(2); + int row = item_ct1.get_group(1); + + if (col >= ncols_pad) { + return; + } + + const float * x_row = x + row * ncols; + auto dst_row = (int *)dpct_local; + + // initialize indices + dst_row[col] = col; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + ggml_sycl_swap(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + ggml_sycl_swap(dst_row[col], dst_row[ixj]); + } + } + } + /* + DPCT1118:1: SYCL group functions and algorithms must be encountered + in converged control flow. You may need to adjust the code. + */ + item_ct1.barrier(sycl::access::fence_space::local_space); + } + } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } +} + + +static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past, + const sycl::nd_item<3> &item_ct1) { + const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (col >= ncols) { + return; + } + + const int i = row*ncols + col; + //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i]; + //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU + dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; +} + +static void scale_f32(const float * x, float * dst, const float scale, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + + dst[i] = scale * x[i]; +} + + +template +static void pool2d_nchw_kernel( + const int ih, const int iw, const int oh, const int ow, + const int kh, const int kw, const int sh, const int sw, + const int ph, const int pw, const int parallel_elements, + const Ti* src, To* dst, const enum ggml_op_pool op, + const sycl::nd_item<3> &item_ct1) { + int idx = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (idx >= parallel_elements) { + return; + } + + const int I_HW = ih * iw; + const int O_HW = oh * ow; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / ow; + const int cur_ow = idx % O_HW % ow; + const Ti* i_ptr = src + nc * I_HW; + To* o_ptr = dst + nc * O_HW; + const int start_h = cur_oh * sh - ph; + const int bh = sycl::max(0, start_h); + const int eh = sycl::min(ih, start_h + kh); + const int start_w = cur_ow * sw - pw; + const int bw = sycl::max(0, start_w); + const int ew = sycl::min(iw, start_w + kw); + + To res = 0; + + switch (op) { + case GGML_OP_POOL_AVG: res = 0; break; + case GGML_OP_POOL_MAX: res = -FLT_MAX; break; + default: + res = (To) sycl::nan(uint32_t(0)); + break; + } + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { +#if DPCT_COMPATIBILITY_TEMP >= 350 + /* + DPCT1098:106: The '*' expression is used instead of the __ldg + call. These two expressions do not provide the exact same + functionality. Check the generated code for potential precision + and/or performance issues. + */ + Ti cur = *(i_ptr + i * iw + j); +#else + Ti cur = i_ptr[i * iw + j]; +#endif + switch (op) { + case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break; + case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break; + default: + res = (To) sycl::nan(uint32_t(0)); + break; + } + } + } + o_ptr[cur_oh * ow + cur_ow] = res; +} + +static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded, + bool reorder_q8_tensor, queue_ptr stream) { + if (reorder_q8_tensor) { + auto local_range = std::size_t(WARP_SIZE); + auto num_quant_blocks = ky * (kx / QK8_1); + auto global_range = num_quant_blocks * local_range; + stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), + [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + quantize_and_reorder_q8_1(x, vy, kx, kx_padded, it); + }); + } else { + const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; + const sycl::range<3> num_blocks(1, ky, block_num_x); + int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE; + static_assert(QK8_1 % WARP_SIZE == 0); + const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE); + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + quantize_q8_1(x, vy, kx, kx_padded, item_ct1); + }); + } + } +} + +static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y, + float *dst, const int ncols_x, + const int nrows_x, + const int nchannels_x, + const int nchannels_y, + queue_ptr stream) { + + const sycl::range<3> block_nums(nchannels_y, nrows_x, 1); + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x, + nchannels_y, item_ct1); + }); + } +} + +static void ggml_mul_mat_vec_nc_f16_f32_sycl( + const void *vx, const float *y, float *dst, const int ncols_x, + const int nrows_x, const int row_stride_x, const int nchannels_x, + const int nchannels_y, const int channel_stride_x, queue_ptr stream) { + + const sycl::range<3> block_nums(nchannels_y, nrows_x, 1); + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x, + row_stride_x, channel_stride_x, + nchannels_y / nchannels_x, item_ct1); + }); + } +} + + + +static void scale_f32_sycl(const float *x, float *dst, const float scale, + const int k, queue_ptr stream) { + const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + scale_f32(x, dst, scale, k, item_ct1); + }); +} + + +static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, + const int nrows, queue_ptr stream) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + const sycl::range<3> block_nums(1, nrows, 1); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + k_sum_rows_f32(x, dst, ncols, item_ct1); + }); +} + +static int next_power_of_2(int x) { + int n = 1; + while (n < x) { + n *= 2; + } + return n; +} + +static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, + const int nrows, ggml_sort_order order, + queue_ptr stream) { + // bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + const sycl::range<3> block_dims(1, 1, ncols_pad); + const sycl::range<3> block_nums(1, nrows, 1); + const size_t shared_mem = ncols_pad * sizeof(int); + + if (order == GGML_SORT_ORDER_ASC) { + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(shared_mem), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_argsort_f32_i32( + x, dst, ncols, ncols_pad, item_ct1, + dpct_local_acc_ct1.get_multi_ptr() + .get()); + }); + }); + } else if (order == GGML_SORT_ORDER_DESC) { + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(shared_mem), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_argsort_f32_i32( + x, dst, ncols, ncols_pad, item_ct1, + dpct_local_acc_ct1.get_multi_ptr() + .get()); + }); + }); + } else { + GGML_ABORT("fatal error"); + } +} + +static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols, + const int nrows, queue_ptr stream) { + const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE); + const sycl::range<3> block_nums(1, nrows, 1); + const size_t shared_mem = 256 * sizeof(float); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor shared_data( + sycl::range<1>(shared_mem/sizeof(float)), cgh); + sycl::local_accessor shared_indices( + sycl::range<1>(shared_mem/sizeof(float)), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + const int tid = item_ct1.get_local_id(2); + const int row = item_ct1.get_global_id(1); + + float max_val = -INFINITY; + int max_idx = -1; + + for (int col = tid; col < ncols; col += 256) { + float val = x[row * ncols + col]; + if (val > max_val) { + max_val = val; + max_idx = col; + } + } + + shared_data[tid] = max_val; + shared_indices[tid] = max_idx; + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (int stride = 256/2; stride > 0; stride >>= 1) { + if (tid < stride) { + float val1 = shared_data[tid]; + float val2 = shared_data[tid + stride]; + if (val2 > val1) { + shared_data[tid] = val2; + shared_indices[tid] = shared_indices[tid + stride]; + } + } + item_ct1.barrier(sycl::access::fence_space::local_space); + } + + + if (tid == 0) { + dst[row] = shared_indices[0]; + } + }); + }); +} +static void diag_mask_inf_f32_sycl(const float *x, float *dst, + const int ncols_x, const int nrows_x, + const int rows_per_channel, const int n_past, + queue_ptr stream) { + const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1); + const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE; + const sycl::range<3> block_nums(1, block_num_x, nrows_x); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + diag_mask_inf_f32(x, dst, ncols_x, + rows_per_channel, n_past, + item_ct1); + }); +} + +static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst, + const struct ggml_tensor *src, + int64_t i3, int64_t i2, + int64_t i1_low, int64_t i1_high, + queue_ptr stream) try { + + dpct::memcpy_direction kind; + char * src_ptr; + if (ggml_backend_buffer_is_host(src->buffer)) { + kind = dpct::host_to_device; + //GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n", __func__); + src_ptr = (char *) src->data; + // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr); + } else if (ggml_backend_buffer_is_sycl(src->buffer)) { + // If buffer is a SYCL buffer + //GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__); + kind = dpct::device_to_device; + src_ptr = (char *) src->data; + } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) { + /* + If buffer is a SYCL split buffer + */ + //GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__); + GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]); + kind = dpct::device_to_device; + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; + int id; + SYCL_CHECK(CHECK_TRY_ERROR( + id = get_current_device_id())); + // GGML_SYCL_DEBUG("current device index %d\n", id); + src_ptr = (char *) extra->data_device[id]; + } else { + // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n"); + GGML_ABORT("fatal error"); + } + char * dst_ptr = (char *) dst; + + GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne); + GGML_TENSOR_LOCALS(int64_t, nb, src, nb); + const enum ggml_type type = src->type; + const int64_t ts = ggml_type_size(type); + const int64_t bs = ggml_blck_size(type); + int64_t i1_diff = i1_high - i1_low; + + const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; + if (nb0 == ts && nb1 == ts*ne0/bs) { + // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1); + // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1)); + return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1, + kind, *stream)); + + } else if (nb0 == ts) { + return CHECK_TRY_ERROR( + dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1, + ts * ne0 / bs, i1_diff, kind, *stream)); + } else { + for (int64_t i1 = 0; i1 < i1_diff; i1++) { + const void * rx = (const void *) ((const char *) x + i1*nb1); + void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); + // pretend the row is a matrix with cols=1 + dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy( + rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream)); + /* + DPCT1001:85: The statement could not be removed. + */ + /* + DPCT1000:86: Error handling if-stmt was detected but could not be + rewritten. + */ + if (r != 0) return r; + } + return 0; + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +inline void ggml_sycl_op_mul_mat_sycl( + ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, + const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, + float *dst_dd_i, const int64_t row_low, const int64_t row_high, + const int64_t src1_ncols, const int64_t src1_padded_row_size, + const queue_ptr &stream) try { + + GGML_ASSERT(src0_dd_i != nullptr); + GGML_ASSERT(src1_ddf_i != nullptr); + GGML_ASSERT(dst_dd_i != nullptr); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne00 == ne10); + + const int64_t row_diff = row_high - row_low; + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + + const int64_t ne0 = dst->ne[0]; // used by MKL only + // the main device has a larger memory buffer to hold the results from all GPUs + // ldc == nrows of the matrix that cuBLAS writes into + int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only + +#ifdef GGML_SYCL_F16 + bool use_fp16 = true; // TODO(Yu) SYCL capability check +#else + bool use_fp16 = false; +#endif + if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && + row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { + ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); + if (src0->type != GGML_TYPE_F16) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, + " : converting src0 to fp16"); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst); + GGML_ASSERT(to_fp16_sycl != nullptr); + size_t ne = row_diff*ne00; + src0_as_f16.alloc(ne); + to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream); + } + const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16 + ? (const sycl::half *)src0_dd_i + : src0_as_f16.get(); + + ggml_sycl_pool_alloc src1_as_f16(ctx.pool()); + if (src1->type != GGML_TYPE_F16) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, + " : converting src1 to fp16"); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); + GGML_ASSERT(to_fp16_sycl != nullptr); + size_t ne = src1_ncols*ne10; + src1_as_f16.alloc(ne); + to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream); + } + const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16 + ? (const sycl::half *)src1->data + src1_padded_row_size + : src1_as_f16.get(); + +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr, + DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else +#endif + { + ggml_sycl_pool_alloc dst_f16(ctx.pool(), row_diff * src1_ncols); + + const sycl::half alpha_f16 = 1.0f; + const sycl::half beta_f16 = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( + *stream, oneapi::math::transpose::trans, + oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, + src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, + dst_f16.get(), dpct::library_data_t::real_half, ldc, + dpct::library_data_t::real_half))); + scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, + " : converting dst to fp32"); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); + to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } + } else { + ggml_sycl_pool_alloc src0_ddq_as_f32(ctx.pool()); + ggml_sycl_pool_alloc src1_ddq_as_f32(ctx.pool()); + if (src0->type != GGML_TYPE_F32) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, + " : converting src0 to fp32"); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst); + GGML_ASSERT(to_fp32_sycl != nullptr); + src0_ddq_as_f32.alloc(row_diff*ne00); + to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream); + } + if (src1->type != GGML_TYPE_F32) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, + " : converting src1 to fp32"); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst); + GGML_ASSERT(to_fp32_sycl != nullptr); + src1_ddq_as_f32.alloc(src1_ncols*ne10); + to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream); + } + const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); + const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); + +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i, + DnnlGemmWrapper::to_dt(), src0_ddf_i, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else +#endif + { + const float alpha = 1.0f; + const float beta = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( + get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + } + } + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_padded_row_size); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = static_cast(opts[0]); + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + + const int64_t IH = dst->src[0]->ne[1]; + const int64_t IW = dst->src[0]->ne[0]; + + const int64_t N = dst->ne[3]; + const int64_t OC = dst->ne[2]; + const int64_t OH = dst->ne[1]; + const int64_t OW = dst->ne[0]; + + const int parallel_elements = N * OC * OH * OW; + const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE; + sycl::range<3> block_nums(1, 1, num_blocks); + main_stream->parallel_for( + sycl::nd_range<3>(block_nums * + sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, + parallel_elements, src0_dd, dst_dd, op, + item_ct1); + }); +} + +inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + const int64_t ne = ggml_nelements(dst->src[0]); + + sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream); +} + +inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + const int64_t ncols = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + + sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); +} + +inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + int32_t * dst_dd = static_cast(dst->data); + + + const int64_t ncols = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream); +} + +inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + int32_t * dst_dd = static_cast(dst->data); + + const int64_t ncols = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + + argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); +} + +inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t ne01 = dst->src[0]->ne[1]; + const int nrows0 = ggml_nrows(dst->src[0]); + + const int n_past = ((int32_t *) dst->op_params)[0]; + + diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); +} + +inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream); + /* + DPCT1010:87: SYCL uses exceptions to report errors and does not use the + error codes. The call was replaced with 0. You need to rewrite this code. + */ + SYCL_CHECK(0); +} + +static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { + static bool peer_access_enabled = false; + + const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE; + + if (peer_access_enabled == enable_peer_access) { + return; + } + +#ifdef NDEBUG + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + SYCL_CHECK(ggml_sycl_set_device(i)); + } + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + SYCL_CHECK(ggml_sycl_set_device(i)); + + for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) { + if (i == id_other) { + continue; + } + if (i != main_device && id_other != main_device) { + continue; + } + + // int can_access_peer; + // SYCL_CHECK(syclDeviceCanAccessPeer(&can_access_peer, id, id_other)); + // if (can_access_peer) { + // if (enable_peer_access) { + // SYCL_CHECK(syclDeviceEnablePeerAccess(id_other, 0)); + // } else { + // SYCL_CHECK(syclDeviceDisablePeerAccess(id_other)); + // } + // } + } + } +#endif // NDEBUG + + peer_access_enabled = enable_peer_access; +} + +static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + ggml_sycl_op_mul_mat_t op, + const bool convert_src1_to_q8_1) try { + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); + const int64_t nrows1 = ggml_nrows(src1); + + GGML_ASSERT(ne03 == ne13); + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer)); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer)); + GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); + + GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); + + const int64_t i02_divisor = ne12 / ne02; + + const size_t src0_ts = ggml_type_size(src0->type); + const size_t src0_bs = ggml_blck_size(src0->type); + const size_t q8_1_ts = sizeof(block_q8_1); + const size_t q8_1_bs = QK8_1; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + + const bool src0_is_contiguous = ggml_is_contiguous(src0); + const bool src1_is_contiguous = ggml_is_contiguous(src1); + + int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); + GGML_ASSERT(!(split && ne02 > 1)); + GGML_ASSERT(!(split && ne03 > 1)); + GGML_ASSERT(!(split && ne02 < ne12)); + + std::array tensor_split; + if (split) { + // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check + // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...); + ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; + tensor_split = buft_ctx->tensor_split; + } + + struct dev_data { + ggml_sycl_pool_alloc src0_dd_alloc; + ggml_sycl_pool_alloc src1_ddf_alloc; + ggml_sycl_pool_alloc src1_ddq_alloc; + ggml_sycl_pool_alloc dst_dd_alloc; + + char *src0_dd = nullptr; + float *src1_ddf = nullptr; // float + char *src1_ddq = nullptr; // q8_1 + float *dst_dd = nullptr; + + int64_t row_low; + int64_t row_high; + }; + + dev_data dev[GGML_SYCL_MAX_DEVICES]; + + int used_devices = 0; + queue_ptr main_stream = ctx.stream(); + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + // by default, use all rows + dev[i].row_low = 0; + dev[i].row_high = ne01; + + // for multi GPU, get the row boundaries from tensor split + // and round to mul_mat_q tile sizes + if (split) { + const int64_t rounding = get_row_rounding(src0->type, tensor_split); + + if (i != 0) { + dev[i].row_low = ne01*tensor_split[i]; + if (dev[i].row_low < ne01) { + dev[i].row_low -= dev[i].row_low % rounding; + } + } + + if (i != ggml_sycl_info().device_count - 1) { + dev[i].row_high = ne01*tensor_split[i + 1]; + if (dev[i].row_high < ne01) { + dev[i].row_high -= dev[i].row_high % rounding; + } + } + } + } + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) { + continue; + } + + used_devices++; + + const bool src1_on_device = i == ctx.device; + const bool dst_on_device = i == ctx.device; + + ggml_sycl_set_device(i); + queue_ptr stream = ctx.stream(i, 0); + + if (src0_is_contiguous) { + dev[i].src0_dd = (char *) src0->data; + } else { + dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0)); + } + + if (src1_on_device && src1_is_contiguous) { + dev[i].src1_ddf = (float *) src1->data; + } else { + dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1)); + } + + if (convert_src1_to_q8_1) { + dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); + + if (src1_on_device && src1_is_contiguous) { + bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder; + scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, + /*num_src=*/2, " : converting src1 to Q8_1"); + quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream); + /* + DPCT1010:90: SYCL uses exceptions to report errors and does not + use the error codes. The call was replaced with 0. You need to + rewrite this code. + */ + SYCL_CHECK(0); + } + } + + if (dst_on_device) { + dev[i].dst_dd = (float *) dst->data; + } else { + const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst); + dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf); + } + } + + // if multiple devices are used they need to wait for the main device + // here an event is recorded that signals that the main device has finished calculating the input data + if (split && used_devices > 1) { + ggml_sycl_set_device(ctx.device); + /* + DPCT1024:91: The original code returned the error code that was further + consumed by the program logic. This original code was replaced with 0. + You may need to rewrite the program logic consuming the error code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + *src0_extra->events[ctx.device][0] = + ctx.stream()->ext_oneapi_submit_barrier())); + } + + const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; + for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { + const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0; + const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) { + continue; + } + + const bool src1_on_device = i == ctx.device; + const bool dst_on_device = i == ctx.device; + const int64_t row_diff = dev[i].row_high - dev[i].row_low; + + ggml_sycl_set_device(i); + queue_ptr stream = ctx.stream(i, is); + + // wait for main GPU data if necessary + if (split && (i != ctx.device || is != 0)) { + /* + DPCT1009:163: SYCL uses exceptions to report errors and does not + use the error codes. The original code was commented out and a + warning string was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier( + {*src0_extra->events[ctx.device][0]}))); + } + + for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) { + const int64_t i03 = i0 / ne12; + const int64_t i02 = i0 % ne12; + + const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; + + // for split tensors the data begins at i0 == i0_offset_low + char * src0_dd_i = dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; + float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10; + char * src1_ddq_i = dev[i].src1_ddq + src1_ddq_i_offset; + float * dst_dd_i = dev[i].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); + + // the main device memory buffer can be on VRAM scratch, with space for all partial results + // in that case an offset on dst_ddf_i is needed + if (i == ctx.device) { + dst_dd_i += dev[i].row_low; // offset is 0 if no tensor split + } + + // copy src0, src1 to device if necessary + if (src1_is_contiguous) { + if (i != ctx.device) { + if (convert_src1_to_q8_1) { + char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset; + SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy( + src1_ddq_i, src1_ddq_i_source, + src1_ncols * src1_padded_col_size * q8_1_ts / + q8_1_bs).wait())); + } else { + + float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device]; + src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; + + SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, + src1_ddf_i, src1_ddf_i_source, + src1_ncols * ne10 * sizeof(float)))); + } + } + } else if (src1_on_device && !src1_is_contiguous) { + SYCL_CHECK(ggml_sycl_cpy_tensor_2d( + src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); + } else { + GGML_ABORT("fatal error"); + } + + if (convert_src1_to_q8_1 && !src1_is_contiguous) { + scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, + /*num_src=*/2, " : converting src1 to Q8_1"); + quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream); + /* + DPCT1010:92: SYCL uses exceptions to report errors and does + not use the error codes. The call was replaced with 0. You + need to rewrite this code. + */ + SYCL_CHECK(0); + } + + if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) { + SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream)); + } + if (src1->type == GGML_TYPE_F16) { + src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10; + } + // do the computation + SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, + dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream))); + /* + DPCT1010:93: SYCL uses exceptions to report errors and does not + use the error codes. The call was replaced with 0. You need to + rewrite this code. + */ + SYCL_CHECK(0); + + // copy dst to host or other device if necessary + if (!dst_on_device) { + void * dst_off_device = dst->data; + if (split) { + // src0 = weight matrix is saved as a transposed matrix for better memory layout. + // dst is NOT transposed. + // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU. + // Instead they need to be copied to the correct slice in ne0 = dst row index. + // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results. + float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); + GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); + dhf_dst_i += src1_col_0*ne0 + dev[i].row_low; + + SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy( + dhf_dst_i, ne0 * sizeof(float), dst_dd_i, + row_diff * sizeof(float), row_diff * sizeof(float), + src1_ncols, dpct::device_to_device, *stream))); + } else { + float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); + GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); + dhf_dst_i += src1_col_0*ne0; + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memcpy(dhf_dst_i, dst_dd_i, + src1_ncols * ne0 * sizeof(float)).wait())); + } + } + + // add event for the main device to wait on until other device is done + if (split && (i != ctx.device || is != 0)) { + /* + DPCT1024:94: The original code returned the error code that + was further consumed by the program logic. This original + code was replaced with 0. You may need to rewrite the + program logic consuming the error code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + *src0_extra->events[i][is] = + stream->ext_oneapi_submit_barrier())); + } + } + } + } + + // main device waits for all other devices to be finished + if (split && ggml_sycl_info().device_count > 1) { + int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; + is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS; + + ggml_sycl_set_device(ctx.device); + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + if (dev[i].row_low == dev[i].row_high) { + continue; + } + for (int64_t is = 0; is < is_max; ++is) { + SYCL_CHECK(CHECK_TRY_ERROR( + ctx.stream()->ext_oneapi_submit_barrier( + {*src0_extra->events[i][is]}))); + } + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + + +static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_get_rows(ctx, dst); +} + +static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_norm(ctx, dst); +} + +static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_rms_norm(ctx, dst); +} + +static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_l2_norm(ctx, dst); +} + +static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_group_norm(ctx, dst); +} + +static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, + ggml_tensor *dst) try { + GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne12 = src1->ne[2]; + + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + queue_ptr main_stream = ctx.stream(); + + void * src0_ddq = src0->data; + float * src1_ddf = (float *) src1->data; + float * dst_ddf = (float *) dst->data; + + ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, + ggml_tensor *dst) try { + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(!ggml_is_permuted(src0)); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; + + const int64_t ne12 = src1->ne[2]; + + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + queue_ptr main_stream = ctx.stream(); + + void * src0_ddq = src0->data; + float * src1_ddf = (float *) src1->data; + float * dst_ddf = (float *) dst->data; + + const int64_t row_stride_x = nb01 / sizeof(sycl::half); + const int64_t channel_stride_x = nb02 / sizeof(sycl::half); + + ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst, + const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23, + size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3, + int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) { + const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); + + if (i13 >= ne13 || i12 >= ne12) { + return; + } + + const int64_t i03 = i13 / r3; + const int64_t i02 = i12 / r2; + + const uint8_t * src0_bytes = reinterpret_cast(src0_as_f16); + const uint8_t * src1_bytes = reinterpret_cast(src1_as_f16); + uint8_t * dst_bytes = static_cast(dst); + + ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03; + ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13; + ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3; +} + +static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst) try { + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155 + // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst + GGML_ASSERT(ggml_is_contiguous(dst)); + + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + queue_ptr queue = ctx.stream(); + + dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 }); + + const sycl::half * src0_f16 = static_cast(src0->data); + float * dst_ddf = static_cast(dst->data); + + const sycl::half * src1_f16 = static_cast(src1->data); + const size_t type_size_src1 = ggml_type_size(src1->type); + GGML_ASSERT(nb10 == type_size_src1); + + // SRC1 strides + int64_t s11 = nb11 / type_size_src1; + int64_t s12 = nb12 / type_size_src1; + int64_t s13 = nb13 / type_size_src1; + ggml_sycl_pool_alloc src1_f16_alloc(ctx.pool()); + + // convert src1 to fp16 + if (src1->type != GGML_TYPE_F16) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2, + " : converting src1 to fp16"); + const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type); + GGML_ASSERT(to_fp16_nc_sycl != nullptr); + const int64_t ne_src1 = ggml_nelements(src1); + src1_f16_alloc.alloc(ne_src1); + to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue); + + src1_f16 = src1_f16_alloc.get(); + s11 = ne10; + s12 = ne11 * s11; + s13 = ne12 * s12; + } + + ggml_sycl_pool_alloc dst_f16(ctx.pool()); + + dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float; + dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float; + + // dst strides + size_t nbd2 = dst->nb[2]; + size_t nbd3 = dst->nb[3]; + + const float alpha_f32 = 1.0f; + const float beta_f32 = 0.0f; + + const void * alpha = &alpha_f32; + const void * beta = &beta_f32; + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + GGML_ASSERT(ne01 == static_cast(nb1/nb0)); + GGML_ASSERT(ne10 == ne00); + + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12] + (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) { + + DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10, + src1, DnnlGemmWrapper::to_dt(), s11, 1, s12, + src0, DnnlGemmWrapper::to_dt(), 1, nb01/nb00, nb02/nb00, + dst, DnnlGemmWrapper::to_dt(), queue, batches_a, batches_b); + }; + + if (r2 == 1 && r3 == 1) { + if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { + dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03); + } + else { + for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { + const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes + const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13; + float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float)); + dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02); + } + } + } else { + // iterate over batches from smaller set of matrices (matrix 0) + for (int64_t ie02 = 0; ie02 < ne02; ++ie02) { + for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { + const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half)); + const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3; + float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float)); + dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1); + } + } + } + } + else +#endif + { + if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { + // there is no broadcast and src0, src1 are contiguous across dims 2, 3 + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans, + oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, + src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf, + mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); + } else { + const int ne23 = ne12 * ne13; + + ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2 * ne23); + ggml_sycl_pool_alloc ptrs_dst(ctx.pool(), 1 * ne23); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); + + sycl::range<3> block_dims(1, ne12, ne13); + queue->submit([&](sycl::handler & cgh) { + const void ** ptrs_src_get = ptrs_src.get(); + void ** ptrs_dst_get = ptrs_dst.get(); + size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); + size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half); + cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, + nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); + }); + }); + + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( + *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, + (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, + (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); + } + } +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +enum class mul_mat_algo { + DMMV = 0, + MMVQ = 1, + MUL_MAT_SYCL = 2, +}; + +inline bool ggml_sycl_supports_mmq(enum ggml_type type) { + // TODO: accuracy issues in MMQ + GGML_UNUSED(type); + return false; +} + +inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + return !g_ggml_sycl_prioritize_dmmv; + default: + return false; + } +} + +inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + +inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + return true; + default: + return false; + } +} + +static bool ggml_sycl_supports_dmmv(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_F16: + return true; + default: + return false; + } +} + +static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, + dpct::queue_ptr stream) { + auto * tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK( + CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size) + .wait())); + GGML_ASSERT((size % sizeof(block_q4_0) == 0)); + GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); + int offset_blks = offset / sizeof(block_q4_0); + auto qs_ptr = data_device + offset_blks * QK4_0 / 2; + auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; + + stream->parallel_for( + size / sizeof(block_q4_0), + [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + const block_q4_0* x = (const block_q4_0*)tmp_buf; + const int ib = i; + + for (int j = 0; j < QK4_0/2; j ++) + { + *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j]; + } + *(d_ptr + ib) = x[ib].d; + }).wait_and_throw(); + + sycl::free(tmp_buf, *stream); +} + +static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q4_K) == 0); + GGML_ASSERT(offset % sizeof(block_q4_K) == 0); + + const int nblocks = size / sizeof(block_q4_K); + + auto * tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait())); + + auto * qs_ptr = data_device; + auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks; + auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); + + stream->parallel_for(nblocks, [=](auto i) { + const block_q4_K * x = (const block_q4_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 2; ++j) { + qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < K_SCALE_SIZE; ++j) { + scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j]; + } + + dm_ptr[ib] = x[ib].dm; + }).wait_and_throw(); + + sycl::free(tmp_buf, *stream); +} + +static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q6_K) == 0); + GGML_ASSERT(offset % sizeof(block_q6_K) == 0); + + const int nblocks = size / sizeof(block_q6_K); + + auto * tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait())); + + auto * ql_ptr = data_device; + auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks; + auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks; + sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks); + + stream + ->parallel_for(nblocks, + [=](auto i) { + const block_q6_K * x = (const block_q6_K *) tmp_buf; + const int ib = i; + + const uint8_t * ql = x[ib].ql; + const uint8_t * qh = x[ib].qh; + uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib; + uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib; + uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib; + + for (int j = 0; j < QK_K / 2; ++j) { + base_ql_ptr[j] = ql[j]; + } + for (int j = 0; j < QK_K / 4; ++j) { + base_qh_ptr[j] = qh[j]; + } + + for (int j = 0; j < QK_K / 16; ++j) { + base_scales_ptr[j] = x[ib].scales[j]; + } + + dm_ptr[ib] = x[ib].d; + }) + .wait_and_throw(); + + sycl::free(tmp_buf, *stream); +} + +static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { + uint8_t * data_device = (uint8_t *) src0->data; + size_t ncols = src0->ne[0]; + size_t nrows = src0->ne[1]; + size_t size = ggml_nbytes(src0); + + switch (src0->type) { + case GGML_TYPE_Q4_0: + reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); + break; + case GGML_TYPE_Q4_K: + reorder_qw_q4_k(data_device, size, 0, stream); + break; + case GGML_TYPE_Q6_K: + reorder_qw_q6_k(data_device, size, 0, stream); + break; + default: + GGML_ABORT("reorder_qw() called with unsupported type"); + break; + } +} + +static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) { + return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT + ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. + dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. + dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; +} + +static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */, + ggml_tensor * dst, mul_mat_algo mm_algorithm) { + if (!should_reorder_tensor(*ctx, dst)) { + return; + } + + ggml_tensor_extra_gpu * extra = static_cast(src0->extra); + if (!extra || extra->optimized_feature.reorder) { + return; // Skip permutations and already reordered tensors + } + + switch (mm_algorithm) { + case mul_mat_algo::DMMV: + if (!ggml_sycl_supports_reorder_dmmv(src0->type)) { + return; + } + break; + case mul_mat_algo::MMVQ: + if (!ggml_sycl_supports_reorder_mmvq(src0->type)) { + return; + } + break; + case mul_mat_algo::MUL_MAT_SYCL: + if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) { + return; + } + break; + } + + reorder_qw(src0, ctx->stream()); + extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering +} + + +static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; +} + +static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; +} + +static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); + int64_t min_compute_capability = INT_MAX; + + if (split) { + ggml_backend_sycl_split_buffer_type_context * buft_ctx = + (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; + auto & tensor_split = buft_ctx->tensor_split; + for (int id = 0; id < ggml_sycl_info().device_count; ++id) { + // skip devices that are not going to do any work: + if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? tensor_split[id + 1] : 1.0f)) { + continue; + } + + if (min_compute_capability > ggml_sycl_info().devices[id].cc) { + min_compute_capability = ggml_sycl_info().devices[id].cc; + } + } + } else { + min_compute_capability = ggml_sycl_info().devices[ctx.device].cc; + } + + // check data types and tensor shapes for custom matrix multiplication kernels: + bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst); + + bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst); + + bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + // mmvq and mmq need the __dp4a instruction which is available for gen12+ + // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e + use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS); +#ifdef SYCL_USE_XMX + use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); +#endif // SYCL_USE_XMX + + + // mmvq path is faster in the CUDA backend. + if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda + // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization + // is enabled takes precedence over DMMV, the current if-else implementation + // requires disabling DMMV if both conditions are met + || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { + use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + } + + if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + // TODO: Refactor and cleanup of mul mat dispatching. + if (src0->ne[3] == 1 && src1->ne[3] == 1) { + // KQ single-batch + // mmv p021 was specific for these dimensions + ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst); + } else { + // The kernel from the if path is faster for that specific case, but does not support all mul mats. + ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); + } + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + // KQV single-batch + ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + // KQ + KQV multi-batch + ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); + } else if (use_dequantize_mul_mat_vec) { + constexpr bool convert_src1_to_q8_1 = false; + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1); + } else if (use_mul_mat_vec_q) { + constexpr bool convert_src1_to_q8_1 = true; + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1); + } else if (use_mul_mat_q) { + constexpr bool convert_src1_to_q8_1 = true; + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1); + } else { + constexpr bool convert_src1_to_q8_1 = false; + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1); + } +} + + +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +__dpct_inline__ static void k_copy_src1_to_contiguous( + const char *__restrict__ src1_original, char *__restrict__ src1_contiguous, + int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping, + const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, + int64_t ne11, int64_t ne10, size_t nb11, size_t nb12, + const sycl::nd_item<3> &item_ct1, int &src1_row) { + int32_t iid1 = item_ct1.get_group(2); + int32_t id = item_ct1.get_group(1); + + const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + + if (row_id_i != i02) { + return; + } + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + if (item_ct1.get_local_id(2) == 0) { + src1_row = + dpct::atomic_fetch_add( + cur_src1_row, 1); + row_mapping[src1_row] = {id, iid1}; + } + /* + DPCT1065:194: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); + float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); + +#pragma unroll + for (int i = item_ct1.get_local_id(2); i < ne10; + i += item_ct1.get_local_range(2)) { + src1_row_contiguous[i] = src1_row_original[i]; + } +} + +__dpct_inline__ static void k_copy_dst_from_contiguous( + char *__restrict__ dst_original, const char *__restrict__ dst_contiguous, + const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1, + size_t nb2, const sycl::nd_item<3> &item_ct1) { + int32_t i = item_ct1.get_group(2); + + const int32_t i1 = row_mapping[i].i1; + const int32_t i2 = row_mapping[i].i2; + + const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1); + float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2); + +#pragma unroll + for (int j = item_ct1.get_local_id(2); j < ne0; + j += item_ct1.get_local_range(2)) { + dst_row_original[j] = dst_row_contiguous[j]; + } +} + +static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, + ggml_tensor *dst) try { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers"); + + const ggml_tensor *ids = dst->src[2]; + GGML_TENSOR_BINARY_OP_LOCALS + + const queue_ptr stream = ctx.stream(); + + const int64_t n_as = ne02; + const int64_t n_ids = ids->ne[0]; + + std::vector ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); + SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); + + ggml_tensor src0_row = *src0; + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + + char *src0_original = (char *)src0->data; + char *src1_original = (char *)src1->data; + char *dst_original = (char *)dst->data; + + src0_row.ne[2] = 1; + src0_row.ne[3] = 1; + src0_row.nb[3] = nb02; + + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; + + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; + if (ne12 == 1) { + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = id; + const int64_t i2 = i12; + + src0_row.data = src0_original + i02*nb02; + src1_row.data = src1_original + i11*nb11 + i12*nb12; + dst_row.data = dst_original + i1*nb1 + i2*nb2; + + ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + } + } + } else { + ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); + ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + + src1_row.data = src1_contiguous.get(); + dst_row.data = dst_contiguous.get(); + + for (int64_t i02 = 0; i02 < n_as; i02++) { + int64_t num_src1_rows = 0; + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + + if (row_id_i != i02) { + continue; + } + + num_src1_rows++; + } + } + + if (num_src1_rows == 0) { + continue; + } + + + ggml_sycl_pool_alloc dev_cur_src1_row(ctx.pool(), 1); + ggml_sycl_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); + + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u)); + sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor src1_row_acc(cgh); + + char *__restrict src1_contiguous_get = + src1_contiguous.get(); + int *__restrict dev_cur_src1_row_get = + dev_cur_src1_row.get(); + mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + size_t ids_nb_ct6 = ids->nb[1]; + size_t ids_nb_ct7 = ids->nb[0]; + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_src1_to_contiguous( + src1_original, src1_contiguous_get, + dev_cur_src1_row_get, + dev_row_mapping_get, ids_dev, i02, + ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12, + item_ct1, src1_row_acc); + }); + }); + } + + src0_row.data = src0_original + i02*nb02; + + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); + src1_row.ne[1] = num_src1_rows; + + src1_row.nb[1] = nb11; + src1_row.nb[2] = num_src1_rows*nb11; + src1_row.nb[3] = num_src1_rows*nb11; + + dst_row.ne[1] = num_src1_rows; + dst_row.nb[1] = nb1; + dst_row.nb[2] = num_src1_rows*nb1; + dst_row.nb[3] = num_src1_rows*nb1; + + ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u)); + sycl::range<3> grid_dims(1, 1, num_src1_rows); + stream->submit([&](sycl::handler &cgh) { + const char *__restrict dst_contiguous_get = + dst_contiguous.get(); + const mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_dst_from_contiguous(dst_original, + dst_contiguous_get, + dev_row_mapping_get, + ne0, nb1, nb2, item_ct1); + }); + }); + } + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_scale(ctx, dst); +} + +static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_diag_mask_inf(ctx, dst); +} + +static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_pool2d(ctx, dst); +} + +static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_im2col(ctx, dst); +} + +static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + ggml_sycl_op_sum(ctx, dst); +} + +static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + ggml_sycl_op_sum_rows(ctx, dst); +} + +static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + ggml_sycl_op_argsort(ctx, dst); +} + +static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + ggml_sycl_op_argmax(ctx, dst); +} + + +static void ggml_sycl_set_main_device(const int main_device) try { + if (dpct::get_current_device_id() == static_cast (main_device)) { + return; + } + check_allow_gpu_index(main_device); + dpct::select_device(main_device); + + if (g_ggml_sycl_debug) { + dpct::device_info prop; + SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( + prop, dpct::dev_mgr::instance().get_device(main_device)))); + GGML_LOG_INFO("Using device %d (%s) as main device\n", + main_device, prop.get_name()); + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try { + if (!g_sycl_loaded) return false; + + if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) { + ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device); + } + + switch (dst->op) { + case GGML_OP_ARGMAX: + ggml_sycl_argmax(ctx, dst); + break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_sycl_op_conv_transpose_1d(ctx, dst); + break; + case GGML_OP_REPEAT: + ggml_sycl_repeat(ctx, dst); + break; + case GGML_OP_GET_ROWS: + ggml_sycl_get_rows(ctx, dst); + break; + case GGML_OP_DUP: + ggml_sycl_dup(ctx, dst); + break; + case GGML_OP_ADD: + case GGML_OP_ADD1: // TODO: more efficient implementation + ggml_sycl_add(ctx, dst); + break; + case GGML_OP_SUB: + ggml_sycl_sub(ctx, dst); + break; + case GGML_OP_ACC: + ggml_sycl_acc(ctx, dst); + break; + case GGML_OP_MUL: + ggml_sycl_mul(ctx, dst); + break; + case GGML_OP_LOG: + ggml_sycl_log(ctx, dst); + break; + case GGML_OP_DIV: + ggml_sycl_div(ctx, dst); + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_NEG: + ggml_sycl_neg(ctx, dst); + break; + case GGML_UNARY_OP_STEP: + ggml_sycl_step(ctx, dst); + break; + case GGML_UNARY_OP_GELU: + ggml_sycl_gelu(ctx, dst); + break; + case GGML_UNARY_OP_SILU: + ggml_sycl_silu(ctx, dst); + break; + case GGML_UNARY_OP_GELU_QUICK: + ggml_sycl_gelu_quick(ctx, dst); + break; + case GGML_UNARY_OP_GELU_ERF: + ggml_sycl_gelu_erf(ctx, dst); + break; + case GGML_UNARY_OP_TANH: + ggml_sycl_tanh(ctx, dst); + break; + case GGML_UNARY_OP_RELU: + ggml_sycl_relu(ctx, dst); + break; + case GGML_UNARY_OP_SIGMOID: + ggml_sycl_sigmoid(ctx, dst); + break; + case GGML_UNARY_OP_HARDSIGMOID: + ggml_sycl_hardsigmoid(ctx, dst); + break; + case GGML_UNARY_OP_HARDSWISH: + ggml_sycl_hardswish(ctx, dst); + break; + case GGML_UNARY_OP_EXP: + ggml_sycl_exp(ctx, dst); + break; + case GGML_UNARY_OP_SGN: + ggml_sycl_sgn(ctx, dst); + break; + case GGML_UNARY_OP_ABS: + ggml_sycl_abs(ctx, dst); + break; + case GGML_UNARY_OP_ELU: + ggml_sycl_elu(ctx, dst); + break; + default: + return false; + } + break; + case GGML_OP_NORM: + ggml_sycl_norm(ctx, dst); + break; + case GGML_OP_GROUP_NORM: + ggml_sycl_group_norm(ctx, dst); + break; + case GGML_OP_CONCAT: + ggml_sycl_op_concat(ctx, dst); + break; + case GGML_OP_UPSCALE: + ggml_sycl_upscale(ctx, dst); + break; + case GGML_OP_PAD: + ggml_sycl_pad(ctx, dst); + break; + case GGML_OP_LEAKY_RELU: + ggml_sycl_leaky_relu(ctx, dst); + break; + case GGML_OP_RMS_NORM: + ggml_sycl_rms_norm(ctx, dst); + break; + case GGML_OP_L2_NORM: + ggml_sycl_l2_norm(ctx, dst); + break; + case GGML_OP_MUL_MAT: + if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { + return false; + } + /* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */ + ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst); + break; + case GGML_OP_MUL_MAT_ID: + if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { + return false; + } + ggml_sycl_mul_mat_id(ctx, dst); + break; + case GGML_OP_OUT_PROD: + ggml_sycl_op_out_prod(ctx, dst); + break; + case GGML_OP_SCALE: + ggml_sycl_scale(ctx, dst); + break; + case GGML_OP_SQR: + ggml_sycl_sqr(ctx, dst); + break; + case GGML_OP_SQRT: + ggml_sycl_sqrt(ctx, dst); + break; + case GGML_OP_SIN: + ggml_sycl_sin(ctx, dst); + break; + case GGML_OP_COS: + ggml_sycl_cos(ctx, dst); + break; + case GGML_OP_CLAMP: + ggml_sycl_clamp(ctx, dst); + break; + case GGML_OP_CPY: + ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]); + break; + case GGML_OP_CONT: + ggml_sycl_dup(ctx, dst); + break; + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__); + break; + case GGML_OP_DIAG_MASK_INF: + ggml_sycl_diag_mask_inf(ctx, dst); + break; + case GGML_OP_SOFT_MAX: + ggml_sycl_op_soft_max(ctx, dst); + break; + case GGML_OP_ROPE: + ggml_sycl_rope(ctx, dst); + break; + case GGML_OP_IM2COL: + ggml_sycl_im2col(ctx, dst); + break; + case GGML_OP_POOL_2D: + ggml_sycl_pool2d(ctx, dst); + break; + case GGML_OP_SUM: + ggml_sycl_sum(ctx, dst); + break; + case GGML_OP_SUM_ROWS: + ggml_sycl_sum_rows(ctx, dst); + break; + case GGML_OP_ARGSORT: + ggml_sycl_argsort(ctx, dst); + break; + case GGML_OP_TIMESTEP_EMBEDDING: + ggml_sycl_op_timestep_embedding(ctx, dst); + break; + case GGML_OP_RWKV_WKV6: + ggml_sycl_op_rwkv_wkv6(ctx, dst); + break; + case GGML_OP_RWKV_WKV7: + ggml_sycl_op_rwkv_wkv7(ctx, dst); + break; + case GGML_OP_GATED_LINEAR_ATTN: + ggml_sycl_op_gated_linear_attn(ctx, dst); + break; + default: + return false; + } + + return true; +} catch (sycl::exception & e) { + std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +GGML_API void ggml_backend_sycl_get_device_description(int device, char *description, + size_t description_size) try { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_description\n"); + dpct::device_info prop; + SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( + prop, dpct::dev_mgr::instance().get_device(device)))); + snprintf(description, description_size, "%s", prop.get_name()); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void ggml_backend_sycl_get_device_memory(int device, size_t *free, + size_t *total) try { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n"); + ggml_sycl_set_device(device); + + /* + DPCT1009:218: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string was + inserted. You need to rewrite this code. + */ + /* + DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for + device information which may not be supported by all compilers or runtimes. + You may need to adjust the code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total))); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +//////////////////////////////////////////////////////////////////////////////// + +// backend + +static const char * ggml_backend_sycl_get_name(ggml_backend_t backend) { + + ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; + + return sycl_ctx->name.c_str(); +} + +static void ggml_backend_sycl_free(ggml_backend_t backend) { + ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; + + delete sycl_ctx; + delete backend; +} + +static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, + ggml_tensor *tensor, + const void *data, size_t offset, + size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); + ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); + const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR( + (stream)->memcpy((char *)tensor->data + offset, data, size))); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend, + const ggml_tensor *tensor, + void *data, size_t offset, + size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); + ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); + const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( + data, (const char *)tensor->data + offset, size))); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, + const ggml_tensor *src, + ggml_tensor *dst) try { + ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; + bool is_cpy_supported = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && + ggml_backend_buffer_is_sycl(src->buffer); + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str()); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str()); + GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported); + if (is_cpy_supported) { + /* + DPCT1009:215: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( + dst->data, src->data, ggml_nbytes(dst)))); + return true; + } + + return false; +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try { + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); + ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; + const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait())); + + GGML_UNUSED(backend); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) { + ggml_sycl_set_main_device(sycl_ctx->device); + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } +#ifndef NDEBUG + assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != nullptr) { + assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); + } + } +#endif + bool ok = ggml_sycl_compute_forward(*sycl_ctx, node); + if (!ok) { + GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + } +} + +#ifdef GGML_SYCL_GRAPH +static bool check_graph_compatibility(ggml_cgraph * cgraph) { + if (ggml_sycl_info().device_count > 1) { + // A sycl_ex::command_graph object can only be created for a single device + GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__); + return false; + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + const ggml_op node_op = cgraph->nodes[i]->op; + switch (node_op) { + default: + break; + case GGML_OP_CONCAT: + // ggml_sycl_op_concat() does a blocking host wait after memcpy operations, + // but wait() can't be called on the events returned by a queue recording + // to a graph. + [[fallthrough]]; + case GGML_OP_MUL_MAT_ID: + // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after + // submitting a memcpy operation, but wait() can't be called on a queue that + // is recording to a graph. + GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__, + ggml_op_name(node_op)); + return false; + } + } + return true; +} +#endif + +static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + auto * sycl_ctx = static_cast(backend->context); + +#ifdef GGML_SYCL_GRAPH + bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph); + if (use_sycl_graph) { + const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph); + if (!graph_support) { + GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device); + ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); + return GGML_STATUS_SUCCESS; + } + + sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}}); + + model_sycl_graph.begin_recording(*(sycl_ctx->stream())); + ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); + model_sycl_graph.end_recording(); + + const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph); + if (!sycl_ctx->exec_graph || !graph_update_support) { + auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) : + model_sycl_graph.finalize(); + sycl_ctx->exec_graph = std::make_unique< + sycl_ex::command_graph>(exec_graph); + } else { + try { + sycl_ctx->exec_graph->update(model_sycl_graph); + GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n"); + } catch (sycl::exception const & e) { + GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what()); + auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}}); + sycl_ctx->exec_graph = std::make_unique< + sycl_ex::command_graph>(exec_graph); + } + } + + sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph)); + } else +#endif + { + ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); + } + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_sycl_event_record(ggml_backend_t backend, ggml_backend_event_t event) +try +{ + ggml_backend_sycl_context *sycl_ctx = + (ggml_backend_sycl_context *)backend->context; + + sycl::event *sycl_event = static_cast(event->context); + + const queue_ptr &stream = sycl_ctx->stream(sycl_ctx->device, 0); + // Record the current state of the queue + SYCL_CHECK(CHECK_TRY_ERROR(*sycl_event = stream->ext_oneapi_submit_barrier())); +} +catch (sycl::exception const &exc) +{ + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try { + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); + sycl::event* sycl_event = static_cast(event->context); + + if (ggml_backend_is_sycl(backend)) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait())); + } else + GGML_ABORT("fatal error"); +} catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static ggml_backend_i ggml_backend_sycl_interface = { + /* .get_name = */ ggml_backend_sycl_get_name, + /* .free = */ ggml_backend_sycl_free, + /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async, + /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async, + // // TODO: update for the new + // interface + /* .synchronize = */ ggml_backend_sycl_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_sycl_graph_compute, + /* .event_record = */ ggml_backend_sycl_event_record, + /* .event_wait = */ ggml_backend_sycl_event_wait, +}; + +static ggml_guid_t ggml_backend_sycl_guid() { + static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 }; + return &guid; +} + +bool ggml_backend_is_sycl(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid()); +} + +int ggml_backend_sycl_get_device_count() { + return ggml_sycl_info().device_count; +} + + +// backend device + +struct ggml_backend_sycl_device_context { + int device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_sycl_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_sycl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; + ggml_sycl_set_device(ctx->device); + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::dev_mgr::instance().get_device(ctx->device).get_memory_info(*free, *total))); +} + +static enum ggml_backend_dev_type ggml_backend_sycl_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_sycl_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_sycl_device_get_name(dev); + props->description = ggml_backend_sycl_device_get_description(dev); + props->type = ggml_backend_sycl_device_get_type(dev); + ggml_backend_sycl_device_get_memory(dev, &props->memory_free, &props->memory_total); + + bool host_buffer = getenv("GGML_SYCL_NO_PINNED") == nullptr; +#ifdef GGML_SYCL_NO_PEER_COPY + bool events = false; +#else + bool events = true; +#endif + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ host_buffer, + /* .buffer_from_host_ptr = */ false, + /* .events = */ events, + }; +} + +static ggml_backend_t ggml_backend_sycl_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; + return ggml_backend_sycl_init(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; + return ggml_backend_sycl_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return ggml_backend_sycl_host_buffer_type(); +} + +static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + GGML_UNUSED(dev); + GGML_UNUSED(ptr); + GGML_UNUSED(size); + GGML_UNUSED(max_tensor_size); + return nullptr; +} + +static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_CONV_TRANSPOSE_1D: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1]->type; + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return true; + } + return false; + } + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_ELU: +#if defined (GGML_SYCL_F16) + return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type); +#else + return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); +#endif + default: + return false; + } + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + struct ggml_tensor * a; + struct ggml_tensor * b; + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != b->ne[3]) { + return false; + } + ggml_type a_type = a->type; + if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS || + a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S || + a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S || + a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M + ) { + if (b->ne[1] == 1 && ggml_nrows(b) > 1) { + return false; + } + } + ggml_type src0_type = op->src[0]->type; + if (src0_type == GGML_TYPE_BF16) { + return false; + } + return true; + } + case GGML_OP_OUT_PROD: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; + case GGML_OP_GET_ROWS: + { + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return true; + default: + return false; + } + } + case GGML_OP_CPY: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1]->type; + if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) { + return true; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return true; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) { + return true; + } + if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) { + return true; + } + if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { + return true; + } + if(src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) { + return true; + } + if(src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) { + return true; + } + if(src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) { + return true; + } + if(src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) { + return true; + } + if(src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) { + return true; + } + return false; + } + case GGML_OP_CONCAT: + { + ggml_type src0_type = op->src[0]->type; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + } + case GGML_OP_DUP: + case GGML_OP_ARGMAX: + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + return true; + case GGML_OP_ADD: + case GGML_OP_ADD1: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_REPEAT: + return true; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_LOG: +#if defined (GGML_SYCL_F16) + return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type)); +#else + return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); +#endif + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + return true; + case GGML_OP_L2_NORM: + case GGML_OP_GROUP_NORM: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_SCALE: + return true; + case GGML_OP_CONT: + return op->src[0]->type != GGML_TYPE_BF16; + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + return true; + case GGML_OP_ROPE: + case GGML_OP_IM2COL: + return true; + case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + case GGML_OP_POOL_2D: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_ARGSORT: + case GGML_OP_ACC: + case GGML_OP_PAD: + case GGML_OP_LEAKY_RELU: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + case GGML_OP_GATED_LINEAR_ATTN: + return true; + default: + return false; + } + + GGML_UNUSED(dev); +} + +static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) { + return false; + } + ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; + ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context; + return buft_ctx->device == sycl_ctx->device; +} + +static int64_t get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + return 0; + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + +static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + const int min_batch_size = 32; + return get_op_batch_size(op) >= min_batch_size; + GGML_UNUSED(dev); +} + +static ggml_backend_event_t +ggml_backend_sycl_device_event_new(ggml_backend_dev_t dev) { + +#ifdef GGML_SYCL_NO_PEER_COPY + return nullptr; +#else + sycl::event *event_ptr = new sycl::event(); + + return new ggml_backend_event{ + /* .device = */ dev, + /* .context = */ event_ptr, + }; +#endif +} + +static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) try { + GGML_UNUSED(dev); + if (event == nullptr) { + return; + } + + if (event->context != nullptr) { + sycl::event *sycl_event = static_cast(event->context); + delete sycl_event; + event->context = nullptr; + } + + delete event; +} catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + + +static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try { + GGML_UNUSED(dev); + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); + + sycl::event *sycl_event = static_cast(event->context); + SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait())); +} catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static const ggml_backend_device_i ggml_backend_sycl_device_interface = { + /* .get_name = */ ggml_backend_sycl_device_get_name, + /* .get_description = */ ggml_backend_sycl_device_get_description, + /* .get_memory = */ ggml_backend_sycl_device_get_memory, + /* .get_type = */ ggml_backend_sycl_device_get_type, + /* .get_props = */ ggml_backend_sycl_device_get_props, + /* .init_backend = */ ggml_backend_sycl_device_init, + /* .get_buffer_type = */ ggml_backend_sycl_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_sycl_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ ggml_backend_sycl_device_buffer_from_host_ptr, + /* .supports_op = */ ggml_backend_sycl_device_supports_op, + /* .supports_buft = */ ggml_backend_sycl_device_supports_buft, + /* .offload_op = */ ggml_backend_sycl_device_offload_op, + /* .event_new = */ ggml_backend_sycl_device_event_new, + /* .event_free = */ ggml_backend_sycl_device_event_free, + /* .event_synchronize = */ ggml_backend_sycl_device_event_synchronize, +}; + +// backend reg + +struct ggml_backend_sycl_reg_context { + std::vector devices; +}; + +static const char * ggml_backend_sycl_reg_get_name(ggml_backend_reg_t reg) { + GGML_UNUSED(reg); + return GGML_SYCL_NAME; +} + +static size_t ggml_backend_sycl_reg_get_device_count(ggml_backend_reg_t reg) { + ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context; + return ctx->devices.size(); +} + +static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t reg, size_t index) { + ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; +} + +static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) { + GGML_UNUSED(reg); + + if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { + return (void *)ggml_backend_sycl_split_buffer_type; + } + + // SYCL doesn't support registering host memory, left here for reference + // "ggml_backend_register_host_buffer" + // "ggml_backend_unregister_host_buffer" + GGML_UNUSED(name); + return nullptr; +} + +static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = { + /* .get_name = */ ggml_backend_sycl_reg_get_name, + /* .get_device_count = */ ggml_backend_sycl_reg_get_device_count, + /* .get_device = */ ggml_backend_sycl_reg_get_device, + /* .get_proc_address = */ ggml_backend_sycl_reg_get_proc_address, +}; + + +// backend registry + +ggml_backend_reg_t ggml_backend_sycl_reg() { + static ggml_backend_reg reg; + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context; + + for (int i = 0; i < ggml_sycl_info().device_count; i++) { + ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context; + dev_ctx->device = i; + dev_ctx->name = GGML_SYCL_NAME + std::to_string(i); + + ggml_sycl_set_device(i); + + dpct::device_info prop; + SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( + prop, dpct::dev_mgr::instance().get_device(i)))); + + dev_ctx->description = prop.get_name(); + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_sycl_device_interface, + /* .reg = */ ®, + /* .context = */ dev_ctx + }; + ctx->devices.push_back(dev); + } + + reg = ggml_backend_reg { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_sycl_reg_interface, + /* .context = */ ctx + }; + } + + initialized = true; + } + + return ® +} + +ggml_backend_t ggml_backend_sycl_init(int device) { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n"); + ggml_check_sycl(); + + check_allow_gpu_index(device); + + ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device); + if (ctx == nullptr) { + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + return nullptr; + }; + + ggml_backend_t sycl_backend = new ggml_backend { + /* .guid = */ ggml_backend_sycl_guid(), + /* .interface = */ ggml_backend_sycl_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device), + /* .context = */ ctx + }; + + return sycl_backend; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_sycl_reg) diff --git a/ggml/src/ggml-sycl/gla.cpp b/ggml/src/ggml-sycl/gla.cpp new file mode 100644 index 0000000000000000000000000000000000000000..879184fdd311139a872f4eca962354f0a1058cc5 --- /dev/null +++ b/ggml/src/ggml-sycl/gla.cpp @@ -0,0 +1,106 @@ +#include + +#include "common.hpp" + +template +static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, u_int T, u_int C, u_int H, float scale, + const float * k, const float * v, const float * r, const float * td, + const float * s, float * dst) { + const u_int head_size = HEAD_SIZE; + const u_int state_size = C * head_size; + const u_int n_seq_tokens = T / B; + sycl::range<1> block_dims((C / H)); + sycl::range<1> grid_dims((B * H)); + stream->submit([&](sycl::handler & cgh) { + /* local memory accessors*/ + auto _k = sycl::local_accessor(sycl::range<1>(head_size), cgh); + auto _r = sycl::local_accessor(sycl::range<1>(head_size), cgh); + auto _td = sycl::local_accessor(sycl::range<1>(head_size), cgh); + + cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) { + u_int tid = item.get_local_id(0); + u_int bid = item.get_group(0); + + u_int batch_i = bid / H; + u_int head_i = bid % H; + + float state[head_size]; + +#pragma unroll + for (u_int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; + t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + + item.barrier(sycl::access::fence_space::local_space); //sync threads + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + item.barrier(sycl::access::fence_space::local_space); //sync threads + + const float _v = v[t]; + float y = 0; + + for (u_int j = 0; j < head_size; j += 4) { + const sycl::float4 & k = (sycl::float4 &) (_k[j]); + const sycl::float4 & r = (sycl::float4 &) (_r[j]); + const sycl::float4 & td = (sycl::float4 &) (_td[j]); + sycl::float4 & s = (sycl::float4 &) (state[j]); + sycl::float4 kv; + + kv.x() = k.x() * _v; + kv.y() = k.y() * _v; + kv.z() = k.z() * _v; + kv.w() = k.w() * _v; + + s.x() = s.x() * td.x() + kv.x(); + s.y() = s.y() * td.y() + kv.y(); + s.z() = s.z() * td.z() + kv.z(); + s.w() = s.w() * td.w() + kv.w(); + + y += r.x() * s.x(); + y += r.y() * s.y(); + y += r.z() * s.z(); + y += r.w() * s.w(); + } + dst[t] = y * scale; + } +#pragma unroll + for (u_int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } + }); + }); +} + +void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/5); + const float * k_d = static_cast(dst->src[0]->data); + const float * v_d = static_cast(dst->src[1]->data); + const float * r_d = static_cast(dst->src[2]->data); + const float * td_d = static_cast(dst->src[3]->data); + const float * s_d = static_cast(dst->src[4]->data); + + const int64_t B = dst->src[4]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + dpct::queue_ptr stream = ctx.stream(); + GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64 || C / H == 128); + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + float * dst_d = (float *) dst->data; + + if (C / H == 64) { + gated_linear_attn_f32_kernel<64>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } else { + gated_linear_attn_f32_kernel<128>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } +} diff --git a/ggml/src/ggml-sycl/gla.hpp b/ggml/src/ggml-sycl/gla.hpp new file mode 100644 index 0000000000000000000000000000000000000000..607cf3a7f3049565110e5c1bfaa06fb8247e31fb --- /dev/null +++ b/ggml/src/ggml-sycl/gla.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_GLA_HPP +#define GGML_SYCL_GLA_HPP + +#include "common.hpp" + +void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_GLA_HPP diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa19c2527dc410336a5967b0aaed08dc3782001d --- /dev/null +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -0,0 +1,136 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "im2col.hpp" + +#include +#include // For std::is_same_v + +#include "ggml.h" + +template +static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW, + int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW, + int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) { + const int64_t work_group_size = item_ct1.get_local_range(2); + const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2)); + + // make each work-item deal with more elements since sycl global range can not exceed max int + for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) { + const int64_t ksize = OW * (KH > 1 ? KW : 1); + const int64_t kx = i / ksize; + const int64_t kd = kx * ksize; + const int64_t ky = (i - kd) / OW; + const int64_t ix = i % OW; + + const int64_t oh = item_ct1.get_group(1); + const int64_t batch = item_ct1.get_group(0) / IC; + const int64_t ic = item_ct1.get_group(0) % IC; + + const int64_t iiw = (ix * s0) + (kx * d0) - p0; + const int64_t iih = (oh * s1) + (ky * d1) - p1; + + const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx); + + const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset); + const int64_t offset_src = offset_src_base + (iih * IW) + iiw; + + const bool out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW); + const float src_val = out_of_bounds ? 0.0f : x[offset_src]; + + if constexpr (std::is_same_v) { + dst[offset_dst] = sycl::half(src_val); + } else if constexpr (std::is_same_v) { + dst[offset_dst] = src_val; + } + } +} + +template +static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, + int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, + int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { + const int64_t parallel_elements = OW * KW * KH; + const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + + // decrease global range when it exceeds the max int + int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE); + + sycl::range<3> block_nums(batch * IC, OH, num_blocks); + sycl::range<3> local_range(1, 1, local_size); + + const int64_t CHW = IC * KH * KW; + + stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { + im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, + p0, p1, d0, d1, item_ct1); + }); +} + +static void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, + int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, + int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { + if (!stream->get_device().has(sycl::aspect::fp16)) { + throw sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported), + "Device does not support half precision (fp16) operations!"); + } + im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, + p1, d0, d1, stream); +} + +static void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, + int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0, + int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { + im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, + d0, d1, stream); +} + +void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *) (dst->op_params))[0]; + const int32_t s1 = ((const int32_t *) (dst->op_params))[1]; + const int32_t p0 = ((const int32_t *) (dst->op_params))[2]; + const int32_t p1 = ((const int32_t *) (dst->op_params))[3]; + const int32_t d0 = ((const int32_t *) (dst->op_params))[4]; + const int32_t d1 = ((const int32_t *) (dst->op_params))[5]; + + const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1; + + const int64_t IC = src1->ne[is_2D ? 2 : 1]; + const int64_t IH = is_2D ? src1->ne[1] : 1; + const int64_t IW = src1->ne[0]; + + const int64_t KH = is_2D ? src0->ne[1] : 1; + const int64_t KW = src0->ne[0]; + + const int64_t OH = is_2D ? dst->ne[2] : 1; + const int64_t OW = dst->ne[1]; + + const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float); + const int64_t batch = src1->ne[is_2D ? 3 : 2]; + const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / sizeof(float); + + queue_ptr stream = ctx.stream(); + + if (dst->type == GGML_TYPE_F16) { + im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, + batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); + } else { + im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, + batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); + } +} diff --git a/ggml/src/ggml-sycl/im2col.hpp b/ggml/src/ggml-sycl/im2col.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dbbb248ddb4fcf6e7df0e592b31ae6df0f2ca871 --- /dev/null +++ b/ggml/src/ggml-sycl/im2col.hpp @@ -0,0 +1,21 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_IM2COL_HPP +#define GGML_SYCL_IM2COL_HPP + +#include "common.hpp" + +void ggml_sycl_op_im2col( + ggml_backend_sycl_context & ctx, ggml_tensor *dst); + +#endif // GGML_SYCL_IM2COL_HPP diff --git a/ggml/src/ggml-sycl/mmq.cpp b/ggml/src/ggml-sycl/mmq.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffb272aa28378b33da19c5c9d9836e51507a618e --- /dev/null +++ b/ggml/src/ggml-sycl/mmq.cpp @@ -0,0 +1,3030 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "mmq.hpp" +#include "vecdotq.hpp" + +typedef void (*allocate_tiles_sycl_t)( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc); +typedef void (*load_tiles_sycl_t)( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row); +typedef float (*vec_dot_q_mul_mat_sycl_t)( + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ms, + const int& i, + const int& j, + const int& k); + + +template +static __dpct_inline__ void +allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, + int *tile_x_qs_q4_0, float *tile_x_d_q4_0) { + (void)x_qh; (void)x_sc; + + *x_ql = tile_x_qs_q4_0; + *x_dm = (sycl::half2 *)tile_x_d_q4_0; +} + +template +static __dpct_inline__ void +load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql, + sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, + int *__restrict__ x_sc, const int &i_offset, const int &i_max, + const int &k, const int &blocks_per_row) { + (void)x_qh; (void)x_sc; + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_0; + const int kqsx = k % QI4_0; + + const block_q4_0 * bx0 = (const block_q4_0 *) vx; + + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { + int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; + } +} + +static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat( + const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, + const int *__restrict__ x_qh, const int *__restrict__ x_sc, + const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, + const int &i, const int &j, const int &k) { + (void)x_qh; (void)x_sc; + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const float * x_dmf = (const float *) x_dm; + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + } + + return vec_dot_q4_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +template +static __dpct_inline__ void +allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, + int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) { + (void)x_qh; (void)x_sc; + + *x_ql = tile_x_qs_q4_1; + *x_dm = tile_x_dm_q4_1; +} + + +template +static __dpct_inline__ void +load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql, + sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, + int *__restrict__ x_sc, const int &i_offset, const int &i_max, + const int &k, const int &blocks_per_row) { + (void)x_qh; (void)x_sc; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_1; + const int kqsx = k % QI4_1; + + const block_q4_1 * bx0 = (const block_q4_1 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { + int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + } +} + +static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat( + const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, + const int *__restrict__ x_qh, const int *__restrict__ x_sc, + const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, + const int &i, const int &j, const int &k) { + (void)x_qh; (void)x_sc; + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + } + + return vec_dot_q4_1_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +template +static __dpct_inline__ void +allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, + int *tile_x_ql_q5_0, float *tile_x_d_q5_0) { + (void)x_qh; (void)x_sc; + + *x_ql = tile_x_ql_q5_0; + *x_dm = (sycl::half2 *)tile_x_d_q5_0; +} + +template +static __dpct_inline__ void +load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql, + sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, + int *__restrict__ x_sc, const int &i_offset, const int &i_max, + const int &k, const int &blocks_per_row) { + (void)x_qh; (void)x_sc; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_0; + const int kqsx = k % QI5_0; + + const block_q5_0 * bx0 = (const block_q5_0 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx; + + const int ql = get_int_from_uint8(bxi->qs, kqsx); + const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = dpct::vectorized_binary( + qs0, 0x10101010, dpct::sub_sat()); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = dpct::vectorized_binary( + qs1, 0x10101010, dpct::sub_sat()); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int kbxd = k % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { + int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; + } +} + +static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat( + const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, + const int *__restrict__ x_qh, const int *__restrict__ x_sc, + const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, + const int &i, const int &j, const int &k) { + (void)x_qh; (void)x_sc; + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + int u[2*VDR_Q5_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + } + + return vec_dot_q8_0_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +template +static __dpct_inline__ void +allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, + int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) { + (void)x_qh; (void)x_sc; + + *x_ql = tile_x_ql_q5_1; + *x_dm = tile_x_dm_q5_1; +} + +template +static __dpct_inline__ void +load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql, + sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, + int *__restrict__ x_sc, const int &i_offset, const int &i_max, + const int &k, const int &blocks_per_row) { + (void)x_qh; (void)x_sc; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_1; + const int kqsx = k % QI5_1; + + const block_q5_1 * bx0 = (const block_q5_1 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { + int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + } +} + +static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat( + const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, + const int *__restrict__ x_qh, const int *__restrict__ x_sc, + const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, + const int &i, const int &j, const int &k) { + (void)x_qh; (void)x_sc; + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; + + int u[2*VDR_Q5_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + } + + return vec_dot_q8_1_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +template +static __dpct_inline__ void +allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, + int *tile_x_qs_q8_0, float *tile_x_d_q8_0) { + (void)x_qh; (void)x_sc; + + *x_ql = tile_x_qs_q8_0; + *x_dm = (sycl::half2 *)tile_x_d_q8_0; +} + +template +static __dpct_inline__ void +load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql, + sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, + int *__restrict__ x_sc, const int &i_offset, const int &i_max, + const int &k, const int &blocks_per_row) { + (void)x_qh; (void)x_sc; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI8_0; + const int kqsx = k % QI8_0; + float * x_dmf = (float *) x_dm; + + const block_q8_0 * bx0 = (const block_q8_0 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { + int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; + } +} + +static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat( + const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, + const int *__restrict__ x_qh, const int *__restrict__ x_sc, + const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, + const int &i, const int &j, const int &k) { + (void)x_qh; (void)x_sc; + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + return vec_dot_q8_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], + y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); +} + +template +static __dpct_inline__ void +allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, + int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K, + int *tile_x_sc_q2_K) { + (void)x_qh; + + *x_ql = tile_x_ql_q2_K; + *x_dm = tile_x_dm_q2_K; + *x_sc = tile_x_sc_q2_K; +} + +template +static __dpct_inline__ void +load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql, + sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, + int *__restrict__ x_sc, const int &i_offset, const int &i_max, + const int &k, const int &blocks_per_row) { + (void)x_qh; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI2_K; + const int kqsx = k % QI2_K; + + const block_q2_K * bx0 = (const block_q2_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { + int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); + + x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); + } +} + +#define VDR_Q2_K_Q8_1_MMQ 2 +// contiguous u/y values +static __dpct_inline__ float +vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u, + const uint8_t *__restrict__ scales, + const sycl::half2 &dm2, const float &d8) { + + int sumi_d = 0; + int sumi_m = 0; + +#pragma unroll + for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { + int sumi_d_sc = 0; + + const int sc = scales[i0 / (QI8_1/2)]; + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product + sumi_m = dpct::dp4a(m, u[i], + sumi_m); // multiply sum of q8_1 values with m + } + + sumi_d += sumi_d_sc * (sc & 0xF); + } + + const sycl::float2 dm2f = + dm2.convert(); + + return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m); +} + +static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat( + const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, + const int *__restrict__ x_qh, const int *__restrict__ x_sc, + const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, + const int &i, const int &j, const int &k) { + (void)x_qh; + + const int kbx = k / QI2_K; + const int ky = (k % QI2_K) * QR2_K; + const float * y_df = (const float *) y_ds; + + int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; + + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); + +#pragma unroll + for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { + v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; + } + + const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + + const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; + return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); +} + +template +static __dpct_inline__ void +allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, + int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K, + int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) { + + *x_ql = tile_x_ql_q3_K; + *x_dm = tile_x_dm_q3_K; + *x_qh = tile_x_qh_q3_K; + *x_sc = tile_x_sc_q3_K; +} + +template +static __dpct_inline__ void +load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql, + sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, + int *__restrict__ x_sc, const int &i_offset, const int &i_max, + const int &k, const int &blocks_per_row) { + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI3_K; + const int kqsx = k % QI3_K; + + const block_q3_K * bx0 = (const block_q3_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int kbxd = k % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { + int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { + int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); + + const int ksc = k % (QI3_K/4); + + const int ksc_low = ksc % (QI3_K/8); + const int shift_low = 4 * (ksc / (QI3_K/8)); + const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + + const int ksc_high = QI3_K/8; + const int shift_high = 2 * ksc; + const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + + const int sc = dpct::vectorized_binary( + sc_low | sc_high, 0x20202020, dpct::sub_sat()); + + x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; + } +} + +#define VDR_Q3_K_Q8_1_MMQ 2 +// contiguous u/y values +static __dpct_inline__ float +vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u, + const int8_t *__restrict__ scales, const float &d3, + const float &d8) { + + int sumi = 0; + +#pragma unroll + for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { + int sumi_sc = 0; + + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product + } + + sumi += sumi_sc * scales[i0 / (QI8_1/2)]; + } + + return d3*d8 * sumi; +} + +static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat( + const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, + const int *__restrict__ x_qh, const int *__restrict__ x_sc, + const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, + const int &i, const int &j, const int &k) { + + const int kbx = k / QI3_K; + const int ky = (k % QI3_K) * QR3_K; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + + int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int shift = 2 * ((ky % 32) / 8); + const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; + + const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vlh = (vh << 2) & 0x04040404; + + v[l] = dpct::vectorized_binary(vll, vlh, dpct::sub_sat()); + } + + const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; + return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); +} + +template +static __dpct_inline__ void +allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, + int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K, + int *tile_x_sc_q4_K) { + (void)x_qh; + + *x_ql = tile_x_ql_q4_K; + *x_dm = tile_x_dm_q4_K; + *x_sc = tile_x_sc_q4_K; +} + +template +static __dpct_inline__ void +load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql, + sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, + int *__restrict__ x_sc, const int &i_offset, const int &i_max, + const int &k, const int &blocks_per_row) { + (void)x_qh; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_K; // == 0 if QK_K == 256 + const int kqsx = k % QI4_K; // == k if QK_K == 256 + + const block_q4_K * bx0 = (const block_q4_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + constexpr int blocks_per_tile_x_row = QI4_K > WARP_SIZE ? 1 : WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { + int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; + +#if QK_K == 256 + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; +#else + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]}; +#endif + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = k % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + + +#define VDR_Q4_K_Q8_1_MMQ 8 + +// contiguous u/y values +static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq( + const int *__restrict__ v, const int *__restrict__ u, + const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m, + const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F, + u[i * QI8_1 + j], sumi_d); // SIMD dot product + } + + const sycl::float2 ds8f = + ds8[i].convert(); + + sumf_d += ds8f.x() * (sc[i] * sumi_d); + sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val + } + + const sycl::float2 dm4f = + dm4.convert(); + + return dm4f.x() * sumf_d - dm4f.y() * sumf_m; +} + + +static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat( + const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, + const int *__restrict__ x_qh, const int *__restrict__ x_sc, + const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, + const int &i, const int &j, const int &k) { + (void)x_qh; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); + + const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; + return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); +} + +template +static __dpct_inline__ void +allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, + int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K, + int *tile_x_sc_q5_K) { + (void)x_qh; + + *x_ql = tile_x_ql_q5_K; + *x_dm = tile_x_dm_q5_K; + *x_sc = tile_x_sc_q5_K; +} + +template +static __dpct_inline__ void +load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql, + sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, + int *__restrict__ x_sc, const int &i_offset, const int &i_max, + const int &k, const int &blocks_per_row) { + (void)x_qh; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_K; // == 0 if QK_K == 256 + const int kqsx = k % QI5_K; // == k if QK_K == 256 + + const block_q5_K * bx0 = (const block_q5_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx; + const int ky = QR5_K*kqsx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; + } + + constexpr int blocks_per_tile_x_row = QI5_K > WARP_SIZE ? 1 : WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { + int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; + +#if QK_K == 256 + x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; +#endif + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = k % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +#define VDR_Q5_K_Q8_1_MMQ 8 + +// contiguous u/y values +static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq( + const int *__restrict__ v, const int *__restrict__ u, + const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m, + const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j], + sumi_d); // SIMD dot product + } + + const sycl::float2 ds8f = + ds8[i].convert(); + + sumf_d += ds8f.x() * (sc[i] * sumi_d); + sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val + } + + const sycl::float2 dm4f = + dm4.convert(); + + return dm4f.x() * sumf_d - dm4f.y() * sumf_m; +} + +static __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat( + const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, + const int *__restrict__ x_qh, const int *__restrict__ x_sc, + const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, + const int &i, const int &j, const int &k) { + (void)x_qh; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); + + const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; + const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; + return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); +} + +template +static __dpct_inline__ void +allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, + int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) { + (void)x_qh; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template +static __dpct_inline__ void +load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql, + sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, + int *__restrict__ x_sc, const int &i_offset, const int &i_max, + const int &k, const int &blocks_per_row) { + (void)x_qh; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI6_K; // == 0 if QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + const block_q6_K * bx0 = (const block_q6_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx; + const int ky = QR6_K*kqsx; + + const int ql = get_int_from_uint8(bxi->ql, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); + const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; + const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + + const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; + const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); + + x_ql[i * (2 * WARP_SIZE + 1) + kq0] = + dpct::vectorized_binary(ql0 | qh0, 0x20202020, + dpct::sub_sat()); + x_ql[i * (2 * WARP_SIZE + 1) + kq1] = + dpct::vectorized_binary(ql1 | qh1, 0x20202020, + dpct::sub_sat()); + } + + constexpr int blocks_per_tile_x_row = QI6_K > WARP_SIZE ? 1 : WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { + int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = sycl::min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; + + x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); + } +} + +#define VDR_Q6_K_Q8_1_MMQ 8 + +// contiguous u/y values +static __dpct_inline__ float +vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u, + const int8_t *__restrict__ sc, const float &d6, + const float *__restrict__ d8) { + + float sumf_d = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { + sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale + +#pragma unroll + for (int i = i0; i < i0 + 2; ++i) { + sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0], + sumi_d.x()); // SIMD dot product + sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1], + sumi_d.x()); // SIMD dot product + + sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4], + sumi_d.y()); // SIMD dot product + sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5], + sumi_d.y()); // SIMD dot product + } + + sumf_d += d8[i0 / 4] * + (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y()); + } + + return d6 * sumf_d; +} + +static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat( + const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, + const int *__restrict__ x_qh, const int *__restrict__ x_sc, + const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, + const int &i, const int &j, const int &k) { + (void)x_qh; + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); + + const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; + const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; + return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); +} + +template +/* +DPCT1110:8: The total declared local variable size in device function mul_mat_q +exceeds 128 bytes and may cause high register pressure. Consult with your +hardware vendor to find the total register size available and adjust the code, +or use smaller sub-group size to avoid high register pressure. +*/ +static __dpct_inline__ void +mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy, + float *__restrict__ dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh, + int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs, + sycl::half2 *tile_y_ds) { + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + const int blocks_per_warp = WARP_SIZE / qi; + + const int & ncols_dst = ncols_y; + + const int row_dst_0 = item_ct1.get_group(2) * mmq_y; + const int & row_x_0 = row_dst_0; + + const int col_dst_0 = item_ct1.get_group(1) * mmq_x; + const int & col_y_0 = col_dst_0; + + float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}}; + + for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { + + load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, + tile_x_qh, tile_x_sc, item_ct1.get_local_id(1), + nrows_x - row_x_0 - 1, item_ct1.get_local_id(2), + blocks_per_row_x); + +#pragma unroll + for (int ir = 0; ir < qr; ++ir) { + const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2); + const int kbxd = kqs / QI8_1; + +#pragma unroll + for (int i = 0; i < mmq_x; i += nwarps) { + const int col_y_eff = dpct::min( + (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i), + ncols_y - 1); // to prevent out-of-bounds memory accesses + + const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; + + const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE + + kqs % WARP_SIZE; + tile_y_qs[index_y] = get_int_from_int8_aligned( + by0->qs, item_ct1.get_local_id(2) % QI8_1); + } + +#pragma unroll + for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { + const int ids = + (ids0 + item_ct1.get_local_id(1) * QI8_1 + + item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) % + mmq_x; + const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1); + const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1); + + // if the sum is not needed it's faster to transform the scale to f32 ahead of time + const sycl::half2 *dsi_src = + &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) + + ir * (WARP_SIZE / QI8_1) + kby] + .ds; + sycl::half2 *dsi_dst = + &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby]; + if (need_sum) { + *dsi_dst = *dsi_src; + } else { + float * dfi_dst = (float *) dsi_dst; + *dfi_dst = (*dsi_src)[0]; + } + } + + /* + DPCT1118:9: SYCL group functions and algorithms must be encountered + in converged control flow. You may need to adjust the code. + */ + /* + DPCT1065:56: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + +// #pragma unroll // unrolling this loop causes too much register pressure + for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + sum[i / WARP_SIZE][j / nwarps] += vec_dot( + tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, + tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i, + item_ct1.get_local_id(1) + j, k); + } + } + } + + /* + DPCT1118:10: SYCL group functions and algorithms must be encountered + in converged control flow. You may need to adjust the code. + */ + /* + DPCT1065:57: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + } + } + +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { + const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1); + + if (col_dst >= ncols_dst) { + return; + } + +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i; + + if (row_dst >= nrows_dst) { + continue; + } + + dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps]; + } + } +} + +#define MMQ_X_Q4_0_RDNA2 64 +#define MMQ_Y_Q4_0_RDNA2 128 +#define NWARPS_Q4_0_RDNA2 8 +#define MMQ_X_Q4_0_RDNA1 64 +#define MMQ_Y_Q4_0_RDNA1 64 +#define NWARPS_Q4_0_RDNA1 8 +#if defined(SYCL_USE_XMX) +#define MMQ_X_Q4_0_AMPERE 4 +#define MMQ_Y_Q4_0_AMPERE 32 +#define NWARPS_Q4_0_AMPERE 4 +#else +#define MMQ_X_Q4_0_AMPERE 64 +#define MMQ_Y_Q4_0_AMPERE 128 +#define NWARPS_Q4_0_AMPERE 4 +#endif +#define MMQ_X_Q4_0_PASCAL 64 +#define MMQ_Y_Q4_0_PASCAL 64 +#define NWARPS_Q4_0_PASCAL 8 + +template static void + mul_mat_q4_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0, + int *tile_y_qs, sycl::half2 *tile_y_ds) { + int * tile_x_ql = nullptr; + sycl::half2 *tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + +//sycl_todo: change according to hardware + + const int mmq_x = MMQ_X_Q4_0_AMPERE; + const int mmq_y = MMQ_Y_Q4_0_AMPERE; + const int nwarps = NWARPS_Q4_0_AMPERE; + allocate_tiles_q4_0(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, + tile_x_qs_q4_0, tile_x_d_q4_0); + mul_mat_q, VDR_Q4_0_Q8_1_MMQ, + vec_dot_q4_0_q8_1_mul_mat>( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, + tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +} + +#define MMQ_X_Q4_1_RDNA2 64 +#define MMQ_Y_Q4_1_RDNA2 128 +#define NWARPS_Q4_1_RDNA2 8 +#define MMQ_X_Q4_1_RDNA1 64 +#define MMQ_Y_Q4_1_RDNA1 64 +#define NWARPS_Q4_1_RDNA1 8 +#if defined(SYCL_USE_XMX) +#define MMQ_X_Q4_1_AMPERE 4 +#define MMQ_Y_Q4_1_AMPERE 32 +#define NWARPS_Q4_1_AMPERE 4 +#else +#define MMQ_X_Q4_1_AMPERE 64 +#define MMQ_Y_Q4_1_AMPERE 128 +#define NWARPS_Q4_1_AMPERE 4 +#endif +#define MMQ_X_Q4_1_PASCAL 64 +#define MMQ_Y_Q4_1_PASCAL 64 +#define NWARPS_Q4_1_PASCAL 8 + +template static void + mul_mat_q4_1( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1, + sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) { + int * tile_x_ql = nullptr; + sycl::half2 *tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + +//sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q4_1_AMPERE; + const int mmq_y = MMQ_Y_Q4_1_AMPERE; + const int nwarps = NWARPS_Q4_1_AMPERE; + allocate_tiles_q4_1(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, + tile_x_qs_q4_1, tile_x_dm_q4_1); + mul_mat_q, VDR_Q4_1_Q8_1_MMQ, + vec_dot_q4_1_q8_1_mul_mat>( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, + tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +} + +#define MMQ_X_Q5_0_RDNA2 64 +#define MMQ_Y_Q5_0_RDNA2 128 +#define NWARPS_Q5_0_RDNA2 8 +#define MMQ_X_Q5_0_RDNA1 64 +#define MMQ_Y_Q5_0_RDNA1 64 +#define NWARPS_Q5_0_RDNA1 8 +#if defined(SYCL_USE_XMX) +#define MMQ_X_Q5_0_AMPERE 4 +#define MMQ_Y_Q5_0_AMPERE 32 +#define NWARPS_Q5_0_AMPERE 4 +#else +#define MMQ_X_Q5_0_AMPERE 128 +#define MMQ_Y_Q5_0_AMPERE 64 +#define NWARPS_Q5_0_AMPERE 4 +#endif +#define MMQ_X_Q5_0_PASCAL 64 +#define MMQ_Y_Q5_0_PASCAL 64 +#define NWARPS_Q5_0_PASCAL 8 + +template static void + mul_mat_q5_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0, + int *tile_y_qs, sycl::half2 *tile_y_ds) { + int * tile_x_ql = nullptr; + sycl::half2 *tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + +//sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q5_0_AMPERE; + const int mmq_y = MMQ_Y_Q5_0_AMPERE; + const int nwarps = NWARPS_Q5_0_AMPERE; + allocate_tiles_q5_0(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, + tile_x_ql_q5_0, tile_x_d_q5_0); + mul_mat_q, VDR_Q5_0_Q8_1_MMQ, + vec_dot_q5_0_q8_1_mul_mat>( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, + tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +} + +#define MMQ_X_Q5_1_RDNA2 64 +#define MMQ_Y_Q5_1_RDNA2 128 +#define NWARPS_Q5_1_RDNA2 8 +#define MMQ_X_Q5_1_RDNA1 64 +#define MMQ_Y_Q5_1_RDNA1 64 +#define NWARPS_Q5_1_RDNA1 8 +#if defined(SYCL_USE_XMX) +#define MMQ_X_Q5_1_AMPERE 4 +#define MMQ_Y_Q5_1_AMPERE 32 +#define NWARPS_Q5_1_AMPERE 4 +#else +#define MMQ_X_Q5_1_AMPERE 128 +#define MMQ_Y_Q5_1_AMPERE 64 +#define NWARPS_Q5_1_AMPERE 4 +#endif +#define MMQ_X_Q5_1_PASCAL 64 +#define MMQ_Y_Q5_1_PASCAL 64 +#define NWARPS_Q5_1_PASCAL 8 + +template static void +mul_mat_q5_1( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1, + sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) { + int * tile_x_ql = nullptr; + sycl::half2 *tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + +//sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q5_1_AMPERE; + const int mmq_y = MMQ_Y_Q5_1_AMPERE; + const int nwarps = NWARPS_Q5_1_AMPERE; + allocate_tiles_q5_1(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, + tile_x_ql_q5_1, tile_x_dm_q5_1); + mul_mat_q, VDR_Q5_1_Q8_1_MMQ, + vec_dot_q5_1_q8_1_mul_mat>( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, + tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +} + +#define MMQ_X_Q8_0_RDNA2 64 +#define MMQ_Y_Q8_0_RDNA2 128 +#define NWARPS_Q8_0_RDNA2 8 +#define MMQ_X_Q8_0_RDNA1 64 +#define MMQ_Y_Q8_0_RDNA1 64 +#define NWARPS_Q8_0_RDNA1 8 +#if defined(SYCL_USE_XMX) +#define MMQ_X_Q8_0_AMPERE 4 +#define MMQ_Y_Q8_0_AMPERE 32 +#define NWARPS_Q8_0_AMPERE 4 +#else +#define MMQ_X_Q8_0_AMPERE 128 +#define MMQ_Y_Q8_0_AMPERE 64 +#define NWARPS_Q8_0_AMPERE 4 +#endif +#define MMQ_X_Q8_0_PASCAL 64 +#define MMQ_Y_Q8_0_PASCAL 64 +#define NWARPS_Q8_0_PASCAL 8 + +template static void + mul_mat_q8_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0, + int *tile_y_qs, sycl::half2 *tile_y_ds) { + int * tile_x_ql = nullptr; + sycl::half2 *tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + +//sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q8_0_AMPERE; + const int mmq_y = MMQ_Y_Q8_0_AMPERE; + const int nwarps = NWARPS_Q8_0_AMPERE; + allocate_tiles_q8_0(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, + tile_x_qs_q8_0, tile_x_d_q8_0); + mul_mat_q, VDR_Q8_0_Q8_1_MMQ, + vec_dot_q8_0_q8_1_mul_mat>( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, + tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +} + +#define MMQ_X_Q2_K_RDNA2 64 +#define MMQ_Y_Q2_K_RDNA2 128 +#define NWARPS_Q2_K_RDNA2 8 +#define MMQ_X_Q2_K_RDNA1 128 +#define MMQ_Y_Q2_K_RDNA1 32 +#define NWARPS_Q2_K_RDNA1 8 +#if defined(SYCL_USE_XMX) +#define MMQ_X_Q2_K_AMPERE 4 +#define MMQ_Y_Q2_K_AMPERE 32 +#define NWARPS_Q2_K_AMPERE 4 +#else +#define MMQ_X_Q2_K_AMPERE 64 +#define MMQ_Y_Q2_K_AMPERE 128 +#define NWARPS_Q2_K_AMPERE 4 +#endif +#define MMQ_X_Q2_K_PASCAL 64 +#define MMQ_Y_Q2_K_PASCAL 64 +#define NWARPS_Q2_K_PASCAL 8 + +template static void +mul_mat_q2_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K, + sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs, + sycl::half2 *tile_y_ds) { + int * tile_x_ql = nullptr; + sycl::half2 *tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + +//sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q2_K_AMPERE; + const int mmq_y = MMQ_Y_Q2_K_AMPERE; + const int nwarps = NWARPS_Q2_K_AMPERE; + allocate_tiles_q2_K(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, + tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K); + mul_mat_q, VDR_Q2_K_Q8_1_MMQ, + vec_dot_q2_K_q8_1_mul_mat>( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, + tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +} + +#define MMQ_X_Q3_K_RDNA2 128 +#define MMQ_Y_Q3_K_RDNA2 64 +#define NWARPS_Q3_K_RDNA2 8 +#define MMQ_X_Q3_K_RDNA1 32 +#define MMQ_Y_Q3_K_RDNA1 128 +#define NWARPS_Q3_K_RDNA1 8 +#if defined(SYCL_USE_XMX) +#define MMQ_X_Q3_K_AMPERE 4 +#define MMQ_Y_Q3_K_AMPERE 32 +#define NWARPS_Q3_K_AMPERE 4 +#else +#define MMQ_X_Q3_K_AMPERE 128 +#define MMQ_Y_Q3_K_AMPERE 128 +#define NWARPS_Q3_K_AMPERE 4 +#endif +#define MMQ_X_Q3_K_PASCAL 64 +#define MMQ_Y_Q3_K_PASCAL 64 +#define NWARPS_Q3_K_PASCAL 8 + +template static void +mul_mat_q3_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K, + sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K, + int *tile_y_qs, sycl::half2 *tile_y_ds) { + int * tile_x_ql = nullptr; + sycl::half2 *tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + +//sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q3_K_AMPERE; + const int mmq_y = MMQ_Y_Q3_K_AMPERE; + const int nwarps = NWARPS_Q3_K_AMPERE; + allocate_tiles_q3_K(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, + tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K, + tile_x_sc_q3_K); + mul_mat_q, VDR_Q3_K_Q8_1_MMQ, + vec_dot_q3_K_q8_1_mul_mat>( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, + tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +} + +#define MMQ_X_Q4_K_RDNA2 64 +#define MMQ_Y_Q4_K_RDNA2 128 +#define NWARPS_Q4_K_RDNA2 8 +#define MMQ_X_Q4_K_RDNA1 32 +#define MMQ_Y_Q4_K_RDNA1 64 +#define NWARPS_Q4_K_RDNA1 8 +#if defined(SYCL_USE_XMX) +#define MMQ_X_Q4_K_AMPERE 4 +#define MMQ_Y_Q4_K_AMPERE 32 +#define NWARPS_Q4_K_AMPERE 4 +#else +#define MMQ_X_Q4_K_AMPERE 64 +#define MMQ_Y_Q4_K_AMPERE 128 +#define NWARPS_Q4_K_AMPERE 4 +#endif +#define MMQ_X_Q4_K_PASCAL 64 +#define MMQ_Y_Q4_K_PASCAL 64 +#define NWARPS_Q4_K_PASCAL 8 + +template static void + mul_mat_q4_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K, + sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs, + sycl::half2 *tile_y_ds) { + int * tile_x_ql = nullptr; + sycl::half2 *tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + +//sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q4_K_AMPERE; + const int mmq_y = MMQ_Y_Q4_K_AMPERE; + const int nwarps = NWARPS_Q4_K_AMPERE; + allocate_tiles_q4_K(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, + tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K); + mul_mat_q, VDR_Q4_K_Q8_1_MMQ, + vec_dot_q4_K_q8_1_mul_mat>( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, + tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +} + +#define MMQ_X_Q5_K_RDNA2 64 +#define MMQ_Y_Q5_K_RDNA2 128 +#define NWARPS_Q5_K_RDNA2 8 +#define MMQ_X_Q5_K_RDNA1 32 +#define MMQ_Y_Q5_K_RDNA1 64 +#define NWARPS_Q5_K_RDNA1 8 +#if defined(SYCL_USE_XMX) +#define MMQ_X_Q5_K_AMPERE 4 +#define MMQ_Y_Q5_K_AMPERE 32 +#define NWARPS_Q5_K_AMPERE 4 +#else +#define MMQ_X_Q5_K_AMPERE 64 +#define MMQ_Y_Q5_K_AMPERE 128 +#define NWARPS_Q5_K_AMPERE 4 +#endif +#define MMQ_X_Q5_K_PASCAL 64 +#define MMQ_Y_Q5_K_PASCAL 64 +#define NWARPS_Q5_K_PASCAL 8 + +template static void +mul_mat_q5_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K, + sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs, + sycl::half2 *tile_y_ds) { + int * tile_x_ql = nullptr; + sycl::half2 *tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + +//sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q5_K_AMPERE; + const int mmq_y = MMQ_Y_Q5_K_AMPERE; + const int nwarps = NWARPS_Q5_K_AMPERE; + allocate_tiles_q5_K(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, + tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K); + mul_mat_q, VDR_Q5_K_Q8_1_MMQ, + vec_dot_q5_K_q8_1_mul_mat>( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, + tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +} + +#define MMQ_X_Q6_K_RDNA2 64 +#define MMQ_Y_Q6_K_RDNA2 128 +#define NWARPS_Q6_K_RDNA2 8 +#define MMQ_X_Q6_K_RDNA1 32 +#define MMQ_Y_Q6_K_RDNA1 64 +#define NWARPS_Q6_K_RDNA1 8 +#if defined(SYCL_USE_XMX) +#define MMQ_X_Q6_K_AMPERE 4 +#define MMQ_Y_Q6_K_AMPERE 32 +#define NWARPS_Q6_K_AMPERE 4 +#else +#define MMQ_X_Q6_K_AMPERE 64 +#define MMQ_Y_Q6_K_AMPERE 64 +#define NWARPS_Q6_K_AMPERE 4 +#endif +#define MMQ_X_Q6_K_PASCAL 64 +#define MMQ_Y_Q6_K_PASCAL 64 +#define NWARPS_Q6_K_PASCAL 8 + +template static void + mul_mat_q6_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm, + int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) { + // int * tile_x_ql = nullptr; + // sycl::half2 *tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + // int * tile_x_sc = nullptr; + +//sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q6_K_AMPERE; + const int mmq_y = MMQ_Y_Q6_K_AMPERE; + const int nwarps = NWARPS_Q6_K_AMPERE; + allocate_tiles_q6_K(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, + tile_x_ql, tile_x_dm, tile_x_sc); + mul_mat_q, VDR_Q6_K_Q8_1_MMQ, + vec_dot_q6_K_q8_1_mul_mat>( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, + tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +} + +static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols_x, + const int nrows_x, const int ncols_y, + const int nrows_y, const int nrows_dst, + dpct::queue_ptr stream) try { + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = ggml_sycl_info().devices[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q4_0_RDNA2; + mmq_y = MMQ_Y_Q4_0_RDNA2; + nwarps = NWARPS_Q4_0_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q4_0_RDNA1; + mmq_y = MMQ_Y_Q4_0_RDNA1; + nwarps = NWARPS_Q4_0_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q4_0_AMPERE; + mmq_y = MMQ_Y_Q4_0_AMPERE; + nwarps = NWARPS_Q4_0_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q4_0_PASCAL; + mmq_y = MMQ_Y_Q4_0_PASCAL; + nwarps = NWARPS_Q4_0_PASCAL; + } else { + GGML_ABORT("fatal error"); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:20: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_qs_q4_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q4_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0), + cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_0( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_qs_q4_0_acc_ct1), + get_pointer(tile_x_d_q4_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:21: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_qs_q4_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q4_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0), + cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_0( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_qs_q4_0_acc_ct1), + get_pointer(tile_x_d_q4_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols_x, + const int nrows_x, const int ncols_y, + const int nrows_y, const int nrows_dst, + dpct::queue_ptr stream) try { + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = ggml_sycl_info().devices[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q4_1_RDNA2; + mmq_y = MMQ_Y_Q4_1_RDNA2; + nwarps = NWARPS_Q4_1_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q4_1_RDNA1; + mmq_y = MMQ_Y_Q4_1_RDNA1; + nwarps = NWARPS_Q4_1_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q4_1_AMPERE; + mmq_y = MMQ_Y_Q4_1_AMPERE; + nwarps = NWARPS_Q4_1_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q4_1_PASCAL; + mmq_y = MMQ_Y_Q4_1_PASCAL; + nwarps = NWARPS_Q4_1_PASCAL; + } else { + GGML_ABORT("fatal error"); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:22: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_qs_q4_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); + sycl::local_accessor tile_x_dm_q4_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1), + cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_1( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_qs_q4_1_acc_ct1), + get_pointer(tile_x_dm_q4_1_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:23: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_qs_q4_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); + sycl::local_accessor tile_x_dm_q4_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1), + cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_1( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_qs_q4_1_acc_ct1), + get_pointer(tile_x_dm_q4_1_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols_x, + const int nrows_x, const int ncols_y, + const int nrows_y, const int nrows_dst, + dpct::queue_ptr stream) try { + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = ggml_sycl_info().devices[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q5_0_RDNA2; + mmq_y = MMQ_Y_Q5_0_RDNA2; + nwarps = NWARPS_Q5_0_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q5_0_RDNA1; + mmq_y = MMQ_Y_Q5_0_RDNA1; + nwarps = NWARPS_Q5_0_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q5_0_AMPERE; + mmq_y = MMQ_Y_Q5_0_AMPERE; + nwarps = NWARPS_Q5_0_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q5_0_PASCAL; + mmq_y = MMQ_Y_Q5_0_PASCAL; + nwarps = NWARPS_Q5_0_PASCAL; + } else { + GGML_ABORT("fatal error"); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:24: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q5_0_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q5_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0), + cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_0( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q5_0_acc_ct1), + get_pointer(tile_x_d_q5_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:25: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q5_0_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q5_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0), + cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_0( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q5_0_acc_ct1), + get_pointer(tile_x_d_q5_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols_x, + const int nrows_x, const int ncols_y, + const int nrows_y, const int nrows_dst, + dpct::queue_ptr stream) try { + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = ggml_sycl_info().devices[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q5_1_RDNA2; + mmq_y = MMQ_Y_Q5_1_RDNA2; + nwarps = NWARPS_Q5_1_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q5_1_RDNA1; + mmq_y = MMQ_Y_Q5_1_RDNA1; + nwarps = NWARPS_Q5_1_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q5_1_AMPERE; + mmq_y = MMQ_Y_Q5_1_AMPERE; + nwarps = NWARPS_Q5_1_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q5_1_PASCAL; + mmq_y = MMQ_Y_Q5_1_PASCAL; + nwarps = NWARPS_Q5_1_PASCAL; + } else { + GGML_ABORT("fatal error"); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:26: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q5_1_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q5_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1), + cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_1( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q5_1_acc_ct1), + get_pointer(tile_x_dm_q5_1_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:27: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q5_1_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q5_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1), + cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_1( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q5_1_acc_ct1), + get_pointer(tile_x_dm_q5_1_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols_x, + const int nrows_x, const int ncols_y, + const int nrows_y, const int nrows_dst, + dpct::queue_ptr stream) try { + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = ggml_sycl_info().devices[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q8_0_RDNA2; + mmq_y = MMQ_Y_Q8_0_RDNA2; + nwarps = NWARPS_Q8_0_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q8_0_RDNA1; + mmq_y = MMQ_Y_Q8_0_RDNA1; + nwarps = NWARPS_Q8_0_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q8_0_AMPERE; + mmq_y = MMQ_Y_Q8_0_AMPERE; + nwarps = NWARPS_Q8_0_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q8_0_PASCAL; + mmq_y = MMQ_Y_Q8_0_PASCAL; + nwarps = NWARPS_Q8_0_PASCAL; + } else { + GGML_ABORT("fatal error"); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:28: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_qs_q8_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q8_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0), + cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q8_0( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_qs_q8_0_acc_ct1), + get_pointer(tile_x_d_q8_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:29: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_qs_q8_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q8_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0), + cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q8_0( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_qs_q8_0_acc_ct1), + get_pointer(tile_x_d_q8_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols_x, + const int nrows_x, const int ncols_y, + const int nrows_y, const int nrows_dst, + dpct::queue_ptr stream) try { + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = ggml_sycl_info().devices[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q2_K_RDNA2; + mmq_y = MMQ_Y_Q2_K_RDNA2; + nwarps = NWARPS_Q2_K_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q2_K_RDNA1; + mmq_y = MMQ_Y_Q2_K_RDNA1; + nwarps = NWARPS_Q2_K_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q2_K_AMPERE; + mmq_y = MMQ_Y_Q2_K_AMPERE; + nwarps = NWARPS_Q2_K_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q2_K_PASCAL; + mmq_y = MMQ_Y_Q2_K_PASCAL; + nwarps = NWARPS_Q2_K_PASCAL; + } else { + GGML_ABORT("fatal error"); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:30: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K), + cgh); + sycl::local_accessor tile_x_sc_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q2_K( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q2_K_acc_ct1), + get_pointer(tile_x_dm_q2_K_acc_ct1), + get_pointer(tile_x_sc_q2_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:31: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K), + cgh); + sycl::local_accessor tile_x_sc_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q2_K( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q2_K_acc_ct1), + get_pointer(tile_x_dm_q2_K_acc_ct1), + get_pointer(tile_x_sc_q2_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols_x, + const int nrows_x, const int ncols_y, + const int nrows_y, const int nrows_dst, + dpct::queue_ptr stream) try { + +#if QK_K == 256 + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = ggml_sycl_info().devices[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q3_K_RDNA2; + mmq_y = MMQ_Y_Q3_K_RDNA2; + nwarps = NWARPS_Q3_K_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q3_K_RDNA1; + mmq_y = MMQ_Y_Q3_K_RDNA1; + nwarps = NWARPS_Q3_K_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q3_K_AMPERE; + mmq_y = MMQ_Y_Q3_K_AMPERE; + nwarps = NWARPS_Q3_K_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q3_K_PASCAL; + mmq_y = MMQ_Y_Q3_K_PASCAL; + nwarps = NWARPS_Q3_K_PASCAL; + } else { + GGML_ABORT("fatal error"); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:32: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K), + cgh); + sycl::local_accessor tile_x_qh_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh); + sycl::local_accessor tile_x_sc_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q3_K( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q3_K_acc_ct1), + get_pointer(tile_x_dm_q3_K_acc_ct1), + get_pointer(tile_x_qh_q3_K_acc_ct1), + get_pointer(tile_x_sc_q3_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:33: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K), + cgh); + sycl::local_accessor tile_x_qh_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh); + sycl::local_accessor tile_x_sc_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q3_K( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q3_K_acc_ct1), + get_pointer(tile_x_dm_q3_K_acc_ct1), + get_pointer(tile_x_qh_q3_K_acc_ct1), + get_pointer(tile_x_sc_q3_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } +#endif +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols_x, + const int nrows_x, const int ncols_y, + const int nrows_y, const int nrows_dst, + dpct::queue_ptr stream) try { + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = ggml_sycl_info().devices[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q4_K_RDNA2; + mmq_y = MMQ_Y_Q4_K_RDNA2; + nwarps = NWARPS_Q4_K_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q4_K_RDNA1; + mmq_y = MMQ_Y_Q4_K_RDNA1; + nwarps = NWARPS_Q4_K_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q4_K_AMPERE; + mmq_y = MMQ_Y_Q4_K_AMPERE; + nwarps = NWARPS_Q4_K_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q4_K_PASCAL; + mmq_y = MMQ_Y_Q4_K_PASCAL; + nwarps = NWARPS_Q4_K_PASCAL; + } else { + GGML_ABORT("fatal error"); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:34: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K), + cgh); + sycl::local_accessor tile_x_sc_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_K( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q4_K_acc_ct1), + get_pointer(tile_x_dm_q4_K_acc_ct1), + get_pointer(tile_x_sc_q4_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:35: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K), + cgh); + sycl::local_accessor tile_x_sc_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_K( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q4_K_acc_ct1), + get_pointer(tile_x_dm_q4_K_acc_ct1), + get_pointer(tile_x_sc_q4_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols_x, + const int nrows_x, const int ncols_y, + const int nrows_y, const int nrows_dst, + dpct::queue_ptr stream) try { + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = ggml_sycl_info().devices[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q5_K_RDNA2; + mmq_y = MMQ_Y_Q5_K_RDNA2; + nwarps = NWARPS_Q5_K_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q5_K_RDNA1; + mmq_y = MMQ_Y_Q5_K_RDNA1; + nwarps = NWARPS_Q5_K_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q5_K_AMPERE; + mmq_y = MMQ_Y_Q5_K_AMPERE; + nwarps = NWARPS_Q5_K_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q5_K_PASCAL; + mmq_y = MMQ_Y_Q5_K_PASCAL; + nwarps = NWARPS_Q5_K_PASCAL; + } else { + GGML_ABORT("fatal error"); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:36: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K), + cgh); + sycl::local_accessor tile_x_sc_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_K( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q5_K_acc_ct1), + get_pointer(tile_x_dm_q5_K_acc_ct1), + get_pointer(tile_x_sc_q5_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:37: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K), + cgh); + sycl::local_accessor tile_x_sc_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_K( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_q5_K_acc_ct1), + get_pointer(tile_x_dm_q5_K_acc_ct1), + get_pointer(tile_x_sc_q5_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols_x, + const int nrows_x, const int ncols_y, + const int nrows_y, const int nrows_dst, + dpct::queue_ptr stream) try { + + int id; + SYCL_CHECK( + CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = ggml_sycl_info().devices[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q6_K_RDNA2; + mmq_y = MMQ_Y_Q6_K_RDNA2; + nwarps = NWARPS_Q6_K_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q6_K_RDNA1; + mmq_y = MMQ_Y_Q6_K_RDNA1; + nwarps = NWARPS_Q6_K_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q6_K_AMPERE; + mmq_y = MMQ_Y_Q6_K_AMPERE; + nwarps = NWARPS_Q6_K_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q6_K_PASCAL; + mmq_y = MMQ_Y_Q6_K_PASCAL; + nwarps = NWARPS_Q6_K_PASCAL; + } else { + GGML_ABORT("fatal error"); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:38: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K), + cgh); + sycl::local_accessor tile_x_sc_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q6_K( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_acc_ct1), + get_pointer(tile_x_dm_acc_ct1), + get_pointer(tile_x_sc_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:39: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor tile_x_ql_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K), + cgh); + sycl::local_accessor tile_x_sc_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q6_K( + vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, + nrows_dst, item_ct1, + get_pointer(tile_x_ql_acc_ct1), + get_pointer(tile_x_dm_acc_ct1), + get_pointer(tile_x_sc_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); + }); + }); + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void ggml_sycl_op_mul_mat_q( + ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, + const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, + float *dst_dd_i, const int64_t row_low, const int64_t row_high, + const int64_t src1_ncols, const int64_t src1_padded_row_size, + const dpct::queue_ptr &stream) try { + + const int64_t ne00 = src0->ne[0]; + + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + const int64_t row_diff = row_high - row_low; + + int device_id; + SYCL_CHECK( + CHECK_TRY_ERROR(device_id = get_current_device_id())); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into + const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff; + + switch (src0->type) { + case GGML_TYPE_Q4_0: + ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q4_1: + ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q5_0: + ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q5_1: + ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q8_0: + ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q2_K: + ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q3_K: + ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q4_K: + ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q5_K: + ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q6_K: + ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + default: + GGML_ABORT("fatal error"); + } + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddf_i); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} diff --git a/ggml/src/ggml-sycl/mmq.hpp b/ggml/src/ggml-sycl/mmq.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3f5297aaa5373172a1e07d13d8baa56bd0651284 --- /dev/null +++ b/ggml/src/ggml-sycl/mmq.hpp @@ -0,0 +1,33 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_MMQ_HPP +#define GGML_SYCL_MMQ_HPP + +#include "common.hpp" + +void ggml_sycl_op_mul_mat_q( + ggml_backend_sycl_context & ctx, + const ggml_tensor* src0, + const ggml_tensor* src1, + ggml_tensor* dst, + const char* src0_dd_i, + const float* src1_ddf_i, + const char* src1_ddq_i, + float* dst_dd_i, + const int64_t row_low, + const int64_t row_high, + const int64_t src1_ncols, + const int64_t src1_padded_row_size, + const dpct::queue_ptr& stream); + +#endif // GGML_SYCL_MMQ_HPP diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b7f0640749377ae08ffdded652ee16ec5511c97 --- /dev/null +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -0,0 +1,1134 @@ +#include "mmvq.hpp" + +#include "ggml.h" +#include "common.hpp" +#include "quants.hpp" +#include "vecdotq.hpp" + +template +static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) { + using block_type = ggml_sycl_reordered::block_q_t; + using block_traits = typename block_type::traits; + + const auto sg = nd_item.get_sub_group(); + const int sg_range = sg.get_group_linear_range(); + const int workgroup_id = nd_item.get_group_linear_id(); + const int sg_id = sg.get_group_linear_id(); + const int row = workgroup_id * sg_range + sg_id; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / block_traits::qk; + constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); + constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + const int nblocks = nrows * (ncols / block_traits::qk); + + static_assert(blocks_per_subgroup > 0); + static_assert(block_elements_per_subgroup > 0); + + float partial_sum = 0.0f; + for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { + const int ibx = row * blocks_per_row + i; // x block index + + const auto bx_offset = block_type::get_block_offset(ibx, nblocks); + const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx); + // Y block index that aligns with ibx + const int iby = i * block_type::block_to_q8_1_ratio(); + const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1; + const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2)); + +#pragma unroll + for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { + // x block quant index when casting the quants to int + const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); + + partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs); + } + } + + auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>()); + + if (sg.leader()) { + dst[row] = sum; + } +} + +template +static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; // Ensuring blocks_per_warp > 0 + + assert(blocks_per_warp > 0); + + // partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; // x block index + + const int iby = i * (qk / QK8_1); // y block index that aligns with ibx + + for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) { + const int iqs = elem + vdr * (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); + } + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx, + const void *__restrict__ vy, + float *__restrict__ dst, const int ncols, + const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + assert(blocks_per_warp>0); + +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx, + const void *__restrict__ vy, + float *__restrict__ dst, const int ncols, + const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + assert(blocks_per_warp>0); +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx, + const void *__restrict__ vy, + float *__restrict__ dst, const int ncols, + const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + assert(blocks_per_warp>0); +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx, + const void *__restrict__ vy, + float *__restrict__ dst, const int ncols, + const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + assert(blocks_per_warp>0); +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx, + const void *__restrict__ vy, + float *__restrict__ dst, const int ncols, + const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + assert(blocks_per_warp>0); +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx, + const void *__restrict__ vy, + float *__restrict__ dst, const int ncols, + const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + assert(blocks_per_warp>0); +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx, + const void *__restrict__ vy, + float *__restrict__ dst, const int ncols, + const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + assert(blocks_per_warp>0); +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx, + const void *__restrict__ vy, + float *__restrict__ dst, const int ncols, + const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + assert(blocks_per_warp>0); +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + + +template +static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, + const void *__restrict__ vy, + float *__restrict__ dst, const int ncols, + const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + assert(blocks_per_warp>0); +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + +static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, + nrows, nd_item); + }); + }); +} + + +static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} +static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + + +static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq2_xxs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq2_xs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq2_s_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq3_xxs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq3_s_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq1_s_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq1_m_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_NL == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq4_nl_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq4_xs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, + ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, + const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size, + const dpct::queue_ptr & stream) { + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const size_t q8_1_ts = sizeof(block_q8_1); + const size_t q8_1_bs = QK8_1; + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + + for (int i = 0; i < src1_ncols; i++) { + const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs; + const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; + float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; + switch (src0->type) { + case GGML_TYPE_Q4_0: + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n"); + mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n"); + mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n"); + mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ1_M: + mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + default: + GGML_ABORT("fatal error"); + } + } + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddf_i); + GGML_UNUSED(ctx); +} diff --git a/ggml/src/ggml-sycl/mmvq.hpp b/ggml/src/ggml-sycl/mmvq.hpp new file mode 100644 index 0000000000000000000000000000000000000000..049b43d453532087761d6e13d4ebf4961d5e49f4 --- /dev/null +++ b/ggml/src/ggml-sycl/mmvq.hpp @@ -0,0 +1,27 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_MMVQ_HPP +#define GGML_SYCL_MMVQ_HPP + +#include "common.hpp" + + +void ggml_sycl_op_mul_mat_vec_q( + ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, + const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, + float *dst_dd_i, const int64_t row_low, const int64_t row_high, + const int64_t src1_ncols, const int64_t src1_padded_row_size, + const dpct::queue_ptr &stream); + +#endif // GGML_SYCL_MMVQ_HPP diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4ec1416849c7e718f27f5cd40ccf1946ced3926f --- /dev/null +++ b/ggml/src/ggml-sycl/norm.cpp @@ -0,0 +1,501 @@ +#include "norm.hpp" +#include "ggml-sycl/common.hpp" +#include "ggml-sycl/presets.hpp" + +static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { + + const int nrows = item_ct1.get_group_range(2); + const int nchannels = item_ct1.get_group_range(1); + + const int nthreads = item_ct1.get_local_range(2); + const int sample = item_ct1.get_group(0); + const int channel = item_ct1.get_group(1); + const int row = item_ct1.get_group(2); + + const int tid = item_ct1.get_local_id(2); + const int nwarps = nthreads / WARP_SIZE; + + const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); + const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); + + x += strided_offset; + dst += packed_offset; + + sycl::float2 mean_var = sycl::float2(0.f, 0.f); + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + mean_var.x() += xi; + mean_var.y() += xi * xi; + } + + // sum up partial sums + mean_var = warp_reduce_sum(mean_var, item_ct1); + if (block_size > WARP_SIZE) { + const auto sub_group = item_ct1.get_sub_group(); + const auto sg_id = sub_group.get_group_linear_id(); + const auto wi_in_sg = sub_group.get_local_linear_id(); + if (wi_in_sg == 0) { + s_sum[sg_id] = mean_var; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + mean_var = 0.f; + const size_t nreduce = ceil_div(nwarps, WARP_SIZE); + for (size_t i = 0; i < nreduce; i += 1) + { + mean_var += s_sum[wi_in_sg + i * WARP_SIZE]; + } + mean_var = warp_reduce_sum(mean_var, item_ct1); + } + + const float mean = mean_var.x() / ncols; + const float var = mean_var.y() / ncols - mean * mean; + const float inv_std = sycl::rsqrt(var + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = (x[col] - mean) * inv_std; + } +} + +static void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps, + const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { + int start = item_ct1.get_group(2) * group_size; + int end = start + group_size; + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + start += item_ct1.get_local_id(2); + size_t nreduce = nwarps / WARP_SIZE; + + if (end >= ne_elements) { + end = ne_elements; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += block_size) { + tmp += x[j]; + } + + tmp = warp_reduce_sum(tmp, item_ct1); + if (block_size > WARP_SIZE) { + + int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + /* + DPCT1118:1: SYCL group functions and algorithms must be encountered in + converged control flow. You may need to adjust the code. + */ + /* + DPCT1065:54: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } + tmp = warp_reduce_sum(tmp, item_ct1); + } + + float mean = tmp / group_size; + tmp = 0.0f; + + for (int j = start; j < end; j += block_size) { + float xi = x[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = warp_reduce_sum(tmp, item_ct1); + if (block_size > WARP_SIZE) { + + int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + /* + DPCT1118:2: SYCL group functions and algorithms must be encountered in + converged control flow. You may need to adjust the code. + */ + /* + DPCT1065:55: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } + tmp = warp_reduce_sum(tmp, item_ct1); + } + + float variance = tmp / group_size; + float scale = sycl::rsqrt(variance + eps); + for (int j = start; j < end; j += block_size) { + dst[j] *= scale; + } +} + +static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { + + const int nrows = item_ct1.get_group_range(2); + const int nchannels = item_ct1.get_group_range(1); + + const int sample = item_ct1.get_group(0); + const int channel = item_ct1.get_group(1); + const int row = item_ct1.get_group(2); + + const int nthreads = item_ct1.get_local_range(2); + + const int tid = item_ct1.get_local_id(2); + const int nwarps = nthreads / WARP_SIZE; + + const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); + const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); + + x += strided_offset; + dst += packed_offset; + + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp, item_ct1); + if (block_size > WARP_SIZE) { + const auto sub_group = item_ct1.get_sub_group(); + const auto sg_id = sub_group.get_group_linear_id(); + const auto wi_in_sg = sub_group.get_local_linear_id(); + if (wi_in_sg == 0) { + s_sum[sg_id] = tmp; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + const size_t nreduce = ceil_div(nwarps, WARP_SIZE); + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[wi_in_sg + i * WARP_SIZE]; + } + tmp = warp_reduce_sum(tmp, item_ct1); + } + + const float mean = tmp / ncols; + const float scale = sycl::rsqrt(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale * x[col]; + } +} + +static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps, + const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + const int tid = item_ct1.get_local_id(2); + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row * ncols + col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp, item_ct1); + if (block_size > WARP_SIZE) { + + int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + /* + DPCT1118:3: SYCL group functions and algorithms must be encountered in + converged control flow. You may need to adjust the code. + */ + item_ct1.barrier(sycl::access::fence_space::local_space); + size_t nreduce = nwarps / WARP_SIZE; + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } + tmp = warp_reduce_sum(tmp, item_ct1); + } + + const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps)); + + for (int col = tid; col < ncols; col += block_size) { + dst[row * ncols + col] = scale * x[row * ncols + col]; + } +} + +static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, + const float eps, queue_ptr stream, int device) { + + const sycl::range<3> global_dims(nsamples, nchannels, nrows); + GGML_ASSERT(ncols % WARP_SIZE == 0); + if (ncols < 1024) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(global_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); + }); + }); + } + else { + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + const sycl::range<3> block_dims(1, 1, work_group_size); + /* + DPCT1049:17: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1( + sycl::range<1>(work_group_size / WARP_SIZE), cgh); + cgh.parallel_for( + sycl::nd_range<3>(global_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); + }); + }); + } +} + +static void group_norm_f32_sycl(const float* x, float* dst, + const int num_groups, const float eps, const int group_size, + const int ne_elements, queue_ptr stream, int device) { + if (group_size < 1024) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + stream->submit([&](sycl::handler& cgh) { + const float eps_ct4 = eps; + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + group_norm_f32( + x, dst, group_size, ne_elements, eps_ct4, item_ct1, + nullptr, WARP_SIZE); + }); + }); + } + else { + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + const sycl::range<3> block_dims(1, 1, work_group_size); + /* + DPCT1049:18: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), + cgh); + + const float eps_ct4 = eps; + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + group_norm_f32(x, dst, group_size, ne_elements, + eps_ct4, item_ct1, + get_pointer(s_sum_acc_ct1), work_group_size); + }); + }); + } +} + +static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); + + const sycl::range<3> global_dims(nsamples, nchannels, nrows); + if (ncols < 1024) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(global_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); + }); + }); + } + else { + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + const sycl::range<3> block_dims(1, 1, work_group_size); + /* + DPCT1049:19: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), + cgh); + cgh.parallel_for( + sycl::nd_range<3>(global_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); + }); + }); + } +} + +static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, + const int nrows, const float eps, + queue_ptr stream, int device) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); + if (ncols < 1024) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + l2_norm_f32(x, dst, ncols, eps, item_ct1, + nullptr, WARP_SIZE); + }); + }); + } + else { + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + const sycl::range<3> block_dims(1, 1, work_group_size); + /* + DPCT1049:19: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), + cgh); + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + l2_norm_f32(x, dst, ncols, eps, item_ct1, + get_pointer(s_sum_acc_ct1), work_group_size); + }); + }); + } +} + +void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_UNARY_OP_LOCALS + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + + norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); +} + +void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + int num_groups = dst->op_params[0]; + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); + + int group_size = dst->src[0]->ne[0] * dst->src[0]->ne[1] * ((dst->src[0]->ne[2] + num_groups - 1) / num_groups); + group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device); +} + +void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); +} + +void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + +} diff --git a/ggml/src/ggml-sycl/norm.hpp b/ggml/src/ggml-sycl/norm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..612cd67cf9183d502b88004dc950087ca5562dd4 --- /dev/null +++ b/ggml/src/ggml-sycl/norm.hpp @@ -0,0 +1,26 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_NORM_HPP +#define GGML_SYCL_NORM_HPP + +#include "common.hpp" + +void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + +void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + +void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + +void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + +#endif // GGML_SYCL_NORM_HPP diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3a17f3a1b88abf3c3b09d3f51f87752c1a02ab47 --- /dev/null +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -0,0 +1,47 @@ +#include "outprod.hpp" + +void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_TENSOR_BINARY_OP_LOCALS + + // Get SYCL queue + dpct::queue_ptr stream = ctx.stream(); + + // Dimension checks + GGML_ASSERT(ne01 == ne11); // Inner dimensions must match + GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows + GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols + + // Get data pointers + const float* src0_d = (const float*)src0->data; + const float* src1_d = (const float*)src1->data; + float* dst_d = (float*)dst->data; + + // GEMM parameters + const float alpha = 1.0f; + const float beta = 0.0f; + + // Handle transposition of src1 + const bool src1_T = ggml_is_transposed(src1); + const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans; + const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); + + try { + // Perform matrix multiplication using oneMath GEMM + oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op, + ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0); + } + catch (sycl::exception const& exc) { + std::cerr << exc.what() << std::endl; + GGML_ASSERT(false); + } +} diff --git a/ggml/src/ggml-sycl/outprod.hpp b/ggml/src/ggml-sycl/outprod.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f50413d3f7a2879502d64028bc27153fc30afd0a --- /dev/null +++ b/ggml/src/ggml-sycl/outprod.hpp @@ -0,0 +1,10 @@ +#ifndef GGML_SYCL_OUTPROD_HPP +#define GGML_SYCL_OUTPROD_HPP + +#include "common.hpp" + +void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + + +#endif // GGML_SYCL_OUTPROD_HPP + diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp new file mode 100644 index 0000000000000000000000000000000000000000..af1890727df8f99576a77cb41b5a9a8dbfd20ce6 --- /dev/null +++ b/ggml/src/ggml-sycl/presets.hpp @@ -0,0 +1,74 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_PRESETS_HPP +#define GGML_SYCL_PRESETS_HPP + +#define GGML_SYCL_MAX_STREAMS 8 +#define GGML_SYCL_MAX_BUFFERS 256 + +#define WARP_SIZE GGML_SYCL_WARP_SIZE +#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses + +#define SYCL_GELU_BLOCK_SIZE 256 +#define SYCL_SILU_BLOCK_SIZE 256 +#define SYCL_TANH_BLOCK_SIZE 256 +#define SYCL_RELU_BLOCK_SIZE 256 +#define SYCL_HARDSIGMOID_BLOCK_SIZE 256 +#define SYCL_HARDSWISH_BLOCK_SIZE 256 +#define SYCL_EXP_BLOCK_SIZE 256 +#define SYCL_NEG_BLOCK_SIZE 256 +#define SYCL_SIGMOID_BLOCK_SIZE 256 +#define SYCL_SQRT_BLOCK_SIZE 256 +#define SYCL_SIN_BLOCK_SIZE 256 +#define SYCL_SQR_BLOCK_SIZE 256 +#define SYCL_CPY_BLOCK_SIZE 32 +#define SYCL_SCALE_BLOCK_SIZE 256 +#define SYCL_CLAMP_BLOCK_SIZE 256 +#define SYCL_ROPE_BLOCK_SIZE 256 +#define SYCL_ALIBI_BLOCK_SIZE 32 +#define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32 +#define SYCL_QUANTIZE_BLOCK_SIZE 256 +#define SYCL_DEQUANTIZE_BLOCK_SIZE 256 +#define SYCL_GET_ROWS_BLOCK_SIZE 256 +#define SYCL_UPSCALE_BLOCK_SIZE 256 +#define SYCL_CONCAT_BLOCK_SIZE 256 +#define SYCL_PAD_BLOCK_SIZE 256 +#define SYCL_ACC_BLOCK_SIZE 256 +#define SYCL_IM2COL_BLOCK_SIZE 256 +#define SYCL_POOL2D_BLOCK_SIZE 256 +#define SYCL_ARGMAX_BLOCK_SIZE 256 +#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256 +#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256 + +// dmmv = dequantize_mul_mat_vec +#ifndef GGML_SYCL_DMMV_X +#define GGML_SYCL_DMMV_X 32 +#endif +#ifndef GGML_SYCL_MMV_Y +#define GGML_SYCL_MMV_Y 1 +#endif + +#ifndef K_QUANTS_PER_ITERATION +#define K_QUANTS_PER_ITERATION 2 +#else +static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); +#endif + +#ifndef GGML_SYCL_PEER_MAX_BATCH_SIZE +#define GGML_SYCL_PEER_MAX_BATCH_SIZE 128 +#endif // GGML_SYCL_PEER_MAX_BATCH_SIZE + +#define MUL_MAT_SRC1_COL_STRIDE 128 + +#define QK_WARP_SIZE 32 +#endif // GGML_SYCL_PRESETS_HPP diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b952db43bfe22fdcbbcccdf60dc79d2c51ec7da --- /dev/null +++ b/ggml/src/ggml-sycl/quants.hpp @@ -0,0 +1,111 @@ +// +// MIT license +// Copyright (C) 2025 Codeplay Software Ltd. +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_QUANTS_HPP +#define GGML_SYCL_QUANTS_HPP + +#include + +#include "ggml-common.h" +#include "ggml.h" + +namespace ggml_sycl_reordered { + +// The reordered block moves quants (qs) and scales(d) to two +// uniform regions of memory that is contiguous in the same tensor. +// What this means is that instead of having: +// [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN] +// We have: +// [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN] +// +// Notes: out-of-bounds qs will run into d values +// Aligment relies on the allocated size of qs + +template struct block_q_t; + +// qk number of weights / quants in a block +// qr number of weights in a byte (described as 'before dequantization') +// for quantization types that has low and high bits split, qr is calculated with +// using the lower bits, e.g for Q6 quants QR6 is 2 +// qi number of 32 bit integers needed to represent all the quants from a block (`qs` field) +// See ggml-common.h to see how these are calculated +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK4_0; + static constexpr uint32_t qi = QI4_0; + static constexpr uint32_t qr = QR4_0; + static constexpr uint32_t vdr_mmvq = 2; + }; + + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * (traits::qk / traits::qr), 0 }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI4_K; + static constexpr uint32_t qr = QR4_K; + static constexpr uint32_t vdr_mmvq = 2; + }; + + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * (traits::qk / traits::qr), 0 }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / traits::qk)); + return { nblocks * (QK_K / 2), + (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } + + constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; } +}; + +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI6_K; + static constexpr uint32_t qr = QR6_K; + static constexpr uint32_t vdr_mmvq = 1; + }; + + static constexpr std::pair get_block_offset(const int block_index, const int n_blocks) { + auto low_bits_index = block_index * (traits::qk / traits::qr); + // the index of high bits it's after all low bits + auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4)); + return { low_bits_index, high_bits_index }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / traits::qk)); + auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4); + auto block_scales = total_qs_bytes + block_index * (QK_K / 16); + auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16); + return { block_scales, sb_scale }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; +} // namespace ggml_sycl_reordered + +#endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp new file mode 100644 index 0000000000000000000000000000000000000000..44473e1e5580c8c77b196a1e705477dc6a3c6e52 --- /dev/null +++ b/ggml/src/ggml-sycl/rope.cpp @@ -0,0 +1,468 @@ +#include "rope.hpp" +#include "ggml-sycl/common.hpp" +#include "ggml.h" + +struct rope_corr_dims { + float v[2]; +}; + +struct mrope_sections { + int v[4]; +}; + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low); + return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale); + } + *cos_theta = sycl::cos(theta) * mscale; + *sin_theta = sycl::sin(theta) * mscale; +} + +template +static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, float freq_scale, float ext_factor, float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, + const sycl::nd_item<3> & item_ct1) { + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); + + if (i0 >= ne0) { + return; + } + + const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + + if (i0 >= n_dims) { + const int i = row * ne0 + i0; + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); + return; + } + + const int row0 = row % ne1; + const int channel0 = row / ne1; + + const int i = row * ne0 + i0; + const int i2 = channel0 * s2 + row0 * s1 + i0; + + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); + + const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const float x0 = x[i2 + 0]; + const float x1 = x[i2 + 1]; + + dst[i + 0] = x0 * cos_theta - x1 * sin_theta; + dst[i + 1] = x0 * sin_theta + x1 * cos_theta; +} + +template +static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, + const sycl::nd_item<3> & item_ct1) { + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); + + if (i0 >= ne0) { + return; + } + + const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + + if (i0 >= n_dims) { + const int i = row * ne0 + i0; + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); + return; + } + + const int row0 = row % ne1; + const int channel0 = row / ne1; + + const int i = row * ne0 + i0 / 2; + const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; + + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); + + const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const float x0 = x[i2 + 0]; + const float x1 = x[i2 + n_dims / 2]; + + dst[i + 0] = x0 * cos_theta - x1 * sin_theta; + dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta; +} + +template +static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, + const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections, + const sycl::nd_item<3> & item_ct1) { + // get index pos + const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); + if (i0 >= ne0) { + return; + } + const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); + + if (i0 >= n_dims) { + const int i = row_dst*ne0 + i0; + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); + return; + } + + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + const int idst = (row_dst * ne0) + (i0 / 2); + const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; + + + float theta_base = 0.0; + if (sector < sections.v[0]) { + theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f); + } + + const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; + float cos_theta; + float sin_theta; + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; + + // store results in dst + dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; + dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta; +} + + + +template +static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, + const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections, + const sycl::nd_item<3> & item_ct1) { + // get index pos + const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); + if (i0 >= ne0) { + return; + } + const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + const int idst = (row_dst * ne0) + (i0 / 2); + const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + + const int sect_dims = sections.v[0] + sections.v[1]; + const int sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0f; + if (sector < sections.v[0]) { + const int p = sector; + theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p); + } else { + // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0] + const int p = sector - sections.v[0]; + theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p); + } + + const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; + float cos_theta; + float sin_theta; + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims]; + + // store results in dst + dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; + dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta; +} + +template +static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, + const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base, + const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float * freq_factors, queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); + const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); + const sycl::range<3> block_nums(1, num_blocks_x, nr); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + if (freq_factors == nullptr) { + /* + DPCT1049:40: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); + } else { + /* + DPCT1049:41: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); + } +} + +template +static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, + const int n_dims, const int nr, const int32_t * pos, const float freq_scale, + const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); + const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); + const sycl::range<3> block_nums(1, num_blocks_x, nr); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + if (freq_factors == nullptr) { + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); + } else { + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); + } +} + +template +static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int nr, const int32_t * pos, + const float freq_scale, const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, + const mrope_sections sections, queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); + const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); + const sycl::range<3> grid_dims(1, n_blocks_y, nr); + const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); + + const float theta_scale = std::pow(freq_base, -2.0f / n_dims); + // Add FP16 capability check if T could be sycl::half + if constexpr (std::is_same_v) { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + } + // launch kernel + if (freq_factors == nullptr) { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } else { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } +} + + + + +// rope vision +template +static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int nr, const int32_t * pos, + const float freq_scale, const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, + const mrope_sections sections, queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); + const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); + const sycl::range<3> grid_dims(1, n_blocks_y, nr); + const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); + + const float theta_scale = std::pow(freq_base, -2.0f / n_dims); + // Add FP16 capability check if T could be sycl::half + if constexpr (std::is_same_v) { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + } + // launch kernel + if (freq_factors == nullptr) { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_vision(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } else { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_vision(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } +} + +inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->src[0]->type == dst->type); + const int64_t ne00 = dst->src[0]->ne[0]; // head dims + const int64_t ne01 = dst->src[0]->ne[1]; // num heads + const int64_t ne02 = dst->src[0]->ne[2]; // num heads + const int64_t nr = ggml_nrows(dst->src[0]); + + const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type); + const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type); + + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + mrope_sections sections; + + // RoPE alteration for extended context + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } + + const int32_t * pos = (const int32_t *) dst->src[1]->data; + + const float * freq_factors = nullptr; + if (dst->src[2] != nullptr) { + freq_factors = (const float *) dst->src[2]->data; + } + + rope_corr_dims corr_dims; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + // compute + if (is_neox) { + GGML_SYCL_DEBUG("%s: neox path\n", __func__); + if (dst->src[0]->type == GGML_TYPE_F32) { + rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, + pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); + } else if (dst->src[0]->type == GGML_TYPE_F16) { + rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, + n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + main_stream); + } else { + GGML_ABORT("fatal error"); + } + } else if (is_mrope && !is_vision) { + GGML_SYCL_DEBUG("%s: mrope path\n", __func__); + if (dst->src[0]->type == GGML_TYPE_F16) { + rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, + s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, + freq_factors, sections, main_stream); + } else if (dst->src[0]->type == GGML_TYPE_F32) { + rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, + nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, + main_stream); + } else { + GGML_ABORT("Fatal error: Tensor type unsupported!"); + } + } else if (is_vision) { + GGML_SYCL_DEBUG("%s: vision path\n", __func__); + if (dst->src[0]->type == GGML_TYPE_F16) { + rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01, + s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, + freq_factors, sections, main_stream); + } else if (dst->src[0]->type == GGML_TYPE_F32) { + rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, + nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, + main_stream); + } else { + GGML_ABORT("Fatal error: Tensor type unsupported!"); + } + } else { + GGML_SYCL_DEBUG("%s: norm path\n", __func__); + if (dst->src[0]->type == GGML_TYPE_F32) { + rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, + pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); + } else if (dst->src[0]->type == GGML_TYPE_F16) { + rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, + n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + main_stream); + } else { + GGML_ABORT("fatal error"); + } + } +} + +void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); + ggml_sycl_op_rope(ctx, dst); +} + diff --git a/ggml/src/ggml-sycl/rope.hpp b/ggml/src/ggml-sycl/rope.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8c7141aac5c9b1fed0d9ff573ec78f614b7a400e --- /dev/null +++ b/ggml/src/ggml-sycl/rope.hpp @@ -0,0 +1,20 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_ROPE_HPP +#define GGML_SYCL_ROPE_HPP + +#include "common.hpp" + +void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst); + +#endif // GGML_SYCL_ROPE_HPP diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..52fcf4b3dbd244db8c0ed92e7f931eb16bb9c199 --- /dev/null +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -0,0 +1,261 @@ +#include "softmax.hpp" + +template +static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, + const int nrows_y, const float scale, const float max_bias, const float m0, + const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { + const int ncols = ncols_template == 0 ? ncols_par : ncols_template; + + const int tid = item_ct1.get_local_id(2); + const int rowx = item_ct1.get_group(2); + const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension + + const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template; + + const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + const int nthreads = block_size; + const int nwarps = nthreads / WARP_SIZE; + size_t nreduce = nwarps / WARP_SIZE; + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = rowx/nrows_y; // head index + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = sycl::pow(base, float(exp)); + } + + float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols; + float max_val = -INFINITY; + + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + break; + } + + const int ix = rowx*ncols + col; + const int iy = rowy*ncols + col; + + const float val = x[ix]*scale + (mask ? slope*static_cast(mask[iy]) : 0.0f); + + vals[col] = val; + max_val = sycl::max(max_val, val); + } + + // find the max value in the block + max_val = warp_reduce_max(max_val, item_ct1); + if (block_size > WARP_SIZE) { + if (warp_id == 0) { + buf[lane_id] = -INFINITY; + for (size_t i = 1; i < nreduce; i += 1) { + buf[lane_id + i * WARP_SIZE] = -INFINITY; + } + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + if (lane_id == 0) { + buf[warp_id] = max_val; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + max_val = buf[lane_id]; + for (size_t i = 1; i < nreduce; i += 1) { + max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]); + } + max_val = warp_reduce_max(max_val, item_ct1); + } + + float tmp = 0.f; +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + if (ncols_template == 0 && col >= ncols) { + break; + } + + const float val = sycl::native::exp(vals[col] - max_val); + tmp += val; + vals[col] = val; + } + + // find the sum of exps in the block + tmp = warp_reduce_sum(tmp, item_ct1); + if (block_size > WARP_SIZE) { + item_ct1.barrier(sycl::access::fence_space::local_space); + if (warp_id == 0) { + buf[lane_id] = 0.f; + for (size_t i = 1; i < nreduce; i += 1) { + buf[lane_id + i * WARP_SIZE] = 0.f; + } + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + if (lane_id == 0) { + buf[warp_id] = tmp; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + tmp = buf[lane_id]; + for (size_t i = 1; i < nreduce; i += 1) { + tmp += buf[lane_id + i * WARP_SIZE]; + } + tmp = warp_reduce_sum(tmp, item_ct1); + } + + const float inv_sum = 1.f / tmp; + +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + return; + } + + const int idst = rowx*ncols + col; + dst[idst] = vals[col] * inv_sum; + } +} + +template +static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par, + const int nrows_y, const float scale, const float max_bias, const float m0, + const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, + const size_t n_local_scratch, queue_ptr stream) { + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor local_buf_acc(n_local_scratch, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + soft_max_f32(x, mask, dst, ncols_par, + nrows_y, scale, max_bias, m0, + m1, n_head_log2, item_ct1, + get_pointer(local_buf_acc)); + }); + }); +} + +template +static void soft_max_f32_sycl(const float * x, const T * mask, + float * dst, const int ncols_x, const int nrows_x, + const int nrows_y, const float scale, const float max_bias, + queue_ptr stream, int device) { + int nth = WARP_SIZE; + int max_block_size = ggml_sycl_info().max_work_group_sizes[device]; + while (nth < ncols_x && nth < max_block_size) nth *= 2; + if (nth>max_block_size) nth = max_block_size; + + const sycl::range<3> block_dims(1, 1, nth); + const sycl::range<3> block_nums(1, 1, nrows_x); + const size_t n_val_tmp = nth / WARP_SIZE; + const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp); + + const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const size_t local_mem_size = stream->get_device().get_info(); + if (n_local_scratch*sizeof(float) < local_mem_size) { + if (ncols_x > max_block_size) { + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + return; + } + switch (ncols_x) { + case 32: + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 64: + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 128: + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 256: + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 512: + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 1024: + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 2048: + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 4096: + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + default: + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + } + } else { + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, WARP_SIZE, stream); + } +} + +void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional + + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t nrows_x = ggml_nrows(dst->src[0]); + const int64_t nrows_y = dst->src[0]->ne[1]; + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, dst->op_params + 1, sizeof(float)); + + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + ggml_sycl_set_device(ctx.device); + dpct::queue_ptr main_stream = ctx.stream(); + + if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { + const sycl::half * src1_dd = static_cast(dst->src[1]->data); + soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, + main_stream, ctx.device); + } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { + const float * src1_dd = static_cast(dst->src[1]->data); + soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + } else { + /* mask unavailable */ + soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + } +} diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2cf8582ec92e9ef52ebc50b446da9e2953309f9c --- /dev/null +++ b/ggml/src/ggml-sycl/softmax.hpp @@ -0,0 +1,20 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_SOFTMAX_HPP +#define GGML_SYCL_SOFTMAX_HPP + +#include "common.hpp" + +void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst); + +#endif // GGML_SYCL_SOFTMAX_HPP diff --git a/ggml/src/ggml-sycl/sycl_hw.cpp b/ggml/src/ggml-sycl/sycl_hw.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da121ffc261e8bb52722668467eafc7ff3fda627 --- /dev/null +++ b/ggml/src/ggml-sycl/sycl_hw.cpp @@ -0,0 +1,13 @@ +#include "sycl_hw.hpp" + + +sycl_hw_info get_device_hw_info(sycl::device *device_ptr) { + sycl_hw_info res; + int32_t id = device_ptr->get_info(); + res.device_id = id; + + syclex::architecture arch = device_ptr->get_info(); + res.arch = arch; + + return res; +} diff --git a/ggml/src/ggml-sycl/sycl_hw.hpp b/ggml/src/ggml-sycl/sycl_hw.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bf689450ce61ff244b54d79b5ea1197c239e9068 --- /dev/null +++ b/ggml/src/ggml-sycl/sycl_hw.hpp @@ -0,0 +1,23 @@ +#ifndef SYCL_HW_HPP +#define SYCL_HW_HPP + +#include +#include +#include +#include + +#include + +namespace syclex = sycl::ext::oneapi::experimental; + +struct sycl_hw_info { + syclex::architecture arch; + int32_t device_id; +}; + +bool is_in_vector(std::vector &vec, int item); + +sycl_hw_info get_device_hw_info(sycl::device *device_ptr); + + +#endif // SYCL_HW_HPP diff --git a/ggml/src/ggml-sycl/tsembd.cpp b/ggml/src/ggml-sycl/tsembd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f6ca626ea7a53f963626fba465e42d4808f6de54 --- /dev/null +++ b/ggml/src/ggml-sycl/tsembd.cpp @@ -0,0 +1,72 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "tsembd.hpp" + +static void timestep_embedding_f32( + const float * timesteps, float * dst, const int nb1, + const int dim, const int max_period, const sycl::nd_item<3> &item_ct1) { + // item_ct1.get_group(1)(blockIDx.y): idx of timesteps->ne[0] + // item_ct1.get_group(2) (blockIDx.x): idx of ((dim + 1) / 2) / BLOCK_SIZE + int i = item_ct1.get_group(1); + int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + float * embed_data = (float *)((char *)dst + i*nb1); + + if (dim % 2 != 0 && j == ((dim + 1) / 2)) { + embed_data[dim] = 0.f; + } + + int half = dim / 2; + if (j >= half) { + return; + } + + float timestep = timesteps[i]; + float freq = (float)sycl::native::exp(-(sycl::log((float)max_period)) * j / half); + float arg = timestep * freq; + embed_data[j] = sycl::cos(arg); + embed_data[j + half] = sycl::sin(arg); +} + +static void timestep_embedding_f32_sycl( + const float * x, float * dst, const int ne00, const int nb1, + const int dim, const int max_period, const queue_ptr& stream) { + // As the kernel returns when thread.idx is larger than dim/2, the half_ceil does not need to pad + int half_ceil = dim / 2; + int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE; + sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE); + sycl::range<3> gridDim(1, ne00, num_blocks); + stream->parallel_for( + sycl::nd_range<3>( + gridDim * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + timestep_embedding_f32( + x, dst, nb1, dim, max_period, item_ct1 + ); + }); +} + +void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + + timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream); +} diff --git a/ggml/src/ggml-sycl/tsembd.hpp b/ggml/src/ggml-sycl/tsembd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4c18748bbffc285ea6fc6d7463d4f7d2f1b54d7f --- /dev/null +++ b/ggml/src/ggml-sycl/tsembd.hpp @@ -0,0 +1,20 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_TSEMBD_HPP +#define GGML_SYCL_TSEMBD_HPP + +#include "common.hpp" + +void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_TSEMBD_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a5d4999419c9ee279f90c849881ec3c380a9e64 --- /dev/null +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -0,0 +1,1307 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_VECDOTQ_HPP +#define GGML_SYCL_VECDOTQ_HPP + +#include "dpct/helper.hpp" +#include "ggml.h" +#include "quants.hpp" + +typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, + const int & iqs); + +static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { + const uint16_t* x16 = + (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte + // alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __dpct_inline__ int get_int_from_uint8( + const uint8_t* x8, + const int& i32) { + const uint16_t* x16 = + (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte + // alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __dpct_inline__ int get_int_from_int8_aligned( + const int8_t* x8, + const int& i32) { + return *( + (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + +static __dpct_inline__ int get_int_from_uint8_aligned( + const uint8_t* x8, + const int& i32) { + return *( + (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + +static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4, + const uint8_t *values, + int &val1, int &val2) { + + uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32; + aux32 = q4 & 0x0f0f0f0f; + uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8); + uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8); + val1 = v1 | (v2 << 16); + aux32 = (q4 >> 4) & 0x0f0f0f0f; + v1 = values[q8[0]] | (values[q8[1]] << 8); + v2 = values[q8[2]] | (values[q8[3]] << 8); + val2 = v1 | (v2 << 16); +} + +#define VDR_Q2_K_Q8_1_MMVQ 1 + +// contiguous v/x values +static __dpct_inline__ float vec_dot_q2_K_q8_1_impl_mmvq( + const int &v, const int *__restrict__ u, const uint8_t *__restrict__ scales, + const sycl::half2 &dm2, const float *__restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR2_K; ++i) { + const int sc = scales[2*i]; + + const int vi = (v >> (2*i)) & 0x03030303; + + sumf_d += + d8[i] * (dpct::dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + sumf_m += d8[i] * + dpct::dp4a( + m, u[i], + 0); // multiply constant q2_K part with sum of q8_1 values + } + + const sycl::float2 dm2f = + dm2.convert(); + + return dm2f.x() * sumf_d - dm2f.y() * sumf_m; +} + + +#define VDR_Q3_K_Q8_1_MMVQ 1 + +// contiguous v/x values +static __dpct_inline__ float vec_dot_q3_K_q8_1_impl_mmvq( + const int &vl, const int &vh, const int *__restrict__ u, + const uint8_t *__restrict__ scales, const int &scale_offset, + const float &d3, const float *__restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int isc = scale_offset + 2*i; + + const int isc_low = isc % (QK_K/32); + const int sc_shift_low = 4 * (isc / (QK_K/32)); + const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; + + const int isc_high = isc % (QK_K/64); + const int sc_shift_high = 2 * (isc / (QK_K/64)); + const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + + const int sc = (sc_low | sc_high) - 32; + + const int vil = (vl >> (2*i)) & 0x03030303; + + const int vih = ((vh >> i) << 2) & 0x04040404; + + const int vi = + dpct::vectorized_binary(vil, vih, dpct::sub_sat()); + + sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d3 * sumf; +} + +#define VDR_Q4_K_Q8_1_MMVQ 2 + +// contiguous v/x values +static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_vmmq( + const int *__restrict__ v, const int *__restrict__ u, + const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m, + const sycl::half2 &dm4, const float *__restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K; ++i) { + const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; + const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; + + const int dot1 = + dpct::dp4a(v1i, u[2 * i + 1], + dpct::dp4a(v0i, u[2 * i + 0], 0)); // SIMD dot product + const int dot2 = + dpct::dp4a(0x01010101, u[2 * i + 1], + dpct::dp4a(0x01010101, u[2 * i + 0], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + } + + const sycl::float2 dm4f = + dm4.convert(); + + return dm4f.x() * sumf_d - dm4f.y() * sumf_m; +} + + +#define VDR_Q5_K_Q8_1_MMVQ 2 + +// contiguous v/x values +static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_vmmq( + const int *__restrict__ vl, const int *__restrict__ vh, + const int *__restrict__ u, const uint8_t *__restrict__ sc, + const uint8_t *__restrict__ m, const sycl::half2 &dm5, + const float *__restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; + const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; + + const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; + const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; + + const int v0i = vl0i | vh0i; + const int v1i = vl1i | vh1i; + + const int dot1 = + dpct::dp4a(v0i, u[2 * i + 0], + dpct::dp4a(v1i, u[2 * i + 1], 0)); // SIMD dot product + const int dot2 = + dpct::dp4a(0x01010101, u[2 * i + 0], + dpct::dp4a(0x01010101, u[2 * i + 1], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); + + } + + const sycl::float2 dm5f = + dm5.convert(); + + return dm5f.x() * sumf_d - dm5f.y() * sumf_m; +} + + +#define VDR_Q6_K_Q8_1_MMVQ 1 + +// contiguous v/x values +static __dpct_inline__ float +vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, + const int *__restrict__ u, + const int8_t *__restrict__ scales, const float &d, + const float *__restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4*i]; + + const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4*i)) << 4) & 0x30303030; + + const int vi = dpct::vectorized_binary( + (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32 + + sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d*sumf; +} + +// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called +// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q + +template struct reorder_vec_dot_q_sycl { + static_assert(T != T, "ggml_type for reorder vecdot not implemented"); +}; + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q4_0; + + using q4_0_block = ggml_sycl_reordered::block_q_t; + using q4_0_traits = typename q4_0_block::traits; + + __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); + sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); + } + + const sycl::float2 ds8f = ds8.convert(); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y()); + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset.first; + const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset.first)); + int v[q4_0_traits::vdr_mmvq]; + int u[2 * q4_0_traits::vdr_mmvq]; + + +#pragma unroll + for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { + v[i] = get_int_from_uint8(bq4_0, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi); + } + + return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds); + }; +}; + +static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, + const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, + const int & iqs) { + int v[2]; + int u[2 * QR4_K]; + float d8[QR4_K]; + + v[0] = q4[0]; + v[1] = q4[4]; + + uint16_t aux[2]; + const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); + } + + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = bq8i->ds[0]; + + const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8); +} + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q4_K; + + using q4_k_block = ggml_sycl_reordered::block_q_t; + using q4_k_traits = typename q4_k_block::traits; + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const int ib = ibx_offset.first / (QK_K / 2); + + const uint8_t * base = static_cast(vbq); + const uint8_t * qs = base + ibx_offset.first; + const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE; + const ggml_half2 * dms = reinterpret_cast(base + d_offset.second); + + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) scs; + + int v[2]; + int u[2 * QR4_K]; + float d8[QR4_K]; + + v[0] = q4[0]; + v[1] = q4[4]; + + uint16_t aux[2]; + const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); + } + + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i); + + d8[i] = ds_values[0]; + + const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8); + } +}; + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q6_K; + + using q6_k_block = ggml_sycl_reordered::block_q_t; + using q6_k_traits = typename q6_k_block::traits; + + __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u, + const int8_t * __restrict__ scales, const float d, + const float * __restrict__ d8) { + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4 * i]; + + const int vil = (vl >> (4 * i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4 * i)) << 4) & 0x30303030; + + const int vi = dpct::vectorized_binary((vil | vih), 0x20202020, + dpct::sub_sat()); // vi = (vil | vih) - 32 + + sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d * sumf; + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds, + const int iqs) { + const int ib = ibx_offset.first / (QK_K / 2); + + const uint8_t * base = static_cast(vbq); + const uint8_t * ql = base + ibx_offset.first; + const uint8_t * qh = base + ibx_offset.second; + const int8_t * scales = reinterpret_cast(base + d_offset.first); + const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4); + const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8); + const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4)); + + const int vl = get_int_from_uint8(ql, iqs); + const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift; + + const int8_t * scs = scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1); + const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i); + d8[i] = ds_values[0]; + } + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8); + } +}; +#define VDR_Q4_0_Q8_1_MMVQ 2 +#define VDR_Q4_0_Q8_1_MMQ 4 + +template +static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, + const sycl::half2 & ds8) { + int sumi = 0; +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); + sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); + } + + const sycl::float2 ds8f = ds8.convert(); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y()); +} + +#define VDR_Q4_1_Q8_1_MMVQ 2 +#define VDR_Q4_1_Q8_1_MMQ 4 + +template +static __dpct_inline__ float vec_dot_q4_1_q8_1_impl(const int *v, const int *u, + const sycl::half2 &dm4, + const sycl::half2 &ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); + sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); + } + +#ifdef GGML_SYCL_F16 + const sycl::float2 tmp = + (dm4 * ds8).convert(); + const float d4d8 = tmp.x(); + const float m4s8 = tmp.y(); +#else + const sycl::float2 dm4f = + dm4.convert(); + const sycl::float2 ds8f = + ds8.convert(); + const float d4d8 = dm4f.x() * ds8f.x(); + const float m4s8 = dm4f.y() * ds8f.y(); +#endif // GGML_SYCL_F16 + + // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it + return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); +} + +#define VDR_Q5_0_Q8_1_MMVQ 2 +#define VDR_Q5_0_Q8_1_MMQ 4 + +template +static __dpct_inline__ float +vec_dot_q5_0_q8_1_impl(const int *vl, const int *vh, const int *u, + const float &d5, const sycl::half2 &ds8) { + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = dpct::dp4a(vi0, u[2 * i + 0], + sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = dpct::dp4a(vi1, u[2 * i + 1], + sumi); // SIMD dot product of quantized values + } + + const sycl::float2 ds8f = + ds8.convert(); + + // second part effectively subtracts 16 from each quant value + return d5 * (sumi * ds8f.x() - (16 * vdr / QI5_0) * ds8f.y()); +} + +#define VDR_Q5_1_Q8_1_MMVQ 2 +#define VDR_Q5_1_Q8_1_MMQ 4 + +template +static __dpct_inline__ float +vec_dot_q5_1_q8_1_impl(const int *vl, const int *vh, const int *u, + const sycl::half2 &dm5, const sycl::half2 &ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = dpct::dp4a(vi0, u[2 * i + 0], + sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = dpct::dp4a(vi1, u[2 * i + 1], + sumi); // SIMD dot product of quantized values + } + +#ifdef GGML_SYCL_F16 + const sycl::float2 tmp = + (dm5 * ds8).convert(); + const float d5d8 = tmp.x(); + const float m5s8 = tmp.y(); + + +#else + const sycl::float2 dm5f = + dm5.convert(); + const sycl::float2 ds8f = + ds8.convert(); + const float d5d8 = dm5f.x() * ds8f.x(); + const float m5s8 = dm5f.y() * ds8f.y(); +#endif // GGML_SYCL_F16 + + // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it + return sumi*d5d8 + m5s8 / (QI5_1 / vdr); +} + +#define VDR_Q8_0_Q8_1_MMVQ 2 +#define VDR_Q8_0_Q8_1_MMQ 8 + +template +static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u, + const float &d8_0, + const float &d8_1) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = dpct::dp4a(v[i], u[i], sumi); + } + + return d8_0*d8_1 * sumi; +} + +template +static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u, + const sycl::half2 &dm8, + const sycl::half2 &ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = dpct::dp4a(v[i], u[i], sumi); + } + +#ifdef GGML_SYCL_F16 + const sycl::float2 tmp = + (dm8 * ds8).convert(); + const float d8d8 = tmp.x(); + const float m8s8 = tmp.y(); +#else + const sycl::float2 dm8f = + dm8.convert(); + const sycl::float2 ds8f = + ds8.convert(); + const float d8d8 = dm8f.x() * ds8f.x(); + const float m8s8 = dm8f.y() * ds8f.y(); +#endif // GGML_SYCL_F16 + + // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it + return sumi*d8d8 + m8s8 / (QI8_1 / vdr); +} + +static __dpct_inline__ float +vec_dot_q4_0_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + + int v[VDR_Q4_0_Q8_1_MMVQ]; + int u[2 * VDR_Q4_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + } + + return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); +} + +static __dpct_inline__ float +vec_dot_q4_1_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + + int v[VDR_Q4_1_Q8_1_MMVQ]; + int u[2*VDR_Q4_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); + } + + return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); +} + +static __dpct_inline__ float +vec_dot_q5_0_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + + int vl[VDR_Q5_0_Q8_1_MMVQ]; + int vh[VDR_Q5_0_Q8_1_MMVQ]; + int u[2*VDR_Q5_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); + vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0); + } + + return vec_dot_q5_0_q8_1_impl(vl, vh, u, bq5_0->d, bq8_1->ds); +} + +static __dpct_inline__ float +vec_dot_q5_1_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + + int vl[VDR_Q5_1_Q8_1_MMVQ]; + int vh[VDR_Q5_1_Q8_1_MMVQ]; + int u[2*VDR_Q5_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); + vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); + } + + return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); +} + +static __dpct_inline__ float +vec_dot_q8_0_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + + int v[VDR_Q8_0_Q8_1_MMVQ]; + int u[VDR_Q8_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_int8(bq8_0->qs, iqs + i); + u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, + bq8_1->ds[0]); +} + +static __dpct_inline__ float +vec_dot_q2_K_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + + const block_q2_K * bq2_K = (const block_q2_K *) vbq; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const uint8_t * scales = bq2_K->scales + scale_offset; + + const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); + int u[QR2_K]; + float d8[QR2_K]; + +#pragma unroll + for (int i = 0; i < QR2_K; ++ i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = bq8_1[bq8_offset + i].ds[0]; + } + + return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); +} + +static __dpct_inline__ float +vec_dot_q3_K_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + + const block_q3_K * bq3_K = (const block_q3_K *) vbq; + + const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const float d = bq3_K->d; + + const int vl = get_int_from_uint8(bq3_K->qs, iqs); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = bq8_1[bq8_offset + i].ds[0]; + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); +} + +static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, + const int & iqs) { +#ifndef GGML_QKK_64 + + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + const int * q4 = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) bq4_K->scales; + + return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs); + +#else + +#if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + const uint16_t * a = (const uint16_t *)bq4_K->scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const float dall = bq4_K->dm[0]; + const float dmin = bq4_K->dm[1]; + + const float d8_1 = bq8_1[0].ds[0]; + const float d8_2 = bq8_1[1].ds[1]; + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * q4 = (const int *)bq4_K->qs + (iqs/2); + const int v1 = q4[0]; + const int v2 = q4[4]; + + const int dot1 = dpct::dp4a(ui2, v2 & 0x0f0f0f0f, dpct::dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = dpct::dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, dpct::dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = dpct::dp4a(0x01010101, ui2, dpct::dp4a(0x01010101, ui1, 0)); + const int dot4 = dpct::dp4a(0x01010101, ui4, dpct::dp4a(0x01010101, ui3, 0)); + + sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); + sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + + return dall * sumf_d - dmin * sumf_m; + +#else + bad_arch(); +#endif // __SYCL_ARCH__ >= VER_4VEC + +#endif +} + +static __dpct_inline__ float +vec_dot_q5_K_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + +#ifndef GGML_QKK_64 + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + int vl[2]; + int vh[2]; + int u[2*QR5_K]; + float d8[QR5_K]; + + const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2)); + const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4)); + + vl[0] = ql[0]; + vl[1] = ql[4]; + + vh[0] = qh[0] >> bq8_offset; + vh[1] = qh[4] >> bq8_offset; + + const uint16_t * scales = (const uint16_t *)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = bq8i->ds[0]; + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8); + +#else + +#if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + const int8_t * s = bq5_K->scales; + + const float d = bq5_K->d; + + const float d8_1 = bq8_1[0].ds[0]; + const float d8_2 = bq8_1[1].ds[1]; + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * ql = (const int *)bq5_K->qs + (iqs/2); + const int vl1 = ql[0]; + const int vl2 = ql[4]; + + const int step = 4 * (iqs/2); // 0, 4, 8, 12 + const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6 + const int in = step%8; // 0, 4, 0, 4 + const int vh = (*((const int *)(bq5_K->qh + in))) >> im; + + const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); + const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); + const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); + const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); + + const float sumf_d = d8_1 * (dpct::dp4a(ui1, v1, 0) * s[0] + dpct::dp4a(ui2, v2, 0) * s[1]) + + d8_2 * (dpct::dp4a(ui3, v3, 0) * s[2] + dpct::dp4a(ui4, v4, 0) * s[3]); + + return d * sumf_d; + +#else + bad_arch(); +#endif // __SYCL_ARCH__ >= VER_4VEC + +#endif +} + +static __dpct_inline__ float +vec_dot_q6_K_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + + const block_q6_K * bq6_K = (const block_q6_K *) vbq; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); + + const int vl = get_int_from_uint8(bq6_K->ql, iqs); + const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; + + const int8_t * scales = bq6_K->scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); + d8[i] = bq8_1[bq8_offset + 2 * i].ds[0]; + } + + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); +} + + +static __dpct_inline__ float +vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs, + const uint64_t *iq2xxs_grid, const uint8_t *ksigns_iq2xs, + const uint8_t *kmask_iq2xs) { +#if QK_K == 256 + const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq; + + const int ib32 = iqs; + const uint16_t * q2 = bq2->qs + 4*ib32; + const uint8_t * aux8 = (const uint8_t *)q2; + const int8_t * q8 = bq8_1[ib32].qs; + uint32_t aux32 = q2[2] | (q2[3] << 16); + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[aux32 & 127]; + for (int j = 0; j < 8; ++j) { + sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + aux32 >>= 7; + } + const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f; + return d * sumi; +#else + assert(false); + return 0.f; +#endif +} + +static __dpct_inline__ float +vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs, + const uint64_t *iq2xs_grid, const uint64_t *ksigns64) { +#if DPCT_COMPATIBILITY_TEMP >= \ + MIN_CC_DP4A // lowest compute capability for integer intrinsics +#if QK_K == 256 + const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq; + + const int ib32 = iqs; + const uint16_t * q2 = bq2->qs + 4*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + const uint8_t ls1 = bq2->scales[ib32] & 0xf; + const uint8_t ls2 = bq2->scales[ib32] >> 4; + int sumi1 = 0; + for (int l = 0; l < 2; ++l) { + const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + const int grid_l = dpct::vectorized_binary( + grid[0] ^ signs[0], signs[0], std::minus<>()); + const int grid_h = dpct::vectorized_binary( + grid[1] ^ signs[1], signs[1], std::minus<>()); + sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1); + sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1); + q8 += 8; + } + int sumi2 = 0; + for (int l = 2; l < 4; ++l) { + const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + const int grid_l = dpct::vectorized_binary( + grid[0] ^ signs[0], signs[0], std::minus<>()); + const int grid_h = dpct::vectorized_binary( + grid[1] ^ signs[1], signs[1], std::minus<>()); + sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2); + sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2); + q8 += 8; + } + const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f; + return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2); +#else + assert(false); + return 0.f; +#endif +#else + assert(false); + return 0.f; +#endif +} + +static __dpct_inline__ float +vec_dot_iq2_s_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { +#if QK_K == 256 + const block_iq2_s * bq2 = (const block_iq2_s *) vbq; + + const int ib32 = iqs; + const int8_t * q8 = bq8_1[ib32].qs; + const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32; + const uint8_t ls1 = bq2->scales[ib32] & 0xf; + const uint8_t ls2 = bq2->scales[ib32] >> 4; + int sumi1 = 0; + for (int l = 0; l < 2; ++l) { + const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300))); + const uint32_t signs0 = dpct::vectorized_binary( + ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, + std::equal_to<>()); + const uint32_t signs1 = dpct::vectorized_binary( + ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, + std::equal_to<>()); + const int grid_l = dpct::vectorized_binary( + grid[0] ^ signs0, signs0, std::minus<>()); + const int grid_h = dpct::vectorized_binary( + grid[1] ^ signs1, signs1, std::minus<>()); + sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1); + sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1); + q8 += 8; + } + int sumi2 = 0; + for (int l = 2; l < 4; ++l) { + const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300))); + const uint32_t signs0 = dpct::vectorized_binary( + ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, + std::equal_to<>()); + const uint32_t signs1 = dpct::vectorized_binary( + ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, + std::equal_to<>()); + const int grid_l = dpct::vectorized_binary( + grid[0] ^ signs0, signs0, std::minus<>()); + const int grid_h = dpct::vectorized_binary( + grid[1] ^ signs1, signs1, std::minus<>()); + sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2); + sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2); + q8 += 8; + } + const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f; + return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2); +#else + assert(false); +#endif +} + +static __dpct_inline__ float +vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs, + const uint32_t *iq3xxs_grid, const uint64_t *ksigns64) { +#if DPCT_COMPATIBILITY_TEMP >= \ + MIN_CC_DP4A // lowest compute capability for integer intrinsics +#if QK_K == 256 + const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq; + + const int ib32 = iqs; + const uint8_t * q3 = bq2->qs + 8*ib32; + const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + uint32_t aux32 = gas[0] | (gas[1] << 16); + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0]; + const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1]; + const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127)); + const int grid_l = dpct::vectorized_binary( + grid1[0] ^ signs[0], signs[0], std::minus<>()); + const int grid_h = dpct::vectorized_binary( + grid2[0] ^ signs[1], signs[1], std::minus<>()); + sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi); + sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi); + q8 += 8; + aux32 >>= 7; + } + const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.5f; + return d * sumi; +#else + assert(false); + return 0.f; +#endif +#else + assert(false); + return 0.f; +#endif +} + +static __dpct_inline__ float +vec_dot_iq3_s_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs, + const uint32_t *iq3s_grid) { +#if QK_K == 256 + const block_iq3_s * bq2 = (const block_iq3_s *) vbq; + + const int ib32 = iqs; + const uint8_t * qs = bq2->qs + 8*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); + const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); + uint32_t signs0 = dpct::vectorized_binary( + ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201, + 0x08040201, std::equal_to<>()); + uint32_t signs1 = dpct::vectorized_binary( + ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201, + 0x08040201, std::equal_to<>()); + const int grid_l = dpct::vectorized_binary( + grid1[0] ^ signs0, signs0, std::minus<>()); + const int grid_h = dpct::vectorized_binary( + grid2[0] ^ signs1, signs1, std::minus<>()); + sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi); + sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi); + q8 += 8; + } + const float d = + (float)bq2->d * + (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) * + bq8_1[ib32].ds[0]; + return d * sumi; +#else + assert(false); +#endif +} + +static __dpct_inline__ float +vec_dot_iq1_s_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs, + const uint32_t *iq1s_grid_gpu) { +#if QK_K == 256 + const block_iq1_s * bq1 = (const block_iq1_s *) vbq; + + const int ib32 = iqs; + int sumi = 0; + const int * q8 = (const int *)bq8_1[ib32].qs; + for (int l = 0; l < 4; ++l) { + const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8))); + int grid0 = grid[0] & 0x0f0f0f0f; + int grid1 = (grid[0] >> 4) & 0x0f0f0f0f; + sumi = dpct::dp4a(q8[2 * l + 1], grid1, + dpct::dp4a(q8[2 * l + 0], grid0, sumi)); + } + + const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA; + const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1); + const float d = d1q * bq8_1[ib32].ds[0]; + const float m = d1q * bq8_1[ib32].ds[1]; + return d * sumi + m * delta; +#else + assert(false); +#endif +} + +static __dpct_inline__ float +vec_dot_iq1_m_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { +#if QK_K == 256 + const block_iq1_m * bq1 = (const block_iq1_m *) vbq; + + const int ib32 = iqs; + int sumi[2] = {0, 0}; + float sumf[2] = {0.f, 0.f}; + + const int * q8 = (const int *)bq8_1[ib32].qs; + for (int l = 0; l < 4; ++l) { + const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8))); + int grid0 = grid[0] & 0x0f0f0f0f; + int grid1 = (grid[0] >> 4) & 0x0f0f0f0f; + sumi[l / 2] = dpct::dp4a(q8[2 * l + 1], grid1, + dpct::dp4a(q8[2 * l + 0], grid0, sumi[l / 2])); + const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA; + const int sumy = dpct::dp4a(q8[2 * l + 1], 0x01010101, + dpct::dp4a(q8[2 * l + 0], 0x01010101, 0)); + sumf[l/2] += delta*sumy; + } + + iq1m_scale_t scale; + const uint16_t * sc = (const uint16_t *)bq1->scales; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = (float)scale.f16 * bq8_1[ib32].ds[0]; + return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1)); +#else + assert(false); +#endif +} + + +static __dpct_inline__ float +vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + + const block_iq4_nl * bq = (const block_iq4_nl *) vbq; + + const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs; + const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs; + + const uint8_t * values = (const uint8_t *)kvalues_iq4nl; + + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) { + const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16); + get_int_from_table_16(aux, values, v1, v2); + sumi1 = dpct::dp4a(v1, q8[l + 0], sumi1); + sumi2 = dpct::dp4a(v2, q8[l + 4], sumi2); + } + + const float d = (float)bq->d * bq8_1->ds[0]; + return d * (sumi1 + sumi2); +} + + +static __dpct_inline__ float +vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + +#if QK_K == 256 + const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq; + const uint8_t * values = (const uint8_t *)kvalues_iq4nl; + + // iqs is 0...7 + const int ib32 = iqs; + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32; + const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)bq4->d * (ls - 32) * bq8_1[ib32].ds[0]; + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + get_int_from_table_16(q4[j], values, v1, v2); + sumi1 = dpct::dp4a(v1, q8[j + 0], sumi1); + sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2); + } + return d * (sumi1 + sumi2); +#else + assert(false); +#endif +} + +#endif // GGML_SYCL_VECDOTQ_HPP diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c10e2f7645e89e045ca25e86a8598e734179ed26 --- /dev/null +++ b/ggml/src/ggml-sycl/wkv.cpp @@ -0,0 +1,293 @@ +#include +#include "wkv.hpp" + +constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE + +// Helper function for the main kernel +template +static void rwkv_wkv6_f32_kernel( + const int B, const int T, const int C, const int H, + const float* k, const float* v, const float* r, + const float* tf, const float* td, const float* s, + float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) { + + const int tid = item_ct1.get_local_id(2); + const int bid = item_ct1.get_group(2); + + const int head_size = block_size; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + // Set up shared memory pointers + float* _k = shared_mem; + float* _r = _k + head_size; + float* _tf = _r + head_size; + float* _td = _tf + head_size; + + // Local state array + float state[block_size]; + + // Load initial state + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + // Sync threads before shared memory operations + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Load time-mixing parameters + _tf[tid] = tf[head_i * head_size + tid]; + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Main sequence processing loop + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; + t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; + t += C) { + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Load current timestep data to shared memory + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + const float _v = v[t]; + float y = 0; + + // Process in chunks of 4 for better vectorization + sycl::float4 k4, r4, tf4, td4, s4; + #pragma unroll + for (int j = 0; j < head_size; j += 4) { + // Load data in vec4 chunks + k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); + + // Compute key-value product + sycl::float4 kv4 = k4 * _v; + + // Accumulate weighted sum + y += sycl::dot(r4, tf4 * kv4 + s4); + + // Update state + s4 = s4 * td4 + kv4; + + // Store updated state + state[j] = s4.x(); + state[j+1] = s4.y(); + state[j+2] = s4.z(); + state[j+3] = s4.w(); + } + + dst[t] = y; + } + + // Save final state + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } +} + +template +static void rwkv_wkv7_f32_kernel( + const int B, const int T, const int C, const int H, + const float* r, const float* w, const float* k, const float* v, + const float* a, const float* b, const float* s, + float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) { + + const int tid = item_ct1.get_local_id(2); + const int bid = item_ct1.get_group(2); + + const int head_size = block_size; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + float* _r = shared_mem; + float* _w = _r + head_size; + float* _k = _w + head_size; + float* _a = _k + head_size; + float* _b = _a + head_size; + + float state[block_size]; + + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i]; + } + + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; + t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; + t += C) { + + item_ct1.barrier(sycl::access::fence_space::local_space); + + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + const float _v = v[t]; + float y = 0, sa = 0; + sycl::float4 a4, s4; + + #pragma unroll + for (int j = 0; j < head_size; j += 4) { + a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); + sa += sycl::dot(a4, s4); + } + + sycl::float4 r4, w4, k4, b4; + #pragma unroll + for (int j = 0; j < head_size; j += 4) { + r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); + + sycl::float4 kv4 = k4 * _v; + + s4 = s4 * w4 + kv4 + sa * b4; + y += sycl::dot(r4, s4); + + state[j] = s4.x(); + state[j+1] = s4.y(); + state[j+2] = s4.z(); + state[j+3] = s4.w(); + } + + dst[t] = y; + } + + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i]; + } +} + +void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6); + const float* k_d = (const float*)dst->src[0]->data; + const float* v_d = (const float*)dst->src[1]->data; + const float* r_d = (const float*)dst->src[2]->data; + const float* tf_d = (const float*)dst->src[3]->data; + const float* td_d = (const float*)dst->src[4]->data; + const float* s_d = (const float*)dst->src[5]->data; + float* dst_d = (float*)dst->data; + + const int64_t B = dst->src[5]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64 + + dpct::queue_ptr stream = ctx.stream(); + + // Calculate execution configuration + const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td + sycl::range<3> block_dims(1, 1, C / H); + sycl::range<3> grid_dims(1, 1, B * H); + + // Submit kernel + if (C / H == WKV_BLOCK_SIZE) { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv6_f32_kernel( + B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } else { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv6_f32_kernel( + B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } +} + +void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7); + const float* r_d = (const float*)dst->src[0]->data; + const float* w_d = (const float*)dst->src[1]->data; + const float* k_d = (const float*)dst->src[2]->data; + const float* v_d = (const float*)dst->src[3]->data; + const float* a_d = (const float*)dst->src[4]->data; + const float* b_d = (const float*)dst->src[5]->data; + const float* s_d = (const float*)dst->src[6]->data; + float* dst_d = (float*)dst->data; + + const int64_t B = dst->src[6]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); + + dpct::queue_ptr stream = ctx.stream(); + + // Calculate execution configuration + const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b + sycl::range<3> block_dims(1, 1, C / H); + sycl::range<3> grid_dims(1, 1, B * H); + + // Submit kernel + if (C / H == WKV_BLOCK_SIZE) { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } else { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } +} diff --git a/ggml/src/ggml-sycl/wkv.hpp b/ggml/src/ggml-sycl/wkv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9f34a1001fd68a12bf01fa4b8654ea363711d927 --- /dev/null +++ b/ggml/src/ggml-sycl/wkv.hpp @@ -0,0 +1,10 @@ +#ifndef GGML_SYCL_WKV_HPP +#define GGML_SYCL_WKV_HPP + +#include "common.hpp" + +void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_WKV_HPP diff --git a/ggml/src/ggml-threading.cpp b/ggml/src/ggml-threading.cpp new file mode 100644 index 0000000000000000000000000000000000000000..25a19eedb90535daccd640056175dbbbc58c1ca7 --- /dev/null +++ b/ggml/src/ggml-threading.cpp @@ -0,0 +1,12 @@ +#include "ggml-threading.h" +#include + +std::mutex ggml_critical_section_mutex; + +void ggml_critical_section_start() { + ggml_critical_section_mutex.lock(); +} + +void ggml_critical_section_end(void) { + ggml_critical_section_mutex.unlock(); +} diff --git a/ggml/src/ggml-threading.h b/ggml/src/ggml-threading.h new file mode 100644 index 0000000000000000000000000000000000000000..dec2c8840aa3669c2f813ba940d192f803e1da5e --- /dev/null +++ b/ggml/src/ggml-threading.h @@ -0,0 +1,14 @@ +#pragma once + +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +GGML_API void ggml_critical_section_start(void); +GGML_API void ggml_critical_section_end(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..95e2ebe643732c3ee086a320317158b8464a04d8 --- /dev/null +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -0,0 +1,185 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + +find_package(Vulkan COMPONENTS glslc REQUIRED) + +function(detect_host_compiler) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + else() + find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + endif() + set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE) + set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) +endfunction() + +# Function to test shader extension support +# Parameters: +# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product") +# TEST_SHADER_FILE - Path to the test shader file +# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result +function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE) + execute_process( + COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error + ) + + if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*") + message(STATUS "${EXTENSION_NAME} not supported by glslc") + set(${RESULT_VARIABLE} OFF PARENT_SCOPE) + else() + message(STATUS "${EXTENSION_NAME} supported by glslc") + set(${RESULT_VARIABLE} ON PARENT_SCOPE) + add_compile_definitions(${RESULT_VARIABLE}) + + # Ensure the extension support is forwarded to vulkan-shaders-gen + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON) + set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE) + endif() +endfunction() + +if (Vulkan_FOUND) + message(STATUS "Vulkan found") + + ggml_add_backend_library(ggml-vulkan + ggml-vulkan.cpp + ../../include/ggml-vulkan.h + ) + + set(VULKAN_SHADER_GEN_CMAKE_ARGS "") + + # Test all shader extensions + test_shader_extension_support( + "GL_KHR_cooperative_matrix" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" + ) + + test_shader_extension_support( + "GL_NV_cooperative_matrix2" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" + ) + + test_shader_extension_support( + "GL_EXT_integer_dot_product" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" + ) + + test_shader_extension_support( + "GL_EXT_bfloat16" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" + ) + + target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) + target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + + # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build + # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector + if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) + endif() + + if (GGML_VULKAN_CHECK_RESULTS) + add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) + endif() + + if (GGML_VULKAN_DEBUG) + add_compile_definitions(GGML_VULKAN_DEBUG) + endif() + + if (GGML_VULKAN_MEMORY_DEBUG) + add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG) + endif() + + if (GGML_VULKAN_SHADER_DEBUG_INFO) + add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + endif() + + if (GGML_VULKAN_VALIDATE) + add_compile_definitions(GGML_VULKAN_VALIDATE) + endif() + + if (GGML_VULKAN_RUN_TESTS) + add_compile_definitions(GGML_VULKAN_RUN_TESTS) + endif() + + # Set up toolchain for host compilation whether cross-compiling or not + if (CMAKE_CROSSCOMPILING) + if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) + set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) + else() + detect_host_compiler() + if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER) + message(FATAL_ERROR "Host compiler not found") + else() + message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}") + endif() + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) + set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) + endif() + else() + # For non-cross-compiling, use empty toolchain (use host compiler) + set(HOST_CMAKE_TOOLCHAIN_FILE "") + endif() + + include(ExternalProject) + + if (CMAKE_CROSSCOMPILING) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}) + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") + endif() + + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$ + -DCMAKE_INSTALL_BINDIR=. + -DCMAKE_BUILD_TYPE=$ + ${VULKAN_SHADER_GEN_CMAKE_ARGS} + + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $ + INSTALL_COMMAND ${CMAKE_COMMAND} --install . --config $ + ) + ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + + set (_ggml_vk_host_suffix $,.exe,>) + set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$") + set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}") + set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp") + set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp") + set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders") + set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv") + + file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp") + + add_custom_command( + OUTPUT ${_ggml_vk_header} + ${_ggml_vk_source} + + COMMAND ${_ggml_vk_genshaders_cmd} + --glslc ${Vulkan_GLSLC_EXECUTABLE} + --input-dir ${_ggml_vk_input_dir} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + --target-cpp ${_ggml_vk_source} + --no-clean + + DEPENDS ${_ggml_vk_shader_files} + vulkan-shaders-gen + vulkan-shaders-gen-build + vulkan-shaders-gen-install + + COMMENT "Generate vulkan shaders" + ) + + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header}) + +else() + message(WARNING "Vulkan not found") +endif() diff --git a/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in b/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..2d8a85696d374e68822f854d6e59247c90a934a8 --- /dev/null +++ b/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in @@ -0,0 +1,15 @@ +set(CMAKE_BUILD_TYPE Release) +set(CMAKE_C_FLAGS -O2) +set(CMAKE_CXX_FLAGS -O2) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) +set(CMAKE_C_COMPILER "@HOST_C_COMPILER@") +set(CMAKE_CXX_COMPILER "@HOST_CXX_COMPILER@") +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@) + +if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC") + foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() +endif() diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d62303aabd7ff4c90926e9aea683704b25f0931 --- /dev/null +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -0,0 +1,10931 @@ +#include "ggml-vulkan.h" +#include +#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS) +#include +#include "ggml-cpu.h" +#endif + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +# define NOMINMAX 1 +# include +# define YIELD() YieldProcessor() +#elif defined(__clang__) || defined(__GNUC__) +# if defined(__x86_64__) ||defined(__i386__) +# include +# define YIELD() _mm_pause() +# elif defined(__arm__) || defined(__aarch64__) +# if defined(__clang__) +# include +# define YIELD() __yield() +# else +# define YIELD() asm volatile("yield") +# endif +# endif +#endif + +#if !defined(YIELD) +#define YIELD() +#endif + +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-vulkan-shaders.hpp" + +// remove this once it's more widely available in the SDK +#if !defined(VK_KHR_shader_bfloat16) + +#define VK_KHR_shader_bfloat16 1 +#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1 +#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000) +#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000) + +typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { + VkStructureType sType; + void* pNext; + VkBool32 shaderBFloat16Type; + VkBool32 shaderBFloat16DotProduct; + VkBool32 shaderBFloat16CooperativeMatrix; +} VkPhysicalDeviceShaderBfloat16FeaturesKHR; +#endif + +#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } + +#define VK_VENDOR_ID_AMD 0x1002 +#define VK_VENDOR_ID_APPLE 0x106b +#define VK_VENDOR_ID_INTEL 0x8086 +#define VK_VENDOR_ID_NVIDIA 0x10de + +#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 + +#define GGML_VK_MAX_NODES 8192 + +#define MAX_VK_BUFFERS 256 + +#define VK_CHECK(err, msg) \ + do { \ + vk::Result err_ = (err); \ + if (err_ != vk::Result::eSuccess) { \ + fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \ + #err, to_string(err_).c_str(), __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + +#ifdef GGML_VULKAN_DEBUG +#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl +#else +#define VK_LOG_DEBUG(msg) ((void) 0) +#endif // GGML_VULKAN_DEBUG + +struct ggml_backend_vk_context; + +#define MAX_PARAMETER_COUNT 8 + +struct vk_pipeline_struct { + std::string name; + vk::ShaderModule shader_module; + vk::PipelineLayout layout; + vk::Pipeline pipeline; + uint32_t push_constant_size; + uint32_t parameter_count; + std::array wg_denoms; + uint32_t align; + // set to true to request the pipeline is compiled after the dryrun + bool needed {}; + // set to true when the shader has been compiled + bool compiled {}; +}; + +typedef std::shared_ptr vk_pipeline; +typedef std::weak_ptr vk_pipeline_ref; + +static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); + +struct vk_matmul_pipeline_struct { + vk_pipeline l, m, s; + vk_pipeline a_l, a_m, a_s; +}; + +typedef std::shared_ptr vk_matmul_pipeline; + +struct vk_matmul_pipeline2 { + vk_matmul_pipeline2() { + f16acc = std::make_shared(); + f32acc = std::make_shared(); + } + vk_matmul_pipeline f32acc; + vk_matmul_pipeline f16acc; +}; + +struct vk_device_struct; +typedef std::shared_ptr vk_device; +typedef std::weak_ptr vk_device_ref; + +struct vk_buffer_struct; +typedef std::shared_ptr vk_buffer; +typedef std::weak_ptr vk_buffer_ref; + +struct ggml_backend_vk_buffer_type_context { + std::string name; + vk_device device; +}; + +struct vk_queue; + +// Stores command pool/buffers. There's an instance of this +// for each (context,queue) pair and for each (device,queue) pair. +struct vk_command_pool { + void init(vk_device& device, vk_queue *q_); + void destroy(vk::Device& device); + + vk::CommandPool pool; + uint32_t cmd_buffer_idx; + std::vector cmd_buffers; + + vk_queue *q; +}; + +// Prevent simultaneous submissions to the same queue. +// This could be per vk_queue if we stopped having two vk_queue structures +// sharing the same vk::Queue. +static std::mutex queue_mutex; + +struct vk_queue { + uint32_t queue_family_index; + vk::Queue queue; + + vk_command_pool cmd_pool; + + vk::PipelineStageFlags stage_flags; + + bool transfer_only; + + // copy everything except the cmd_pool + void copyFrom(vk_queue &other) { + queue_family_index = other.queue_family_index; + queue = other.queue; + stage_flags = other.stage_flags; + transfer_only = other.transfer_only; + } +}; + +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); +static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { + /* .get_name = */ ggml_backend_vk_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +#ifdef GGML_VULKAN_MEMORY_DEBUG +class vk_memory_logger; +#endif +class vk_perf_logger; +static void ggml_vk_destroy_buffer(vk_buffer& buf); + +static constexpr uint32_t mul_mat_vec_max_cols = 8; +static constexpr uint32_t p021_max_gqa_ratio = 8; + +enum vk_device_architecture { + OTHER, + AMD_GCN, + AMD_RDNA1, + AMD_RDNA2, + AMD_RDNA3, + INTEL_XE2, +}; + +static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { + vk::PhysicalDeviceProperties props = device.getProperties(); + + if (props.vendorID == VK_VENDOR_ID_AMD) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool amd_shader_core_properties = false; + bool integer_dot_product = false; + bool subgroup_size_control = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) { + amd_shader_core_properties = true; + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) { + integer_dot_product = true; + } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + subgroup_size_control = true; + } + } + + if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) { + return vk_device_architecture::OTHER; + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &shader_core_props_amd; + shader_core_props_amd.pNext = &integer_dot_props; + integer_dot_props.pNext = &subgroup_size_control_props; + + device.getProperties2(&props2); + + if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) { + return vk_device_architecture::AMD_GCN; + } + if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) { + // RDNA + if (shader_core_props_amd.wavefrontsPerSimd == 20) { + return vk_device_architecture::AMD_RDNA1; + } + if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) { + return vk_device_architecture::AMD_RDNA3; + } + return vk_device_architecture::AMD_RDNA2; + } + } else if (props.vendorID == VK_VENDOR_ID_INTEL) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool subgroup_size_control = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + subgroup_size_control = true; + } + } + + if (!subgroup_size_control) { + return vk_device_architecture::OTHER; + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &subgroup_size_control_props; + device.getProperties2(&props2); + + if (subgroup_size_control_props.minSubgroupSize == 16) { + // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8. + // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value. + // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html + // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html + return vk_device_architecture::INTEL_XE2; + } + } + return vk_device_architecture::OTHER; +} + +struct vk_device_struct { + std::mutex mutex; + + vk::PhysicalDevice physical_device; + vk::PhysicalDeviceProperties properties; + std::string name; + uint64_t max_memory_allocation_size; + uint64_t suballocation_block_size; + bool fp16; + bool pipeline_robustness; + vk::Device device; + uint32_t vendor_id; + vk::DriverId driver_id; + vk_device_architecture architecture; + vk_queue compute_queue; + vk_queue transfer_queue; + bool single_queue; + uint32_t subgroup_size; + uint32_t shader_core_count; + bool uma; + bool prefer_host_memory; + bool float_controls_rte_fp16; + bool subgroup_add; + bool subgroup_shuffle; + + bool integer_dot_product; + + bool subgroup_size_control; + uint32_t subgroup_min_size; + uint32_t subgroup_max_size; + bool subgroup_require_full_support; + + bool coopmat_support; + bool coopmat_acc_f32_support {}; + bool coopmat_acc_f16_support {}; + bool coopmat_bf16_support {}; + bool coopmat_support_16x16x16_f16acc {}; + bool coopmat_support_16x16x16_f32acc {}; + bool coopmat1_fa_support {}; + uint32_t coopmat_m; + uint32_t coopmat_n; + uint32_t coopmat_k; + + bool coopmat_int_support; + uint32_t coopmat_int_m; + uint32_t coopmat_int_n; + uint32_t coopmat_int_k; + + bool coopmat2; + + size_t idx; + + bool mul_mat_l[GGML_TYPE_COUNT]; + bool mul_mat_m[GGML_TYPE_COUNT]; + bool mul_mat_s[GGML_TYPE_COUNT]; + bool mul_mat_id_l[GGML_TYPE_COUNT]; + bool mul_mat_id_m[GGML_TYPE_COUNT]; + bool mul_mat_id_s[GGML_TYPE_COUNT]; + + // set to true to indicate that some shaders need to be compiled after the dryrun + bool need_compiles {}; + + vk::DescriptorSetLayout dsl; + + vk_matmul_pipeline pipeline_matmul_f32 {}; + vk_matmul_pipeline pipeline_matmul_f32_f16 {}; + vk_matmul_pipeline pipeline_matmul_bf16 {}; + vk_matmul_pipeline2 pipeline_matmul_f16; + vk_matmul_pipeline2 pipeline_matmul_f16_f32; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT]; + + vk_matmul_pipeline pipeline_matmul_id_f32 {}; + vk_matmul_pipeline pipeline_matmul_id_bf16 {}; + vk_matmul_pipeline2 pipeline_matmul_id_f16; + vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_matmul_split_k_reduce; + vk_pipeline pipeline_quantize_q8_1; + + vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; + vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; + vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; + vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_acc_f32; + + // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] + vk_pipeline pipeline_add[2][2][2]; + vk_pipeline pipeline_add_norepeat[2][2][2]; + vk_pipeline pipeline_sub[2][2][2]; + vk_pipeline pipeline_sub_norepeat[2][2][2]; + vk_pipeline pipeline_mul[2][2][2]; + vk_pipeline pipeline_mul_norepeat[2][2][2]; + vk_pipeline pipeline_div[2][2][2]; + vk_pipeline pipeline_div_norepeat[2][2][2]; + + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; + vk_pipeline pipeline_upscale_f32; + vk_pipeline pipeline_scale_f32; + vk_pipeline pipeline_sqr_f32; + vk_pipeline pipeline_sin_f32; + vk_pipeline pipeline_cos_f32; + vk_pipeline pipeline_clamp_f32; + vk_pipeline pipeline_pad_f32; + vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; + vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_norm_f32; + vk_pipeline pipeline_group_norm_f32; + vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_rms_norm_back_f32; + vk_pipeline pipeline_l2_norm_f32; + + // [src/dst 0=fp32,1=fp16] + vk_pipeline pipeline_gelu[2]; + vk_pipeline pipeline_gelu_quick[2]; + vk_pipeline pipeline_silu[2]; + vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_tanh[2]; + vk_pipeline pipeline_sigmoid[2]; + + vk_pipeline pipeline_leaky_relu_f32; + vk_pipeline pipeline_silu_back_f32; + vk_pipeline pipeline_diag_mask_inf_f32; + vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; + vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; + vk_pipeline pipeline_soft_max_back_f32; + vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; + vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; + vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; + vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; + vk_pipeline pipeline_argsort_f32; + vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_argmax_f32; + vk_pipeline pipeline_count_equal_i32; + vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; + vk_pipeline pipeline_timestep_embedding_f32; + vk_pipeline pipeline_conv_transpose_1d_f32; + vk_pipeline pipeline_pool2d_f32; + vk_pipeline pipeline_rwkv_wkv6_f32; + vk_pipeline pipeline_rwkv_wkv7_f32; + vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_conv2d_dw_whcn_f32; + vk_pipeline pipeline_conv2d_dw_cwhn_f32; + + // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2]; + + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + + vk_pipeline pipeline_flash_attn_split_k_reduce; + + std::unordered_map pipelines; + + std::vector> pinned_memory; + + vk::Fence fence; + vk_buffer sync_staging; + + ggml_backend_buffer_type buffer_type; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + std::unique_ptr memory_logger; +#endif + + // for GGML_VK_PERF_LOGGER + std::unique_ptr perf_logger; + vk::QueryPool query_pool; + int32_t num_queries; + + ~vk_device_struct() { + VK_LOG_DEBUG("destroy device " << name); + + device.destroyFence(fence); + + ggml_vk_destroy_buffer(sync_staging); + + compute_queue.cmd_pool.destroy(device); + transfer_queue.cmd_pool.destroy(device); + + for (auto& pipeline : pipelines) { + if (pipeline.second.expired()) { + continue; + } + + vk_pipeline pl = pipeline.second.lock(); + ggml_vk_destroy_pipeline(device, pl); + } + pipelines.clear(); + + device.destroyDescriptorSetLayout(dsl); + + device.destroy(); + } +}; + +void vk_command_pool::init(vk_device& device, vk_queue *q_) { + cmd_buffer_idx = 0; + q = q_; + + vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index); + pool = device->device.createCommandPool(command_pool_create_info); +} + +void vk_command_pool::destroy(vk::Device& device) { + device.destroyCommandPool(pool); + pool = nullptr; + cmd_buffers.clear(); +} + +struct vk_buffer_struct { + vk::Buffer buffer = VK_NULL_HANDLE; + vk::DeviceMemory device_memory = VK_NULL_HANDLE; + vk::MemoryPropertyFlags memory_property_flags; + void * ptr; + size_t size = 0; + + vk_device device; + + ~vk_buffer_struct() { + if (size == 0) { + return; + } + VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")"); + + device->device.freeMemory(device_memory); + device->device.destroyBuffer(buffer); + } +}; + +struct vk_subbuffer { + vk_buffer buffer; + uint64_t offset; + uint64_t size; + + operator vk::DescriptorBufferInfo() const { + return { buffer->buffer, offset, size }; + } +}; + +struct vk_semaphore { + vk::Semaphore s; + uint64_t value; +}; + +struct vk_submission { + vk::CommandBuffer buffer; + std::vector wait_semaphores; + std::vector signal_semaphores; +}; + +typedef std::vector vk_sequence; + +struct vk_mat_mat_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t k_split; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; + uint32_t padded_N; +}; +struct vk_mat_vec_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; +}; + +struct vk_mat_mat_id_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; + uint32_t padded_N; +}; +struct vk_mat_vec_id_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t nei0; uint32_t ne11; +}; + +struct vk_flash_attn_push_constants { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +}; + +struct vk_op_push_constants { + uint32_t KX; + uint32_t KY; + float param1; + float param2; +}; + +struct vk_op_unary_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + float param1; float param2; + uint32_t ne0_012mp; uint32_t ne0_012L; + uint32_t ne0_01mp; uint32_t ne0_01L; + uint32_t ne0_0mp; uint32_t ne0_0L; + uint32_t ne1_012mp; uint32_t ne1_012L; + uint32_t ne1_01mp; uint32_t ne1_01L; + uint32_t ne1_0mp; uint32_t ne1_0L; +}; +static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) +{ + // compute L = ceil(log2(d)); + L = 0; + while (L < 32 && (uint32_t{1} << L) < d) { + L++; + } + + mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1); +} + +template void init_pushconst_fastdiv(T &p) { + GGML_UNUSED(p); + static_assert(!std::is_const::value, "unexpected type"); +} + +template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) { + // Compute magic values to divide by these six numbers. + init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L); + init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L); + init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L); + init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L); + init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L); + init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L); +} + +struct vk_op_binary_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23; + uint32_t misalign_offsets; + float param1; float param2; int32_t param3; +}; + +struct vk_op_diag_mask_push_constants { + uint32_t ncols; + uint32_t rows_per_channel; + int32_t n_past; +}; + +struct vk_op_rope_push_constants { + uint32_t ncols; + uint32_t n_dims; + float freq_scale; + uint32_t p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint32_t has_ff; + uint32_t ne02; + uint32_t s1; + uint32_t s2; + int32_t sections[4]; + uint32_t is_back; +}; + +struct vk_op_soft_max_push_constants { + uint32_t KX; + uint32_t KY; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; + uint32_t nrows_x; +}; + +struct vk_op_argsort_push_constants { + uint32_t ncols; + uint32_t ncols_pad; + int32_t order; +}; + +struct vk_op_im2col_push_constants { + uint32_t batch_offset; uint32_t offset_delta; + uint32_t IC; + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t KW; uint32_t KH; + uint32_t pelements; + uint32_t CHW; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; + int32_t d0; int32_t d1; +}; + +struct vk_op_timestep_embedding_push_constants { + uint32_t nb1; + uint32_t dim; + uint32_t max_period; +}; + +struct vk_op_conv_transpose_1d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +}; + +struct vk_op_pool2d_push_constants { + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t OC; + uint32_t pelements; + uint32_t op; + int32_t k0; int32_t k1; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; +}; + +struct vk_op_rwkv_wkv6_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; + +struct vk_op_rwkv_wkv7_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; + +struct vk_op_conv2d_dw_push_constants { + uint32_t ne; + uint32_t batches; + uint32_t channels; + uint32_t dst_w; + uint32_t dst_h; + uint32_t src_w; + uint32_t src_h; + uint32_t knl_w; + uint32_t knl_h; + int32_t stride_x; + int32_t stride_y; + int32_t pad_x; + int32_t pad_y; + int32_t dilation_x; + int32_t dilation_y; +}; + +struct vk_op_upscale_push_constants { + uint32_t ne; uint32_t a_offset; uint32_t d_offset; + uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; + float sf0; float sf1; float sf2; float sf3; +}; + +// Allow pre-recording command buffers +struct vk_staging_memcpy { + vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} + + void * dst; + const void * src; + size_t n; +}; + +struct vk_context_struct { + vk_submission * s; + std::vector seqs; + + int exit_tensor_idx; + + std::vector in_memcpys; + std::vector out_memcpys; + + vk_command_pool * p {}; +}; +typedef std::shared_ptr vk_context; +typedef std::weak_ptr vk_context_ref; + +struct ggml_vk_garbage_collector { + std::vector tl_semaphores; + std::vector semaphores; + std::vector events; + std::vector temp_buffers; + std::vector contexts; +}; + +#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG) +#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl + +static std::string format_size(size_t size) { + const size_t kib = 1024; + const size_t mib = kib * 1024; + const size_t gib = mib * 1024; + + std::ostringstream oss; + oss << std::fixed << std::setprecision(2); + + if (size >= gib) { + oss << static_cast(size) / gib << " GiB"; + } else if (size >= mib) { + oss << static_cast(size) / mib << " MiB"; + } else if (size >= kib) { + oss << static_cast(size) / kib << " KiB"; + } else { + oss << size << " B"; + } + + return oss.str(); +} + +static std::mutex log_mutex; + +class vk_memory_logger { +public: + vk_memory_logger(): total_device(0), total_host(0) {} + void log_allocation(vk_buffer_ref buf_ref, size_t size); + void log_deallocation(vk_buffer_ref buf_ref); + +private: + std::map allocations; // Track allocations + size_t total_device; + size_t total_host; +}; +#else +#define VK_LOG_MEMORY(msg) ((void) 0) +#endif // GGML_VULKAN_MEMORY_DEBUG + +class vk_perf_logger { +public: + void print_timings() { + std::cerr << "----------------\nVulkan Timings:" << std::endl; + for (const auto& t : timings) { + uint64_t total = 0; + for (const auto& time : t.second) { + total += time; + } + std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl; + } + + timings.clear(); + } + + void log_timing(const ggml_tensor * node, uint64_t time) { + if (node->op == GGML_OP_UNARY) { + timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); + return; + } + if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { + const uint64_t m = node->src[0]->ne[1]; + const uint64_t n = node->src[1]->ne[1]; + const uint64_t k = node->src[1]->ne[0]; + std::string name = ggml_op_name(node->op); + if (n == 1) { + name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k); + } else { + name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); + } + timings[name].push_back(time); + return; + } + timings[ggml_op_name(node->op)].push_back(time); + } +private: + std::map> timings; +}; + +struct ggml_backend_vk_context { + std::string name; + + vk_device device; + + size_t semaphore_idx, event_idx; + ggml_vk_garbage_collector gc; + size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; + vk_buffer prealloc_x, prealloc_y, prealloc_split_k; + vk::Fence fence, almost_ready_fence; + bool almost_ready_fence_pending {}; + + vk_buffer buffer_pool[MAX_VK_BUFFERS]; + + vk_context_ref compute_ctx; + vk_context_ref transfer_ctx; + + std::vector tensor_ctxs; + + std::vector descriptor_pools; + std::vector descriptor_sets; + uint32_t descriptor_set_idx {}; + uint32_t pipeline_descriptor_set_requirements {}; + + vk_command_pool compute_cmd_pool; + vk_command_pool transfer_cmd_pool; +}; + +static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT + +static uint64_t vk_tensor_offset(const ggml_tensor * tensor) { + if (tensor->view_src) { + return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base; + } + return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; +} + +struct ggml_backend_vk_buffer_context { + vk_device_ref device; + vk_buffer dev_buffer; + std::string name; + + ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) : + device(device), + dev_buffer(dev_buffer), + name(name) { + } + + ~ggml_backend_vk_buffer_context() { + ggml_vk_destroy_buffer(dev_buffer); + } +}; + +#ifdef GGML_VULKAN_MEMORY_DEBUG +void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { + std::lock_guard guard(log_mutex); + vk_buffer buf = buf_ref.lock(); + const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); + const std::string type = device ? "device" : "host"; + allocations[buf->buffer] = size; + total_device += device ? size : 0; + total_host += device ? 0 : size; + VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); +} + +void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { + if (buf_ref.expired() || buf_ref.lock()->size == 0) { + return; + } + + std::lock_guard guard(log_mutex); + vk_buffer buf = buf_ref.lock(); + const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); + std::string type = device ? "device" : "host"; + auto it = allocations.find(buf->buffer); + total_device -= device ? it->second : 0; + total_host -= device ? 0 : it->second; + if (it != allocations.end()) { + VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); + allocations.erase(it); + } else { + VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer); + } +} +#endif // GGML_VULKAN_MEMORY_DEBUG + +struct vk_instance_t { + vk::Instance instance; + + std::vector device_indices; + vk_device devices[GGML_VK_MAX_DEVICES]; +}; + +static bool vk_instance_initialized = false; +static vk_instance_t vk_instance; + +static bool vk_perf_logger_enabled = false; + +#ifdef GGML_VULKAN_CHECK_RESULTS +static size_t vk_skip_checks; +static size_t vk_output_tensor; + +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); +static void ggml_vk_check_results_0(ggml_tensor * tensor); +static void ggml_vk_check_results_1(ggml_tensor * tensor); +#endif + +typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +static void ggml_backend_vk_free(ggml_backend_t backend); + +// Wait for ctx->fence to be signaled. +static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { + // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep + // during this wait. + if (ctx->almost_ready_fence_pending) { + VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence"); + ctx->device->device.resetFences({ ctx->almost_ready_fence }); + ctx->almost_ready_fence_pending = false; + } + + // Spin (w/pause) waiting for the graph to finish executing. + vk::Result result; + while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) { + if (result != vk::Result::eNotReady) { + fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__); + exit(1); + } + for (uint32_t i = 0; i < 100; ++i) { + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + } + } + ctx->device->device.resetFences({ ctx->fence }); +} + +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; + +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, + uint32_t parameter_count, std::array wg_denoms, std::vector specialization_constants, + bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { + VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count << + ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << + disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); + GGML_ASSERT(parameter_count > 0); + GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT); + GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT + + vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); + + vk::PushConstantRange pcr( + vk::ShaderStageFlagBits::eCompute, + 0, + pipeline->push_constant_size + ); + + vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr); + pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info); + + std::vector specialization_entries(specialization_constants.size()); + + for (size_t i = 0; i < specialization_constants.size(); i++) { + specialization_entries[i].constantID = i; + specialization_entries[i].offset = i * sizeof(uint32_t); + specialization_entries[i].size = sizeof(uint32_t); + } + + vk::SpecializationInfo specialization_info( + specialization_entries.size(), + specialization_entries.data(), + specialization_constants.size() * sizeof(uint32_t), + specialization_constants.data() + ); + + vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{}; + + if (device->subgroup_require_full_support && require_full_subgroups) { + pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT; + } + + vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( + pipeline_shader_stage_create_flags, + vk::ShaderStageFlagBits::eCompute, + pipeline->shader_module, + entrypoint.c_str(), + &specialization_info); + + vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info; + pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size; + if (device->subgroup_size_control && required_subgroup_size > 0) { + GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size); + pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info); + } + + vk::ComputePipelineCreateInfo compute_pipeline_create_info( + vk::PipelineCreateFlags{}, + pipeline_shader_create_info, + pipeline->layout); + + vk::PipelineRobustnessCreateInfoEXT rci; + + if (device->pipeline_robustness && disable_robustness) { + rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + compute_pipeline_create_info.setPNext(&rci); + } + + try { + pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } + pipeline->compiled = true; + + { + std::lock_guard guard(device->mutex); + device->pipelines.insert({ pipeline->name, pipeline }); + } + + { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + } + compile_count_cond.notify_all(); +} + +static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")"); + device.destroyPipelineLayout(pipeline->layout); + + device.destroyShaderModule(pipeline->shader_module); + + device.destroyPipeline(pipeline->pipeline); +} + +static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) { + VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); + ctx->pipeline_descriptor_set_requirements += n; + if (!pipeline->compiled) { + pipeline->needed = true; + ctx->device->need_compiles = true; + } +} + +static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) { + + if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) { + // Enough descriptors are available + return; + } + + vk_device& device = ctx->device; + + uint32_t to_alloc = ctx->pipeline_descriptor_set_requirements - ctx->descriptor_sets.size(); + uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; + uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + while (to_alloc > 0) { + const uint32_t alloc_count = std::min(pool_remaining, to_alloc); + to_alloc -= alloc_count; + pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + if (pool_idx >= ctx->descriptor_pools.size()) { + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); + } + + std::vector layouts(alloc_count); + for (uint32_t i = 0; i < alloc_count; i++) { + layouts[i] = device->dsl; + } + vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data()); + std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); + ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end()); + + pool_idx++; + } +} + +static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) { + VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); + + if (p.cmd_buffers.size() > p.cmd_buffer_idx) { + // Reuse command buffer + return p.cmd_buffers[p.cmd_buffer_idx++]; + } + + vk::CommandBufferAllocateInfo command_buffer_alloc_info( + p.pool, + vk::CommandBufferLevel::ePrimary, + 1); + const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); + auto buf = cmd_buffers.front(); + + p.cmd_buffers.push_back(buf); + p.cmd_buffer_idx++; + + return buf; +} + +static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { + if (ctx->seqs.empty()) { + if (fence) { + std::lock_guard guard(queue_mutex); + ctx->p->q->queue.submit({}, fence); + } + return; + } + VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")"); + + std::vector> tl_wait_vals; + std::vector> tl_signal_vals; + std::vector> tl_wait_semaphores; + std::vector> tl_signal_semaphores; + std::vector tl_submit_infos; + std::vector submit_infos; + int idx = -1; + std::vector> stage_flags; + + size_t reserve = 0; + + for (const auto& sequence : ctx->seqs) { + reserve += sequence.size(); + } + + // Pre-reserve vectors to prevent reallocation, which invalidates pointers + tl_wait_semaphores.reserve(reserve); + tl_wait_vals.reserve(reserve); + tl_signal_semaphores.reserve(reserve); + tl_signal_vals.reserve(reserve); + tl_submit_infos.reserve(reserve); + submit_infos.reserve(reserve); + stage_flags.reserve(reserve); + + for (const auto& sequence : ctx->seqs) { + for (const auto& submission : sequence) { + stage_flags.push_back({}); + idx++; + tl_wait_vals.push_back({}); + tl_wait_semaphores.push_back({}); + tl_signal_vals.push_back({}); + tl_signal_semaphores.push_back({}); + for (size_t i = 0; i < submission.wait_semaphores.size(); i++) { + stage_flags[idx].push_back(ctx->p->q->stage_flags); + tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value); + tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s); + } + for (size_t i = 0; i < submission.signal_semaphores.size(); i++) { + tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value); + tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s); + } + tl_submit_infos.push_back({ + (uint32_t) submission.wait_semaphores.size(), + tl_wait_vals[idx].data(), + (uint32_t) submission.signal_semaphores.size(), + tl_signal_vals[idx].data(), + }); + tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo; + tl_submit_infos[idx].pNext = nullptr; + vk::SubmitInfo si{ + (uint32_t) submission.wait_semaphores.size(), + tl_wait_semaphores[idx].data(), + stage_flags[idx].data(), + 1, + &submission.buffer, + (uint32_t) submission.signal_semaphores.size(), + tl_signal_semaphores[idx].data(), + }; + si.setPNext(&tl_submit_infos[idx]); + submit_infos.push_back(si); + } + } + + std::lock_guard guard(queue_mutex); + ctx->p->q->queue.submit(submit_infos, fence); + + ctx->seqs.clear(); +} + +static uint32_t ggml_vk_find_queue_family_index(std::vector& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) { + VK_LOG_DEBUG("ggml_vk_find_queue_family_index()"); + const uint32_t qfsize = queue_family_props.size(); + + // Try with avoid preferences first + for (uint32_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) { + return i; + } + } + + // Fall back to only required + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) { + return i; + } + } + + // Fall back to reusing compute queue + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) { + return i; + } + } + + // Fall back to ignoring min_num_queries + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueFlags & required) { + return i; + } + } + + // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations. + // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional. + if (compute_index >= 0) { + return compute_index; + } + + std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl; + + for(auto &q_family : queue_family_props) { + std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl; + } + abort(); +} + +static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) { + VK_LOG_DEBUG("ggml_vk_create_queue()"); + std::lock_guard guard(device->mutex); + + q.queue_family_index = queue_family_index; + q.transfer_only = transfer_only; + + q.cmd_pool.init(device, &q); + + q.queue = device->device.getQueue(queue_family_index, queue_index); + + q.stage_flags = stage_flags; +} + +static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); + ctx->gc.contexts.emplace_back(result); + result->p = &p; + return result; +} + +static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); + result->p = &p; + return result; +} + +static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); + ctx->gc.semaphores.push_back({ semaphore, 0 }); + return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1]; +} + +static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); + if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) { + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); + ctx->gc.tl_semaphores.push_back({ semaphore, 0 }); + } + return &ctx->gc.tl_semaphores[ctx->semaphore_idx++]; +} + +static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { + if (ctx->event_idx >= ctx->gc.events.size()) { + ctx->gc.events.push_back(ctx->device->device.createEvent({})); + } + return ctx->gc.events[ctx->event_idx++]; +} + +static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) { + VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()"); + + // Requires command buffers to be done + device->device.resetCommandPool(p.pool); + p.cmd_buffer_idx = 0; +} + +static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { + VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()"); + + // Arbitrary frequency to cleanup/reuse command buffers + static constexpr uint32_t cleanup_frequency = 10; + + if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool); + } + if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool); + } +} + + +static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { + for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { + vk::MemoryType memory_type = mem_props->memoryTypes[i]; + if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) && + (flags & memory_type.propertyFlags) == flags && + mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) { + return static_cast(i); + } + } + return UINT32_MAX; +} + +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { + VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")"); + if (size > device->max_memory_allocation_size) { + throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); + } + + vk_buffer buf = std::make_shared(); + + if (size == 0) { + buf->size = 0; + return buf; + } + + vk::BufferCreateInfo buffer_create_info{ + vk::BufferCreateFlags(), + size, + vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst, + vk::SharingMode::eExclusive, + 0, + nullptr, + }; + + buf->buffer = device->device.createBuffer(buffer_create_info); + + vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); + + vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); + + uint32_t memory_type_index = UINT32_MAX; + + memory_type_index = find_properties(&mem_props, &mem_req, req_flags); + buf->memory_property_flags = req_flags; + + if (memory_type_index == UINT32_MAX && fallback_flags) { + memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); + buf->memory_property_flags = fallback_flags; + } + + if (memory_type_index == UINT32_MAX) { + device->device.destroyBuffer(buf->buffer); + throw vk::OutOfDeviceMemoryError("No suitable memory type found"); + } + + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); + } catch (const vk::SystemError& e) { + if (buf->memory_property_flags != fallback_flags) { + // Try again with fallback flags + memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); + buf->memory_property_flags = fallback_flags; + + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); + } + catch (const vk::SystemError& e) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } else { + // Out of Host/Device memory, clean up buffer + device->device.destroyBuffer(buf->buffer); + throw e; + } + } + buf->ptr = nullptr; + + if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + } + + device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); + + buf->device = device; + buf->size = size; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + device->memory_logger->log_allocation(buf, size); +#endif + + return buf; +} + +static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { + try { + return ggml_vk_create_buffer(device, size, req_flags, fallback_flags); + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } +} + +static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { + vk_buffer buf; + try { + if (device->prefer_host_memory) { + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); + } else if (device->uma) { + // Fall back to host memory type + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + } else { + // use rebar if available, otherwise fallback to device only visible memory + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); + } + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } + + return buf; +} + +static void ggml_vk_destroy_buffer(vk_buffer& buf) { + if (buf == nullptr) { + return; + } + +#ifdef GGML_VULKAN_MEMORY_DEBUG + if (buf->device != nullptr) { + buf->device->memory_logger->log_deallocation(buf); + } +#endif + + buf.reset(); +} + +static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { + return { buf, 0, VK_WHOLE_SIZE }; +} + +static void ggml_vk_sync_buffers(vk_context& ctx) { + VK_LOG_DEBUG("ggml_vk_sync_buffers()"); + + const bool transfer_queue = ctx->p->q->transfer_only; + + ctx->s->buffer.pipelineBarrier( + ctx->p->q->stage_flags, + ctx->p->q->stage_flags, + {}, + { { + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) } + } }, + {}, + {} + ); +} + +static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events) { + VK_LOG_DEBUG("ggml_vk_wait_events()"); + if (events.empty()) { + return; + } + + ctx->s->buffer.waitEvents( + events, + ctx->p->q->stage_flags, + ctx->p->q->stage_flags, + {}, + {}, + {} + ); +} + +enum FaCodePath { + FA_SCALAR, + FA_COOPMAT1, + FA_COOPMAT2, +}; + +// number of rows/cols for flash attention shader +static constexpr uint32_t flash_attention_num_small_rows = 32; +static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; +static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; + +// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. +// 128 threads split into four subgroups, each subgroup does 1/4 +// of the Bc dimension. +static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; +static constexpr uint32_t scalar_flash_attention_Bc = 64; +static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; + +static uint32_t get_fa_num_small_rows(FaCodePath path) { + if (path == FA_COOPMAT2) { + return flash_attention_num_small_rows; + } else { + return scalar_flash_attention_num_small_rows; + } +} + +static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { + GGML_UNUSED(clamp); + + if (path == FA_SCALAR) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, 64}; + } else { + return {scalar_flash_attention_num_large_rows, 32}; + } + } + + if (path == FA_COOPMAT1) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; + } else { + return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + } + } + + // small rows, large cols + if (small_rows) { + return {get_fa_num_small_rows(FA_COOPMAT2), 32}; + } + + // small cols to reduce register count + if (ggml_is_quantized(type) || D == 256) { + return {64, 32}; + } + return {64, 64}; +} + +static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { + + uint32_t lut_size = 0; + switch (src0_type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + lut_size = 2*2048; + break; + case GGML_TYPE_IQ2_XXS: + lut_size = 8*256; + break; + case GGML_TYPE_IQ2_XS: + lut_size = 8*512; + break; + case GGML_TYPE_IQ2_S: + lut_size = 8*1024; + break; + case GGML_TYPE_IQ3_XXS: + lut_size = 4*256; + break; + case GGML_TYPE_IQ3_S: + lut_size = 4*512; + break; + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + lut_size = 4*16; + break; + default: + break; + } + + // Needs to be kept up to date on shader changes + const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; + const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const uint32_t warps = warptile[0] / warptile[10]; + + const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; + const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0; + const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; + + const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported); + + return supported; +} + +struct GpuPipelineConfig { + // GPU architecture identifier. + // Example: vk_device_architecture::AMD_GCN + vk_device_architecture arch; + + // Mapping of pipeline names to their specific subgroup sizes. + // Example: {"soft_max_f32", 64} + std::unordered_map pipelines; + + // Default subgroup size for this GPU. + // Defaults to 0 if not explicitly provided. + uint32_t default_subgroup_size = 0; +}; + +// Pipeline configuration for RDNA1 GPUs. +static const std::unordered_map rdna1_pipelines = { + {"soft_max", 64}, {"im2col", 64}, + {"argmax", 64}, {"mul_mat_vec", 64}, + {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32} +}; + +// Pipeline configuration for RDNA2 GPUs. +static const std::unordered_map rdna2_pipelines = { + {"soft_max", 64}, {"im2col", 64}, +}; + +static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32; + +// Define configurations for different GPUs. +static std::vector gpu_pipeline_configs = { + { + vk_device_architecture::AMD_RDNA1, + { + rdna1_pipelines, + }, + RDNA_DEFAULT_SUBGROUP_SIZE + }, + { + vk_device_architecture::AMD_RDNA2, + { + rdna2_pipelines, + }, + RDNA_DEFAULT_SUBGROUP_SIZE + }, +}; + +static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) { + for (const auto &config : gpu_pipeline_configs) { + if (config.arch == arch) { + auto pipIt = config.pipelines.find(pipeline_name); + if (pipIt != config.pipelines.end()) { + return pipIt->second; + } + std::vector> sorted_pipelines(config.pipelines.begin(), config.pipelines.end()); + std::sort(sorted_pipelines.begin(), sorted_pipelines.end(), + [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); }); + for (const auto &entry : sorted_pipelines) { + if (pipeline_name.find(entry.first) != std::string::npos) { + return entry.second; + } + } + return config.default_subgroup_size; + } + } + return 0; // If no matching configuration is found +} + +static void ggml_vk_load_shaders(vk_device& device) { + VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); + + // some shaders have a minimum subgroup size + const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); + const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); + const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); + + // mulmat + std::vector l_warptile, m_warptile, s_warptile, + l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, + l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, + l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, + l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; + std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, + l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, + l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, + l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; + + uint32_t l_align, m_align, s_align; + if (device->coopmat2) { + // spec constants and tile sizes for non-quant matmul/matmul_id + l_warptile = { 256, 128, 256, 64, 1 }; + m_warptile = { 256, 128, 128, 64, 0 }; + s_warptile = { 128, 64, 64, 64, 0 }; + l_wg_denoms = {128, 256, 1 }; + m_wg_denoms = {128, 128, 1 }; + s_wg_denoms = { 64, 64, 1 }; + + // spec constants and tile sizes for quant matmul (non-Qi_K) + l_warptile_mmq = { 256, 128, 256, 64, 1 }; + m_warptile_mmq = { 256, 128, 128, 64, 1 }; + s_warptile_mmq = { 256, 32, 64, 128, 0 }; + l_mmq_wg_denoms = { 128, 256, 1 }; + m_mmq_wg_denoms = { 128, 128, 1 }; + s_mmq_wg_denoms = { 32, 64, 1 }; + + // spec constants and tile sizes for quant matmul (Qi_K) + l_warptile_mmq_k = { 256, 64, 128, 64, 1 }; + m_warptile_mmq_k = { 256, 32, 64, 64, 0 }; + s_warptile_mmq_k = { 256, 32, 32, 128, 0 }; + l_mmq_wg_denoms_k = { 64, 128, 1 }; + m_mmq_wg_denoms_k = { 32, 64, 1 }; + s_mmq_wg_denoms_k = { 32, 32, 1 }; + + // spec constants and tile sizes for quant matmul_id + l_warptile_mmqid = { 256, 128, 64, 16, 0 }; + m_warptile_mmqid = { 256, 128, 64, 16, 0 }; + s_warptile_mmqid = { 256, 128, 64, 16, 0 }; + l_mmqid_wg_denoms = { 128, 64, 1 }; + m_mmqid_wg_denoms = { 128, 64, 1 }; + s_mmqid_wg_denoms = { 128, 64, 1 }; + + l_align = 128; + m_align = 64; + s_align = 32; + } else { + // Matrix cores require different warp group sizes + const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2; + const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4; + const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; + + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + + l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + + l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + + // chip specific tuning + if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { + m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } + + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; + m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; + s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; + l_align = 128; + m_align = 64; + s_align = 32; + + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { + ggml_type t = (ggml_type)i; + // Disable medium and large matrix multiplication if not enough shared memory is available + // Check mmq warptiles as the largest configuration + // Throw an error if not enough for any matrix multiplication is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) { + std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl; + throw std::runtime_error("Shared memory size too small for matrix multiplication."); + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) { + device->mul_mat_m[i] = false; + device->mul_mat_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) { + device->mul_mat_l[i] = false; + } + + // Disable mul_mat_id if not enough shared memory is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) { + device->mul_mat_id_s[i] = false; + device->mul_mat_id_m[i] = false; + device->mul_mat_id_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) { + device->mul_mat_id_m[i] = false; + device->mul_mat_id_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) { + device->mul_mat_id_l[i] = false; + } + } + } + + if (!device->pipeline_matmul_f32) { + device->pipeline_matmul_f32 = std::make_shared(); + } + if (!device->pipeline_matmul_f32_f16) { + device->pipeline_matmul_f32_f16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_f32) { + device->pipeline_matmul_id_f32 = std::make_shared(); + } + if (!device->pipeline_matmul_bf16) { + device->pipeline_matmul_bf16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_bf16) { + device->pipeline_matmul_id_bf16 = std::make_shared(); + } + + std::vector> compiles; + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, + uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, + uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + + if (!require_full_subgroups && required_subgroup_size == 0) { + required_subgroup_size = get_subgroup_size(name, device->architecture); + } + + if (!pipeline) { + pipeline = std::make_shared(); + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + } + + if (!pipeline->needed || pipeline->compiled) { + return; + } + { + // wait until fewer than N compiles are in progress + uint32_t N = std::max(1u, std::thread::hardware_concurrency()); + std::unique_lock guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, + parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + }; + + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1}; + }; + + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + // For scalar, use 128 (arbitrary) + uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) + ? scalar_flash_attention_workgroup_size + : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows); + + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. + // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. + const uint32_t D_lsb = D ^ (D & (D-1)); + uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); + + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); + return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; + }; + +#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + +#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256) + + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + } +#endif +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) + } +#endif +#undef CREATE_FA2 +#undef CREATE_FA + +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) + } +#endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + } +#endif + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) +#undef CREATE_MM +#undef CREATE_MM2 + } else +#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat_support) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->coopmat_acc_f16_support) { \ + CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + if (device->coopmat_acc_f32_support) { \ + CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ) + } +#endif + + if (device->coopmat_acc_f16_support) { + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } else { + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + } +#endif + + if (device->coopmat_acc_f16_support) { + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } else { + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM2 +#undef CREATE_MM + } else +#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->fp16) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + } \ + if (device->mul_mat ## ID ## _m[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + } \ + if (device->mul_mat ## ID ## _s[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + } \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); +#undef CREATE_MM2 +#undef CREATE_MMQ +#undef CREATE_MM + } else { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } + // reusing CREATE_MM from the fp32 path + if ((device->coopmat2 || device->coopmat_support) +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + && !device->coopmat_bf16_support +#endif + ) { + // use scalar tile sizes + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 }; + + l_wg_denoms = {128, 128, 1 }; + m_wg_denoms = { 64, 64, 1 }; + s_wg_denoms = { 32, 32, 1 }; + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM + + // mul mat vec + + // the number of rows computed per shader depends on GPU model and quant + uint32_t rm_stdq = 1; + uint32_t rm_kq = 2; + if (device->vendor_id == VK_VENDOR_ID_AMD) { + if (device->architecture == AMD_GCN) { + rm_stdq = 2; + rm_kq = 4; + } + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) + rm_stdq = 2; + uint32_t rm_iq = 2 * rm_kq; + + for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + } + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + + // dequant shaders + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + + // get_rows + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + + for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { + if (device->subgroup_add && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); + } + } + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + } + + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + +#define CREATE_BINARY(name, namemod, spec) \ + for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ + "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); + + CREATE_BINARY(add, , {0}) + CREATE_BINARY(add, _norepeat, {1}) + CREATE_BINARY(sub, , {0}) + CREATE_BINARY(sub, _norepeat, {1}) + CREATE_BINARY(mul, , {0}) + CREATE_BINARY(mul, _norepeat, {1}) + CREATE_BINARY(div, , {0}) + CREATE_BINARY(div, _norepeat, {1}) +#undef CREATE_BINARY + + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + +#define CREATE_UNARY(name) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + CREATE_UNARY(gelu) + CREATE_UNARY(gelu_quick) + CREATE_UNARY(silu) + CREATE_UNARY(relu) + CREATE_UNARY(tanh) + CREATE_UNARY(sigmoid) +#undef CREATE_UNARY + + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } + + ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + } + + ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + + for (auto &c : compiles) { + c.wait(); + } + device->need_compiles = false; +} + +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); + +static vk_device ggml_vk_get_device(size_t idx) { + VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); + + if (vk_instance.devices[idx] == nullptr) { + VK_LOG_DEBUG("Initializing new vk_device"); + vk_device device = std::make_shared(); + vk_instance.devices[idx] = device; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + device->memory_logger = std::unique_ptr(new vk_memory_logger()); +#endif + if (vk_perf_logger_enabled) { + device->perf_logger = std::unique_ptr(new vk_perf_logger()); + } + + size_t dev_num = vk_instance.device_indices[idx]; + + std::vector physical_devices = vk_instance.instance.enumeratePhysicalDevices(); + + if (dev_num >= physical_devices.size()) { + std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl; + throw std::runtime_error("Device not found"); + } + + device->physical_device = physical_devices[dev_num]; + const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + + device->architecture = get_device_architecture(device->physical_device); + + const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); + device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; + + bool fp16_storage = false; + bool fp16_compute = false; + bool maintenance4_support = false; + bool sm_builtins = false; + bool amd_shader_core_properties2 = false; + bool pipeline_robustness = false; + bool coopmat2_support = false; + device->coopmat_support = false; + device->integer_dot_product = false; + bool bfloat16_support = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { + maintenance4_support = true; + } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { + fp16_storage = true; + } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { + fp16_compute = true; + } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { + sm_builtins = true; + } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) { + amd_shader_core_properties2 = true; + } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { + pipeline_robustness = true; + } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + device->subgroup_size_control = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + device->coopmat_support = true; + device->coopmat_m = 0; + device->coopmat_n = 0; + device->coopmat_k = 0; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; +#endif +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + device->integer_dot_product = true; +#endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; +#endif + } + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceMaintenance4Properties props4; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; + vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; + vk::PhysicalDeviceVulkan11Properties vk11_props; + vk::PhysicalDeviceVulkan12Properties vk12_props; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + driver_props.pNext = &vk11_props; + vk11_props.pNext = &vk12_props; + + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; + + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&props4; + last_struct = (VkBaseOutStructure *)&props4; + } + if (sm_builtins) { + last_struct->pNext = (VkBaseOutStructure *)&sm_props; + last_struct = (VkBaseOutStructure *)&sm_props; + } + if (amd_shader_core_properties2) { + last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + } + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_props; + } + +#if defined(VK_NV_cooperative_matrix2) + vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props; + last_struct = (VkBaseOutStructure *)&coopmat2_props; + } +#endif + + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + + device->physical_device.getProperties2(&props2); + device->properties = props2.properties; + device->vendor_id = device->properties.vendorID; + device->driver_id = driver_props.driverID; + + const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); + + if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { + device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); + } else if (maintenance4_support) { + device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); + } else { + device->max_memory_allocation_size = props3.maxMemoryAllocationSize; + } + + const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE"); + + if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { + device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE); + } else { + // Limit batching of allocations to 1GB by default to avoid fragmentation issues + device->suballocation_block_size = 1024*1024*1024; + } + device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); + + device->subgroup_size = subgroup_props.subgroupSize; + device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + if (sm_builtins) { + device->shader_core_count = sm_props.shaderSMCount; + } else if (amd_shader_core_properties2) { + device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else { + device->shader_core_count = 0; + } + device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + + device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); + + device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); + + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; + + device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + + if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) { + device->coopmat_support = false; + } + + device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); + + // Try to find a non-graphics compute queue and transfer-focused queues + const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); + const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); + + const float priorities[] = { 1.0f, 1.0f }; + device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; + + std::vector device_queue_create_infos; + if (compute_queue_family_index != transfer_queue_family_index) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1}); + } else if(!device->single_queue) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities}); + } else { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); + } + vk::DeviceCreateInfo device_create_info; + std::vector device_extensions; + vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); + + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = nullptr; + device_features2.features = (VkPhysicalDeviceFeatures)device_features; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + VkPhysicalDeviceVulkan12Features vk12_features; + vk12_features.pNext = nullptr; + vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; + vk11_features.pNext = &vk12_features; + + last_struct = (VkBaseOutStructure *)&vk12_features; + + VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; + pl_robustness_features.pNext = nullptr; + pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; + pl_robustness_features.pipelineRobustness = VK_FALSE; + + if (pipeline_robustness) { + last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features; + last_struct = (VkBaseOutStructure *)&pl_robustness_features; + device_extensions.push_back("VK_EXT_pipeline_robustness"); + } + + VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; + subgroup_size_control_features.pNext = nullptr; + subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; + subgroup_size_control_features.computeFullSubgroups = false; + subgroup_size_control_features.subgroupSizeControl = false; + + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; + } + +#if defined(VK_KHR_cooperative_matrix) + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (device->coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } +#endif + +#if defined(VK_NV_cooperative_matrix2) + VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; + coopmat2_features.pNext = nullptr; + coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; + last_struct = (VkBaseOutStructure *)&coopmat2_features; + device_extensions.push_back("VK_NV_cooperative_matrix2"); + } +#endif + +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.pNext = nullptr; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif + + VkPhysicalDeviceMaintenance4Features maint4_features {}; + maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&maint4_features; + last_struct = (VkBaseOutStructure *)&maint4_features; + device_extensions.push_back("VK_KHR_maintenance4"); + } + + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + device_extensions.push_back("VK_KHR_shader_integer_dot_product"); + } + + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); + + device->fp16 = device->fp16 && vk12_features.shaderFloat16; + + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; + + if (device->subgroup_size_control) { + device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; + device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; + device_extensions.push_back("VK_EXT_subgroup_size_control"); + } + + device->subgroup_size_control = device->subgroup_size_control && + (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) && + subgroup_size_control_features.subgroupSizeControl; + + if (device->subgroup_size_control) { + device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; + } + +#if defined(VK_KHR_cooperative_matrix) + device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; + + // coopmat1 fa shader currently assumes 32 invocations per subgroup + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && + device->subgroup_size_control && device->subgroup_min_size <= 32 && + device->subgroup_max_size >= 32; +#endif + + if (coopmat2_support) { +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (coopmat2_features.cooperativeMatrixWorkgroupScope && + coopmat2_features.cooperativeMatrixFlexibleDimensions && + coopmat2_features.cooperativeMatrixReductions && + coopmat2_features.cooperativeMatrixConversions && + coopmat2_features.cooperativeMatrixPerElementOperations && + coopmat2_features.cooperativeMatrixTensorAddressing && + coopmat2_features.cooperativeMatrixBlockLoads && + vk12_features.bufferDeviceAddress) { + + std::vector flexible_dimensions; + uint32_t count = 0; + + PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV = + (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV) + vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV"); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr); + + VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {}; + empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV; + flexible_dimensions.resize(count, empty_prop); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data()); + + bool found_fp16_128 = false, + found_fp16_256 = false, + found_fp32_128 = false, + found_fp32_256 = false; + // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 + // with 32x16x16 and 256 with 32x32x16. + for (auto &prop : flexible_dimensions) { + if (prop.saturatingAccumulation == VK_FALSE && + prop.scope == VK_SCOPE_WORKGROUP_KHR && + prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } + } + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } + } + } + } + if (found_fp16_128 && found_fp16_256 && + found_fp32_128 && found_fp32_256 && + coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { + device->coopmat2 = true; + } + } +#endif + } + + if (!vk11_features.storageBuffer16BitAccess) { + std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl; + throw std::runtime_error("Unsupported device"); + } + + device_extensions.push_back("VK_KHR_16bit_storage"); + +#ifdef GGML_VULKAN_VALIDATE + device_extensions.push_back("VK_KHR_shader_non_semantic_info"); +#endif + + if (device->fp16) { + device_extensions.push_back("VK_KHR_shader_float16_int8"); + } + +#if defined(VK_KHR_cooperative_matrix) + if (device->coopmat_support) { + // Query supported shapes + std::vector cm_props; + + PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR = + (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"); + + uint32_t cm_props_num; + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr); + + cm_props.resize(cm_props_num); + + for (auto& prop : cm_props) { + prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; + } + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data()); + + VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size()); + + for (auto& prop : cm_props) { + VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope)); + + if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f32_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f32_support = true; + } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f32acc = true; + } + } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f16_support = true; + } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f16acc = true; + } + } + } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup && + device->coopmat_int_m == 0 + ) { + device->coopmat_int_support = true; + device->coopmat_int_m = prop.MSize; + device->coopmat_int_n = prop.NSize; + device->coopmat_int_k = prop.KSize; + } +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_bf16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_bf16_support = true; + } + } +#endif + } + + if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { + // No suitable matmul mode found + GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n"); + device->coopmat_support = false; + } + if (getenv("GGML_VK_DISABLE_BFLOAT16")) { + device->coopmat_bf16_support = false; + } + } + + if (device->coopmat_support) { + device_extensions.push_back("VK_KHR_cooperative_matrix"); + } +#if defined(VK_KHR_shader_bfloat16) + if (device->coopmat_bf16_support) { + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif +#endif + device->name = GGML_VK_NAME + std::to_string(idx); + + device_create_info = { + vk::DeviceCreateFlags(), + device_queue_create_infos, + {}, + device_extensions + }; + device_create_info.setPNext(&device_features2); + device->device = device->physical_device.createDevice(device_create_info); + + // Queues + ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false); + + // Shaders + // Disable matmul tile sizes early if performance low or not supported + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { + switch (device->vendor_id) { +#ifndef GGML_VULKAN_RUN_TESTS + case VK_VENDOR_ID_AMD: + case VK_VENDOR_ID_INTEL: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; + case VK_VENDOR_ID_APPLE: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = false; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = false; + break; +#endif + default: + device->mul_mat_l[i] = true; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = true; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; + } + } + + + std::vector dsl_binding; + std::vector dsl_binding_flags; + for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) { + dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); + dsl_binding_flags.push_back({}); + } + + vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; + + vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( + {}, + dsl_binding); + descriptor_set_layout_create_info.setPNext(&dslbfci); + device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); + + ggml_vk_load_shaders(device); + + if (!device->single_queue) { + const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; + ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); + } else { + // TODO: Use pointer or reference to avoid copy + device->transfer_queue.copyFrom(device->compute_queue); + device->transfer_queue.cmd_pool.init(device, &device->transfer_queue); + } + + device->buffer_type = { + /* .iface = */ ggml_backend_vk_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx), + /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device }, + }; + + device->fence = device->device.createFence({}); + + device->idx = idx; + + return device; + } + + return vk_instance.devices[idx]; +} + +static void ggml_vk_print_gpu_info(size_t idx) { + GGML_ASSERT(idx < vk_instance.device_indices.size()); + size_t dev_num = vk_instance.device_indices[idx]; + VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")"); + GGML_ASSERT(vk_instance_initialized); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + if (dev_num >= devices.size()) { + std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl; + throw std::runtime_error("Device not found"); + } + + vk::PhysicalDevice physical_device = devices[dev_num]; + std::vector ext_props = physical_device.enumerateDeviceExtensionProperties(); + + bool fp16_storage = false; + bool fp16_compute = false; + bool coopmat_support = false; + bool coopmat2_support = false; + bool integer_dot_product = false; + + for (auto properties : ext_props) { + if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { + fp16_storage = true; + } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { + fp16_compute = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + coopmat_support = true; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; +#endif +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + integer_dot_product = true; +#endif + } + } + + const vk_device_architecture device_architecture = get_device_architecture(physical_device); + + const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); + bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; + + bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + + // Pointer to the last chain element + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props; + + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + + physical_device.getProperties2(&props2); + + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = nullptr; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + VkPhysicalDeviceVulkan12Features vk12_features; + vk12_features.pNext = nullptr; + vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; + vk11_features.pNext = &vk12_features; + + // Pointer to the last chain element + last_struct = (VkBaseOutStructure *)&vk12_features; + +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } +#endif + + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + } + + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); + + fp16 = fp16 && vk12_features.shaderFloat16; + + uint32_t default_subgroup_size = get_subgroup_size("", device_architecture); + const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize; + const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + + integer_dot_product = integer_dot_product + && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated + && shader_integer_dot_product_features.shaderIntegerDotProduct; + + coopmat_support = coopmat_support +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + && coopmat_features.cooperativeMatrix +#endif + && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); + + std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; + + std::string device_name = props2.properties.deviceName.data(); + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, + props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); + + if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { + GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); + } +} + +static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); + +static void ggml_vk_instance_init() { + if (vk_instance_initialized) { + return; + } + VK_LOG_DEBUG("ggml_vk_instance_init()"); + + uint32_t api_version = vk::enumerateInstanceVersion(); + + if (api_version < VK_API_VERSION_1_2) { + std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl; + GGML_ABORT("fatal error"); + } + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; + + const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); + const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); +#ifdef __APPLE__ + const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); +#endif + + std::vector layers; + + if (validation_ext) { + layers.push_back("VK_LAYER_KHRONOS_validation"); + } + std::vector extensions; + if (validation_ext) { + extensions.push_back("VK_EXT_validation_features"); + } +#ifdef __APPLE__ + if (portability_enumeration_ext) { + extensions.push_back("VK_KHR_portability_enumeration"); + } +#endif + vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); +#ifdef __APPLE__ + if (portability_enumeration_ext) { + instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; + } +#endif + + std::vector features_enable; + vk::ValidationFeaturesEXT validation_features; + + if (validation_ext) { + features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; + validation_features = { + features_enable, + {}, + }; + validation_features.setPNext(nullptr); + instance_create_info.setPNext(&validation_features); + GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); + } + vk_instance.instance = vk::createInstance(instance_create_info); + vk_instance_initialized = true; + + vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; + + // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan + char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); + if (devices_env != nullptr) { + size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + + std::string devices(devices_env); + std::replace(devices.begin(), devices.end(), ',', ' '); + + std::stringstream ss(devices); + size_t tmp; + while (ss >> tmp) { + if(tmp >= num_available_devices) { + std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl; + throw std::runtime_error("Invalid Vulkan device index"); + } + vk_instance.device_indices.push_back(tmp); + } + } else { + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + // If no vulkan devices are found, return early + if (devices.empty()) { + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); + return; + } + + // Default to using all dedicated GPUs + for (size_t i = 0; i < devices.size(); i++) { + vk::PhysicalDeviceProperties2 new_props; + vk::PhysicalDeviceDriverProperties new_driver; + vk::PhysicalDeviceIDProperties new_id; + new_props.pNext = &new_driver; + new_driver.pNext = &new_id; + devices[i].getProperties2(&new_props); + + if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) { + // Check if there are two physical devices corresponding to the same GPU + auto old_device = std::find_if( + vk_instance.device_indices.begin(), + vk_instance.device_indices.end(), + [&devices, &new_id](const size_t k){ + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceIDProperties old_id; + old_props.pNext = &old_id; + devices[k].getProperties2(&old_props); + return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); + } + ); + if (old_device == vk_instance.device_indices.end()) { + vk_instance.device_indices.push_back(i); + } else { + // There can be two physical devices corresponding to the same GPU if there are 2 different drivers + // This can cause error when splitting layers aross the devices, need to keep only 1 + VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID"); + + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceDriverProperties old_driver; + old_props.pNext = &old_driver; + devices[*old_device].getProperties2(&old_props); + + std::map driver_priorities {}; + int old_priority = std::numeric_limits::max(); + int new_priority = std::numeric_limits::max(); + + // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id + // Smaller number -> higher priority + switch (old_props.properties.vendorID) { + case VK_VENDOR_ID_AMD: + driver_priorities[vk::DriverId::eMesaRadv] = 1; + driver_priorities[vk::DriverId::eAmdOpenSource] = 2; + driver_priorities[vk::DriverId::eAmdProprietary] = 3; + break; + case VK_VENDOR_ID_INTEL: + driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1; + driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2; + break; + case VK_VENDOR_ID_NVIDIA: + driver_priorities[vk::DriverId::eNvidiaProprietary] = 1; +#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235 + driver_priorities[vk::DriverId::eMesaNvk] = 2; +#endif + break; + } + + if (driver_priorities.count(old_driver.driverID)) { + old_priority = driver_priorities[old_driver.driverID]; + } + if (driver_priorities.count(new_driver.driverID)) { + new_priority = driver_priorities[new_driver.driverID]; + } + + if (new_priority < old_priority) { + auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device); + vk_instance.device_indices.erase(r, vk_instance.device_indices.end()); + vk_instance.device_indices.push_back(i); + + VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName); + } + else { + VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl); + } + } + } + } + + // If no dedicated GPUs found, fall back to the first non-CPU device. + // If only CPU devices are available, return without devices. + if (vk_instance.device_indices.empty()) { + for (size_t i = 0; i < devices.size(); i++) { + if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) { + vk_instance.device_indices.push_back(i); + break; + } + } + } + + if (vk_instance.device_indices.empty()) { + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); + return; + } + } + GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); + + for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { + ggml_vk_print_gpu_info(i); + } +} + +static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { + VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")"); + ggml_vk_instance_init(); + GGML_ASSERT(idx < vk_instance.device_indices.size()); + + ctx->name = GGML_VK_NAME + std::to_string(idx); + + ctx->device = ggml_vk_get_device(idx); + + ctx->semaphore_idx = 0; + ctx->event_idx = 0; + + ctx->prealloc_size_x = 0; + ctx->prealloc_size_y = 0; + ctx->prealloc_size_split_k = 0; + + ctx->fence = ctx->device->device.createFence({}); + ctx->almost_ready_fence = ctx->device->device.createFence({}); + + ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); + ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + +#ifdef GGML_VULKAN_CHECK_RESULTS + const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); + vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); + const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR"); + vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor)); +#endif +} + +static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) { + VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant[type]; +} + +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")"); + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f32; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f32_f16; + } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_bf16; + } + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f32acc; + } + } + + // MMQ + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; + + if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + return nullptr; + } + + return pipelines; + } + + if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { + return nullptr; + } + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + if (ctx->device->coopmat2) { + assert(src1_type == GGML_TYPE_F16); + return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc; + } + if (ctx->device->coopmat_support) { + return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; + } + return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); + GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); + + switch (a_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1]; +} + +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()"); + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f32; + } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_id_bf16; + } + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f32acc; + } + } + + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); + GGML_ASSERT(b_type == GGML_TYPE_F32); + + switch (a_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; +} + +static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { + VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")"); + VK_LOG_MEMORY("ggml_vk_pool_malloc"); + + int best_i = -1; + size_t best_size = std::numeric_limits::max(); //smallest unused buffer that fits our needs + int worst_i = -1; + size_t worst_size = 0; //largest unused buffer seen so far + for (int i = 0; i < MAX_VK_BUFFERS; ++i) { + vk_buffer &b = ctx->buffer_pool[i]; + if (b != nullptr && b->size >= size && b->size < best_size) { + best_i = i; + best_size = b->size; + } + if (b != nullptr && b->size > worst_size) { + worst_i = i; + worst_size = b->size; + } + } + if(best_i != -1) { + //found the smallest buffer that fits our needs + vk_buffer b = ctx->buffer_pool[best_i]; + ctx->buffer_pool[best_i].reset(); + return b; + } + if(worst_i != -1) { + //no buffer that fits our needs, resize largest one to save memory + vk_buffer& b = ctx->buffer_pool[worst_i]; + ggml_vk_destroy_buffer(b); + } + + return ggml_vk_create_buffer_device(ctx->device, size); +} + +static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) { + VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")"); + for (int i = 0; i < MAX_VK_BUFFERS; ++i) { + vk_buffer& b = ctx->buffer_pool[i]; + if (b == nullptr) { + b = buffer; + return; + } + } + std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl; + ggml_vk_destroy_buffer(buffer); +} + +// Returns an available temporary buffer that may only be used temporarily, it will be reused +static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) { + // Try to find existing temp buffer with enough capacity + for (auto& buffer : ctx->gc.temp_buffers) { + if (buffer->size >= size) { + return buffer; + } + } + + VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")"); + + // Otherwise create new buffer + vk_buffer buf = ggml_vk_pool_malloc(ctx, size); + ctx->gc.temp_buffers.push_back(buf); + + return buf; +} + +static void * ggml_vk_host_malloc(vk_device& device, size_t size) { + VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); + vk_buffer buf = ggml_vk_create_buffer(device, size, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + + if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) { + fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n", + size/1024.0/1024.0); + device->device.freeMemory(buf->device_memory); + device->device.destroyBuffer(buf->buffer); + return nullptr; + } + + device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); + + return buf->ptr; +} + +static void ggml_vk_host_free(vk_device& device, void* ptr) { + if (ptr == nullptr) { + return; + } + VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); + vk_buffer buf; + size_t index; + for (size_t i = 0; i < device->pinned_memory.size(); i++) { + const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); + const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); + if (ptr >= addr && ptr < endr) { + buf = std::get<2>(device->pinned_memory[i]); + index = i; + break; + } + } + if (buf == nullptr) { + fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n"); + return; + } + + ggml_vk_destroy_buffer(buf); + + device->pinned_memory.erase(device->pinned_memory.begin() + index); +} + +static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { + buf = nullptr; + buf_offset = 0; + for (size_t i = 0; i < device->pinned_memory.size(); i++) { + const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); + const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); + if (ptr >= addr && ptr < endr) { + buf = std::get<2>(device->pinned_memory[i]); + buf_offset = ((const uint8_t *)ptr) - addr; + break; + } + } +} + +static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) { + vk_submission s; + s.buffer = ggml_vk_create_cmd_buffer(device, p); + if (one_time) { + s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); + } else { + s.buffer.begin({ vk::CommandBufferUsageFlags{} }); + } + + return s; +} + +template size_t push_constant_size(const T &t) { + static_assert(std::is_class::value, "T must be a struct/class"); + GGML_UNUSED(t); + return sizeof(T); +} +template size_t push_constant_size(const std::vector &t) { + GGML_UNUSED(t); + return sizeof(T) * t.size(); +} +template size_t push_constant_size(const std::array &t) { + GGML_UNUSED(t); + return sizeof(T) * N; +} + +template const T *push_constant_data(const T &t) { + static_assert(std::is_class::value, "T must be a struct/class"); + return &t; +} +template const T *push_constant_data(const std::vector &t) { + return t.data(); +} +template const T *push_constant_data(const std::array &t) { + return t.data(); +} + +template +static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, const T &push_constants, std::array elements) { + const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); + const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); + const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); + VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {"; + for (auto& buffer : descriptor_buffer_infos) { + std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; + } + std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); + GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); + GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); + + vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; + vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; + ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); + + subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants)); + subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); + subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, + pipeline->layout, + 0, + { descriptor_set }, + {}); + subctx->s->buffer.dispatch(wg0, wg1, wg2); +} + +static void ggml_vk_end_submission(vk_submission& s, std::vector wait_semaphores, std::vector signal_semaphores) { + s.buffer.end(); + + s.wait_semaphores = std::move(wait_semaphores); + s.signal_semaphores = std::move(signal_semaphores); +} + +static void ggml_vk_ctx_end(vk_context& ctx) { + VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")"); + if (ctx->s == nullptr) { + return; + } + + ctx->s->buffer.end(); + ctx->s = nullptr; +} + +static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { + VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")"); + if (subctx->s != nullptr) { + ggml_vk_ctx_end(subctx); + } + + subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) }); + subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); +} + +static size_t ggml_vk_align_size(size_t width, size_t align) { + VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")"); + return CEIL_DIV(width, align) * align; +} + +static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector* memcpys = nullptr) { + if (memcpys == nullptr) { + memcpy(dst, src, size); + } else { + memcpys->emplace_back(dst, src, size); + } +} + +static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) { + if (device->sync_staging == nullptr || device->sync_staging->size < size) { + VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")"); + ggml_vk_destroy_buffer(device->sync_staging); + device->sync_staging = ggml_vk_create_buffer_check(device, size, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + } +} + +static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")"); + GGML_ASSERT(!ggml_is_contiguous(tensor)); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl; + GGML_ABORT("fatal error"); + } + // Check if src is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset); + + const uint64_t ne0 = tensor->ne[0]; + const uint64_t ne1 = tensor->ne[1]; + const uint64_t ne2 = tensor->ne[2]; + const uint64_t ne3 = tensor->ne[3]; + const uint64_t nb0 = tensor->nb[0]; + const uint64_t nb1 = tensor->nb[1]; + const uint64_t nb2 = tensor->nb[2]; + const uint64_t nb3 = tensor->nb[3]; + const ggml_type type = tensor->type; + const uint64_t ts = ggml_type_size(type); + const uint64_t bs = ggml_blck_size(type); + + const uint64_t dstnb0 = ts; + const uint64_t dstnb1 = dstnb0*(ne0/bs); + const uint64_t dstnb2 = dstnb1*ne1; + const uint64_t dstnb3 = dstnb2*ne2; + + const uint64_t ne = ggml_nelements(tensor); + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + std::vector slices; + + for (uint64_t i3 = 0; i3 < ne3; i3++) { + for (uint64_t i2 = 0; i2 < ne2; i2++) { + // Find longest contiguous slice + if (ne1*nb1 == dstnb2) { + slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 }); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + if (ne0*nb0/bs == dstnb1) { + slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 }); + } else { + const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; + const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + for (uint64_t i0 = 0; i0 < ne0; i0++) { + slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); + } + } + } + } + } + } + + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + return; + } + + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + + // Staging buffer required + vk_buffer& staging = ctx->device->sync_staging; + const uint64_t copy_size = ts*ne/bs; + ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); + VkBufferCopy buf_copy{ 0, offset, copy_size }; + + ggml_vk_sync_buffers(subctx); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + + for (uint64_t i3 = 0; i3 < ne3; i3++) { + for (uint64_t i2 = 0; i2 < ne2; i2++) { + // Find longest contiguous slice + if (ne1*nb1 == dstnb2) { + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + if (ne0*nb0/bs == dstnb1) { + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys); + } else { + const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; + const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + for (uint64_t i0 = 0; i0 < ne0; i0++) { + deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys); + } + } + } + } + } + } +} + +static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl; + GGML_ABORT("fatal error"); + } + // Check if src is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(dst->device, src, buf, buf_offset); + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + std::vector slices(1); + if (width == spitch) { + // Only do single write if stride is equal + slices[0].srcOffset = buf_offset; + slices[0].dstOffset = offset; + slices[0].size = width * height; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = buf_offset + i * spitch; + slices[i].dstOffset = offset + i * width; + slices[i].size = width; + } + } + + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + return; + } + VK_LOG_DEBUG("STAGING"); + + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + + // Staging buffer required + const size_t copy_size = width*height; + ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); + + vk_buffer& staging_buffer = dst->device->sync_staging; + + VkBufferCopy buf_copy = { + 0, + offset, + copy_size}; + + ggml_vk_sync_buffers(subctx); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + + if (width == spitch) { + deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); + } else { + for (size_t i = 0; i < height; i++) { + deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); + } + } +} + +static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); + return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); +} + +static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { + VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); + + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); + } + } else { + std::lock_guard guard(dst->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(dst->device, subctx); + ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + ggml_vk_ctx_end(subctx); + + for (auto& cpy : subctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ggml_vk_submit(subctx, dst->device->fence); + VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); + dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); + } +} + +static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); + ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); +} + +static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")"); + GGML_ASSERT(width > 0); + GGML_ASSERT(height > 0); + GGML_ASSERT(src != nullptr); + + // TODO: staging_offset is not used + + // Check if dst is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(src->device, dst, buf, buf_offset); + + std::vector slices(1); + if (width == spitch && width == dpitch) { + // Only do single write if stride is equal + slices[0].srcOffset = offset; + slices[0].dstOffset = buf_offset; + slices[0].size = width * height; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = offset + i * spitch; + slices[i].dstOffset = buf_offset + i * dpitch; + slices[i].size = width; + } + } + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); + + return; + } + VK_LOG_DEBUG("STAGING"); + + if (!sync_staging) { + GGML_ABORT("Asynchronous read from non-pinned memory not supported"); + } + + // Fall back to staging buffer + const size_t copy_size = dpitch * height; + ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); + + vk_buffer& staging_buffer = src->device->sync_staging; + + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); + + deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); +} + +static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) { + return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); +} + +static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); + + // If the device is not an UMA device the memory is host-accessible through rebar. While writing + // through PCIe is sufficient fast reading back data from PCIe is slower than going through + // the HW device to host copy path. + if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { + GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); + + memcpy(dst, (uint8_t *) src->ptr + offset, size); + } else { + std::lock_guard guard(src->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(src->device, subctx); + ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); + ggml_vk_ctx_end(subctx); + + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); + src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); + + for (auto& cpy : subctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + } +} + +static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); + // Make sure both buffers are on same device + GGML_ASSERT(src->device == dst->device); + + VkBufferCopy bc{ src_offset, dst_offset, size }; + + vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); +} + +static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { + if (src->device == dst->device) { + std::lock_guard guard(src->device->mutex); + VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); + // Copy within the device + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(src->device, subctx); + ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); + ggml_vk_ctx_end(subctx); + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); + src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); + } else { + VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); + // Copy device to device + ggml_vk_ensure_sync_staging_buffer(src->device, size); + ggml_vk_ensure_sync_staging_buffer(dst->device, size); + + // Copy to src staging buffer + ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); + // memcpy to dst staging buffer + memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size); + // Copy to dst buffer + ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size); + } +} + +static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")"); + + ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); +} + +static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); + + std::lock_guard guard(dst->device->mutex); + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(dst->device, subctx); + subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); + ggml_vk_ctx_end(subctx); + + ggml_vk_submit(subctx, dst->device->fence); + VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); + dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); +} + +static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); + + uint32_t split_k = 1; + if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) { + // If k is 'large' and the SMs will fill less than halfway, use split_k. + uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); + uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); + if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) { + split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); + // Clamp to 2 or 4 + split_k = std::min(split_k, 4u); + if (split_k == 3) { + split_k = 2; + } + if (ctx->device->coopmat2) { + // coopmat2 shader expects splits to be aligned to 256 + while (split_k > 1 && ((k / split_k) % 256) != 0) { + split_k /= 2; + } + } + } + } + + return split_k; +} + +static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + + if (ctx->device->coopmat2) { + // Use large shader when the N dimension is greater than the medium shader's tile size + uint32_t crossover_large = mmp->m->wg_denoms[1]; + if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + return aligned ? mmp->a_l : mmp->l; + } + // Use medium shader when the N dimension is greater than the small shader's tile size + uint32_t crossover_medium = mmp->s->wg_denoms[1]; + if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_l : mmp->l; + + GGML_UNUSED(src1_type); +} + +static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; +} + +static void ggml_vk_matmul( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, + uint32_t padded_n) { + VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); + ggml_vk_sync_buffers(subctx); + if (split_k == 1) { + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); + return; + } + + GGML_ASSERT(batch_stride_d == m * n); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n }; + // Make sure enough workgroups get assigned for split k to work + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_vk_sync_buffers(subctx); + const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); +} + +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); + + if (ctx->device->coopmat2) { + // Use large shader when the N dimension is greater than the medium shader's tile size + uint32_t crossover_large = mmp->m->wg_denoms[1]; + if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + return aligned ? mmp->a_l : mmp->l; + } + // Use medium shader when the N dimension is greater than the small shader's tile size + uint32_t crossover_medium = mmp->s->wg_denoms[1]; + if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_l : mmp->l; +} + +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; +} + +static void ggml_vk_matmul_id( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11, + uint32_t padded_n) { + VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << + "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << + "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << + "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); + ggml_vk_sync_buffers(subctx); + const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, + nei0, nei1, nbi1, ne11, padded_n }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as }); +} + +static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { + + // Choose "contiguous copy" shader if src/dst are contiguous + bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); + + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_cpy_f32_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f16; + } else { + return ctx->device->pipeline_cpy_f32_f16; + } + } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f16; + } else { + return ctx->device->pipeline_cpy_f16_f16; + } + } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f32; + } else { + return ctx->device->pipeline_cpy_f16_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_bf16; + } else { + return ctx->device->pipeline_cpy_f32_bf16; + } + } + if (src->type == GGML_TYPE_F32) { + switch (to) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_f32_quant[to]; + default: + break; + } + } + + if (to == GGML_TYPE_F32) { + switch (src->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_quant_f32[src->type]; + default: + break; + } + } + + if (src->type == to) { + // Copy two or four bytes at a time, depending on block size. + // For quantized types, we scale by block size/type size. But + // this path is also used for bf16->bf16 for example, where the + // type size must be exactly 2 or 4. + GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4); + if ((ggml_type_size(src->type) % 4) == 0) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_contig_cpy_f16_f16; + } + } + + std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; + GGML_ABORT("fatal error"); +} + +static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) { + VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), "; + std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); + const int tensor_type_size = ggml_type_size(tensor->type); + + const uint32_t ne = ggml_nelements(tensor); + std::array elements; + + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + vk_op_unary_push_constants pc = { + (uint32_t)ne, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]), + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + init_pushconst_fastdiv(pc); + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); +} + +static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { + switch(type) { + case GGML_TYPE_Q8_1: + return ctx->device->pipeline_quantize_q8_1; + default: + std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl; + GGML_ABORT("fatal error"); + } +} + +static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) { + VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")"); + + vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); +} + +static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || + !ggml_vk_dim01_contiguous(src1); + + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; + + // Check for mmq first + vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; + + if (mmp == nullptr) { + // Fall back to f16 dequant mul mat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + quantize_y = false; + } + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); + + if (qx_needs_dequant) { + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; + const int x_ne = ne01 * ne00; + const int y_ne = padded_n * ne10; + const int d_ne = ne11 * ne01; + + const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + } + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || + (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1); + } + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + } + if (y_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + if (quantize_y) { + ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13); + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + // compute + ggml_vk_matmul( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + ne01, ne11, ne10, + ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, + split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n + ); // NOLINT +} + +static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + // batch_n indicates that we need to compute a few vector results, and this assumes + // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides. + GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1); + bool batch_n = ne11 > 1; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne11 * ne01; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride + uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01; + uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11); + uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21); + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + uint32_t groups_x = ne01; + uint32_t groups_z = 1; + + if (ne01 > max_groups_x) { + groups_z = 64; + groups_x = CEIL_DIV(groups_x, groups_z); + } + + // compute + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + stride_batch_x, stride_batch_y, stride_batch_d, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, + pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); +} + +static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + // const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + // const uint64_t ne13 = src1->ne[3]; + + GGML_ASSERT(ne11 == 1); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src1_uma = d_Qy != nullptr; + } + + const uint64_t x_ne = ne00 * ne01 * ne02; + const uint64_t y_ne = ne10 * ne11 * ne12; + const uint64_t d_ne = ne01 * ne11 * ne12; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t d_sz = sizeof(float) * d_ne; + + // With grouped query attention there are > 1 Q matrices per K, V matrix. + uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02; + if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) { + gqa_ratio = 1; + } + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_Qx = src0_buf_ctx->dev_buffer; + const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + + const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; + + const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; + + // compute + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + + uint32_t workgroups_z = (uint32_t)ne12; + // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups + if (gqa_ratio > 1) { + workgroups_z /= gqa_ratio; + } + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z }); +} + +static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(!ggml_is_permuted(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + // const uint64_t ne03 = src0->ne[3]; + + const uint64_t nb01 = src0->nb[1]; + const uint64_t nb02 = src0->nb[2]; + + const uint64_t nb12 = src1->nb[2]; + + // const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + // const uint64_t ne13 = src1->ne[3]; + + GGML_ASSERT(ne11 == 1); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src1_uma = d_Qy != nullptr; + } + + const uint64_t d_ne = ne01 * ne11 * ne12; + + const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); + const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); + const uint32_t channel_stride_y = nb12 / sizeof(float); + + const uint64_t qx_sz = ggml_nbytes(src0); + const uint64_t qy_sz = ggml_nbytes(src1); + const uint64_t d_sz = sizeof(float) * d_ne; + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_Qx = src0_buf_ctx->dev_buffer; + const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + + const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; + + const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; + + // compute + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); +} + +static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); + if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && + // detect 0213 permutation, and batch size of 1 + src0->nb[0] <= src0->nb[2] && + src0->nb[2] <= src0->nb[1] && + src0->nb[1] <= src0->nb[3] && + src1->nb[0] <= src1->nb[2] && + src1->nb[2] <= src1->nb[1] && + src1->nb[1] <= src1->nb[3] && + src0->ne[3] == 1 && + src1->ne[3] == 1) { + ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun); + } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && + !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { + ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun); + // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) + // when ne12 and ne13 are one. + } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) { + ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); + } else { + ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); + } +} + +static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t nei0 = ids->ne[0]; + const uint64_t nei1 = ids->ne[1]; + GGML_ASSERT(nei0 * nei1 <= 4096); + + const uint32_t nbi1 = ids->nb[1]; + const uint32_t nbi2 = ids->nb[2]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t n_as = ne02; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + ids_uma = d_ids != nullptr; + } + + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || + !ggml_vk_dim01_contiguous(src1); + + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig; + + if (qx_needs_dequant) { + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = padded_n * ne10; + const uint64_t d_ne = ne21 * ne20; + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + } + if (y_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + // compute + ggml_vk_matmul_id( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, + ne01, ne21, ne10, ne10, ne10, ne01, + stride_batch_x, stride_batch_y, ne20*ne21, + n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n + ); // NOLINT +} + +static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t nei0 = ids->ne[0]; + const uint64_t nei1 = ids->ne[1]; + + const uint64_t nbi2 = ids->nb[2]; + + GGML_ASSERT(nei1 == 1); + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + ids_uma = d_ids != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne21 * ne20; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if(!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + uint32_t groups_x = ne01; + uint32_t groups_z = 1; + + if (ne01 > max_groups_x) { + groups_z = 64; + groups_x = CEIL_DIV(groups_x, groups_z); + } + + // compute + const vk_mat_vec_id_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), + (uint32_t)nei0, (uint32_t)ne11, + }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, + vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, + pc, { groups_x, (uint32_t)nei0, groups_z }); +} + +static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")"); + if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + } else { + ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + } +} + +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) { + // Needs to be kept up to date on shader changes + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = scalar_flash_attention_num_large_rows; + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t acctype = f32acc ? 4 : 2; + const uint32_t f16vec4 = 8; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * acctype; + + const uint32_t Qf = Br * (D / 4 + 2) * f16vec4; + + const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; + const uint32_t sfsh = Bc * sfshstride * acctype; + + const uint32_t kshstride = D / 4 + 2; + const uint32_t ksh = Bc * kshstride * f16vec4; + + const uint32_t slope = Br * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + +static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; + std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; + std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const uint32_t nem1 = mask ? mask->ne[1] : 0; + const uint32_t nbm1 = mask ? mask->nb[1] : 0; + + const uint32_t D = neq0; + uint32_t N = neq1; + const uint32_t KV = nek1; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(nev1 == nek1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + assert(dst->type == GGML_TYPE_F32); + assert(q->type == GGML_TYPE_F32); + assert(k->type == v->type); + + FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : + ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT1) { + const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || + (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); + + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32); + + if (!coopmat_shape_supported || !coopmat_shmem_supported) { + path = FA_SCALAR; + } + } + + uint32_t gqa_ratio = 1; + uint32_t qk_ratio = neq2 / nek2; + uint32_t workgroups_x = (uint32_t)neq1; + uint32_t workgroups_y = (uint32_t)neq2; + uint32_t workgroups_z = (uint32_t)neq3; + + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. + // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). + uint32_t max_gqa; + switch (path) { + case FA_SCALAR: + case FA_COOPMAT1: + // We may switch from coopmat1 to scalar, so use the scalar limit for both + max_gqa = scalar_flash_attention_num_large_rows; + break; + case FA_COOPMAT2: + max_gqa = get_fa_num_small_rows(FA_COOPMAT2); + break; + default: + GGML_ASSERT(0); + } + + if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + // grouped query attention - make the N dimension equal to gqa_ratio, reduce + // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 + // and change addressing calculations to index Q's dimension 2. + gqa_ratio = qk_ratio; + N = gqa_ratio; + workgroups_y /= N; + } + + vk_pipeline *pipelines; + bool small_rows = N <= get_fa_num_small_rows(path); + + // coopmat1 does not actually support "small rows" (it needs 16 rows). + // So use scalar instead. + if (small_rows && path == FA_COOPMAT1) { + path = FA_SCALAR; + } + + // scalar is faster than coopmat2 when N==1 + if (N == 1 && path == FA_COOPMAT2) { + path = FA_SCALAR; + } + + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + + switch (path) { + case FA_SCALAR: + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + break; + case FA_COOPMAT1: + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + break; + case FA_COOPMAT2: + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + break; + default: + GGML_ASSERT(0); + } + assert(pipelines); + + const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); + const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); + const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + + bool aligned = (KV % pipelines[1]->align) == 0 && + // the "aligned" shader variant will forcibly align strides, for performance + (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0); + + vk_pipeline pipeline = pipelines[aligned]; + assert(pipeline); + + uint32_t split_kv = KV; + uint32_t split_k = 1; + + // Use a placeholder core count if one isn't available. split_k is a big help for perf. + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + + // Try to use split_k when KV is large enough to be worth the overhead + if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { + // Try to run two workgroups per SM. + split_k = ctx->device->shader_core_count * 2 / workgroups_y; + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align); + split_k = CEIL_DIV(KV, split_kv); + workgroups_x = split_k; + } + } + + // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). + const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0; + if (split_k_size > ctx->device->max_memory_allocation_size) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + } + return; + } + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head_kv = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; + size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; + + bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); + ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset); + Q_uma = d_Q != nullptr; + K_uma = d_K != nullptr; + V_uma = d_V != nullptr; + D_uma = d_D != nullptr; + if (mask) { + ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); + M_uma = d_M != nullptr; + } + } + + + ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + + if (!Q_uma) { + d_Q = q_buf_ctx->dev_buffer; + q_buf_offset = vk_tensor_offset(q) + q->view_offs; + } + if (!K_uma) { + d_K = k_buf_ctx->dev_buffer; + k_buf_offset = vk_tensor_offset(k) + k->view_offs; + } + if (!V_uma) { + d_V = v_buf_ctx->dev_buffer; + v_buf_offset = vk_tensor_offset(v) + v->view_offs; + } + if (!D_uma) { + d_D = d_buf_ctx->dev_buffer; + d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + if (!M_uma) { + d_M = d_Q; + m_buf_offset = q_buf_offset; + if (mask) { + ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context; + d_M = m_buf_ctx->dev_buffer; + m_buf_offset = vk_tensor_offset(mask) + mask->view_offs; + } + } + + const vk_flash_attn_push_constants pc = { N, KV, + (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, + (uint32_t)neq2, (uint32_t)neq3, + (uint32_t)nek2, (uint32_t)nek3, + (uint32_t)nev2, (uint32_t)nev3, + nem1, + q_stride, (uint32_t)nbq2, (uint32_t)nbq3, + k_stride, (uint32_t)nbk2, (uint32_t)nbk3, + v_stride, (uint32_t)nbv2, (uint32_t)nbv3, + nbm1, + scale, max_bias, logit_softcap, + mask != nullptr, n_head_log2, m0, m1, + gqa_ratio, split_kv, split_k }; + + ggml_vk_sync_buffers(subctx); + + if (split_k > 1) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + }, + // We only use split_k when group query attention is enabled, which means + // there's no more than one tile of rows (i.e. workgroups_x would have been + // one). We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. + pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + + ggml_vk_sync_buffers(subctx); + const std::array pc2 = { D, (uint32_t)ne1, split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, + { + vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + pc2, { (uint32_t)ne1, 1, 1 }); + } else { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + pc, { workgroups_x, workgroups_y, workgroups_z }); + } +} + +static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { + switch (op) { + case GGML_OP_GET_ROWS: + GGML_ASSERT(src1->type == GGML_TYPE_I32); + if (dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_get_rows[src0->type]; + } + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_get_rows_f32[src0->type]; + } + return nullptr; + case GGML_OP_ACC: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_acc_f32; + } + return nullptr; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) { + return nullptr; + } + switch (op) { + case GGML_OP_ADD: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_SUB: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_MUL: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_DIV: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + default: + break; + } + return nullptr; + case GGML_OP_CONCAT: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_concat_f32; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_concat_f16; + } + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_concat_i32; + } + return nullptr; + case GGML_OP_UPSCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) { + return ctx->device->pipeline_upscale_f32; + } + return nullptr; + case GGML_OP_SCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_scale_f32; + } + return nullptr; + case GGML_OP_SQR: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sqr_f32; + } + return nullptr; + case GGML_OP_SIN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sin_f32; + } + return nullptr; + case GGML_OP_COS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cos_f32; + } + return nullptr; + case GGML_OP_CLAMP: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_clamp_f32; + } + return nullptr; + case GGML_OP_PAD: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pad_f32; + } + return nullptr; + case GGML_OP_REPEAT: + if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { + return ctx->device->pipeline_repeat_f32; + } + return nullptr; + case GGML_OP_REPEAT_BACK: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_repeat_back_f32; + } + return nullptr; + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); + case GGML_OP_SILU_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_silu_back_f32; + } + return nullptr; + case GGML_OP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_norm_f32; + } + return nullptr; + case GGML_OP_GROUP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_group_norm_f32; + } + return nullptr; + case GGML_OP_RMS_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rms_norm_f32; + } + return nullptr; + case GGML_OP_RMS_NORM_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rms_norm_back_f32; + } + return nullptr; + case GGML_OP_L2_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_l2_norm_f32; + } + return nullptr; + case GGML_OP_UNARY: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + + switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_SILU: + return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_GELU: + return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_GELU_QUICK: + return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_RELU: + return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_TANH: + return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SIGMOID: + return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16]; + default: + break; + } + return nullptr; + case GGML_OP_DIAG_MASK_INF: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_diag_mask_inf_f32; + } + return nullptr; + case GGML_OP_SOFT_MAX: + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; + } + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; + } + return nullptr; + case GGML_OP_SOFT_MAX_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_soft_max_back_f32; + } + return nullptr; + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + { + const int mode = ((const int32_t *) dst->op_params)[2]; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_neox) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_neox_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_neox_f16; + } + } else if (is_mrope && !is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_multi_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_multi_f16; + } + } else if (is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_vision_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_vision_f16; + } + } else { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_norm_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_norm_f16; + } + } + return nullptr; + } + case GGML_OP_ARGSORT: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_argsort_f32; + } + return nullptr; + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sum_rows_f32; + } + return nullptr; + case GGML_OP_ARGMAX: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_argmax_f32; + } + return nullptr; + case GGML_OP_COUNT_EQUAL: + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) { + return ctx->device->pipeline_count_equal_i32; + } + return nullptr; + case GGML_OP_IM2COL: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_f32_f16; + } + return nullptr; + case GGML_OP_TIMESTEP_EMBEDDING: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_timestep_embedding_f32; + } + return nullptr; + case GGML_OP_CONV_TRANSPOSE_1D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv_transpose_1d_f32; + } + return nullptr; + case GGML_OP_POOL_2D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pool2d_f32; + } + return nullptr; + case GGML_OP_RWKV_WKV6: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv6_f32; + } + return nullptr; + case GGML_OP_RWKV_WKV7: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv7_f32; + } + return nullptr; + case GGML_OP_OPT_STEP_ADAMW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_opt_step_adamw_f32; + } + return nullptr; + case GGML_OP_LEAKY_RELU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_leaky_relu_f32; + } + return nullptr; + case GGML_OP_CONV_2D_DW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f32; + } + } + return nullptr; + default: + return nullptr; + } + + GGML_UNUSED(src2); +} + +static bool ggml_vk_op_supports_incontiguous(ggml_op op) { + switch (op) { + case GGML_OP_CPY: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_ROPE: + case GGML_OP_RMS_NORM: + case GGML_OP_CONV_2D_DW: + case GGML_OP_IM2COL: + return true; + default: + return false; + } +} + +static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t) +{ + return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));; +} + +template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + GGML_UNUSED(p); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(dst); + static_assert(!std::is_const::value, "unexpected type"); + GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); + GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); + GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0); + GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0)); + + p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; + + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.a_offset = a_offset; + p.d_offset = d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template +static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + if (src1 != nullptr) { + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + } + if (src2 != nullptr) { + std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; + } + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT + GGML_ASSERT(dst->buffer != nullptr); + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + const uint64_t ne0 = ne00 * ne01; + + const bool use_src1 = src1 != nullptr; + const uint64_t ne10 = use_src1 ? src1->ne[0] : 0; + const uint64_t ne11 = use_src1 ? src1->ne[1] : 0; + const uint64_t ne12 = use_src1 ? src1->ne[2] : 0; + const uint64_t ne13 = use_src1 ? src1->ne[3] : 0; + const uint64_t ne1 = ne10 * ne11; + // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0; + + const bool use_src2 = src2 != nullptr; + const uint64_t ne20 = use_src2 ? src2->ne[0] : 0; + const uint64_t ne21 = use_src2 ? src2->ne[1] : 0; + const uint64_t ne22 = use_src2 ? src2->ne[2] : 0; + const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; + const uint64_t ne2 = ne20 * ne21; + + const uint64_t ned0 = dst->ne[0]; + const uint64_t ned1 = dst->ne[1]; + const uint64_t ned2 = dst->ne[2]; + const uint64_t ned3 = dst->ne[3]; + const uint64_t ned = ned0 * ned1; + + init_pushconst_fastdiv(pc); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); + + if (pipeline == nullptr) { + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type); + if (src1 != nullptr) { + std::cerr << " and " << ggml_type_name(src1->type); + } + std::cerr << " to " << ggml_type_name(dst->type) << std::endl; + GGML_ABORT("fatal error"); + } + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; + ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; + + vk_buffer d_X = nullptr; + size_t x_buf_offset = 0; + vk_buffer d_Y = nullptr; + size_t y_buf_offset = 0; + vk_buffer d_Z = nullptr; + size_t z_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool src2_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset); + src0_uma = d_X != nullptr; + if (use_src1) { + ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset); + src1_uma = d_Y != nullptr; + } + if (use_src2) { + ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset); + src2_uma = d_Z != nullptr; + } + } + + uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0; + uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0; + uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0; + uint64_t d_sz = ggml_type_size(dst->type) * ned; + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + + // Workaround for tiny tensor inputs on ROPE + if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) { + y_sz = VK_WHOLE_SIZE; + } + + GGML_ASSERT(d_D != nullptr); + uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + if(!src0_uma) { + d_X = src0_buf_ctx->dev_buffer; + x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_X != nullptr); + } + if (use_src1 && !src1_uma) { + d_Y = src1_buf_ctx->dev_buffer; + y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Y != nullptr); + } + if (use_src2 && !src2_uma) { + d_Z = src2_buf_ctx->dev_buffer; + z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; + GGML_ASSERT(d_Z != nullptr); + } + // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets. + init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst); + x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + + if (op_supports_incontiguous) { + x_sz = ggml_nbytes(src0); + y_sz = use_src1 ? ggml_nbytes(src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) : 0; + d_sz = ggml_nbytes(dst); + + if (x_buf_offset + x_sz >= d_X->size) { + x_sz = VK_WHOLE_SIZE; + } + if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { + y_sz = VK_WHOLE_SIZE; + } + if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { + z_sz = VK_WHOLE_SIZE; + } + if (d_buf_offset + d_sz >= d_D->size) { + d_sz = VK_WHOLE_SIZE; + } + } + + std::array elements; + + // Single call if dimension 2 is contiguous + GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); + + switch (op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + { + const uint32_t nr = ggml_nrows(src0); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } break; + case GGML_OP_RMS_NORM: + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + break; + + case GGML_OP_SUM: + // We use GGML_OP_SUM_ROWS with 1 row. + elements = { 1, 1, 1 }; + break; + case GGML_OP_GROUP_NORM: + { + const uint32_t num_groups = dst->op_params[0]; + elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; + } break; + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; + break; + case GGML_OP_GET_ROWS: + elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; + break; + case GGML_OP_ARGSORT: + elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; + break; + case GGML_OP_IM2COL: + { + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t batch = src1->ne[is_2D ? 3 : 2]; + + elements = { OW * KW * KH, OH, batch * IC }; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + const uint32_t dim = dst->op_params[0]; + uint32_t half_ceil = (dim + 1) / 2; + elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1} + } break; + case GGML_OP_POOL_2D: + { + const uint32_t N = dst->ne[3]; + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + elements = { N * OC * OH * OW, 1, 1}; + } break; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_DIV: + case GGML_OP_MUL: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_CPY: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_UNARY: + case GGML_OP_CONV_2D_DW: + { + uint32_t ne = ggml_nelements(dst); + if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + } break; + default: + elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; + break; + } + + if (!op_supports_incontiguous) { + if (x_sz != VK_WHOLE_SIZE) { + x_sz *= ne02 * ne03; + } + if (use_src1 && y_sz != VK_WHOLE_SIZE) { + y_sz *= ne12 * ne13; + } + if (use_src2 && z_sz != VK_WHOLE_SIZE) { + z_sz *= ne22 * ne23; + } + if (d_sz != VK_WHOLE_SIZE) { + d_sz *= ned2 * ned3; + } + } + + if (op == GGML_OP_SOFT_MAX) { + // Empty src1 is possible in soft_max, but the shader needs a buffer + vk_subbuffer subbuf_y; + if (use_src1) { + subbuf_y = { d_Y, y_buf_offset, y_sz }; + } else { + subbuf_y = { d_X, 0, x_sz }; + } + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { + // Empty src2 is possible in rope, but the shader needs a buffer + vk_subbuffer subbuf_z; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; + } + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_IM2COL) { + // im2col uses only src1 and dst buffers + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_COUNT_EQUAL) { + ggml_vk_sync_buffers(subctx); + // count_equal assumes that destination buffer is initialized with zeroes + ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (use_src2) { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (use_src1) { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } +} + +static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, offset, + }, dryrun); +} + +static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) { + GGML_ASSERT(version == 6 || version == 7); + int num_srcs = version == 6 ? 6 : 7; + + for (int i = 0; i < num_srcs; i++) { + GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); + } + + GGML_ASSERT(dst->buffer != nullptr); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + for (int i = 0; i < num_srcs; i++) { + src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; + } + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 }; + bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false }; + + if (ctx->device->uma) { + for (int i = 0; i < num_srcs; i++) { + ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]); + srcs_uma[i] = d_srcs[i] != nullptr; + } + + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); + dst_uma = d_D != nullptr; + } + + uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 }; + for (int i = 0; i < num_srcs; i++) { + src_sizes[i] = ggml_nbytes(dst->src[i]); + if (!srcs_uma[i]) { + d_srcs[i] = src_buf_ctxs[i]->dev_buffer; + src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs; + } + } + + const uint64_t dst_size = ggml_nbytes(dst); + if (!dst_uma) { + d_D = dst_buf_ctx->dev_buffer; + dst_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + std::array elements = { + (uint32_t)(pc.B * pc.H), + 1, + 1 + }; + + if (version == 6) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, pc, elements); + } else if (version == 7) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, pc, elements); + } else { + // shouldn't happen + GGML_ASSERT(false); + } +} + +static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[5]->ne[1]; + + ggml_vk_op_f32_wkv( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + 6, + dryrun + ); +} + +static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[6]->ne[1]; + + ggml_vk_op_f32_wkv( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + 7, + dryrun + ); +} + +static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * x = dst->src[0]; + const ggml_tensor * g = dst->src[1]; + const ggml_tensor * gm = dst->src[2]; + const ggml_tensor * gv = dst->src[3]; + const ggml_tensor * p = dst->src[4]; + + GGML_ASSERT(x->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(gm->type == GGML_TYPE_F32); + GGML_ASSERT(gv->type == GGML_TYPE_F32); + GGML_ASSERT(p->type == GGML_TYPE_F32); + GGML_ASSERT(dst->buffer != nullptr); + GGML_ASSERT(ggml_is_contiguous(x)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(gm)); + GGML_ASSERT(ggml_is_contiguous(gv)); + GGML_ASSERT(ggml_is_contiguous(p)); + GGML_ASSERT(ggml_are_same_shape(x, g)); + GGML_ASSERT(ggml_are_same_shape(x, gm)); + GGML_ASSERT(ggml_are_same_shape(x, gv)); + GGML_ASSERT(ggml_nelements(p) == 7); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context; + ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context; + ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context; + ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context; + ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context; + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr; + size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0; + bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, x->data, d_X, x_offset); + ggml_vk_host_get(ctx->device, g->data, d_G, g_offset); + ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset); + ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset); + ggml_vk_host_get(ctx->device, p->data, d_P, p_offset); + + X_uma = d_X != nullptr; + G_uma = d_G != nullptr; + GM_uma = d_GM != nullptr; + GV_uma = d_GV != nullptr; + P_uma = d_P != nullptr; + } + + if (!X_uma) { + d_X = x_buf_ctx->dev_buffer; + x_offset = vk_tensor_offset(x) + x->view_offs; + } + if (!G_uma) { + d_G = g_buf_ctx->dev_buffer; + g_offset = vk_tensor_offset(g) + g->view_offs; + } + if (!GM_uma) { + d_GM = gm_buf_ctx->dev_buffer; + gm_offset = vk_tensor_offset(gm) + gm->view_offs; + } + if (!GV_uma) { + d_GV = gv_buf_ctx->dev_buffer; + gv_offset = vk_tensor_offset(gv) + gv->view_offs; + } + if (!P_uma) { + d_P = p_buf_ctx->dev_buffer; + p_offset = vk_tensor_offset(p) + p->view_offs; + } + + const uint64_t x_size = ggml_nbytes(x); + const uint64_t g_size = ggml_nbytes(g); + const uint64_t gm_size = ggml_nbytes(gm); + const uint64_t gv_size = ggml_nbytes(gv); + const uint64_t p_size = ggml_nbytes(p); + + std::array elements = { (uint32_t)ggml_nelements(x), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_X, x_offset, x_size }, + vk_subbuffer{ d_G, g_offset, g_size }, + vk_subbuffer{ d_GM, gm_offset, gm_size }, + vk_subbuffer{ d_GV, gv_offset, gv_size }, + vk_subbuffer{ d_P, p_offset, p_size }, + }, pc, elements); +} + +static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t n = ggml_nelements(dst->src[0]); + + ggml_vk_op_f32_opt_step_adamw( + ctx, subctx, dst, + { (uint32_t)n, 0, 0.0f, 0.0f }, + dryrun + ); +} + +static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + int * op_params = (int *)dst->op_params; + + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, op_params[0], + }, dryrun); +} + +static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + + const float sf0 = (float)dst->ne[0] / src0->ne[0]; + const float sf1 = (float)dst->ne[1] / src0->ne[1]; + const float sf2 = (float)dst->ne[2] / src0->ne[2]; + const float sf3 = (float)dst->ne[3] / src0->ne[3]; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { + (uint32_t)ggml_nelements(dst), 0, 0, + (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], + sf0, sf1, sf2, sf3, + }, dryrun); +} + +static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], op_params[1], + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + uint32_t ne = (uint32_t)ggml_nelements(src0); + if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { + ne, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const int * int_op_params = (const int *)dst->op_params; + const float * float_op_params = (const float *)dst->op_params; + + const uint32_t num_groups = int_op_params[0]; + const float eps = float_op_params[1]; + const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); +} + +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); +} + +static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + + float scale = op_params[0]; + float max_bias = op_params[1]; + + const uint32_t ncols = (uint32_t)src0->ne[0]; + const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); + const uint32_t nrows_y = (uint32_t)src0->ne[1]; + + const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + ncols, + src1 != nullptr ? nrows_y : (uint32_t)0, + scale, max_bias, + m0, m1, + n_head_log2, + nrows_x, + }, dryrun); +} + +static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun); +} + +static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + // const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + const float freq_base = ((float *) dst->op_params)[5]; + const float freq_scale = ((float *) dst->op_params)[6]; + const float ext_factor = ((float *) dst->op_params)[7]; + const float attn_factor = ((float *) dst->op_params)[8]; + const float beta_fast = ((float *) dst->op_params)[9]; + const float beta_slow = ((float *) dst->op_params)[10]; + int sections[4] {}; + if (mode & GGML_ROPE_TYPE_MROPE) { + memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); + } + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type); + uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { + (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, + src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, + sections[0], sections[1], sections[2], sections[3], backprop + }, dryrun); +} + +static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; + + uint32_t ncols = src0->ne[0]; + + uint32_t ncols_pad = 1; + while (ncols_pad < ncols) { + ncols_pad *= 2; + } + + GGML_ASSERT(ncols_pad <= 1024); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { + ncols, + ncols_pad, + op_params[0], + }, dryrun); +} + +static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const int32_t s0 = dst->op_params[0]; + const int32_t s1 = dst->op_params[1]; + const int32_t p0 = dst->op_params[2]; + const int32_t p1 = dst->op_params[3]; + const int32_t d0 = dst->op_params[4]; + const int32_t d1 = dst->op_params[5]; + + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t IW = src1->ne[0]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 + + const uint32_t pelements = OW * KW * KH; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { + batch_offset, offset_delta, + IC, IW, IH, OW, OH, KW, KH, + pelements, + IC * KH * KW, + s0, s1, p0, p1, d0, d1, + }, dryrun); +} + +static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t dim = dst->op_params[0]; + const uint32_t max_period = dst->op_params[1]; + const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { + nb1, dim, max_period, + }, dryrun); +} + +static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + // src0: (K, Cout, Cin, 1) -- kernel + // src1: (L, Cin, 1, 1) -- input + // dst: (*, Cout, 1, 1) + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + const int32_t s0 = dst->op_params[0]; + + vk_op_conv_transpose_1d_push_constants p{}; + p.Cout = static_cast(ne01); + p.Cin = static_cast(ne02); + p.K = static_cast(ne00); + p.L = static_cast(ne10); + p.KL = static_cast(ne0); + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb11 = static_cast(nb11 / nb10); + p.nb1 = static_cast(nb1 / nb0); + p.s0 = static_cast(s0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun); +} + +static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + uint32_t op = static_cast(dst->op_params[0]); + const int32_t k1 = dst->op_params[1]; + const int32_t k0 = dst->op_params[2]; + const int32_t s1 = dst->op_params[3]; + const int32_t s0 = dst->op_params[4]; + const int32_t p1 = dst->op_params[5]; + const int32_t p0 = dst->op_params[6]; + + const uint32_t IH = src0->ne[1]; + const uint32_t IW = src0->ne[0]; + + const uint32_t N = dst->ne[3]; + + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + + const uint32_t parallel_elements = N * OC * OH * OW; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, { + IW, IH, OW, OH, OC, + parallel_elements, + op, + k0, k1, s0, s1, p0, p1, + }, dryrun); +} + +static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + vk_op_conv2d_dw_push_constants p{}; + p.ne = ggml_nelements(dst); + p.channels = dst->ne[2]; + p.batches = dst->ne[3]; + p.dst_w = dst->ne[0]; + p.dst_h = dst->ne[1]; + p.src_w = src1->ne[0]; + p.src_h = src1->ne[1]; + p.knl_w = src0->ne[0]; + p.knl_h = src0->ne[1]; + p.stride_x = dst->op_params[0]; + p.stride_y = dst->op_params[1]; + p.pad_x = dst->op_params[2]; + p.pad_y = dst->op_params[3]; + p.dilation_x = dst->op_params[4]; + p.dilation_y = dst->op_params[5]; + + GGML_ASSERT(src0->ne[3] == p.channels); + GGML_ASSERT(src1->ne[3] == p.batches); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); +} + +static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const float * op_params = (const float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); +} + +#ifdef GGML_VULKAN_RUN_TESTS +static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) { + if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) { + float val; + if (type == GGML_TYPE_F32) { + val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0); + } else if (type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0)); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +template +static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) { + VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")"); + const size_t x_ne = m * k * batch; + const size_t y_ne = k * n * batch; + const size_t d_ne = m * n * batch; + + vk_pipeline p; + std::string shname; + if (shader_size == 0) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_s; + shname = "F32_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_s; + shname = "F32_F16_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s; + shname = "F16_F32_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_s; + shname = "F16_ALIGNED_S"; + } else { + GGML_ABORT("fatal error"); + } + } else if (shader_size == 1) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_m; + shname = "F32_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_m; + shname = "F32_F16_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m; + shname = "F16_F32_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_m; + shname = "F16_ALIGNED_M"; + } else { + GGML_ABORT("fatal error"); + } + } else if (shader_size == 2) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_l; + shname = "F32_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_l; + shname = "F32_F16_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l; + shname = "F16_F32_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_l; + shname = "F16_ALIGNED_L"; + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ASSERT(0); + } + + const size_t kpad = ggml_vk_align_size(k, p->align); + + if (k != kpad) { + if (shader_size == 0) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->s; + shname = "F32_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->s; + shname = "F32_F16_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->s; + shname = "F16_F32_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->s; + shname = "F16_S"; + } + } else if (shader_size == 1) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->m; + shname = "F32_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->m; + shname = "F32_F16_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->m; + shname = "F16_F32_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->m; + shname = "F16_M"; + } + } else if (shader_size == 2) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->l; + shname = "F32_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->l; + shname = "F32_F16_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->l; + shname = "F16_F32_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->l; + shname = "F16_L"; + } + } + } + + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); + + if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); + } + } + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + + ggml_pipeline_allocate_descriptor_sets(ctx); + + vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + + X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne); + Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne); + float* d = (float *) malloc(sizeof(float) * d_ne); + + for (size_t i = 0; i < x_ne; i++) { + if (std::is_same()) { + x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = 1.0f; + // x[i] = i + 1; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + } else if (std::is_same()) { + x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // x[i] = ggml_fp32_to_fp16(1.0f); + // x[i] = ggml_fp32_to_fp16(i + 1); + // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + } else { + GGML_ABORT("fatal error"); + } + } + for (size_t i = 0; i < y_ne; i++) { + if (std::is_same()) { + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 1.0f : 0.0f; + // y[i] = i + 1; + } else if (std::is_same()) { + y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + // y[i] = ggml_fp32_to_fp16(i + 1); + } else { + GGML_ABORT("fatal error"); + } + } + + ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); + ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ggml_vk_ctx_begin(ctx->device, subctx); + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); + + auto end = std::chrono::high_resolution_clock::now(); + double time = std::chrono::duration_cast(end-begin).count() / 1000.0; + + // copy dst to host + ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne); + + float * d_chk = (float *) malloc(sizeof(float) * d_ne); + + ggml_init_params iparams = { + /*.mem_size =*/ 1024*1024*1024, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ggml_ctx = ggml_init(iparams); + + ggml_type src0_type; + ggml_type src1_type; + + if (std::is_same()) { + src0_type = GGML_TYPE_F32; + } else if (std::is_same()) { + src0_type = GGML_TYPE_F16; + } else { + GGML_ABORT("fatal error"); + } + if (std::is_same()) { + src1_type = GGML_TYPE_F32; + } else if (std::is_same()) { + src1_type = GGML_TYPE_F16; + } else { + GGML_ABORT("fatal error"); + } + + ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch); + ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch); + ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); + + src0_ggml->data = x; + src1_ggml->data = y; + tensor_ggml->data = d_chk; + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_ggml); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); + + ggml_free(ggml_ctx); + + double avg_err = 0.0; + int first_err_n = -1; + int first_err_m = -1; + int first_err_b = -1; + + for (size_t i = 0; i < m*n*batch; i++) { + double err = std::fabs(d[i] - d_chk[i]); + avg_err += err; + + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { + first_err_b = i / (m * n); + first_err_n = (i % (m * n)) / m; + first_err_m = (i % (m * n)) % m; + } + } + + avg_err /= m * n; + + double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1 || std::isnan(avg_err)) { + std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << "Expected result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + if (split_k > 1) { + float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); + ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); + + std::cerr << "d_buf0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf2: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf3: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + free(split_k_buf); + } + } + + free(d_chk); + + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); + + ggml_vk_destroy_buffer(d_X); + ggml_vk_destroy_buffer(d_Y); + ggml_vk_destroy_buffer(d_D); + + free(x); + free(y); + free(d); +} + +static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + i3 = std::max(i3, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) { + ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr); +} + +static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) { + if (quant == GGML_TYPE_F32) { + memcpy(to, from, sizeof(float) * ne); + return; + } + + const auto * tt = ggml_get_type_traits(quant); + + ggml_to_float_t dequant_fn = tt->to_float; + + dequant_fn(from, to, ne); +} + +static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { + VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")"); + const size_t x_sz = sizeof(float) * ne; + const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne; + const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); + float * x = (float *) malloc(x_sz); + void * qx = malloc(qx_sz); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal); + float * x_ref = (float *) malloc(x_sz); + ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16); + + for (size_t i = 0; i < ne; i++) { + x[i] = rand() / (float)RAND_MAX; + } + + vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant); + + ggml_vk_quantize_data(x, qx, ne, quant); + ggml_vk_dequantize_data(qx, x_ref, ne, quant); + + ggml_pipeline_request_descriptor_sets(ctx, p, 1); + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + + ggml_pipeline_allocate_descriptor_sets(ctx); + + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ggml_vk_ctx_begin(ctx->device, subctx); + const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; + ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1}); + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); + + auto end = std::chrono::high_resolution_clock::now(); + + double ms_dequant = std::chrono::duration_cast(end-begin).count() / 1000.0; + ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16); + + int first_err = -1; + + double avg_err = 0.0; + for (size_t i = 0; i < ne; i++) { + double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i])); + avg_err += error; + + if (first_err < 0 && error > 0.05) { + first_err = i; + } + } + + avg_err /= ne; + + std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1) { + std::cerr << "first_error = " << first_err << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { + std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", "; + } + std::cerr << std::endl << "Expected result: " << std::endl << std::endl; + for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { + std::cerr << x_ref[i] << ", "; + } + std::cerr << std::endl; + } + + ggml_vk_destroy_buffer(x_buf); + ggml_vk_destroy_buffer(qx_buf); + + free(x); + free(qx); + free(x_ref); + free(x_chk); +} + +// This does not work without ggml q8_1 quantization support +// +// typedef uint16_t ggml_half; +// typedef uint32_t ggml_half2; +// +// #define QK8_1 32 +// typedef struct { +// union { +// struct { +// ggml_half d; // delta +// ggml_half s; // d * sum(qs[i]) +// } GGML_COMMON_AGGR_S; +// ggml_half2 ds; +// } GGML_COMMON_AGGR_U; +// int8_t qs[QK8_1]; // quants +// } block_q8_1; +// +// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { +// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")"); +// GGML_ASSERT(quant == GGML_TYPE_Q8_1); +// +// const size_t x_sz = sizeof(float) * ne; +// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); +// float * x = (float *) malloc(x_sz); +// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz); +// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz); +// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); +// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); +// +// for (size_t i = 0; i < ne; i++) { +// x[i] = rand() / (float)RAND_MAX; +// } +// +// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant); +// +// ggml_pipeline_request_descriptor_sets(ctx, p, 1); +// +// if (ctx->device->need_compiles) { +// ggml_vk_load_shaders(ctx->device); +// } +// +// ggml_pipeline_allocate_descriptor_sets(ctx); +// +// ggml_vk_buffer_write(x_buf, 0, x, x_sz); +// +// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); +// ggml_vk_ctx_begin(ctx->device, subctx); +// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne); +// ggml_vk_ctx_end(subctx); +// +// auto begin = std::chrono::high_resolution_clock::now(); +// +// ggml_vk_submit(subctx, ctx->fence); +// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences"); +// ctx->device->device.resetFences({ ctx->fence }); +// ggml_vk_queue_command_pools_cleanup(ctx->device); +// +// auto end = std::chrono::high_resolution_clock::now(); +// +// double ms_quant = std::chrono::duration_cast(end-begin).count() / 1000.0; +// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz); +// +// ggml_vk_quantize_data(x, qx_res, ne, quant); +// +// int first_err = -1; +// +// for (size_t i = 0; i < ne / 32; i++) { +// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// for (size_t j = 0; j < 32; j++) { +// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]); +// +// if (first_err < 0 && error > 1) { +// first_err = i; +// } +// } +// } +// +// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl; +// +// if (first_err != -1) { +// std::cerr << "first_error = " << first_err << std::endl; +// std::cerr << "Actual result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " "; +// } +// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " "; +// } +// std::cerr << std::endl; +// } +// +// ggml_vk_destroy_buffer(x_buf); +// ggml_vk_destroy_buffer(qx_buf); +// +// free(x); +// free(qx); +// free(qx_res); +// } + +static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) { + VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); + const size_t x_ne = m * k * batch; + const size_t y_ne = k * n * batch; + const size_t d_ne = m * n * batch; + + vk_matmul_pipeline2 * pipelines; + + if (mmq) { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1; + } else { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat; + } + + const bool fp16acc = ctx->device->fp16; + + vk_pipeline p; + std::string shname; + if (shader_size == 0) { + p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; + } else if (shader_size == 1) { + p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; + } else if (shader_size == 2) { + p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; + } else { + GGML_ASSERT(0); + } + + const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align); + + if (mmq || k != kpad) { + if (shader_size == 0) { + p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s; + shname = std::string(ggml_type_name(quant)) + "_S"; + } else if (shader_size == 1) { + p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m; + shname = std::string(ggml_type_name(quant)) + "_M"; + } else if (shader_size == 2) { + p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l; + shname = std::string(ggml_type_name(quant)) + "_L"; + } else { + GGML_ASSERT(0); + } + } + + if (p == nullptr) { + std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl; + return; + } + + const size_t x_sz = sizeof(float) * x_ne; + const size_t y_sz = sizeof(float) * y_ne; + const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); + const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz; + const size_t d_sz = sizeof(float) * d_ne; + float * x = (float *) malloc(x_sz); + float * y = (float *) malloc(y_sz); + void * qx = malloc(qx_sz); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + float * d = (float *) malloc(d_sz); + float * d_chk = (float *) malloc(d_sz); + + for (size_t i = 0; i < x_ne; i++) { + x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + // x[i] = i % k; + } + + ggml_vk_quantize_data(x, qx, x_ne, quant); + + for (size_t i = 0; i < y_ne; i++) { + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 1.0f : 0.0f; + // y[i] = i % k; + } + + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); + + if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); + } + } + if (mmq) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it); + } + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + + ggml_pipeline_allocate_descriptor_sets(ctx); + + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); + ggml_vk_buffer_write(y_buf, 0, y, y_sz); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ggml_vk_ctx_begin(ctx->device, subctx); + if (mmq) { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne); + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + } else { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + } + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); + + auto end = std::chrono::high_resolution_clock::now(); + + double time_ms = std::chrono::duration_cast(end-begin).count() / 1000.0; + ggml_vk_buffer_read(d_buf, 0, d, d_sz); + + ggml_init_params iparams = { + /*.mem_size =*/ 1024*1024*1024, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ggml_ctx = ggml_init(iparams); + + ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch); + ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch); + ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); + + src0_ggml->data = qx; + src1_ggml->data = y; + tensor_ggml->data = d_chk; + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_ggml); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); + + ggml_free(ggml_ctx); + + double avg_err = 0.0; + int first_err_n = -1; + int first_err_m = -1; + int first_err_b = -1; + + for (size_t i = 0; i < m*n*batch; i++) { + double err = std::fabs(d[i] - d_chk[i]); + avg_err += err; + + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { + first_err_b = i / (m * n); + first_err_n = (i % (m * n)) / m; + first_err_m = (i % (m * n)) % m; + } + } + + avg_err /= m * n; + + double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST dequant matmul " << shname; + if (mmq) { + std::cerr << " mmq"; + } + std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.01 || std::isnan(avg_err)) { + std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "Expected result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "src0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "src1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b); + + if (split_k > 1) { + float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); + ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); + + std::cerr << "d_buf0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf2: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf3: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + free(split_k_buf); + } + } + + ggml_vk_destroy_buffer(qx_buf); + ggml_vk_destroy_buffer(y_buf); + ggml_vk_destroy_buffer(qy_buf); + ggml_vk_destroy_buffer(d_buf); + + free(x); + free(qx); + free(y); + free(d); + free(d_chk); +} +#endif + +static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { +#if defined(GGML_VULKAN_RUN_TESTS) + const std::vector vals { + 512, 512, 128, + 128, 512, 512, + 4096, 512, 4096, + 11008, 512, 4096, + 4096, 512, 11008, + 32000, 512, 4096, + 8, 8, 8, + 100, 46, 576, + 623, 111, 128, + 100, 46, 558, + 512, 1, 256, + 128, 110, 622, + 511, 511, 127, + 511, 511, 7, + 511, 511, 17, + 49, 49, 128, + 128, 49, 49, + 4096, 49, 4096, + }; + const size_t num_it = 100; + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true); + + abort(); + + for (size_t i = 0; i < vals.size(); i += 3) { + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2); + std::cerr << '\n' << std::endl; + + if (vals[i + 2] % 32 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0); + std::cerr << '\n' << std::endl; + } + + if (vals[i + 2] % 256 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K); + std::cerr << '\n' << std::endl; + } + } + + GGML_ABORT("fatal error"); +#endif + + if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")"); + // Resize buffer + if (ctx->prealloc_x != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_x); + } + ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x); + } + if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")"); + // Resize buffer + if (ctx->prealloc_y != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_y); + } + ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + } + if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); + } +} + +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); + +// Returns true if node has enqueued work into the queue, false otherwise +// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. +static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ + if (ggml_is_empty(node) || !node->buffer) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); + ctx->semaphore_idx = 0; + + const ggml_tensor * src0 = node->src[0]; + const ggml_tensor * src1 = node->src[1]; + const ggml_tensor * src2 = node->src[2]; + const ggml_tensor * src3 = node->src[3]; + + switch (node->op) { + // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + return false; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + break; + default: + return false; + } + break; + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_ACC: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_SILU_BACK: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_OPT_STEP_ADAMW: + break; + default: + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; + GGML_ABORT("fatal error"); + return false; + } + + vk_context compute_ctx; + + if (!dryrun) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + } else { + switch (node->op) { + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_SILU_BACK: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_UNARY: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_LEAKY_RELU: + { + // These operations all go through ggml_vk_op_f32, so short-circuit and + // do the only thing needed for the dryrun. + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return false; + } + default: + break; + } + } + + switch (node->op) { + case GGML_OP_REPEAT: + ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_REPEAT_BACK: + ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ACC: + ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_GET_ROWS: + ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ADD: + ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_SUB: + ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_MUL: + ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_DIV: + ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_CONCAT: + ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_UPSCALE: + ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SCALE: + ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SQR: + ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SIN: + ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COS: + ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CLAMP: + ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_PAD: + ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SILU_BACK: + ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_NORM: + ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_GROUP_NORM: + ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_RMS_NORM: + ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_RMS_NORM_BACK: + ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_L2_NORM: + ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); + break; + default: + return false; + } + break; + case GGML_OP_DIAG_MASK_INF: + ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SOFT_MAX: + ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_SOFT_MAX_BACK: + ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ROPE: + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun); + + break; + case GGML_OP_ROPE_BACK: + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun); + + break; + case GGML_OP_ARGSORT: + ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SUM: + ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SUM_ROWS: + ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ARGMAX: + ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COUNT_EQUAL: + ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_IM2COL: + ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_TIMESTEP_EMBEDDING: + ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_POOL_2D: + ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CONV_2D_DW: + ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_LEAKY_RELU: + ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_MUL_MAT: + ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_MUL_MAT_ID: + ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + + case GGML_OP_FLASH_ATTN_EXT: + ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); + + break; + + case GGML_OP_RWKV_WKV6: + ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + + break; + + case GGML_OP_RWKV_WKV7: + ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun); + + break; + + case GGML_OP_OPT_STEP_ADAMW: + ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); + + break; + default: + return false; + } + + if (dryrun) { + return false; + } + + ctx->tensor_ctxs[node_idx] = compute_ctx; + +#if defined(GGML_VULKAN_CHECK_RESULTS) + // Force context reset on each node so that each tensor ends up in its own context + // and can be run and compared to its CPU equivalent separately + last_node = true; +#endif + + if (submit || last_node) { + ggml_vk_ctx_end(compute_ctx); + + // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward + if (last_node) { + compute_ctx->exit_tensor_idx = node_idx_begin; + } + else { + compute_ctx->exit_tensor_idx = -1; + } + + ctx->compute_ctx.reset(); + + bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready); + if (!ok) { + if (node->op == GGML_OP_UNARY) { + std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } + else { + std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; + } + } + + } + return true; +} + +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { + ggml_backend_buffer * buf = nullptr; + + switch (tensor->op) { + case GGML_OP_ADD: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_SILU_BACK: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + case GGML_OP_LEAKY_RELU: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_OPT_STEP_ADAMW: + buf = tensor->buffer; + + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + buf = tensor->buffer; + break; + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_FLASH_ATTN_EXT: + buf = tensor->buffer; + + break; + default: + return false; + } + + if (buf == nullptr) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); + + vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); + + // always wait for the GPU work to be done for the last submit + if (tensor_idx == subctx->exit_tensor_idx) { + use_fence = true; + } + + // Only run if ctx hasn't been submitted yet + if (!subctx->seqs.empty()) { +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_0(tensor); + use_fence = true; +#endif + + // Do staging buffer copies + for (auto& cpy : subctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) { + ggml_vk_submit(subctx, ctx->almost_ready_fence); + ctx->almost_ready_fence_pending = true; + } else { + ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); + } + + if (use_fence) { + ggml_vk_wait_for_fence(ctx); + } +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_1(tensor); +#endif + } + + if (tensor_idx == subctx->exit_tensor_idx) { + // Do staging buffer copies + for (auto& cpy : subctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + subctx->in_memcpys.clear(); + subctx->out_memcpys.clear(); + } + + return true; +} + +// Clean up after graph processing is done +static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); + for (auto& buffer : ctx->gc.temp_buffers) { + ggml_vk_pool_free(ctx, buffer); + } + ctx->gc.temp_buffers.clear(); + + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); + + for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { + ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); + } + ctx->gc.semaphores.clear(); + + for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) { + ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s }); + } + ctx->gc.tl_semaphores.clear(); + ctx->semaphore_idx = 0; + + ctx->event_idx = 0; + + for (auto& event : ctx->gc.events) { + ctx->device->device.resetEvent(event); + } + + ctx->tensor_ctxs.clear(); + ctx->gc.contexts.clear(); + ctx->pipeline_descriptor_set_requirements = 0; + ctx->descriptor_set_idx = 0; +} + +// Clean up on backend free +static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); + ggml_vk_graph_cleanup(ctx); + + ggml_vk_destroy_buffer(ctx->prealloc_x); + ggml_vk_destroy_buffer(ctx->prealloc_y); + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + + for (auto& buffer : ctx->buffer_pool) { + ggml_vk_destroy_buffer(buffer); + } + + ctx->prealloc_size_x = 0; + ctx->prealloc_size_y = 0; + ctx->prealloc_size_split_k = 0; + + for (auto& event : ctx->gc.events) { + ctx->device->device.destroyEvent(event); + } + ctx->gc.events.clear(); + + ctx->device->device.destroyFence(ctx->fence); + ctx->device->device.destroyFence(ctx->almost_ready_fence); + + for (auto& pool : ctx->descriptor_pools) { + ctx->device->device.destroyDescriptorPool(pool); + } + ctx->descriptor_pools.clear(); + ctx->descriptor_sets.clear(); + + ctx->compute_cmd_pool.destroy(ctx->device->device); + ctx->transfer_cmd_pool.destroy(ctx->device->device); +} + +static int ggml_vk_get_device_count() { + ggml_vk_instance_init(); + + return vk_instance.device_indices.size(); +} + +static void ggml_vk_get_device_description(int device, char * description, size_t description_size) { + ggml_vk_instance_init(); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + vk::PhysicalDeviceProperties props; + devices[device].getProperties(&props); + + snprintf(description, description_size, "%s", props.deviceName.data()); +} + +// backend interface + +#define UNUSED GGML_UNUSED + +// device backend + +static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name; +} + +static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { + VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()"); + ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + ggml_vk_destroy_buffer(ctx->dev_buffer); + delete ctx; +} + +static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { + return vk_ptr_base; + + UNUSED(buffer); +} + +static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + } + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + uint32_t val32 = (uint32_t)value * 0x01010101; + ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size); +} + +static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + if (ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + + return true; + } + return false; + + UNUSED(buffer); +} + +static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size); +} + +static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { + /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, + /* .get_base = */ ggml_backend_vk_buffer_get_base, + /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, + /* .clear = */ ggml_backend_vk_buffer_clear, + /* .reset = */ NULL, +}; + +// vk buffer type +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")"); + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + + vk_buffer dev_buffer = nullptr; + try { + dev_buffer = ggml_vk_create_buffer_device(ctx->device, size); + } catch (const vk::SystemError& e) { + return nullptr; + } + + ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name); + + return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size); +} + +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + return ctx->device->properties.limits.minStorageBufferOffsetAlignment; +} + +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + return ctx->device->suballocation_block_size; +} + +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_nbytes(tensor); + + UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { + ggml_vk_instance_init(); + + VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")"); + + vk_device dev = ggml_vk_get_device(dev_num); + + return &dev->buffer_type; +} + +// host buffer type + +static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) { + return GGML_VK_NAME "_Host"; + + UNUSED(buft); +} + +static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { + return GGML_VK_NAME "_Host"; + + UNUSED(buffer); +} + +static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); + ggml_vk_host_free(vk_instance.devices[0], buffer->context); +} + +static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")"); + + size += 32; // Behave like the CPU buffer type + void * ptr = nullptr; + try { + ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what()); + // fallback to cpu buffer + return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + } + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer; + + return buffer; + + UNUSED(buft); +} + +static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment; + + UNUSED(buft); +} + +// Should be changed to return device-specific host buffer type +// but that probably requires changes in llama.cpp +ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = { + /* .iface = */ { + /* .get_name = */ ggml_backend_vk_host_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0), + /* .context = */ nullptr, + }; + + // Make sure device 0 is initialized + ggml_vk_instance_init(); + ggml_vk_get_device(0); + + return &ggml_backend_vk_buffer_type_host; +} + + +// backend + +static const char * ggml_backend_vk_name(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return ctx->name.c_str(); +} + +static void ggml_backend_vk_free(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")"); + + ggml_vk_cleanup(ctx); + + delete ctx; + delete backend; +} + +static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return &ctx->device->buffer_type; +} + +static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + return true; + } + + return false; +} + +static void ggml_backend_vk_synchronize(ggml_backend_t backend) { + VK_LOG_DEBUG("ggml_backend_vk_synchronize()"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if(ctx->transfer_ctx.expired()) { + return; + } + + vk_context transfer_ctx = ctx->transfer_ctx.lock(); + + ggml_vk_ctx_end(transfer_ctx); + + for (auto& cpy : transfer_ctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ggml_vk_submit(transfer_ctx, ctx->fence); + ggml_vk_wait_for_fence(ctx); + + for (auto& cpy : transfer_ctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ctx->transfer_ctx.reset(); +} + +static bool ggml_vk_is_empty(ggml_tensor * node) { + return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; +} + +static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + uint64_t total_mat_mul_bytes = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false); + if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { + total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } + } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + ggml_vk_preallocate_buffers(ctx); + ggml_pipeline_allocate_descriptor_sets(ctx); + + int last_node = cgraph->n_nodes - 1; + + // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly + while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { + last_node -= 1; + } + + // Reserve tensor context space for all nodes + ctx->tensor_ctxs.resize(cgraph->n_nodes); + + bool first_node_in_batch = true; // true if next node will be first node in a batch + int submit_node_idx = 0; // index to first node in a batch + + vk_context compute_ctx; + if (vk_perf_logger_enabled) { + // allocate/resize the query pool + if (ctx->device->num_queries < cgraph->n_nodes + 1) { + if (ctx->device->query_pool) { + ctx->device->device.destroyQueryPool(ctx->device->query_pool); + } + vk::QueryPoolCreateInfo query_create_info; + query_create_info.queryType = vk::QueryType::eTimestamp; + query_create_info.queryCount = cgraph->n_nodes + 100; + ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info); + ctx->device->num_queries = query_create_info.queryCount; + } + + ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1); + + GGML_ASSERT(ctx->compute_ctx.expired()); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0); + } + + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. + // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB + // (and scaled down based on model size, so smaller models submit earlier). + // Also submit at least every 100 nodes, in case there are workloads without as much matmul. + int nodes_per_submit = 100; + int submitted_nodes = 0; + int submit_count = 0; + uint64_t mul_mat_bytes = 0; + uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u); + for (int i = 0; i < cgraph->n_nodes; i++) { + if (first_node_in_batch) { + submit_node_idx = i; + } + + if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { + mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } + + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) + bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; + bool submit = (submitted_nodes >= nodes_per_submit) || + (mul_mat_bytes >= mul_mat_bytes_per_submit) || + (i == last_node) || + (almost_ready && !ctx->almost_ready_fence_pending); + + bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit); + + if (vk_perf_logger_enabled) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1); + } + + if (enqueued) { + ++submitted_nodes; + +#ifndef GGML_VULKAN_CHECK_RESULTS + if (first_node_in_batch) { + first_node_in_batch = false; + } +#endif + } + + if (submit && enqueued) { + first_node_in_batch = true; + submitted_nodes = 0; + mul_mat_bytes = 0; + if (submit_count < 3) { + mul_mat_bytes_per_submit *= 2; + } + submit_count++; + } + } + + if (vk_perf_logger_enabled) { + // End the command buffer and submit/wait + GGML_ASSERT(!ctx->compute_ctx.expired()); + compute_ctx = ctx->compute_ctx.lock(); + ggml_vk_ctx_end(compute_ctx); + + ggml_vk_submit(compute_ctx, ctx->device->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); + ctx->device->device.resetFences({ ctx->device->fence }); + + // Get the results and pass them to the logger + std::vector timestamps(cgraph->n_nodes + 1); + VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); + for (int i = 0; i < cgraph->n_nodes; i++) { + if (!ggml_vk_is_empty(cgraph->nodes[i])) { + ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod)); + } + } + + ctx->device->perf_logger->print_timings(); + } + + ggml_vk_graph_cleanup(ctx); + + return GGML_STATUS_SUCCESS; + + UNUSED(backend); +} + +// TODO: enable async and synchronize +static ggml_backend_i ggml_backend_vk_interface = { + /* .get_name = */ ggml_backend_vk_name, + /* .free = */ ggml_backend_vk_free, + /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, + /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async, + /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, + /* .synchronize = */ NULL, // ggml_backend_vk_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_vk_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +static ggml_guid_t ggml_backend_vk_guid() { + static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; + return &guid; +} + +ggml_backend_t ggml_backend_vk_init(size_t dev_num) { + VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")"); + + ggml_backend_vk_context * ctx = new ggml_backend_vk_context; + ggml_vk_init(ctx, dev_num); + + ggml_backend_t vk_backend = new ggml_backend { + /* .guid = */ ggml_backend_vk_guid(), + /* .interface = */ ggml_backend_vk_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), + /* .context = */ ctx, + }; + + return vk_backend; +} + +bool ggml_backend_is_vk(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid()); +} + +int ggml_backend_vk_get_device_count() { + return ggml_vk_get_device_count(); +} + +void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + int dev_idx = vk_instance.device_indices[device]; + ggml_vk_get_device_description(dev_idx, description, description_size); +} + +void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; + + vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + + for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { + if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total = heap.size; + *free = heap.size; + break; + } + } +} + +////////////////////////// + +struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; + ggml_backend_vk_get_device_memory(ctx->device, free, total); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) { + UNUSED(dev); + return ggml_backend_vk_host_buffer_type(); +} + +static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { + UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_vk_device_get_name(dev); + props->description = ggml_backend_vk_device_get_description(dev); + props->type = ggml_backend_vk_device_get_type(dev); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ true, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { + UNUSED(params); + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_init(ctx->device); +} + +static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + ggml_type src0_type = op->src[0]->type; + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { + // If there's not enough shared memory for row_ids and the result tile, fallback to CPU + return false; + } + switch (src0_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + break; + default: + return false; + } + struct ggml_tensor * a; + struct ggml_tensor * b; + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != b->ne[3]) { + return false; + } + if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) || + !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { + return false; + } + if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) { + // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader. + // So don't support this combination for now. + return false; + } + + return true; + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + bool coopmat2 = device->coopmat2; + switch (op->src[0]->ne[0]) { + case 64: + case 80: + case 96: + case 112: + case 128: + case 256: + break; + default: + return false; + } + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + if (op->type != GGML_TYPE_F32) { + return false; + } + if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { + return false; + } + // It's straightforward to support different K/V dequant, but would + // significantly increase the number of pipelines + if (op->src[1]->type != op->src[2]->type) { + return false; + } + switch (op->src[1]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + // supported in scalar and coopmat2 paths + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently + //case GGML_TYPE_Q2_K: + //case GGML_TYPE_Q3_K: + //case GGML_TYPE_Q4_K: + //case GGML_TYPE_Q5_K: + //case GGML_TYPE_Q6_K: + //case GGML_TYPE_IQ1_S: + //case GGML_TYPE_IQ1_M: + //case GGML_TYPE_IQ2_XXS: + //case GGML_TYPE_IQ2_XS: + //case GGML_TYPE_IQ2_S: + //case GGML_TYPE_IQ3_XXS: + //case GGML_TYPE_IQ3_S: + //case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + // currently supported only in coopmat2 path + if (!coopmat2) { + return false; + } + break; + default: + return false; + } + if (!coopmat2 && !device->subgroup_shuffle) { + // scalar FA uses subgroupShuffle + return false; + } + return true; + } + case GGML_OP_GET_ROWS: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + } + } break; + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_DUP: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type; + + if (src0_type == GGML_TYPE_F32) { + switch (src1_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } + } + if (src1_type == GGML_TYPE_F32) { + switch (src0_type) { + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } + } + + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return true; + } + + // We can handle copying from a type to the same type if it's + // contiguous (memcpy). We use f16 or f32 shaders to do the copy, + // so the type/block size must be a multiple of 4. + if (src0_type == src1_type && + ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) && + (ggml_type_size(src0_type) % 2) == 0) { + return true; + } + return false; + } break; + case GGML_OP_REPEAT: + return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + case GGML_OP_REPEAT_BACK: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_RMS_NORM: + return true; + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_L2_NORM: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); + case GGML_OP_SILU_BACK: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_UPSCALE: + return op->op_params[0] == GGML_SCALE_MODE_NEAREST; + case GGML_OP_ACC: + case GGML_OP_CONCAT: + case GGML_OP_SCALE: + case GGML_OP_PAD: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_2D_DW: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + case GGML_OP_LEAKY_RELU: + case GGML_OP_OPT_STEP_ADAMW: + return true; + case GGML_OP_CONV_TRANSPOSE_1D: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + default: + return false; + } + + UNUSED(dev); +} + +static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) { + return false; + } + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return buft_ctx->device->idx == ctx->device; +} + +static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + const int min_batch_size = 32; + + return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + + UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_vk_device_i = { + /* .get_name = */ ggml_backend_vk_device_get_name, + /* .get_description = */ ggml_backend_vk_device_get_description, + /* .get_memory = */ ggml_backend_vk_device_get_memory, + /* .get_type = */ ggml_backend_vk_device_get_type, + /* .get_props = */ ggml_backend_vk_device_get_props, + /* .init_backend = */ ggml_backend_vk_device_init, + /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_vk_device_supports_op, + /* .supports_buft = */ ggml_backend_vk_device_supports_buft, + /* .offload_op = */ ggml_backend_vk_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) { + UNUSED(reg); + return GGML_VK_NAME; +} + +static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) { + UNUSED(reg); + return ggml_backend_vk_get_device_count(); +} + +static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) { + static std::vector devices; + + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { + ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; + char desc[256]; + ggml_backend_vk_get_device_description(i, desc, sizeof(desc)); + ctx->device = i; + ctx->name = GGML_VK_NAME + std::to_string(i); + ctx->description = desc; + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, + /* .context = */ ctx, + }); + } + initialized = true; + } + } + + GGML_ASSERT(device < devices.size()); + return devices[device]; +} + +static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = { + /* .get_name = */ ggml_backend_vk_reg_get_name, + /* .get_device_count = */ ggml_backend_vk_reg_get_device_count, + /* .get_device = */ ggml_backend_vk_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_vk_reg() { + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_vk_reg_i, + /* .context = */ nullptr, + }; + try { + ggml_vk_instance_init(); + return ® + } catch (const vk::SystemError& e) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what()); + return nullptr; + } +} + +// Extension availability +static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions) { +#ifdef GGML_VULKAN_VALIDATE + bool portability_enumeration_ext = false; + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + return true; + } + } + if (!portability_enumeration_ext) { + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; + } +#endif + return false; + + UNUSED(instance_extensions); +} +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions) { +#ifdef __APPLE__ + bool portability_enumeration_ext = false; + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + return true; + } + } + if (!portability_enumeration_ext) { + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; + } +#endif + return false; + + UNUSED(instance_extensions); +} + +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { + switch (props.vendorID) { + case VK_VENDOR_ID_INTEL: + // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost, + // while some older hardware (ex. Arc A770) has performance regressions + return arch == vk_device_architecture::INTEL_XE2; + case VK_VENDOR_ID_AMD: + if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { + // Workaround for AMD proprietary driver reporting support on all GPUs + return arch == vk_device_architecture::AMD_RDNA3; + } + return true; + default: + return true; + } +} + +// checks + +#ifdef GGML_VULKAN_CHECK_RESULTS +static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector& done, int level = 0) { + if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) { + return; + } + for (int j = 0; j < level; j++) { + std::cerr << " "; + } + std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl; + + done.push_back(tensor); + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] != nullptr) { + ggml_vk_print_graph_origin(tensor->src[i], done, level + 1); + } + } +} + +static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) { + if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + i3 = std::max(i3, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_I32) { + val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) { + void * tensor_data = tensor->data; + + const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer); + + if (is_gpu) { + const size_t tensor_size = ggml_nbytes(tensor); + tensor_data = malloc(tensor_size); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_buffer buffer_gpu = buf_ctx->dev_buffer; + ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size); + } + + std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; + std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl; + if (tensor->src[0] != nullptr) { + std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl; + } + if (tensor->src[1] != nullptr) { + std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl; + } + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + + if (is_gpu) { + free(tensor_data); + } +} + +void * comp_result; +size_t comp_size; +size_t comp_nb[GGML_MAX_DIMS]; +size_t check_counter = 0; +static void ggml_vk_check_results_0(ggml_tensor * tensor) { + if (tensor->op == GGML_OP_TRANSPOSE) { + return; + } + + check_counter++; + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { + return; + } + + VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")"); + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + + struct ggml_init_params iparams = { + /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ggml_ctx = ggml_init(iparams); + + std::array src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + std::array src_size = {0, 0, 0, 0, 0, 0}; + std::array src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"}; + + struct ggml_tensor * tensor_clone = nullptr; + + for (int i = 0; i < 6; i++) { + ggml_tensor * srci = tensor->src[i]; + if (srci == nullptr) { + continue; + } + ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci); + size_t srci_size = ggml_nbytes(srci); + + src_clone[i] = srci_clone; + src_size[i] = ggml_nbytes(srci); + src_buffer[i] = malloc(srci_size); + + srci_clone->data = src_buffer[i]; + if (ggml_backend_buffer_is_host(srci->buffer)) { + memcpy(srci_clone->data, srci->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(srci->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(srci) + srci->view_offs; + if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) { + for (int i3 = 0; i3 < srci->ne[3]; i3++) { + for (int i2 = 0; i2 < srci->ne[2]; i2++) { + const int idx = i3*srci->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]); + } + } + + srci_clone->nb[0] = srci->nb[0]; + srci_clone->nb[1] = srci->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1]; + } + } else { + if (offset + srci_size >= buffer_gpu->size) { + srci_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(srci, srci_name[i]); + } + } + + if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); + } else if (tensor->op == GGML_OP_MUL_MAT) { + tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_MUL_MAT_ID) { + tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } else if (tensor->op == GGML_OP_SUB) { + tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_MUL) { + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_DIV) { + tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_CONCAT) { + tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_UPSCALE) { + tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); + } else if (tensor->op == GGML_OP_SCALE) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]); + } else if (tensor->op == GGML_OP_SQR) { + tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SIN) { + tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COS) { + tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_CLAMP) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); + } else if (tensor->op == GGML_OP_PAD) { + tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); + } else if (tensor->op == GGML_OP_REPEAT) { + tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_REPEAT_BACK) { + tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_ADD) { + tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ACC) { + tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + } else if (tensor->op == GGML_OP_NORM) { + tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_GROUP_NORM) { + const float * float_params = (const float *)tensor->op_params; + tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]); + } else if (tensor->op == GGML_OP_RMS_NORM) { + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); + } else if (tensor->op == GGML_OP_SILU_BACK) { + tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_L2_NORM) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); + } else if (tensor->op == GGML_OP_SOFT_MAX) { + if (src1 != nullptr) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]); + } else { + tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); + } + } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { + tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]); + } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; + const float freq_base = ((float *) tensor->op_params)[5]; + const float freq_scale = ((float *) tensor->op_params)[6]; + const float ext_factor = ((float *) tensor->op_params)[7]; + const float attn_factor = ((float *) tensor->op_params)[8]; + const float beta_fast = ((float *) tensor->op_params)[9]; + const float beta_slow = ((float *) tensor->op_params)[10]; + if (mode & GGML_ROPE_TYPE_MROPE) { + int32_t *sections = ((int32_t *) tensor->op_params) + 11; + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + } else { + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + } + } else if (tensor->op == GGML_OP_UNARY) { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_SILU: + tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU: + tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU_QUICK: + tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_RELU: + tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_TANH: + tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_SIGMOID: + tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]); + break; + default: + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { + if (src1 == nullptr) { + tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); + tensor_clone->type = tensor->type; + } else { + tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); + } + } else if (tensor->op == GGML_OP_CONT) { + tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_RESHAPE) { + tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_VIEW) { + tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); + } else if (tensor->op == GGML_OP_PERMUTE) { + int32_t * params = (int32_t *)tensor->op_params; + tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]); + } else if (tensor->op == GGML_OP_TRANSPOSE) { + tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_GET_ROWS) { + tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ARGSORT) { + tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_SUM) { + tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SUM_ROWS) { + tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_ARGMAX) { + tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COUNT_EQUAL) { + tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_IM2COL) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + + const bool is_2D = tensor->op_params[6] == 1; + tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { + const int32_t dim = tensor->op_params[0]; + const int32_t max_period = tensor->op_params[1]; + tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){ + const int32_t s0 = tensor->op_params[0]; + const int32_t p0 = tensor->op_params[1]; + const int32_t d0 = tensor->op_params[2]; + tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0); + } else if (tensor->op == GGML_OP_POOL_2D) { + enum ggml_op_pool op = static_cast(tensor->op_params[0]); + const int32_t k0 = tensor->op_params[1]; + const int32_t k1 = tensor->op_params[2]; + const int32_t s0 = tensor->op_params[3]; + const int32_t s1 = tensor->op_params[4]; + const int32_t p0 = tensor->op_params[5]; + const int32_t p1 = tensor->op_params[6]; + + tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); + } else if (tensor->op == GGML_OP_LEAKY_RELU) { + const float * op_params = (const float *)tensor->op_params; + tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); + } else if (tensor->op == GGML_OP_RWKV_WKV6) { + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + } else if (tensor->op == GGML_OP_RWKV_WKV7) { + tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], + src_clone[4], src_clone[5], src_clone[6]); + } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { + src_clone[0]->flags = src0->flags; + tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4]); + } + else { + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_clone); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8); + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(tensor_clone, "tensor_clone"); + } + + comp_size = ggml_nbytes(tensor_clone); + + comp_result = malloc(comp_size); + memcpy(comp_result, tensor_clone->data, comp_size); + memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); + + for (int i = 0; i < 6; i++) { + if (src_buffer[i] != nullptr) { + free(src_buffer[i]); + } + } + + ggml_free(ggml_ctx); + + VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); +} + +static void ggml_vk_check_results_1(ggml_tensor * tensor) { + if (tensor->op == GGML_OP_TRANSPOSE) { + return; + } + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { + return; + } + + VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")"); + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + ggml_tensor * src2 = tensor->src[2]; + ggml_tensor * src3 = tensor->src[3]; + + void * tensor_data = tensor->data; + + if (ggml_backend_buffer_is_vk(tensor->buffer)) { + size_t tensor_size = ggml_nbytes(tensor); + tensor_data = malloc(tensor_size); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs; + if (offset + tensor_size >= buffer_gpu->size) { + tensor_size = buffer_gpu->size - offset; + } + + ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size); + } + + float first_error_result = -1.0f; + float first_error_correct = -1.0f; + std::array first_error = { -1, -1, -1, -1 }; + double avg_err = 0.0; + size_t counter = 0; + + for (int i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size; + float correct = 0.0f; + float result = 0.0f; + + if (buffer_size_fit) { + if (tensor->type == GGML_TYPE_F32) { + correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); + result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_I32) { + correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_I64) { + correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else { + std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; + } + } else { + std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl; + GGML_ABORT("fatal error"); + } + + if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) { + std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } + const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f; + if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) { + first_error[0] = i0; + first_error[1] = i1; + first_error[2] = i2; + first_error[3] = i3; + first_error_result = result; + first_error_correct = correct; + } + + // Special case, value is infinite, avoid NaN result in avg_err + // NaN also appears in results, if both are nan error is 0 + if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) { + avg_err += std::fabs(correct - result) / denom; + } + counter++; + } + } + } + } + + avg_err /= counter; + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + } + + if (avg_err > 0.5 || std::isnan(avg_err)) { + std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } else { + std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl; + } + + free(comp_result); + comp_result = nullptr; + comp_size = 0; + + if (ggml_backend_buffer_is_vk(tensor->buffer)) { + free(tensor_data); + } + + VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")"); +} +#endif + +GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..14e9daaa01a25f980c00576cdb70674a64dc9c13 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -0,0 +1,27 @@ +cmake_minimum_required(VERSION 3.19) +project("vulkan-shaders-gen" C CXX) + +find_package (Threads REQUIRED) + +if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat glslc support") +endif() +if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 glslc support") +endif() +if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + message(STATUS "Enabling dot glslc support") +endif() +if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + message(STATUS "Enabling bfloat16 glslc support") +endif() + +set(TARGET vulkan-shaders-gen) +add_executable(${TARGET} vulkan-shaders-gen.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_compile_features(${TARGET} PRIVATE cxx_std_17) +target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp new file mode 100644 index 0000000000000000000000000000000000000000..d896f1ef0beee5ec499f7d092205ffa2df45cd8a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + const uint offset = p.param3; + const uint src1_i = idx - offset; + const uint oz = src1_i / p.nb02; + const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; + const uint ox = src1_i % p.nb01; + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + } else { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); + } +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp new file mode 100644 index 0000000000000000000000000000000000000000..2b4085c4f82d5d49c5fb25c13ffe77ca490a5a90 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -0,0 +1,29 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp new file mode 100644 index 0000000000000000000000000000000000000000..eaf4da341e348904c2fb2ed5fcf8a2dc4cdefac4 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmpmax[BLOCK_SIZE]; +shared uint tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + if (col >= p.KX) { + return; + } + A_TYPE amax = data_a[row*p.KX + col]; + tmp[col] = col; + + for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) { + A_TYPE val = data_a[row*p.KX + i]; + if (val > amax) { + amax = val; + tmp[col] = i; + } + } + tmpmax[col] = amax; + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s && col + s < p.KX) { + if (tmpmax[col] < tmpmax[col + s]) { + tmpmax[col] = tmpmax[col + s]; + tmp[col] = tmp[col + s]; + } + } + barrier(); + } + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp new file mode 100644 index 0000000000000000000000000000000000000000..d4fa45b1e106fdd0172b6de8b4de1b23b4cb24ea --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -0,0 +1,69 @@ +#version 450 + +#include "types.comp" + +#define BLOCK_SIZE 1024 +#define ASC 0 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint ncols_pad; + uint order; +} p; + +shared int dst_row[BLOCK_SIZE]; + +void swap(uint idx0, uint idx1) { + int tmp = dst_row[idx0]; + dst_row[idx0] = dst_row[idx1]; + dst_row[idx1] = tmp; +} + +void main() { + // bitonic sort + const int col = int(gl_LocalInvocationID.x); + const uint row = gl_WorkGroupID.y; + + const uint row_offset = row * p.ncols; + + // initialize indices + if (col < p.ncols_pad) { + dst_row[col] = col; + } + barrier(); + + for (uint k = 2; k <= p.ncols_pad; k *= 2) { + for (uint j = k / 2; j > 0; j /= 2) { + const uint ixj = col ^ j; + if (col < p.ncols_pad && ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= p.ncols || + (dst_row[ixj] < p.ncols && (p.order == ASC ? + data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] : + data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]])) + ) { + swap(col, ixj); + } + } else { + if (dst_row[ixj] >= p.ncols || + (dst_row[col] < p.ncols && (p.order == ASC ? + data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] : + data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]])) + ) { + swap(col, ixj); + } + } + } + barrier(); + } + } + + if (col < p.ncols) { + data_d[row_offset + col] = dst_row[col]; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp new file mode 100644 index 0000000000000000000000000000000000000000..1e5cb8dae4e10d088bf22eb4ab7cf9c2732734a2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp new file mode 100644 index 0000000000000000000000000000000000000000..9ee2f1fae2074aab562c863475f05b706590f7c1 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + const int dim = p.param3; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne22*p.ne21*p.ne20); + const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20; + const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20); + const uint i2_offset = i2*p.ne21*p.ne20; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne20; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20; + + uint o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03)); + + const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10; + const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20; + + const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]); +#else + if (is_src0) { + data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx]; + } else { + data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx]; + } +#endif +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp new file mode 100644 index 0000000000000000000000000000000000000000..6567a8c54cf493bec8214b9f45324dacd2a74ad9 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +#extension GL_EXT_control_flow_attributes : require + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + // fast path for when all four iterations are in-bounds + if (idx + (num_iter-1)*num_threads < p.ne) { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } else { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp new file mode 100644 index 0000000000000000000000000000000000000000..938c74da5007476446fef614377a43504a502b40 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp @@ -0,0 +1,105 @@ +#version 450 + +#include "types.comp" + +layout (push_constant) uniform parameter +{ + uint ne; + uint batches; + uint channels; + uint dst_w; + uint dst_h; + uint src_w; + uint src_h; + uint knl_w; + uint knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; +layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE conv_2d_dw_whcn(uint idx) { + uint i0 = idx / p.dst_w; + uint dst_x = idx - i0 * p.dst_w; + uint i1 = i0 / p.dst_h; + uint dst_y = i0 - i1 * p.dst_h; + uint n = i1 / p.channels; + uint c = i1 - n * p.channels; + + uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w; + uint knl_i = c * p.knl_h * p.knl_w; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]); + sum = fma(v, k, sum); + } + } + return sum; +} + +FLOAT_TYPE conv_2d_dw_cwhn(uint idx) { + uint i0 = idx / p.channels; + uint c = idx - i0 * p.channels; + uint i1 = i0 / p.dst_w; + uint dst_x = i0 - i1 * p.dst_w; + uint n = i1 / p.dst_h; + uint dst_y = i1 - n * p.dst_h; + + uint src_i = n * p.channels * p.src_h * p.src_w; + uint src_row = p.src_w * p.channels; + uint knl_row = p.knl_w * p.channels; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]); + sum = fma(v, k, sum); + } + } + return sum; +} + +void main() { + uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + FLOAT_TYPE result = +#ifdef WHCN + conv_2d_dw_whcn(idx); +#else + conv_2d_dw_cwhn(idx); +#endif + dst_data[idx] = D_TYPE(result); +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp new file mode 100644 index 0000000000000000000000000000000000000000..b17b4e83eec4b6202bdf34a3ffd9adc6324f9323 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp @@ -0,0 +1,98 @@ +#version 450 + +#include "types.comp" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin] +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin] +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout] + +layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +} p; + + +uint32_t Cout_idx = gl_WorkGroupID.x; +const uint32_t bs = gl_WorkGroupSize.x; +uint32_t tid = gl_LocalInvocationID.x; +// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K. +uint32_t tmp_len = bs*p.s0+p.K; +shared D_TYPE tmp[4096]; + +uint splitWork(uint workSize){ + return (bs + workSize -1) / bs; +} + +void main(){ + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx < tmp_len){ + tmp[idx] = 0.0; + } + } + + uint32_t L_blocks = splitWork(p.L); + for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){ + if(L_block_id > 0){ + barrier(); + // Shift values in tmp to the current processing window + for(int i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx >= bs*p.s0 && idx < tmp_len){ + tmp[idx-bs*p.s0] = tmp[idx]; + tmp[idx] = 0.0; + }else if(idx >= p.K && idx < bs*p.s0){ + tmp[idx] = 0.0; + } + } + } + barrier(); + + // Save contributions of the block to tmp + uint32_t L_idx = L_block_id*bs + tid; + for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){ + D_TYPE dp = 0.0; + for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){ + A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02]; + if(L_idx < p.L){ + B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11]; + dp = fma(elemKrn, elemInp, dp); + } + } + tmp[tid*p.s0 + K_idx] += dp; + barrier(); + } + + // Save the computed values except the last block that can have different size + uint32_t KLb_idx = L_block_id*bs*p.s0; + if(L_block_id < L_blocks-1){ + for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){ + uint32_t sh_idx = p.s0*tid+s0_idx; + uint32_t KL_idx = KLb_idx+sh_idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx]; + } + } + } + } + + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx]; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp new file mode 100644 index 0000000000000000000000000000000000000000..f476a2e3dd83e816af6d4fb819f5cdcdb062cf53 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); +#else + data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; +#endif +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp new file mode 100644 index 0000000000000000000000000000000000000000..dbc7daa3328f6e322f2bfbcd312d28bb7bce991a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" +#include "dequant_funcs.comp" + +#if defined(DATA_A_IQ4_NL) +// 16 invocations needed for init_iq4nl_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = get_doffset() + dst_idx(idx); + uint src_idx = src0_idx_quant(idx, QUANT_K); + + const uint a_offset = 0; + const uint ib = src_idx; + const vec2 dm = get_dm(ib, a_offset); + + [[unroll]] for (int j = 0; j < QUANT_K; j += 4) { + vec4 v = dequantize4(ib, j / QUANT_R, a_offset); + v = v * dm.x + vec4(dm.y); + +#if QUANT_R == 2 + data_d[dst_idx + j/2 + 0] = v[0]; + data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1]; + data_d[dst_idx + j/2 + 1] = v[2]; + data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3]; +#else + data_d[dst_idx + j + 0] = v[0]; + data_d[dst_idx + j + 1] = v[1]; + data_d[dst_idx + j + 2] = v[2]; + data_d[dst_idx + j + 3] = v[3]; +#endif + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp new file mode 100644 index 0000000000000000000000000000000000000000..9c76437d9b0b986405ffc8ac11abdfe417a0e2f8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -0,0 +1,242 @@ +#version 450 + +#if RTE16 +#extension GL_EXT_spirv_intrinsics : enable +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif // RTE16 + +#include "types.comp" +#include "generic_unary_head.comp" + +#if defined(DATA_A_IQ4_NL) +// 16 invocations needed for init_iq4nl_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +layout (binding = 0) readonly buffer S {float data_s[];}; +layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; + +#if defined(DATA_A_Q4_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id; + + const uint xi0 = min(15, int(x0 + 8.5)); + const uint xi1 = min(15, int(x1 + 8.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q4_1) +void quantize(uint dst_idx, uint src_idx) +{ + float vmin = 1.0/0.0; + float vmax = -vmin; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) { + const float v = data_s[src_idx + j]; + + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(vmin); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - vmin)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id; + + const uint xi0 = min(15, int(x0 + 0.5)); + const uint xi1 = min(15, int(x1 + 0.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q5_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id; + + const uint xi0 = min(31, int(x0 + 16.5)); + const uint xi1 = min(31, int(x1 + 16.5)); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2); + } + data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF); + data_q[dst_idx].qh[1] = uint16_t(qh >> 16); +} +#endif + +#if defined(DATA_A_Q5_1) +void quantize(uint dst_idx, uint src_idx) +{ + float min = data_s[src_idx + 0]; + float max = min; + + [[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) { + const float v = data_s[src_idx + j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = (d != 0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(min); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - min)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id; + + const uint xi0 = uint(x0 + 0.5); + const uint xi1 = uint(x1 + 0.5); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2); + } + data_q[dst_idx].qh = qh; +} +#endif + +#if defined(DATA_A_Q8_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; // absolute max + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) { + const float v = data_s[src_idx + j]; + amax = max(amax, abs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) { + const float x0 = data_s[src_idx + j]*id; + + data_q[dst_idx].qs[j] = int8_t(round(x0)); + } +} +#endif + +#if defined(DATA_A_IQ4_NL) +uint best_index(float x) { + if (x <= kvalues_iq4nl[0]) return 0; + if (x >= kvalues_iq4nl[15]) return 15; + int ml = 0, mu = 15; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav; + } + return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu; +} + +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + float sumqx = 0, sumq2 = 0; + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id; + const uint xi0 = best_index(x0); + const uint xi1 = best_index(x1); + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j]; + const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d); + +} +#endif + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = dst_idx_quant(idx, QUANT_K); + uint src_idx = get_aoffset() + src0_idx(idx); + + quantize(dst_idx, src_idx); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp new file mode 100644 index 0000000000000000000000000000000000000000..0b8d02f58fc316a242d2f9b93bfdcf79eb977198 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp new file mode 100644 index 0000000000000000000000000000000000000000..d9345497c73fd30a3aa0c636384ff50ceb2ddf33 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp @@ -0,0 +1,31 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.comp" +#include "generic_head.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +const uint CHUNK_SIZE = 512; + +void main() { + const uint base = gl_WorkGroupID.x * CHUNK_SIZE; + const uint col = gl_LocalInvocationID.x; + + uint count = 0; + [[unroll]] + for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) { + const uint idx = base + i + col; + if (idx >= p.KX) { + break; + } + count += uint(data_a[idx] == data_b[idx]); + } + + atomicAdd(data_d[0], D_TYPE(count)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp new file mode 100644 index 0000000000000000000000000000000000000000..a4d3fca556208e339a08170194cd54198b0c3125 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x * 16; + + if (i >= p.nel) { + return; + } + + [[unroll]] for (uint l = 0; l < 16; l++) { + data_b[i + l] = D_TYPE(data_a[i + l]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp new file mode 100644 index 0000000000000000000000000000000000000000..0d9739d40609af24b9dbc6d4b8124dae815f02b4 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -0,0 +1,462 @@ +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#endif + +#include "types.comp" + +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + +#if defined(DATA_A_F32) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_F16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_BF16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); +} +#endif + +#if defined(DATA_A_Q4_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2(vui & 0xF, vui >> 4) - 8.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); +} +#endif + +#if defined(DATA_A_Q4_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(vui & 0xF, vui >> 4); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); +} +#endif + +#if defined(DATA_A_Q5_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0]; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f); +} +#endif + +#if defined(DATA_A_Q5_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = data_a[a_offset + ib].qh; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = data_a_packed16[a_offset + ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y); +} +#endif + +#if defined(DATA_A_Q8_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy; + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#if defined(DATA_A_IQ1_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1; + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + // Signed bitfield extract. + const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + +#if defined(DATA_A_IQ1_M) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. + const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + +#if defined(DATA_A_IQ2_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid[iqs % 4] * (sign0 ? -1.0 : 1.0), + grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint qs = data_a[a_offset + ib].qs[iqs / 4]; + const uint qh = data_a[a_offset + ib].qh[iqs / 32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[iqs / 64]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4)); + return db * vec2( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[ib32 / 2]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4)); + return db * vec4( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0), + int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0), + int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ4_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint iq = 16 * ib32 + (iqs % 16); + + const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; + const uint qshift = (iqs & 16) >> 2; + u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float dl = float(int(sl | (sh << 4)) - 32); + return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint iq = 16 * ib32 + (iqs % 16); + + const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; + const uint qshift = (iqs & 16) >> 2; + u8vec4 qs = u8vec4( + data_a[a_offset + ib].qs[iq + 0], + data_a[a_offset + ib].qs[iq + 1], + data_a[a_offset + ib].qs[iq + 2], + data_a[a_offset + ib].qs[iq + 3] + ); + qs = (qs >> qshift) & uint8_t(0xF); + + const float dl = float(int(sl | (sh << 4)) - 32); + return dl * vec4( + kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], + kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); +} +#endif + +#if defined(DATA_A_IQ4_NL) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]); +} +#endif + +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(0, 0); +} +#endif + +#if defined(DATA_A_IQ1_M) +vec2 get_dm(uint ib, uint a_offset) { + const uint16_t[4] scales = data_a[a_offset + ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + return vec2(d, 0); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), 0); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp new file mode 100644 index 0000000000000000000000000000000000000000..9cb7da2daab5da60f63dc01580db0689fd42bbe2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -0,0 +1,699 @@ + +#include "types.comp" + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { + block_q4_0_packed16 block; +}; + +float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]); + qs >>= shift; + qs &= 0x0F0F; + qs = unpack8(qs)[idx & 1]; + float16_t ret = (float16_t(qs) - float16_t(8)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { + block_q4_1 block; +}; + +float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(qs) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { + block_q5_0 block; +}; + +float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0]; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { + block_q5_1 block; +}; + +float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = bl.block.qh; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = float16_t(qs | qh) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { + block_q8_0_packed16 block; +}; + +float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + // Load 16b and select the byte for this element + int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1]; + float16_t ret = float16_t(qs) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { + block_q2_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 { + block_q2_K_packed16 block; +}; + +float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); + const f16vec2 d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint scalesi = (idx & 0xF0) >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qs = (qs >> qsshift) & 0x0303; + qs = unpack8(qs)[idx & 1]; + + const uint scales = bl.block.scales[scalesi]; + float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { + block_q3_K block; +}; + +float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + const uint n = iqs / 128; // 0,1 + const uint qsi = n * 32 + (iqs % 32); // 0..63 + const uint hmi = (iqs % 32); // 0..31 + const uint j = (iqs % 128) / 8; // 0..15 + const uint is = iqs / 16; // 0..15 + const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + uint32_t scaleidx0 = (is < 8) ? is : (is-8); + uint32_t scaleidx0shift = (is < 8) ? 0 : 4; + uint32_t scaleidx1 = is + 8 - (is/4)*4; + uint32_t scaleidx1shift = (is/4)*2; + + const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); + + const float16_t dl = bl.block.d * float16_t(us - 32); + + float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4)); + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { + block_q4_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 { + block_q4_K_packed16 block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { + block_q4_K_packed128 block; +}; + +#if defined(IS_MUL_MM2) + +// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales +// into shared memory and then process the whole tile using those scales. +// There is a fetch function that loads into private variables and then a store +// function that stores into shared memory. +// Q4_K and Q5_K have the same encoding of scales, so everything is shared except +// the part that fetches from the structure (which has a different block layout). +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +const uint shAscales_stride = (BM + 2); +// 1 scale per 32 elements -> 8 scales per block, per row +shared vec2 shAscales[8 * shAscales_stride]; +uvec4 row_v; +#endif + +#if defined(DATA_A_Q4_K) +layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];}; + +void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q4_k_packed128[block_index].q4k[0]; + } +} +#endif +#if defined(DATA_A_Q5_K) +layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];}; + +void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q5_k_packed128[block_index].q5k[0]; + } +} +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +void store_scalesQ4_K(uint tid) +{ + barrier(); + + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) { + uint is = idx + is_start; + uvec4 v = row_v; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); + shAscales[is * shAscales_stride + tid_row] = vec2(d,m); + } + + barrier(); +} +#endif + +#endif + +float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q4k[0]; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); +#endif + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; + + float ret = d * float(qs) - m; + + return float16_t(ret); +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { + block_q5_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 { + block_q5_K_packed16 block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 { + block_q5_K_packed128 block; +}; + +float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q5k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); +#endif + + uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); + qh = ((qh >> is) & 0x101) << 4; + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4)) & 0x0F0F; + qs = unpack8(qs | qh)[idx & 1]; + + float ret = d * float(qs) - m; + + return float16_t(ret); +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { + block_q6_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 { + block_q6_K_packed16 block; +}; + +float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x40) >> 6; // 0,1 + const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 + const uint is = (idx & 0xF0) >> 4; // 0..15 + + const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); + + uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]); + ql = (ql >> (b * 4)) & 0x0F0F; + + uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qh = ((qh >> qhshift) & 0x0303) << 4; + + int q = unpack8(ql | qh)[idx & 1]; + + float16_t ret = dscale * float16_t(q - 32); + + return ret; +} + +#if defined(DATA_A_IQ1_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S { + block_iq1_s block; +}; + +float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; + const uint ib8 = (idx & 0xF8) >> 3; + + const uint qh = bl.block.qh[ib32]; + const uint qs = bl.block.qs[ib8]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]; + + float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta)); + return ret; +} +#endif + +#if defined(DATA_A_IQ1_M) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M { + block_iq1_m block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 { + block_iq1_m_packed64 block; +}; + +float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl); + const uint idx = coordInBlock[1]; + + uvec2 scales = unpack32(bl64.block.scales); + const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16))); + + const uint ib8 = (idx & 0xF8) >> 3; + const uint ib16 = (idx & 0xF0) >> 4; + const int i8 = int(idx % 8); + const uint sc = bl.block.scales[ib8 / 8]; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | ((qh & 7) << 8)]; + + float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta)); + return ret; +} +#endif + +#if defined(DATA_A_IQ2_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS { + block_iq2_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 { + block_iq2_xxs_packed16 block; +}; + +float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0x18) >> 3; // 0..3 + const uint iqs = 8 * ib32 + ib8; + + const uint qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + + uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); + + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ2_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS { + block_iq2_xs block; +}; + +float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint is = (idx & 0xE0) >> 5; // 0..8 + const uint sshift = (idx & 0x10) >> 2; // 0,4 + const uint iqs = (idx & 0xF8) >> 3; // 0..63 + + const uint16_t qs = bl.block.qs[iqs]; + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); + + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ2_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S { + block_iq2_s block; +}; + +float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0xF8) >> 3; // 0..31 + const uint qhshift = 2 * (ib8 % 4); + + const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib32]; + const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6); + + const float d = float(bl.block.d); + const float db = d * 0.25 * (0.5 + scale); + const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign)); + uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 v = db * vec2(sign01) * vec2(unpack8(g2)); + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS { + block_iq3_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 { + block_iq3_xxs_packed16 block; +}; + +float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); + uint idx = coordInBlock[1]; + + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint signs = pack32(u16vec2( + bl16.block.qs[is/2+0], + bl16.block.qs[is/2+1] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6); + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); + const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ3_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S { + block_iq3_s block; +}; + +float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint iqh = (idx & 0xE0) >> 5; + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint qh = bl.block.qh[iqh]; + const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6)); + const uint scale = bl.block.scales[iqs / 16]; + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ4_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS { + block_iq4_xs block; +}; + +float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + + const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 16) >> 2; + const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF; + + float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]); + return ret; +} +#endif + +#if defined(DATA_A_IQ4_NL) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { + block_iq4_nl block; +}; + +float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; + return ret; +} +#endif + +#if defined(DATA_A_Q4_0) +#define dequantFuncA dequantFuncQ4_0 +#elif defined(DATA_A_Q4_1) +#define dequantFuncA dequantFuncQ4_1 +#elif defined(DATA_A_Q5_0) +#define dequantFuncA dequantFuncQ5_0 +#elif defined(DATA_A_Q5_1) +#define dequantFuncA dequantFuncQ5_1 +#elif defined(DATA_A_Q8_0) +#define dequantFuncA dequantFuncQ8_0 +#elif defined(DATA_A_Q2_K) +#define dequantFuncA dequantFuncQ2_K +#elif defined(DATA_A_Q3_K) +#define dequantFuncA dequantFuncQ3_K +#elif defined(DATA_A_Q4_K) +#define dequantFuncA dequantFuncQ4_K +#define fetch_scales fetch_scalesQ4_K +#define store_scales store_scalesQ4_K +#elif defined(DATA_A_Q5_K) +#define dequantFuncA dequantFuncQ5_K +#define fetch_scales fetch_scalesQ5_K +#define store_scales store_scalesQ4_K +#elif defined(DATA_A_Q6_K) +#define dequantFuncA dequantFuncQ6_K +#elif defined(DATA_A_IQ1_S) +#define dequantFuncA dequantFuncIQ1_S +#elif defined(DATA_A_IQ1_M) +#define dequantFuncA dequantFuncIQ1_M +#elif defined(DATA_A_IQ2_XXS) +#define dequantFuncA dequantFuncIQ2_XXS +#elif defined(DATA_A_IQ2_XS) +#define dequantFuncA dequantFuncIQ2_XS +#elif defined(DATA_A_IQ2_S) +#define dequantFuncA dequantFuncIQ2_S +#elif defined(DATA_A_IQ3_XXS) +#define dequantFuncA dequantFuncIQ3_XXS +#elif defined(DATA_A_IQ3_S) +#define dequantFuncA dequantFuncIQ3_S +#elif defined(DATA_A_IQ4_XS) +#define dequantFuncA dequantFuncIQ4_XS +#elif defined(DATA_A_IQ4_NL) +#define dequantFuncA dequantFuncIQ4_NL +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp new file mode 100644 index 0000000000000000000000000000000000000000..8d806435b7163f60b1e71e7e5d1c1a4c7b626bc6 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp @@ -0,0 +1,13 @@ +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint M; + uint K; + uint stride_a; + uint stride_b; + uint nel; +} p; + +#include "types.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp new file mode 100644 index 0000000000000000000000000000000000000000..b604c1881a5eae903e0b86cb38ed591468f3f6b0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq1_m data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint ib64 = ib32 / 2; + const uint b_idx = 256 * ib + 32 * ib32; + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = data_a[ib].scales[ib64]; + [[unroll]] for (int l = 0; l < 4; ++l) { + const uint ib16 = 2 * ib32 + l / 2; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1)); + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp new file mode 100644 index 0000000000000000000000000000000000000000..fd1e4e30d252bff8c28061fa56203c719d84af32 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq1_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + uint qh = data_a[ib].qh[ib32]; + const float d = float(data_a[ib].d); + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint hi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp new file mode 100644 index 0000000000000000000000000000000000000000..48f6b65bc40ce143912ed9a8f7dc80ca25b52f6f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + uint qh = data_a[ib].qh[ib32]; + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l]; + qs |= (qh << (8 - 2 * l)) & 0x300; + const uvec2 grid = iq2s_grid[qs & 511]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp new file mode 100644 index 0000000000000000000000000000000000000000..a08331c40de322ba4253037d3b77eb43df2138f0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint16_t qs = data_a[ib].qs[4 * ib32 + l]; + const uint sign7 = qs >> 9; + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uvec2 grid = iq2xs_grid[qs & 511]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp new file mode 100644 index 0000000000000000000000000000000000000000..e370690bcb089c8caeeb8c91a241095f4bcd1297 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -0,0 +1,48 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[8*is + 4], + data_a[ib].qs[8*is + 5], + data_a[ib].qs[8*is + 6], + data_a[ib].qs[8*is + 7] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp new file mode 100644 index 0000000000000000000000000000000000000000..c3f4bca5d95e2a4a1355c635e5e0cafa38bccec1 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale nibble. + // Each block contains 4 scale bytes (8 scales) for 256 output values. + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf)); + + // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes. + uint qh = data_a[ib].qh[is]; + [[unroll]] for (uint l = 0; l < 8; ++l) { + uint qs = data_a[ib].qs[8 * is + l]; + uint gidx = qs | ((qh << (8 - l)) & 256); + uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1)); + u8vec4 grid = unpack8(iq3s_grid[gidx]); + data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp new file mode 100644 index 0000000000000000000000000000000000000000..a92b82961afdae747cfc5fbaea6d405d46357ea0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // 8 threads handle 1 superblock + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + const uint s_idx = QUANT_K / 4 + 4 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[s_idx + 0], + data_a[ib].qs[s_idx + 1], + data_a[ib].qs[s_idx + 2], + data_a[ib].qs[s_idx + 3] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.5; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + // Restore parity bit. + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]); + const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp new file mode 100644 index 0000000000000000000000000000000000000000..46d9ad15ebafc5f4c045feb7b0094e6eb58bb1e6 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = float(data_a[ib].d); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp new file mode 100644 index 0000000000000000000000000000000000000000..f930852a48a74d0fa5538c8a20ff50522e9b5bc0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (1 scale and 32 quantized values) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + + const float d = float(data_a[ib].d); + // Scales are 6 bits + const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF) + | (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4); + const float dl = d * (int(scale) - 32); + + const uint b_idx = 256 * ib + 32 * ib32; + const uint q_idx = 16 * ib32; + [[unroll]] for (uint l = 0; l < 16; ++l) { + data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp new file mode 100644 index 0000000000000000000000000000000000000000..157154af3a328d5df56e8dcbc52b5cde9f385937 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; + if (i >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; + + const uint y_idx = i * QUANT_K + 128 * ip + il; + + const uint ql_idx = 32 * ip + il; + const uint8_t qs = data_a[i].qs[32 * ip + il]; + + FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); + FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); + data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); + data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); + data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); + data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp new file mode 100644 index 0000000000000000000000000000000000000000..c17dd0d99911668af5e243833df66e12acd506b1 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = uint(gl_WorkGroupID.x * 256 + wgy); + if (i >= p.M * p.K / QUANT_K) { + return; + } + + const uint r = gl_LocalInvocationID.x / 4; + const uint tid = r / 2; + const uint is0 = r % 2; + const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4); + const uint n = tid / 4; + const uint j = tid - 4*n; + + const uint8_t m = uint8_t(1 << (4*n + j)); + const uint is = 8*n + 2*j + is0; + const uint shift = 2*j; + + const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) : + (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4)); + const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); + const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); + + const uint y_idx = i * QUANT_K + 128 * n + 32 * j; + const uint qs_idx = 32*n; + + for (uint l = l0; l < l0 + 4; ++l) { + data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4))); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp new file mode 100644 index 0000000000000000000000000000000000000000..408185327255b4ce769d4d0636cf3a7b3e684b84 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp @@ -0,0 +1,30 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = float(data_a[ib].d); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp new file mode 100644 index 0000000000000000000000000000000000000000..2f27eee686eb9cc2787163626dad39cb03023910 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp new file mode 100644 index 0000000000000000000000000000000000000000..987f113a35ad030f60d837045199ed4cfd107fb1 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -0,0 +1,68 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 8; + const uint ir = tid % 8; + const uint is = 2 * il; + const uint n = 4; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + n * ir; + const uint qs_idx = 32*il + n * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + [[unroll]] for (uint l = 0; l < n; ++l) { + data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1); + data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp new file mode 100644 index 0000000000000000000000000000000000000000..b20b80529217412aa8932bd976ac7abdf57d3e53 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp new file mode 100644 index 0000000000000000000000000000000000000000..dc59fe3b77ee3b5e822cb9dd8b76032989478b97 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint qh = data_a[ib].qh; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp new file mode 100644 index 0000000000000000000000000000000000000000..6db5403b6613ec86233a0c2c927687ea8c56fb09 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -0,0 +1,70 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 16; + const uint ir = tid % 16; + const uint is = 2 * il; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; + const uint qs_idx = 32*il + 2 * ir; + const uint qh_idx = 2 * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + const uint8_t hm1 = uint8_t(1 << (2 * il )); + const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); + data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); + data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp new file mode 100644 index 0000000000000000000000000000000000000000..0b91317550f97989cbfb707bf2624c030eeacaec --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -0,0 +1,33 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; + if (i >= p.M * p.K / QUANT_K) { + return; + } + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; + + const uint y_idx = i * QUANT_K + 128 * ip + il; + + const uint ql_idx = 64 * ip + il; + const uint8_t qh = data_a[i].qh[32 * ip + il]; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); + + data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); + data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); + data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); + data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp new file mode 100644 index 0000000000000000000000000000000000000000..bd1344a88d129555fea5562ced72805c3fca52a3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp @@ -0,0 +1,31 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q8_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 16*il; + + const float d = float(data_a[ib].d); + + const uint q_idx = 16*il; + + [[unroll]] for (uint l = 0; l < 16; l += 2) { + data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]); + data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp new file mode 100644 index 0000000000000000000000000000000000000000..26d8bc22ad7fd2a2fb62d44ba9778f096b17f90b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -0,0 +1,34 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint ncols; + uint rows_per_channel; + uint n_past; +} p; + +#include "types.comp" + +layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint col = gl_GlobalInvocationID.y; + const uint row = gl_GlobalInvocationID.x; + + if (col >= p.ncols) { + return; + } + + const uint i = row*p.ncols + col; + if (col > p.n_past + row % p.rows_per_channel) { + data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000)); + } else { + data_d[i] = D_TYPE(data_a[i]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp new file mode 100644 index 0000000000000000000000000000000000000000..9fb69c6c15b69f19b6a41d8fcfc165c0038ab517 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp new file mode 100644 index 0000000000000000000000000000000000000000..ce230a8f7d91038ed84f2ba4841052f110e5fbf7 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -0,0 +1,337 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.comp" +#include "flash_attn_base.comp" + +const uint32_t D_per_thread = D / D_split; + +const uint32_t cols_per_iter = WorkGroupSize / D_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +shared FLOAT_TYPE tmpsh[WorkGroupSize]; +shared vec4 tmpshv4[WorkGroupSize]; + +shared float masksh[Bc][Br]; +shared vec4 Qf[Br][D / 4]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + } + } + barrier(); + + vec4 Of[Br][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = vec4(0.0); + } + } + + float Lf[Br], Mf[Br]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + float slope[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = 1.0; + } + + // ALiBi + if (p.max_bias > 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + float Sf[Br][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = 0.0; + } + } + + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + } + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + // Compute sum across the D_split + [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); + } + } + } + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + } + } + } + + if (p.mask != 0) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + } + } + barrier(); + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float mvf = masksh[c * cols_per_iter + col_tid][r]; + + Sf[r][c] += slope[r]*mvf; + } + } + barrier(); + } + + float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + rowmaxf[r] = Sf[r][0]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + } + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Pf[r][c] = exp(Sf[r][c] - Mf[r]); + } + eMf[r] = exp(Moldf[r] - Mf[r]); + + // Compute sum across row of P + rowsumf[r] = 0.0; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowsumf[r] += Pf[r][c]; + } + + Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = eMf[r] * Of[r][d]; + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] += Pf[r][c] * Vf; + } + } + } + + barrier(); + } + + // reduce across threads + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float rowmaxf, eMf; + + tmpsh[tid] = Mf[r]; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + } + barrier(); + } + rowmaxf = tmpsh[d_tid]; + barrier(); + + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf = exp(Moldf - Mf[r]); + + Lf[r] = eMf*Lf[r]; + + tmpsh[tid] = Lf[r]; + + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + } + barrier(); + } + Lf[r] = tmpsh[d_tid]; + barrier(); + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = eMf * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + Of[r][d] += tmpshv4[tid + s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[d_tid]; + barrier(); + } + } + + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] *= Lfrcp[r]; + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (i * Br + r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp new file mode 100644 index 0000000000000000000000000000000000000000..61d90e2d8ed219448c2d9cd710fe6165b399b4b0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -0,0 +1,162 @@ + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; +layout (constant_id = 4) const uint32_t Clamp = 0; +layout (constant_id = 5) const uint32_t D_split = 16; + + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, + iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + q_stride, k_stride, v_stride, m_stride; + +void init_indices() +{ + N = p.N; + KV = p.KV; + + i = gl_WorkGroupID.x; + split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + Tr = CEIL_DIV(N, Br); + + start_j = split_k_index * p.split_kv / Bc; + end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + iq2 = gl_WorkGroupID.y * p.gqa_ratio; + iq3 = gl_WorkGroupID.z; + + // broadcast factors + rk2 = p.neq2/p.nek2; + rk3 = p.neq3/p.nek3; + + rv2 = p.neq2/p.nev2; + rv3 = p.neq3/p.nev3; + + // k indices + ik3 = iq3 / rk3; + ik2 = iq2 / rk2; + + // v indices + iv3 = iq3 / rv3; + iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + k_stride = p.nb11; + v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp new file mode 100644 index 0000000000000000000000000000000000000000..da478be24fb6e7298ddff54438bf00313075a9c2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -0,0 +1,360 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable + +#include "types.comp" +#include "flash_attn_base.comp" + +const uint32_t D_per_thread = D / D_split; +const uint32_t row_split = 4; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; + +const uint32_t qstride = D / 4 + 2; // in units of f16vec4 +shared f16vec4 Qf[Br * qstride]; + +// Avoid padding for D==256 to make it fit in 48KB shmem. +const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; +shared ACC_TYPE sfsh[Bc * sfshstride]; + +const uint32_t kshstride = D / 4 + 2; // in units of f16vec4 +shared f16vec4 ksh[Bc * kshstride]; + +shared float slope[Br]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + const uint32_t tid = gl_LocalInvocationIndex; + + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + +#define tile_row(r) (row_tid * rows_per_thread + (r)) + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + } + } + barrier(); + + ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = ACC_TYPEV4(0.0); + } + } + + float Lf[rows_per_thread], Mf[rows_per_thread]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + // ALiBi + if (p.max_bias > 0.0f) { + if (tid < Br) { + uint r = tid; + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + barrier(); + } else { + if (tid < Br) { + uint r = tid; + slope[r] = 1.0; + } + barrier(); + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t c = (idx + tid) / (D / 4); + if (c < Bc && d < D / 4) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + + ksh[c * kshstride + d] = K_Tf; + } + } + barrier(); + + // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br + // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 + // This is written transposed in order to allow for N being 8 if implementations need it + coopmat SfMat = coopmat(0); + coopmat KMat; + coopmat QMat; + + for (uint32_t d = 0; d < D / 16; ++d) { + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + + SfMat = coopMatMulAdd(KMat, QMat, SfMat); + } + + uint coord = gl_SubgroupID * MatBc * sfshstride; + coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); + barrier(); + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / Br; + uint32_t r = (idx + tid) % Br; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + } + } + barrier(); + } + + if (p.mask != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); + } + } + barrier(); + } + + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + } + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + } + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + float Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf); + } + } + } + + barrier(); + } + + // reduce across threads + + float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE M = Mf[r]; + tmpsh[tid] = M; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + M = max(M, tmpsh[tid ^ s]); + barrier(); + tmpsh[tid] = M; + barrier(); + } + rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + eMf[r] = exp(Moldf[r] - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE L = Lf[r]; + tmpsh[tid] = L; + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + L += tmpsh[tid ^ s]; + barrier(); + tmpsh[tid] = L; + barrier(); + } + Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + Of[r][d] += tmpshv4[tid ^ s]; + barrier(); + tmpshv4[tid] = Of[r][d]; + barrier(); + } + Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= float16_t(Lfrcp[r]); + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (i * Br + tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp new file mode 100644 index 0000000000000000000000000000000000000000..6acf67a03a46351fbfd1d28540f951f02ef5648e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -0,0 +1,267 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable +#extension GL_EXT_null_initializer : enable + +#include "types.comp" +#include "dequant_funcs_cm2.comp" +#include "flash_attn_base.comp" + +layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; +layout (binding = 1) readonly buffer K {uint8_t data_k[];}; +layout (binding = 2) readonly buffer V {uint8_t data_v[];}; +layout (binding = 3) readonly buffer M {uint8_t data_m[];}; + +ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return max(x, y); +} + +ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return x; +} + +// Replace matrix elements >= numRows or numCols with 'replace' +ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) { + if (row >= numRows || col >= numCols) { + return replace; + } + return elem; +} + +ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) +{ + return exp(elem); +} + +ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) +{ + return max(elem0, elem1); +} + +#if defined(BLOCK_SIZE) +#define DECODEFUNC , DEQUANTFUNC +#else +#define DECODEFUNC +#endif + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c < D) { + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); + tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if defined(BLOCK_SIZE) + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); +#endif + + tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); + tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); + tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); + + // hint to the compiler that strides are aligned for the aligned variant of the shader + if (Clamp != gl_CooperativeMatrixClampModeConstantNV) + { + q_stride &= ~7; +#if !defined(BLOCK_SIZE) + k_stride &= ~7; + v_stride &= ~7; +#endif + m_stride &= ~7; + } + tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); + tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); + tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); + + coopmat Q; + coopmat Qf16; + + uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); + + Qf16 = coopmat(Q); + Qf16 *= float16_t(p.scale); + + coopmat O = coopmat(0); + + coopmat L, M; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + L = coopmat(0); + M = coopmat(NEG_FLT_MAX_OVER_2); + + coopmat slopeMat = coopmat(1.0); + + // ALiBi + if (p.max_bias > 0.0f) { + coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); + } + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + coopmat S = coopmat(0); + + coopmat K_T; + + uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); + S = coopMatMulAdd(Qf16, K_T, S); + + if (p.logit_softcap != 0.0f) { + [[unroll]] + for (int k = 0; k < S.length(); ++k) { + S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); + } + } + + if (p.mask != 0) { + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + + coopmat mv; + + coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slopeMat*coopmat(mv); + } + + // Clear padding elements to -inf, so they don't contribute to rowmax + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C); + } + + coopmat rowmax, P, rowsum, eM; + + coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); + + coopmat Mold = M; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + coopMatPerElementNV(M, rowmax, Max, Mold); + coopMatPerElementNV(P, S - M, Exp); + coopMatPerElementNV(eM, Mold - M, Exp); + + // Clear padding elements to 0, so they don't contribute to rowsum + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); + } + + coopmat P_A = coopmat(P); + + // compute rowsum by multiplying by matrix of all ones. + coopmat One = coopmat(1.0); + + rowsum = coopmat(0.0); + rowsum = coopMatMulAdd(P_A, One, rowsum); + + coopmat V; + uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); + + L = eM*L + rowsum; + + // This is the "diagonal" matrix in the paper, but since we do componentwise + // multiply rather than matrix multiply it has the diagonal element smeared + // across the row + coopmat eMdiag; + + // resize eM by using smear/reduce + coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); + + // multiply with fp16 accumulation, then add to O. + coopmat PV = coopmat(0); + PV = coopMatMulAdd(P_A, V, PV); + + O = eMdiag * O + coopmat(PV); + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + coopmat O_D = coopmat(O); + + uint32_t o_offset = D * p.ne1 * split_k_index; + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + return; + } + + coopmat Ldiag; + + // resize L by using smear/reduce + coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); + + [[unroll]] + for (int k = 0; k < Ldiag.length(); ++k) { + Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; + } + + O = Ldiag*O; + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + coopmat O_D = coopmat(O); + if (p.gqa_ratio > 1) { + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + } else { + tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + + // permute dimensions + tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); + + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp new file mode 100644 index 0000000000000000000000000000000000000000..a7e3956854c442dc56b53779932c55c6d4eba610 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -0,0 +1,59 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#define BLOCK_SIZE 32 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; + +layout (push_constant) uniform parameter { + uint D; + uint N; + uint k_num; +} p; + +void main() { + // Each workgroup handles a row + const uint n = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + uint D = p.D; + uint N = p.N; + uint k_num = p.k_num; + + uint l_offset = D * N * k_num + n; + uint m_offset = D * N * k_num + N + n; + uint lm_stride = N * 2; + + // Compute the max m value for the row + float m_max = -1.0/0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + float m = data_a[m_offset + k * lm_stride]; + m_max = max(m_max, m); + } + + // Compute L based on m_max + float L = 0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + float l = data_a[l_offset + k * lm_stride]; + float m = data_a[m_offset + k * lm_stride]; + L += exp(m - m_max) * l; + } + + L = 1.0 / L; + + // Scale and sum the O contributions based on m_max and store the result to memory + for (uint d = tid; d < D; d += BLOCK_SIZE) { + float O = 0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + uint o_offset = D * N * k + D * n + d; + float m = data_a[m_offset + k * lm_stride]; + O += exp(m - m_max) * data_a[o_offset]; + } + O *= L; + data_d[D * n + d] = O; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp new file mode 100644 index 0000000000000000000000000000000000000000..4cc7a68ca18c51ec0c9970a5c0eb0fde113ca788 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp @@ -0,0 +1,25 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float xi = float(data_a[i]); + const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); + data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp new file mode 100644 index 0000000000000000000000000000000000000000..e6e6fcfd20e260db832cf63ec3a579dbd3f01df9 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_QUICK_COEF = -1.702f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x)))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp new file mode 100644 index 0000000000000000000000000000000000000000..062e2a4cdf2d89eaa4e0bef52f011df304e9af73 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp @@ -0,0 +1,64 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; + uint misalign_offsets; + float param1; float param2; int param3; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +// true if src0/src1 are the same shape and the indices can be reused without additional modulus +layout(constant_id = 0) const bool norepeat = false; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; } +uint get_doffset() { return p.misalign_offsets & 0xFF; } + +// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1 +uint fastmod(uint a, uint b) { + if ((b & (b-1)) == 0) { + return a & (b-1); + } + return a % b; +} + +uint fastdiv(uint a, uint b) { + return (a < b) ? 0 : (a / b); +} + +void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) { + i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00)); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00)); + const uint i02_offset = i02*p.ne01*p.ne00; + i01 = (idx - i03_offset - i02_offset) / p.ne00; + i00 = idx - i03_offset - i02_offset - i01*p.ne00; +} + +uint src0_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint src1_idx(uint i00, uint i01, uint i02, uint i03) { + if (norepeat) { + return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10; + } else { + return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10; + } +} + +uint dst_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp new file mode 100644 index 0000000000000000000000000000000000000000..66e46ae6796b80803ab5f4031b121109ef802793 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp @@ -0,0 +1,9 @@ +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + float param1; + float param2; +} p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp new file mode 100644 index 0000000000000000000000000000000000000000..8dc9d360d52b4fafd8575d966f36ab0ea9158856 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp @@ -0,0 +1,76 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + float param1; float param2; + + uint ne0_012mp; uint ne0_012L; + uint ne0_01mp; uint ne0_01L; + uint ne0_0mp; uint ne0_0L; + uint ne1_012mp; uint ne1_012L; + uint ne1_01mp; uint ne1_01L; + uint ne1_0mp; uint ne1_0L; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +uint src0_idx(uint idx) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint dst_idx(uint idx) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; +} + +uint src0_idx_quant(uint idx, uint qk) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00; +} + +uint dst_idx_quant(uint idx, uint qk) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp new file mode 100644 index 0000000000000000000000000000000000000000..ee6b86a18ddf2e4a6ad108967e6d83f05b23cca7 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -0,0 +1,33 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = gl_GlobalInvocationID.x; + const uint i10 = gl_GlobalInvocationID.y; + const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; + const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; + + if (i00 >= p.ne00) { + return; + } + + const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + +#if defined(DATA_A_BF16) + FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); +#else + FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); +#endif +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[d_offset + i00] = D_TYPE(v); +#else + data_d[d_offset + i00] = D_TYPE(v); +#endif +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp new file mode 100644 index 0000000000000000000000000000000000000000..cfd645a38a8babe4431bf2003799523876e4c4a4 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -0,0 +1,41 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.comp" +#include "generic_binary_head.comp" +#include "dequant_funcs.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = (gl_GlobalInvocationID.x)*2; + const uint i10 = gl_GlobalInvocationID.y; + const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; + const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; + +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + if (i00 >= p.ne00) { + return; + } + + const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + + const uint ib = a_offset + i00/QUANT_K; // block index + const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index + const uint iybs = i00 - i00%QUANT_K; // dst block start index + const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + + vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + v = v * dm.x + dm.y; + + data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); + data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp new file mode 100644 index 0000000000000000000000000000000000000000..b6a0d56454951ff5f293019bb235835bbd222377 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp @@ -0,0 +1,66 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared float tmp[BLOCK_SIZE]; + +void main() { + const uint group_size = p.KX; + const float eps = p.param1; + + const uint tid = gl_LocalInvocationID.x; + const uint start = gl_WorkGroupID.x * group_size + tid; + const uint end = (gl_WorkGroupID.x + 1) * group_size; + + tmp[tid] = 0.0f; + + // Calculate mean + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + tmp[tid] += float(data_a[col]); + } + + // tmp up partial tmps and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float mean = tmp[0] / group_size; + barrier(); + tmp[tid] = 0.0f; + + // Calculate variance + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + const float xi = float(data_a[col]) - mean; + data_d[col] = D_TYPE(xi); + tmp[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float variance = tmp[0] / group_size; + const float scale = inversesqrt(variance + eps); + + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + data_d[col] *= D_TYPE(scale); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp new file mode 100644 index 0000000000000000000000000000000000000000..09aa849e8815c13fc2adcf7d437a7aeb0d28f3ae --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -0,0 +1,100 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_spirv_intrinsics: enable +#extension GL_EXT_control_flow_attributes : require + +#if RTE16 +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif + +layout (push_constant) uniform parameter +{ + uint batch_offset; uint offset_delta; + uint IC; + uint IW; uint IH; + uint OW; uint OH; + uint KW; uint KH; + uint pelements; + uint CHW; + int s0; int s1; + int p0; int p1; + int d0; int d1; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +const uint NUM_ITER = 512 / BLOCK_SIZE; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint gidx = gl_GlobalInvocationID.x; + + const uint oh = gl_GlobalInvocationID.y; + const uint batch = gl_GlobalInvocationID.z / p.IC; + const uint ic = gl_GlobalInvocationID.z % p.IC; + + const uint src_base = ic * p.offset_delta + batch * p.batch_offset; + const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH); + const int oh_s1 = int(oh) * p.s1; + const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); + + const uint base_linear_idx = gidx * NUM_ITER; + + const uint max_ky = ksize / p.OW; + + uint current_kx = base_linear_idx / ksize; + const uint rem = base_linear_idx - (current_kx * ksize); + uint current_ky = rem / p.OW; + uint current_ix = rem % p.OW; + + A_TYPE values[NUM_ITER]; + uint offset_dst[NUM_ITER]; + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + values[idx] = A_TYPE(0); + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint linear_idx = base_linear_idx + idx; + + if (linear_idx >= p.pelements) { + continue; + } + + const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; + const uint iih = oh_s1 + current_ky * p.d1 - p.p1; + + offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx; + + if ((iih < p.IH) && (iiw < p.IW)) { + values[idx] = data_a[src_base + iih * p.IW + iiw]; + } + + if (++current_ix == p.OW) { + current_ix = 0; + if (++current_ky == max_ky) { + current_ky = 0; + current_kx++; + } + } + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint linear_idx = base_linear_idx + idx; + + if (linear_idx >= p.pelements) { + continue; + } + + data_d[offset_dst[idx]] = D_TYPE(values[idx]); + } + +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp new file mode 100644 index 0000000000000000000000000000000000000000..deba8c39856293383de9e13465660e7544b7d336 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp new file mode 100644 index 0000000000000000000000000000000000000000..d90a99aea55d33e7fdac112faed1fe7d632cc533 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float val = float(data_a[i]); + data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp new file mode 100644 index 0000000000000000000000000000000000000000..43de19df8eb0c8338f9acd5365b5d925f0e96cfd --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp new file mode 100644 index 0000000000000000000000000000000000000000..4c64fd47af718d5e43224648e2f5a7b527739723 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp @@ -0,0 +1,48 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 0) readonly buffer A4 {vec4 data_a4[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; +layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];}; + +layout (push_constant) uniform parameter { + uint ne; + uint k_num; +} p; + +void main() { + // Each invocation handles four consecutive components + const uint idx = gl_GlobalInvocationID.x * 4; + + if (idx >= p.ne) { + return; + } + + // Check if all four components are in bounds and aligned, + // then use vector loads + if (idx + 3 < p.ne && (p.ne % 4) == 0) { + vec4 result = vec4(0.0f); + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a4[(i * p.ne + idx) / 4]; + } + + data_d4[idx / 4] = result; + } else { + [[unroll]] for (uint j = 0; j < 4; ++j) { + if (idx + j < p.ne) { + float result = 0.0f; + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a[i * p.ne + idx + j]; + } + + data_d[idx + j] = result; + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp new file mode 100644 index 0000000000000000000000000000000000000000..bb429dd594588a62c36c6b0d9a758aedd85d86aa --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -0,0 +1,169 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) +#define K_PER_ITER 8 +#else +#define K_PER_ITER 2 +#endif + + +uint a_offset, b_offset, d_offset, y_offset; + +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = (col%QUANT_K)/QUANT_R; // quant index + const uint iybs = col - col%QUANT_K; // y block start index + +#if K_PER_ITER == 8 +#if QUANT_R == 2 + const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); + const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); + const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); +#else + const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); +#endif +#else + // Check if the second of the pair of elements is OOB, and don't fetch B or + // accumulate it. We still fetch a pair of elements for A, which is fine for + // quantized formats since they'll be within the same block. We should + // probably skip fetching the second element for F16/F32, but as of now we + // still do. + const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); + + FLOAT_TYPE b0 = 0, b1 = 0; + b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); + if (!OOB) { + b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); + } +#endif + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + +#if K_PER_ITER == 8 + vec4 v = dequantize4(ib, iqs, a_offset); + vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); + + const vec2 dm = get_dm(ib, a_offset); + if (dm.y != 0) { // quant has min component + v = v * dm.x + dm.y; + v2 = v2 * dm.x + dm.y; + } + + // matrix multiplication + FLOAT_TYPE rowtmp = dot(bv0, v); + rowtmp += dot(bv1, v2); + + if (dm.y == 0) + rowtmp *= dm.x; + + temp[j][n] += rowtmp; +#else + const vec2 v = dequantize(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); + if (!OOB) { + temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + } +#endif + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; + + get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; + + y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); + if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + num_iters++; + } + int unroll_count = 4; + uint unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + // If the K dimension is odd, we need lastiter==true on the last iteration + // so OOB is computed correctly. Skip some unrolling to make that happen. + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + + unroll_count = 2; + unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); + i++; + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp new file mode 100644 index 0000000000000000000000000000000000000000..903753c7e2ec5fc3fed4ab5fc8f9809a9a73678b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -0,0 +1,118 @@ +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_8bit_storage : require + +#ifdef MUL_MAT_ID +#define EXPERT_COUNT 8 +#endif + +#include "types.comp" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +#include "dequant_funcs.comp" + +layout (push_constant) uniform parameter +{ + uint ncols; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint ne11; +#else + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.y; +#else + const uint batch_idx = gl_GlobalInvocationID.y; +#endif + +#ifndef MUL_MAT_ID + uint batch_idx_a = 0; + if (batch_idx != 0) { + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + batch_idx_a = i03 * p.ne02 + i02; + } +#else + const uint expert_id = data_ids[expert_idx]; +#endif + + a_offset = +#ifdef MUL_MAT_ID + expert_id * p.batch_stride_a; +#else + batch_idx_a * p.batch_stride_a; +#endif + b_offset = +#ifdef MUL_MAT_ID + (expert_idx % p.ne11) * p.stride_b; +#else + batch_idx * p.batch_stride_b; +#endif + d_offset = +#ifdef MUL_MAT_ID + expert_idx * p.stride_d; +#else + batch_idx * p.batch_stride_d; +#endif +} + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (constant_id = 1) const uint NUM_ROWS = 1; +layout (constant_id = 2) const uint NUM_COLS = 1; + +shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; + +void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + // sum up partial sums and write back result + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] = temp[j][n]; + } + } + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] += tmpsh[j][n][tid + s]; + } + } + } + barrier(); + } + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp new file mode 100644 index 0000000000000000000000000000000000000000..e4acbd4f9626118d3e932c5fd0baee254e60c564 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -0,0 +1,82 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint16_t[4] scales = data_a[ibi].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1)); + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1)); + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1); + + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp new file mode 100644 index 0000000000000000000000000000000000000000..309da0991ae63f9eee5ff284e85b56385138964e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -0,0 +1,79 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint qh = data_a[ibi].qh[ib32]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp new file mode 100644 index 0000000000000000000000000000000000000000..8d01536fa69c0f6fd4953e4615685e877a89289f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + const uint qh = data_a[ibi].qh[ib32]; + const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147 + const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy; + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint8_t sign = sign16[l]; + const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300); + const uvec2 grid = iq2s_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp new file mode 100644 index 0000000000000000000000000000000000000000..c496043241072d30923fde064bd2b68bf1d1d889 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[2 * itid + l]; + const uint sign = qs >> 9; + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x)); + const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp new file mode 100644 index 0000000000000000000000000000000000000000..94d4b92e1ee694aa74abc083efd3a41a08e8e8e5 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[4 * ib32 + 2], + data_a_packed16[ibi].qs[4 * ib32 + 3])); + const float db = d * 0.25 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x)); + const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp new file mode 100644 index 0000000000000000000000000000000000000000..f021e404761997bd2c6dfcdc91881b6ce8fa4219 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const float dscale = d * (1 + 2 * scale); + const uint qh = data_a[ibi].qh[ib32]; + FLOAT_TYPE sum[NUM_COLS]; + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + sum[j] = 0.0; + } + [[unroll]] for (uint l = 0; l < 4; ++l) { + const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147 + const uint sign = data_a[ibi].signs[4 * ib32 + l]; + const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)])); + const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + sum[j] = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + sum[j])))))))); + } + } + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + temp[j][n] = fma(dscale, sum[j], temp[j][n]); + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp new file mode 100644 index 0000000000000000000000000000000000000000..3fe9dc3a4113a9f8f53c9291b855fa0fb3a96d86 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -0,0 +1,88 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32], + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32 + 1])); + const float db = d * 0.5 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs0 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l]; + const uint qs1 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l + 1]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq3xxs_grid[qs0])); + const vec4 grid1 = vec4(unpack8(iq3xxs_grid[qs1])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp new file mode 100644 index 0000000000000000000000000000000000000000..bc633369f9bb58dc4445cfddb8503706cb26babb --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -0,0 +1,118 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#define BLOCK_SIZE 32 +#define FLOAT_TYPE float + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; + +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout (push_constant) uniform parameter +{ + uint ncols_x; + uint nrows_x; + uint row_stride_x; + uint channel_stride_x; + uint channel_stride_y; + uint channel_x_divisor; + uint ne12; + uint b_offset; + uint d_offset; +} p; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint row_x = gl_GlobalInvocationID.y; + const uint channel = gl_GlobalInvocationID.z; + const uint channel_x = channel / p.channel_x_divisor; + const uint channel_y = channel % p.ne12; + + const uint nrows_y = p.ncols_x; + const uint nrows_dst = p.nrows_x; + const uint row_dst = row_x; + + const uint idst = channel*nrows_dst + row_dst; + + FLOAT_TYPE temp = 0.0f; + + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0; + + for (uint col_x0 = 0; col_x0 < p.ncols_x;) { + + // Unroll 2x and do vec4 loads if aligned + const uint unroll_count = 2; + if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) { + [[unroll]] for (uint i = 0; i < unroll_count; ++i) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel_y*p.channel_stride_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } + // do vec4 loads if aligned + } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel_y*p.channel_stride_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + if (col_x >= p.ncols_x) { + break; + } + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel_y*p.channel_stride_y + row_y; + + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp); + col_x0 += BLOCK_SIZE; + } + } + + tmp[tid] = temp; + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + if (tid == 0) { + dst[idst] = tmp[0]; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp new file mode 100644 index 0000000000000000000000000000000000000000..7aa070eebdf72b4849468098de1a3a30238cf218 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp @@ -0,0 +1,154 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif + +#define FLOAT_TYPE float + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; + +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout(constant_id = 0) const int BLOCK_SIZE = 32; +// gqa_ratio is in the range [1,8] +layout(constant_id = 1) const uint gqa_ratio = 1; + +layout (push_constant) uniform parameter +{ + uint ncols_x; + uint nrows_x; + uint nchannels_x; + uint nchannels_y; + uint b_offset; + uint d_offset; +} p; + +#if !USE_SUBGROUP_ADD +shared FLOAT_TYPE tmp[8][BLOCK_SIZE]; +#endif + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint row_x = gl_GlobalInvocationID.y; + + uint channel, channel_x; + + // When gqa_ratio > 1, each invocation does multiple rows. + // The row in the A matrix is starting from channel / gqa_ratio and the + // rows in the B matrix are [channel, channel+gqa_ratio). + // When gpa_ratio is 1, each invocation does one row. + if (gqa_ratio > 1) { + channel_x = gl_GlobalInvocationID.z; + channel = channel_x * gqa_ratio; + } else { + channel = gl_GlobalInvocationID.z; + channel_x = channel / (p.nchannels_y / p.nchannels_x);; + } + + const uint nrows_y = p.ncols_x; + const uint nrows_dst = p.nrows_x; + const uint row_dst = row_x; + + FLOAT_TYPE temp[8]; + [[unroll]] for (uint i = 0; i < 8; ++i) { + temp[i] = FLOAT_TYPE(0.0f); + } + + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0; + + for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { + + // Use vec4 loads if aligned + if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + + uint col_x = col_x0 + 4*tid; + const uint row_y = col_x; + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const vec4 av4 = vec4(data_a_v4[ix / 4]); + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + vec4 bv4 = data_b_v4[iy / 4]; + temp[c] += dot(av4, bv4); + } + + col_x0 += 3*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + const uint row_y = col_x; + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]); + } + } + } + +#if USE_SUBGROUP_ADD + // reduce vec4 at a time + vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]); + t = subgroupAdd(t); + temp[0] = t[0]; + temp[1] = t[1]; + temp[2] = t[2]; + temp[3] = t[3]; + if (gqa_ratio > 4) { + t = vec4(temp[4], temp[5], temp[6], temp[7]); + t = subgroupAdd(t); + temp[4] = t[0]; + temp[5] = t[1]; + temp[6] = t[2]; + temp[7] = t[3]; + } +#else + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + tmp[c][tid] = temp[c]; + } + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] += tmp[c][tid + s]; + tmp[c][tid] = temp[c]; + } + } + barrier(); + } + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] = tmp[c][tid]; + } +#endif + + if (tid == 0) { + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // dst is not transposed and not permuted + const uint idst = (channel + c)*nrows_dst + row_dst; + dst[idst] = temp[c]; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp new file mode 100644 index 0000000000000000000000000000000000000000..423ceb8a3df463b917e3f828e3d163cfcfab0c16 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -0,0 +1,130 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + } + barrier(); + + if (i >= num_blocks_per_row) + continue; + } else { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + barrier(); + } + + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im], + fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im], + fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im], + fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im], + fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im], + fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im], + fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp new file mode 100644 index 0000000000000000000000000000000000000000..e91724a28db2289e8e4705b7201da367b5c590c3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -0,0 +1,132 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); + + // 0, 1, 16, 17 + uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); + qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + if (all_threads) { + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + const uint itid8 = itid%8; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_im4 = v_im*4; + const uint v_in = itid - 8*v_im; // 0...7 + + const uint32_t m = 0x01010101 << (4 * v_im); + uint32_t hm_m[4]; + [[unroll]] for (uint j = 0; j < 4; ++j) + hm_m[j] = m << j; + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint s_shift = v_im4 + 2*(itid8/4); + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp new file mode 100644 index 0000000000000000000000000000000000000000..f9cde064887a8c1a0680431f2e93b8d363ab950a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -0,0 +1,136 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; + + const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); + const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); + const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); + const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_lo4.x; + const FLOAT_TYPE q4_1 = qs0_lo4.y; + const FLOAT_TYPE q4_2 = qs0_lo4.z; + const FLOAT_TYPE q4_3 = qs0_lo4.w; + const FLOAT_TYPE q4_4 = qs0_hi4.x; + const FLOAT_TYPE q4_5 = qs0_hi4.y; + const FLOAT_TYPE q4_6 = qs0_hi4.z; + const FLOAT_TYPE q4_7 = qs0_hi4.w; + const FLOAT_TYPE q4_8 = qs64_lo4.x; + const FLOAT_TYPE q4_9 = qs64_lo4.y; + const FLOAT_TYPE q4_10 = qs64_lo4.z; + const FLOAT_TYPE q4_11 = qs64_lo4.w; + const FLOAT_TYPE q4_12 = qs64_hi4.x; + const FLOAT_TYPE q4_13 = qs64_hi4.y; + const FLOAT_TYPE q4_14 = qs64_hi4.z; + const FLOAT_TYPE q4_15 = qs64_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); + vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); + vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); + vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 + const uint n = 4; + + const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = n * (2 * ir + v_in); // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp new file mode 100644 index 0000000000000000000000000000000000000000..6c84ef3cde3ff5e8bb6f9d57bb9ae8f7d3655a96 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -0,0 +1,167 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + + uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; + uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; + uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); + const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + + const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); + const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); + const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); + const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_16_lo4.x; + const FLOAT_TYPE q4_1 = qs0_16_lo4.y; + const FLOAT_TYPE q4_2 = qs0_16_lo4.z; + const FLOAT_TYPE q4_3 = qs0_16_lo4.w; + const FLOAT_TYPE q4_4 = qs0_16_hi4.x; + const FLOAT_TYPE q4_5 = qs0_16_hi4.y; + const FLOAT_TYPE q4_6 = qs0_16_hi4.z; + const FLOAT_TYPE q4_7 = qs0_16_hi4.w; + const FLOAT_TYPE q4_8 = qs64_80_lo4.x; + const FLOAT_TYPE q4_9 = qs64_80_lo4.y; + const FLOAT_TYPE q4_10 = qs64_80_lo4.z; + const FLOAT_TYPE q4_11 = qs64_80_lo4.w; + const FLOAT_TYPE q4_12 = qs64_80_hi4.x; + const FLOAT_TYPE q4_13 = qs64_80_hi4.y; + const FLOAT_TYPE q4_14 = qs64_80_hi4.z; + const FLOAT_TYPE q4_15 = qs64_80_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); + vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); + vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); + vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); + vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); + vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); + vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); + vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 + + const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = 4*ir + 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp new file mode 100644 index 0000000000000000000000000000000000000000..d53d9ee0a2723cb20cede02b4f2d3252513d2178 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -0,0 +1,130 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) + sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + const uint32_t qh4_u32 = (qh_u32 & 0x30303030); + const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + const vec4 q0 = vec4(unpack8(q0_u32)) - 32; + const vec4 q1 = vec4(unpack8(q1_u32)) - 32; + const vec4 q2 = vec4(unpack8(q2_u32)) - 32; + const vec4 q3 = vec4(unpack8(q3_u32)) - 32; + + if (all_threads) { + sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]); + vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]); + vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); + vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); + + FLOAT_TYPE sum[4] = {0, 0, 0, 0}; + [[unroll]] for (uint l = 0; l < 4; ++l) { + sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]); + sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]); + sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); + sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); + } + temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]); + } + } +} + +void compute_outputs(const uint first_row, const uint num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 + + const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 + const uint is = v_in / 4; + + const uint ql_offset = 64*v_im + l0; + const uint qh_offset = 32*v_im + l0; + const uint s_offset = 8*v_im + is; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp new file mode 100644 index 0000000000000000000000000000000000000000..26163b167c7ed712f98439ee260e4056e999ff2e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -0,0 +1,868 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif +#if defined(DATA_A_IQ1_M) +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#if defined(DATA_A_BF16) && defined(COOPMAT) +#extension GL_EXT_bfloat16 : enable +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.comp" + +#ifndef LOAD_VEC_A +#define LOAD_VEC_A 1 +#endif +#ifndef LOAD_VEC_B +#define LOAD_VEC_B 1 +#endif + +#if !defined(TO_FLOAT_TYPE) +#define TO_FLOAT_TYPE FLOAT_TYPE +#endif + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK + 8) +#else +#define SHMEM_STRIDE (BK + 1) +#endif + +shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[4096]; +#endif // MUL_MAT_ID + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; + +#ifdef MUL_MAT_ID + uint _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + row_ids[_ne1] = u16vec2(ii0, ii1); + _ne1++; + } + } + } + + barrier(); + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; +#ifdef MUL_MAT_ID + uint pos_b = 0; +#else + uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; +#endif + +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + FLOAT_TYPE cache_a[WMITER * TM]; + FLOAT_TYPE cache_b[TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { + +#if defined(DATA_A_F32) || defined(DATA_A_F16) +#if LOAD_VEC_A == 8 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); + buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); + buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); + buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); + buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); + buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); + buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); + buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); +#elif LOAD_VEC_A == 4 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); + buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); + buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); + buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); +#else + if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); + } else { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); + } +#endif +#elif defined(DATA_A_BF16) +#if LOAD_VEC_A == 4 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x); + buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y); + buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z); + buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w); +#else + if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); + } else { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0)); + } +#endif +#elif defined(DATA_A_Q4_0) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; + const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v0.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); + buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); + buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); + buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); + buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); + buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); + buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); +#elif defined(DATA_A_Q4_1) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE(v0.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); + buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); + buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); + buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); + buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); + buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); + buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); +#elif defined(DATA_A_Q5_0) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); +#elif defined(DATA_A_Q5_1) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint uint_qh = data_a_packed16[ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); +#elif defined(DATA_A_Q8_0) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; + const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); +#elif defined(DATA_A_Q2_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const uint scales = data_a[ib].scales[scalesi]; + const vec2 d = vec2(data_a[ib].d); + + const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q3_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); + const float dl = float(data_a[ib].d) * float(us - 32); + + buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); + buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); +#elif defined(DATA_A_Q4_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); + buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +#elif defined(DATA_A_Q5_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); + buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); +#elif defined(DATA_A_Q6_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); + + buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); + buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ1_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx % 128) / 4; + const int i8 = 2 * int(idx % 4); + + const float d = float(data_a[ib].d); + const uint qh = data_a[ib].qh[ib32]; + const uint qs = data_a[ib].qs[ib8]; + const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + const vec2 v = dl * (vec2(gvec) + delta); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ1_M) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib8 = (idx % 128) / 4; + const uint ib16 = ib8 / 2; + const int i8 = 2 * int(idx % 4); + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + const uint sc = scales[ib8 / 8]; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + const vec2 v = dl * (vec2(gvec) + delta); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ2_XXS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx / 4) % 4; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[8 * ib32 + ib8]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[8*ib32 + 4], + data_a[ib].qs[8*ib32 + 5], + data_a[ib].qs[8*ib32 + 6], + data_a[ib].qs[8*ib32 + 7] + )); + const float db = d * 0.25 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ2_XS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx / 4) % 4; // 0..3 + + const float d = float(data_a[ib].d); + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const float db = d * 0.25 * (0.5 + scale); + const uint qs = data_a[ib].qs[4 * ib32 + ib8]; + const uint sign7 = qs >> 9; + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ2_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib8 = (idx % 128) / 4; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 + + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + + const float d = float(data_a[ib].d); + const float db = d * 0.25 * (0.5 + scale); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; + const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ3_XXS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = (idx % 128) / 2; // 0..63 + const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[is+0], + data_a[ib].qs[is+1], + data_a[ib].qs[is+2], + data_a[ib].qs[is+3] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ3_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = (idx % 128) / 2; // 0..63 + const uint iqh = iqs / 8; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint qh = data_a[ib].qh[iqh]; + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const uint scale = data_a[ib].scales[iqs / 16]; + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ4_XS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint iq = 16 * ib32 + 2 * (idx % 8); + + const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 8) >> 1; + u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float d = float(data_a[ib].d); + const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ4_NL) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + + buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d; + buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; + buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; + buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; +#endif + } + [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { +#if LOAD_VEC_B == 8 +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +#else + const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; +#endif + const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; + buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); + buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); + buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); + buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); + buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); + buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); + buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); + buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); +#elif LOAD_VEC_B == 4 +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +#else + const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; +#endif + const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; + buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x); + buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y); + buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z); + buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w); +#elif !MUL_MAT_ID + if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); + } else { + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); + } +#else + const uint row_i = ic * BN + loadc_b + l; + if (row_i < _ne1) { + const u16vec2 row_idx = row_ids[row_i]; + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); + } else { + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); + } +#endif + } + + barrier(); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + +#ifdef COOPMAT + [[unroll]] for (uint i = 0; i < BK; i += TK) { + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + // Load from shared into cache + coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } +#else + [[unroll]] for (uint i = 0; i < BK; i++) { + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint j = 0; j < TM; j++) { + cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; + } + } + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint j = 0; j < TN; j++) { + cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; + } + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]); + } + } + } + } + } +#endif + + barrier(); + } + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; + + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp new file mode 100644 index 0000000000000000000000000000000000000000..9184657573281455975fc70eaae438e601327822 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -0,0 +1,441 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable +#ifdef DATA_A_BF16 +#extension GL_EXT_bfloat16 : enable +#endif + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#define IS_MUL_MM2 1 + +layout (constant_id = 0) const uint BLOCK_SIZE = 256; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant + +layout (constant_id = 4) const bool enable_smaller_matrices = false; +const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN; +const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN; + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif + // N dimension for the B matrix can be >= p.N + uint padded_N; +} p; + + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#if QUANT_K > 1 +#define DECODEFUNCA , dequantFuncA + +#include "dequant_funcs_cm2.comp" + +#else +#define DECODEFUNCA +#endif + +#if !defined(fetch_scales) +#define fetch_scales(a, b, c, d, e, f) +#endif +#if !defined(store_scales) +#define store_scales(a) +#endif + +#if defined(DATA_A_BF16) +#define MAT_TYPE bfloat16_t +#else +#define MAT_TYPE FLOAT_TYPE +#endif + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; + +shared u16vec4 row_ids[4096]; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { + B_TYPE b[]; +}; + +uint _ne1; +shared uint _ne1_sh; + +B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint row_i = blockCoords[0]; + + if (row_i >= _ne1) { + return B_TYPE(0.0); + } + + const u16vec4 row_idx = row_ids[row_i]; + B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; + + return ret; +} + +D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) +{ + uint dr = ir * BM + r; + uint dc = ic * BN + c; + + if (dr < p.M && dc < _ne1) { + uint row_i = dc; + const u16vec4 row_idx = row_ids[row_i]; + data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; + } + return elem; +} + +#endif + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint tid = gl_LocalInvocationIndex; + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + +#ifdef MUL_MAT_ID + // Spread the search across all elements in the first subgroup + if (gl_SubgroupID == 0) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + + for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { + bool in_range = i < num_elements; + uint ii0 = i % p.nei0; + uint ii1 = i / p.nei0; + uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + uint idx = subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx) { + row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); + } + _ne1 += subgroupBallotBitCount(ballot); + } + _ne1_sh = _ne1; + } + + barrier(); + + _ne1 = _ne1_sh; + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + uint start_k = 0; + const uint end_k = p.K; +#else + uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + +#ifdef MUL_MAT_ID + uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; + uint pos_b = 0; +#else + uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; + uint pos_b = batch_idx * p.batch_stride_b; + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + + uint stride_a = p.stride_a / QUANT_K; + uint stride_b = p.stride_b; + + // Hint to the compiler that values are aligned (want 16B alignment). + // Quants are always block-aligned, no alignment needed. +#if ALIGNED +#if QUANT_K == 1 + stride_a &= ~7; +#endif + stride_b &= ~7; +#endif + + // Create layouts for both clamped and unclamped accesses + tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); + +#if QUANT_K > 1 + tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); + tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); +#endif + + // Use end_k rather than p.K as the dimension because that's what + // we need to bound check against when using split_k. + // Bounds check B against padded_N, but bounds check D against N. + tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); + tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); + tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); + tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if !defined(MUL_MAT_ID) + + const uint START_ALIGN_K = 256; + // For Qi_K (block size 256), unroll whole 256 element tiles. + // For legacy quants (block size 32), unroll 8x. + const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8); + const uint unroll_count = UNROLL_K / BK; + + // Detect a fast path where all loads are entirely in bounds and no clamping is required + if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 && +#if QUANT_K == 1 + (stride_a % 8) == 0 && +#endif + (stride_b % 8) == 0) { + // Hint to the compiler that values are aligned (want 16B alignment) + start_k &= ~(START_ALIGN_K-1); + stride_b &= ~7; +#if QUANT_K == 1 + stride_a &= ~7; +#endif + + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + uint k_iters = (end_k - start_k) / UNROLL_K; + uint block_k = start_k; + + // fetch scale values for a tile of quants. These will be copied into shared memory. + // The fetches and stores are pipelined to hide the latency. + fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true); + + if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); + return; + } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); + return; + } else { + coopmat sum = coopmat(0.0); + + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); + return; + } + } else +#endif // !defined(MUL_MAT_ID) + { + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + + tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1); + + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); + + coopmat sum; + sum = coopmat(0.0); + + uint k_iters = (end_k - start_k + BK - 1) / BK; + + fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + store_scales(tid); + if (block_k + BK < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + +#ifdef MUL_MAT_ID + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); +#else + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); +#endif + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp new file mode 100644 index 0000000000000000000000000000000000000000..83de90eb7e0f23304ccece780b00857a79c521ed --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -0,0 +1,442 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#extension GL_EXT_integer_dot_product : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif +layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +// layout (constant_id = 3) const uint BK = 32; +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#define BK 32 + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK / 4 + 4) +#else +#define SHMEM_STRIDE (BK / 4 + 1) +#endif + +shared int32_t buf_a_qs[BM * SHMEM_STRIDE]; + +#ifndef COOPMAT +#if QUANT_AUXF == 1 +shared FLOAT_TYPE buf_a_dm[BM]; +#else +shared FLOAT_TYPE_VEC2 buf_a_dm[BM]; +#endif +#endif + +shared int32_t buf_b_qs[BN * SHMEM_STRIDE]; +#ifndef COOPMAT +shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; +#endif + +#define LOAD_VEC_A (4 * QUANT_R) +#define LOAD_VEC_B 4 + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[4096]; +#endif // MUL_MAT_ID + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +#include "mul_mmq_funcs.comp" + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + + const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK; + const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK; + +#ifdef MUL_MAT_ID + uint _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + row_ids[_ne1] = u16vec2(ii0, ii1); + _ne1++; + } + } + } + + barrier(); + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a_ib = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / BK; +#ifdef MUL_MAT_ID + uint pos_b_ib = 0; +#else + uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; +#endif + +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat cm_result; + + coopmat factors[cms_per_row * cms_per_col]; + + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + int32_t cache_a_qs[WMITER * TM * BK / 4]; + + int32_t cache_b_qs[TN * BK / 4]; + + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + +#if QUANT_AUXF == 1 + FLOAT_TYPE cache_a_dm[WMITER * TM]; +#else + FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; +#endif + + FLOAT_TYPE_VEC2 cache_b_ds[TN]; + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { + const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK; + const uint iqs = loadr_a; + const uint buf_ib = loadc_a + l; + + if (iqs == 0) { +#if QUANT_AUXF == 1 + buf_a_dm[buf_ib] = get_d(ib); +#else + buf_a_dm[buf_ib] = get_dm(ib); +#endif + } +#if QUANT_R == 1 + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs); +#else + const i32vec2 vals = repack(ib, iqs); + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x; + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y; +#endif + } + [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) { +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; + const uint ib = idx / 8; + const uint iqs = idx & 0x7; +#else + const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; + const uint iqs = loadr_b; +#endif + + const uint buf_ib = loadc_b + l; + + if (iqs == 0) { + buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds); + } + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs]; + } + + barrier(); + + pos_a_ib += 1; + pos_b_ib += 1; + +#ifdef COOPMAT + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint ib_a = warp_r * WM + cm_row * TM; + // Load from shared into cache + coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { + cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; + } + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint ib_b = warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { + cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; + } + + cm_result = coopmat(0); + cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); + } + + coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); + } + } +#else + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; + cache_a_dm[wsir * TM + cr] = buf_a_dm[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + } + + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; + cache_b_ds[cc] = buf_b_ds[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint cache_a_idx = wsir * TM + cr; + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + int32_t q_sum = 0; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k], + cache_b_qs[cc * (BK / 4) + idx_k]); + } + + sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]); + } + } + } + } +#endif + + barrier(); + } + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < BN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; + + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp new file mode 100644 index 0000000000000000000000000000000000000000..63b15471bd3aaa00db343e5bd7925dc5986bcfa9 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp @@ -0,0 +1,99 @@ +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#include "types.comp" + +// Each iqs value maps to a 32-bit integer + +#if defined(DATA_A_Q4_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y)); +} +#endif + +#if defined(DATA_A_Q4_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +} +#endif + +#if defined(DATA_A_Q5_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y)); +} +#endif + +#if defined(DATA_A_Q5_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +} +#endif + +#if defined(DATA_A_Q8_0) +int32_t repack(uint ib, uint iqs) { + // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4 + return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1])); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(float(q_sum) * da * dsb.x); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +FLOAT_TYPE get_d(uint ib) { + return FLOAT_TYPE(data_a[ib].d); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp new file mode 100644 index 0000000000000000000000000000000000000000..6627a50bd949a07fb621c9147c845425d10f3408 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared vec2 sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = vec2(0.0f, 0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const float xi = float(data_a[row*p.KX + col]); + sum[tid].x += xi; + sum[tid].y += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const float mean = sum[0].x / p.KX; + const float var = sum[0].y / p.KX - mean * mean; + const float inv_std = inversesqrt(var + p.param1); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp new file mode 100644 index 0000000000000000000000000000000000000000..e0214fe7645c217e00ec77558b87beb5e6407be0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) buffer X {A_TYPE x[];}; +layout (binding = 1) readonly buffer G {A_TYPE grad[];}; +layout (binding = 2) buffer GM {A_TYPE gradm[];}; +layout (binding = 3) buffer GV {A_TYPE gradv[];}; +layout (binding = 4) readonly buffer P {float params[7];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float alpha = params[0]; + const float beta1 = params[1]; + const float beta2 = params[2]; + const float eps = params[3]; + const float wd = params[4]; + const float beta1h = params[5]; + const float beta2h = params[6]; + + const float gi = grad[i]; + const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1); + const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2); + + gradm[i] = gmi; + gradv[i] = gvi; + + const float mh = gmi*beta1h; + const float vh = sqrt(gvi*beta2h) + eps; + + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp new file mode 100644 index 0000000000000000000000000000000000000000..450b67fc55d3760eb8c36bae55fbd6d9e41da86d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -0,0 +1,28 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne12*p.ne11*p.ne10); + const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; + const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10); + const uint i2_offset = i2*p.ne11*p.ne10; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; + + const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; + + const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp new file mode 100644 index 0000000000000000000000000000000000000000..b6124411a054c7a168c01c85b097c22d3d763000 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp @@ -0,0 +1,74 @@ +#version 450 + +#include "types.comp" + +#extension GL_EXT_shader_16bit_storage : require + +layout(push_constant) uniform parameter { + uint IW; uint IH; + uint OW; uint OH; + uint OC; + uint pelements; + uint op; + int k0; int k1; + int s0; int s1; + int p0; int p1; +} p; + +#define BLOCK_SIZE 512 +#define FLT_MAX 3.402823466e+38F +#define OP_POOL_MAX 0u +#define OP_POOL_AVG 1u + +layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout(binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.pelements) { + return; + } + + const uint O_HW = p.OW * p.OH; + + const uint nc = idx / O_HW; + const uint cur_oh = (idx % O_HW) / p.OW; + const uint cur_ow = (idx % O_HW) % p.OW; + + const int start_h = int(cur_oh) * p.s0 - p.p0; + const uint bh = max(start_h, 0); + const uint eh = min(start_h + p.k0, p.IH); + + const int start_w = int(cur_ow) * p.s1 - p.p1; + const uint bw = max(start_w, 0); + const uint ew = min(start_w + p.k1, p.IW); + + const float scale = 1.0 / float(p.k0 * p.k1); + float res; + + if (p.op == OP_POOL_AVG) { + res = 0.0; + } else if (p.op == OP_POOL_MAX) { + res = -FLT_MAX; + } else { + return; + } + + #pragma unroll + for (uint i = bh; i < eh; i++) { + #pragma unroll + for (uint j = bw; j < ew; j++) { + const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]); + + if (p.op == OP_POOL_AVG) { + res += cur * scale; + } else if (p.op == OP_POOL_MAX) { + res = max(res, cur); + } + } + } + + data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp new file mode 100644 index 0000000000000000000000000000000000000000..e2e020fec2c6a9e247a50461b27e335327cb0a7a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -0,0 +1,77 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint ne; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint GROUP_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {vec4 data_a[];}; +layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];}; + +shared float shmem[GROUP_SIZE]; + +void quantize() { + const uint wgid = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + // Each thread handles a vec4, so 8 threads handle a block + const uint blocks_per_group = GROUP_SIZE / 8; + + const uint block_in_wg = tid / 8; + + const uint ib = wgid * blocks_per_group + block_in_wg; + const uint iqs = tid % 8; + + if (ib >= gl_NumWorkGroups.x * blocks_per_group) { + return; + } + + const uint a_idx = ib * 8 + iqs; + + vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f); + const vec4 abs_vals = abs(vals); + + // Find absolute max for each block + shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] = max(shmem[tid], shmem[tid + s]); + } + barrier(); + } + + const float amax = shmem[block_in_wg * 8]; + const float d = amax / 127.0; + const float d_inv = d != 0.0 ? 1.0 / d : 0.0; + vals = round(vals * d_inv); + data_b[ib].qs[iqs] = pack32(i8vec4(round(vals))); + barrier(); + + // Calculate the sum for each block + shmem[tid] = vals.x + vals.y + vals.z + vals.w; + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] += shmem[tid + s]; + } + barrier(); + } + if (iqs == 0) { + const float sum = shmem[tid]; + + data_b[ib].ds = f16vec2(vec2(d, sum * d)); + } +} + +void main() { + quantize(); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp new file mode 100644 index 0000000000000000000000000000000000000000..4f806270c7799a5e62f28db338f8752ecb0beac3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + data_d[i] = D_TYPE(max(float(data_a[i]), 0)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp new file mode 100644 index 0000000000000000000000000000000000000000..1568b141de59eb8e51becee7b9360ca5410cd0a4 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +uint src0_idx_mod(uint idx) { + const uint i13 = idx / (p.ne12*p.ne11*p.ne10); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00; +} + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp new file mode 100644 index 0000000000000000000000000000000000000000..d86279934f17655efaa48ccce674895a43623aa0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + // Destination multi-index (inlined dst_idx) + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; + + // Accumulate from sources + A_TYPE acc = A_TYPE(0); + for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) { + for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) { + for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) { + for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) { + acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00]; + } + } + } + } + + data_d[get_doffset() + d_idx] = D_TYPE(acc); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp new file mode 100644 index 0000000000000000000000000000000000000000..deb8ee9960f58240fb5ea33b06809ffb2f4632e9 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -0,0 +1,52 @@ +#version 450 + +#include "generic_unary_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = gl_WorkGroupID.x; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + const uint tid = gl_LocalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp new file mode 100644 index 0000000000000000000000000000000000000000..76009f3df6783a5a6d1f7b6550c7c4d528935768 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp @@ -0,0 +1,55 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer G {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum_xx[BLOCK_SIZE]; +shared FLOAT_TYPE sum_xg[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + // Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5 + + // partial sums for thread in warp + sum_xx[tid] = FLOAT_TYPE(0.0f); + sum_xg[tid] = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]); + const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]); + sum_xx[tid] += xi * xi; + sum_xg[tid] += xi * gi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum_xx[tid] += sum_xx[tid + s]; + sum_xg[tid] += sum_xg[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE eps = FLOAT_TYPE(p.param1); + const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE scale_g = inversesqrt(mean + eps); + const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE( + scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) + + scale_x * FLOAT_TYPE(data_b[row*p.KX + col])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp new file mode 100644 index 0000000000000000000000000000000000000000..96c9c4cbd307ced519430adfeee6218e5869db63 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -0,0 +1,58 @@ +#include "types.comp" + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_spirv_intrinsics: enable + +#if RTE16 +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif + +layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {int data_pos[];}; +layout (binding = 2) readonly buffer Z {float data_ff[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint n_dims; + float freq_scale; + uint p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint has_ff; + uint ne02; + uint s1; + uint s2; + int sections[4]; + uint is_back; +} p; + +float rope_yarn_ramp(const float low, const float high, const uint i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { + float mscale = p.attn_factor; + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = p.freq_scale * theta_extrap; + float theta = theta_interp; + if (p.ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); + } + // Backprogagation uses inverted rotation + if (p.is_back != 0) { + theta = -theta; + } + cos_theta = cos(theta) * mscale; + sin_theta = sin(theta) * mscale; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp new file mode 100644 index 0000000000000000000000000000000000000000..4f5b1a0ecaf5dce002598b6519c0dd053444d2a4 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -0,0 +1,60 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp new file mode 100644 index 0000000000000000000000000000000000000000..db775c456cae82f2767b159f9101b11eee4abc12 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp new file mode 100644 index 0000000000000000000000000000000000000000..4ad35e549d77fa9db0317f6caa4b4427089336ed --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; + + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + 1]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp new file mode 100644 index 0000000000000000000000000000000000000000..cedacc4d14439184e58c8bdf5f61845c037ee28c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -0,0 +1,47 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const int sect_dims = p.sections[0] + p.sections[1]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + const uint p0 = sector; + theta_base = data_pos[channel_x]*pow(p.theta_scale, p0); + } + else if (sector >= p.sections[0] && sector < sec_w) { + const uint p0 = sector - p.sections[0]; + theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0); + } + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp new file mode 100644 index 0000000000000000000000000000000000000000..4663428dee0a2878987a2eda364ebec1c66b5169 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -0,0 +1,24 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp new file mode 100644 index 0000000000000000000000000000000000000000..5c9e5c350323b2e306005bcefeed16d5c79fc703 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i])))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp new file mode 100644 index 0000000000000000000000000000000000000000..4d36f88e089bcf61d137541886c608a8286416c9 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float xi = float(data_a[i]); + data_d[i] = D_TYPE(xi / (1.0f + exp(-xi))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp new file mode 100644 index 0000000000000000000000000000000000000000..f9afa9b13c1f224d07126bd74e1d5e9fe4029227 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_x[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // Compute derivative of SiLU(x): 1/(1+exp(-x)) - x*exp(-x)/(1+exp(-x))^2 + + const float xi = float(data_x[i]); + const float s = 1.0f / (1.0f + exp(-xi)); + data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp new file mode 100644 index 0000000000000000000000000000000000000000..d7c15a16959535a112772c5956db8a555d0852cd --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp new file mode 100644 index 0000000000000000000000000000000000000000..51fc2dc7ed406cf7a87401ed3913f8ae2241a95f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -0,0 +1,173 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + uint nrows_x; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate +// over all the columns. The main function tries to pass a constant here, +// as if it were a template function, to allow unrolling. +void soft_max(uint num_iters) { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; + + if (rowx >= p.nrows_x) { + return; + } + + float slope = 1.0f; + + // ALiBi + if (p.max_bias > 0.0f) { + const uint h = rowx/p.KY; // head index + + const float base = h < p.n_head_log2 ? p.m0 : p.m1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + + slope = pow(base, exp); + } + + // Find max + FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); + + // Cache values while we compute the max, so we don't need to read them + // again when we're ready to compute exp(x-max). + const uint DATA_CACHE_SIZE = 16; + FLOAT_TYPE data_cache[DATA_CACHE_SIZE]; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + FLOAT_TYPE a = FLOAT_TYPE(0); + if (col < p.KX) { + a = data_a[rowx * p.KX + col]; + } + + FLOAT_TYPE b = FLOAT_TYPE(0); + if (p.KY > 0 && col < p.KX) { + b = data_b[rowy * p.KX + col]; + } + + FLOAT_TYPE v = a * p.scale + slope * b; + + if (col < p.KX) { + max_val = max(max_val, v); + } + + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = v; + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + max_val = vals[0]; + barrier(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + // compute exp(a*scale+b*slope), add it to sum, and cache the new value + // in data_cache if possible. + const uint i = rowx * p.KX + col; + FLOAT_TYPE val; + if (idx < DATA_CACHE_SIZE) { + val = exp(data_cache[idx] - max_val); + } else { + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + } + sum += val; + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = val; + } else { + data_d[i] = D_TYPE(val); + } + } + + // reduce across the workgroup + vals[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + sum = vals[0]; + + FLOAT_TYPE rcpdivisor = 1.0/sum; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + continue; + } + + if (idx < DATA_CACHE_SIZE) { + data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor); + } else { + data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); + } + } +} + +void main() { + // instantiate the soft_max function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + soft_max(num_blocks); + } else if (num_blocks > 16) { + soft_max(32); + } else if (num_blocks > 8) { + soft_max(16); + } else if (num_blocks > 4) { + soft_max(8); + } else if (num_blocks == 4) { + soft_max(4); + } else if (num_blocks == 3) { + soft_max(3); + } else if (num_blocks == 2) { + soft_max(2); + } else if (num_blocks == 1) { + soft_max(1); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp new file mode 100644 index 0000000000000000000000000000000000000000..29bd77d7e1c886f6d806ffef23fc07015bdba99c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -0,0 +1,50 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "generic_head.comp" +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// In this shader Y = softmax(X) and X is not provided as input. + +layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_y[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum_yg[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + FLOAT_TYPE scale = p.param1; + + // partial sums for thread in warp + sum_yg[tid] = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]); + const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]); + sum_yg[tid] += yi * gi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum_yg[tid] += sum_yg[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE dot_yg = sum_yg[0]; + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale + * (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg) + * FLOAT_TYPE(data_y[row*p.KX + col])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp new file mode 100644 index 0000000000000000000000000000000000000000..ef43598baf3a562a9b23f93278ee9b4a7ef09987 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp new file mode 100644 index 0000000000000000000000000000000000000000..72353cc3296ed66d0ee10287e5a44c83115d735f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp @@ -0,0 +1,29 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp new file mode 100644 index 0000000000000000000000000000000000000000..961e5ffa1f56f42b6b35a855ecc582d47ac7d298 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + tmp[col] = FLOAT_TYPE(0.0f); + + for (uint i = col; i < p.KX; i += BLOCK_SIZE) { + tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]); + } + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s) { + tmp[col] += tmp[col + s]; + } + barrier(); + } + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp new file mode 100644 index 0000000000000000000000000000000000000000..8a6f868f58a7c82234780abe78fb8d30a9525d9f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp new file mode 100644 index 0000000000000000000000000000000000000000..fd0ba401feb0cbf2ef08daa5e686f3d93620a81a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_EXT_bfloat16 : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp new file mode 100644 index 0000000000000000000000000000000000000000..28eb24e11f871cdb7e1413757181ce3cdb3ed481 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_NV_cooperative_matrix2 : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp new file mode 100644 index 0000000000000000000000000000000000000000..8c5dd1bd1679c8476ae9815aad549a5e0dfa980f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_KHR_cooperative_matrix : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp new file mode 100644 index 0000000000000000000000000000000000000000..470e3074d938adb636135cd91eeda75766c20474 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_EXT_integer_dot_product : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp new file mode 100644 index 0000000000000000000000000000000000000000..79e065a9313aa1552ca55cec3d588fcefbcf5266 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -0,0 +1,41 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint nb1; + uint dim; + uint max_period; +} p; + +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 256 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_WorkGroupID.y; + const uint j = gl_GlobalInvocationID.x; + const uint d_offset = i * p.nb1; + + if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) { + data_d[d_offset + p.dim] = 0.f; + } + + const uint half_dim = p.dim / 2; + if (j >= half_dim) { + return; + } + + const float timestep = float(data_a[i]); + const float freq = float(exp(-log(p.max_period) * j / half_dim)); + const float arg = timestep * freq; + data_d[d_offset + j] = D_TYPE(cos(arg)); + data_d[d_offset + j + half_dim] = D_TYPE(sin(arg)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp new file mode 100644 index 0000000000000000000000000000000000000000..3bde717832b45115d4e538a5e2d44c086e1d0e2a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -0,0 +1,1373 @@ +#if !defined(GGML_TYPES_COMP) +#define GGML_TYPES_COMP + +#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_16bit_storage : require + +#if defined(DATA_A_F32) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE float +#elif LOAD_VEC_A == 4 +#define A_TYPE vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE mat2x4 +#endif +#endif + +#if defined(DATA_A_F16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE float16_t +#elif LOAD_VEC_A == 4 +#define A_TYPE f16vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE f16mat2x4 +#endif +#endif + +#if defined(DATA_A_BF16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE uint16_t +#elif LOAD_VEC_A == 4 +#define A_TYPE u16vec4 +#elif LOAD_VEC_A == 8 +#error unsupported +#endif +#endif + +#define QUANT_K_Q4_0 32 +#define QUANT_R_Q4_0 2 + +struct block_q4_0 +{ + float16_t d; + uint8_t qs[16]; +}; +struct block_q4_0_packed16 +{ + float16_t d; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q4_0) +#define QUANT_K QUANT_K_Q4_0 +#define QUANT_R QUANT_R_Q4_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q4_0 +#define A_TYPE_PACKED16 block_q4_0_packed16 +#endif + +#define QUANT_K_Q4_1 32 +#define QUANT_R_Q4_1 2 + +struct block_q4_1 +{ + float16_t d; + float16_t m; + uint8_t qs[16]; +}; + +struct block_q4_1_packed16 +{ + float16_t d; + float16_t m; + uint16_t qs[16/2]; +}; + +struct block_q4_1_packed32 +{ + f16vec2 dm; + uint32_t qs[16/4]; +}; + +#if defined(DATA_A_Q4_1) +#define QUANT_K QUANT_K_Q4_1 +#define QUANT_R QUANT_R_Q4_1 +#define QUANT_AUXF 2 +#define A_TYPE block_q4_1 +#define A_TYPE_PACKED16 block_q4_1_packed16 +#define A_TYPE_PACKED32 block_q4_1_packed32 +#endif + +#define QUANT_K_Q5_0 32 +#define QUANT_R_Q5_0 2 + +struct block_q5_0 +{ + float16_t d; + uint16_t qh[2]; + uint8_t qs[16]; +}; + +struct block_q5_0_packed16 +{ + float16_t d; + uint16_t qh[2]; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q5_0) +#define QUANT_K QUANT_K_Q5_0 +#define QUANT_R QUANT_R_Q5_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q5_0 +#define A_TYPE_PACKED16 block_q5_0_packed16 +#endif + +#define QUANT_K_Q5_1 32 +#define QUANT_R_Q5_1 2 + +struct block_q5_1 +{ + float16_t d; + float16_t m; + uint qh; + uint8_t qs[16]; +}; + +struct block_q5_1_packed16 +{ + float16_t d; + float16_t m; + uint qh; + uint16_t qs[16/2]; +}; + +struct block_q5_1_packed32 +{ + f16vec2 dm; + uint qh; + uint32_t qs[16/4]; +}; + +#if defined(DATA_A_Q5_1) +#define QUANT_K QUANT_K_Q5_1 +#define QUANT_R QUANT_R_Q5_1 +#define QUANT_AUXF 2 +#define A_TYPE block_q5_1 +#define A_TYPE_PACKED16 block_q5_1_packed16 +#define A_TYPE_PACKED32 block_q5_1_packed32 +#endif + +#define QUANT_K_Q8_0 32 +#define QUANT_R_Q8_0 1 + +struct block_q8_0 +{ + float16_t d; + int8_t qs[32]; +}; +struct block_q8_0_packed16 +{ + float16_t d; + int16_t qs[32/2]; +}; +struct block_q8_0_packed32 +{ + float16_t d; + int32_t qs[32/4]; +}; + +#if defined(DATA_A_Q8_0) +#define QUANT_K QUANT_K_Q8_0 +#define QUANT_R QUANT_R_Q8_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q8_0 +#define A_TYPE_PACKED16 block_q8_0_packed16 +#define A_TYPE_PACKED32 block_q8_0_packed32 +#endif + +#define QUANT_K_Q8_1 32 +#define QUANT_R_Q8_1 1 + +struct block_q8_1 +{ + f16vec2 ds; + int8_t qs[32]; +}; +struct block_q8_1_packed16 +{ + f16vec2 ds; + int16_t qs[16]; +}; +struct block_q8_1_packed32 +{ + f16vec2 ds; + int32_t qs[8]; +}; + +// K-quants +#define QUANT_K_Q2_K 256 + +struct block_q2_K +{ + uint8_t scales[QUANT_K_Q2_K/16]; + uint8_t qs[QUANT_K_Q2_K/4]; + f16vec2 d; +}; + +struct block_q2_K_packed16 +{ + uint16_t scales[QUANT_K_Q2_K/16/2]; + uint16_t qs[QUANT_K_Q2_K/4/2]; + f16vec2 d; +}; + +struct block_q2_K_packed32 +{ + uint32_t scales[QUANT_K_Q2_K/16/4]; + uint32_t qs[QUANT_K_Q2_K/4/4]; + f16vec2 d; +}; + +#if defined(DATA_A_Q2_K) +#define QUANT_K QUANT_K_Q2_K +#define A_TYPE block_q2_K +#define A_TYPE_PACKED16 block_q2_K_packed16 +#define A_TYPE_PACKED32 block_q2_K_packed32 +#endif + +#define QUANT_K_Q3_K 256 + +struct block_q3_K +{ + uint8_t hmask[QUANT_K_Q3_K/8]; + uint8_t qs[QUANT_K_Q3_K/4]; + uint8_t scales[12]; + float16_t d; +}; + +struct block_q3_K_packed16 +{ + uint16_t hmask[QUANT_K_Q3_K/8/2]; + uint16_t qs[QUANT_K_Q3_K/4/2]; + uint16_t scales[12/2]; + float16_t d; +}; + +#if defined(DATA_A_Q3_K) +#define QUANT_K QUANT_K_Q3_K +#define A_TYPE block_q3_K +#define A_TYPE_PACKED16 block_q3_K_packed16 +#endif + +#define QUANT_K_Q4_K 256 + +struct block_q4_K +{ + f16vec2 d; + uint8_t scales[3*QUANT_K_Q4_K/64]; + uint8_t qs[QUANT_K_Q4_K/2]; +}; + +struct block_q4_K_packed16 +{ + f16vec2 d; + uint16_t scales[3*QUANT_K_Q4_K/64/2]; + uint16_t qs[QUANT_K_Q4_K/2/2]; +}; + +struct block_q4_K_packed32 +{ + f16vec2 d; + uint32_t scales[3*QUANT_K_Q4_K/64/4]; + uint32_t qs[QUANT_K_Q4_K/2/4]; +}; + +struct block_q4_K_packed128 +{ + uvec4 q4k[9]; +}; + +#if defined(DATA_A_Q4_K) +#define QUANT_K QUANT_K_Q4_K +#define A_TYPE block_q4_K +#define A_TYPE_PACKED16 block_q4_K_packed16 +#define A_TYPE_PACKED32 block_q4_K_packed32 +#endif + +#define QUANT_K_Q5_K 256 + +struct block_q5_K +{ + f16vec2 d; + uint8_t scales[12]; + uint8_t qh[QUANT_K_Q5_K/8]; + uint8_t qs[QUANT_K_Q5_K/2]; +}; + +struct block_q5_K_packed16 +{ + f16vec2 d; + uint16_t scales[12/2]; + uint16_t qh[QUANT_K_Q5_K/8/2]; + uint16_t qs[QUANT_K_Q5_K/2/2]; +}; + +struct block_q5_K_packed128 +{ + uvec4 q5k[11]; +}; + +#if defined(DATA_A_Q5_K) +#define QUANT_K QUANT_K_Q5_K +#define A_TYPE block_q5_K +#define A_TYPE_PACKED16 block_q5_K_packed16 +#endif + +#define QUANT_K_Q6_K 256 + +struct block_q6_K +{ + uint8_t ql[QUANT_K_Q6_K/2]; + uint8_t qh[QUANT_K_Q6_K/4]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +struct block_q6_K_packed16 +{ + uint16_t ql[QUANT_K_Q6_K/2/2]; + uint16_t qh[QUANT_K_Q6_K/4/2]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +#if defined(DATA_A_Q6_K) +#define QUANT_K QUANT_K_Q6_K +#define A_TYPE block_q6_K +#define A_TYPE_PACKED16 block_q6_K_packed16 +#endif + +// IQuants + +#define QUANT_K_IQ1_S 256 +#define QUANT_R_IQ1_S 1 + +struct block_iq1_s { + float16_t d; + uint8_t qs[QUANT_K_IQ1_S/8]; + uint16_t qh[QUANT_K_IQ1_S/32]; +}; + +#define QUANT_K_IQ1_M 256 +#define QUANT_R_IQ1_M 1 + +struct block_iq1_m { + uint8_t qs[QUANT_K_IQ1_M/8]; + uint8_t qh[QUANT_K_IQ1_M/16]; + uint16_t scales[QUANT_K_IQ1_M/64]; +}; + +struct block_iq1_m_packed64 { + uint64_t qs[QUANT_K_IQ1_M/8/8]; + uint64_t qh[QUANT_K_IQ1_M/16/8]; + uint64_t scales; +}; + +#if defined(DATA_A_IQ1_S) +#define QUANT_K QUANT_K_IQ1_S +#define QUANT_R QUANT_R_IQ1_S +#define A_TYPE block_iq1_s +#endif + +#if defined(DATA_A_IQ1_M) +#define QUANT_K QUANT_K_IQ1_M +#define QUANT_R QUANT_R_IQ1_M +#define A_TYPE block_iq1_m +#endif + +#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f + +// Packed IQ1S grid where every 2 vec8 are encoded on 32 bits (2 bits per coordinate). +const uint[1024] iq1s_grid_const = { + 0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01, + 0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4, + 0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41, + 0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f, + 0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334, + 0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f, + 0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040, + 0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f, + 0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5, + 0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3, + 0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff, + 0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570, + 0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f, + 0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf, + 0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f, + 0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07, + 0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc, + 0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374, + 0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0, + 0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001, + 0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043, + 0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc, + 0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117, + 0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f, + 0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5, + 0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474, + 0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d, + 0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd, + 0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50, + 0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10, + 0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30, + 0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1, + 0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c, + 0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074, + 0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134, + 0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7, + 0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3, + 0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450, + 0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577, + 0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c, + 0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5, + 0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c, + 0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00, + 0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300, + 0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc, + 0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034, + 0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077, + 0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5, + 0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117, + 0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f, + 0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5, + 0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404, + 0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1, + 0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd, + 0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71, + 0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7, + 0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00, + 0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44, + 0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00, + 0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0, + 0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303, + 0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343, + 0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd, + 0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031, + 0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011, + 0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c, + 0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4, + 0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c, + 0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174, + 0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7, + 0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d, + 0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4, + 0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c, + 0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7, + 0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510, + 0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33, + 0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4, + 0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73, + 0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f, + 0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337, + 0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343, + 0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030, + 0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075, + 0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4, + 0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170, + 0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705, + 0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c, + 0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c, + 0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514, + 0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c, + 0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3, + 0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70, + 0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03, + 0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c, + 0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c, + 0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074, + 0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104, + 0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7, + 0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757, + 0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c, + 0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c, + 0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4, + 0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc, + 0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03, + 0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc, + 0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54, + 0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f, + 0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf, + 0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c, + 0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c, + 0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4, + 0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174, + 0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700, + 0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7, + 0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d, + 0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531, + 0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf, + 0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57, + 0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13, + 0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01, + 0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f, + 0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7, + 0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074, + 0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107, + 0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd, + 0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0, + 0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7, + 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 +}; + +shared uint16_t iq1s_grid[2048]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq1s_grid_const.length(); i += wgsize.x) { + uint idx = i + gl_LocalInvocationIndex.x; + if (iq1s_grid_const.length() % wgsize.x == 0 || idx < iq1s_grid_const.length()) { + u16vec2 g = unpack16(iq1s_grid_const[idx]); + iq1s_grid[2*idx+0] = g.x; + iq1s_grid[2*idx+1] = g.y; + } + } + barrier(); +} +#endif + +#define QUANT_K_IQ2_XXS 256 +#define QUANT_R_IQ2_XXS 1 + +struct block_iq2_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_XXS/4]; +}; + +struct block_iq2_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XXS/8]; +}; + +#if defined(DATA_A_IQ2_XXS) + +const uvec2[256] iq2xxs_grid_const = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x082b0808, 0x08080808), + uvec2(0x082b082b, 0x08080808), uvec2(0x082b2b08, 0x08080808), uvec2(0x082b2b2b, 0x08080808), uvec2(0x19080819, 0x08080808), + uvec2(0x19081908, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), + uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b082b2b, 0x08080808), + uvec2(0x2b2b082b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x2b081908, 0x08080819), uvec2(0x2b192b08, 0x08080819), + uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x082b082b, 0x0808082b), uvec2(0x2b08082b, 0x0808082b), + uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x082b0819, 0x08081908), + uvec2(0x082b1908, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x192b0808, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), + uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), uvec2(0x08082b08, 0x08081919), + uvec2(0x082b0808, 0x08081919), uvec2(0x1908192b, 0x08081919), uvec2(0x192b2b19, 0x08081919), uvec2(0x2b080808, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b2b1908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x08081919, 0x08082b08), + uvec2(0x08082b08, 0x08082b08), uvec2(0x08191908, 0x08082b08), uvec2(0x082b2b08, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x2b082b08, 0x08082b08), + uvec2(0x08081908, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x0808082b, 0x08082b2b), uvec2(0x08191908, 0x08082b2b), + uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x082b0819, 0x08190808), + uvec2(0x19080808, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), + uvec2(0x2b191919, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x082b0808, 0x08190819), + uvec2(0x19190808, 0x08190819), uvec2(0x19192b2b, 0x08190819), uvec2(0x2b080808, 0x08190819), uvec2(0x082b1908, 0x0819082b), + uvec2(0x19081919, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x08082b08, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x082b1919, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08192b08, 0x08191919), + uvec2(0x192b082b, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x0819192b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x2b2b0808, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x19081908, 0x082b0808), + uvec2(0x192b0819, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b08082b, 0x082b0808), uvec2(0x082b2b19, 0x082b0819), + uvec2(0x19082b08, 0x082b0819), uvec2(0x08080808, 0x082b082b), uvec2(0x0808082b, 0x082b082b), uvec2(0x08080819, 0x082b1908), + uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x19080808, 0x082b1908), uvec2(0x1919192b, 0x082b1908), + uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x192b1908, 0x082b1919), uvec2(0x2b190808, 0x082b192b), + uvec2(0x08082b08, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), uvec2(0x2b191908, 0x082b2b08), uvec2(0x19081908, 0x082b2b2b), + uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x08192b08, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x19080808, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x1919192b, 0x19080808), uvec2(0x192b0808, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), + uvec2(0x2b190808, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x192b0819, 0x19080819), + uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x19082b08, 0x1908082b), uvec2(0x1919192b, 0x1908082b), uvec2(0x192b2b08, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b192b19, 0x19081908), + uvec2(0x0819082b, 0x19081919), uvec2(0x082b1908, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08080819, 0x19082b08), + uvec2(0x08081908, 0x19082b08), uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x08080808, 0x19082b19), uvec2(0x19192b08, 0x19082b19), uvec2(0x192b0819, 0x19082b19), uvec2(0x2b08082b, 0x19082b19), + uvec2(0x19081919, 0x19082b2b), uvec2(0x2b190808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x08082b08, 0x19190808), + uvec2(0x08190819, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x2b080808, 0x19190808), + uvec2(0x2b082b08, 0x19190808), uvec2(0x08081908, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x2b2b1908, 0x19190819), + uvec2(0x2b190819, 0x1919082b), uvec2(0x2b190808, 0x19191908), uvec2(0x2b19082b, 0x19191908), uvec2(0x08082b2b, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x19191908, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x08190819, 0x19192b08), + uvec2(0x08192b19, 0x19192b08), uvec2(0x192b1908, 0x19192b08), uvec2(0x19080808, 0x19192b19), uvec2(0x08082b08, 0x19192b2b), + uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x192b2b08, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x19191919, 0x192b0819), uvec2(0x08192b08, 0x192b082b), uvec2(0x192b0808, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x08081919, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x0819082b, 0x192b1919), + uvec2(0x2b081908, 0x192b1919), uvec2(0x1908082b, 0x192b2b08), uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), + uvec2(0x08082b2b, 0x2b080808), uvec2(0x19080819, 0x2b080808), uvec2(0x2b08082b, 0x2b080808), uvec2(0x08081908, 0x2b080819), + uvec2(0x08192b08, 0x2b080819), uvec2(0x19080808, 0x2b080819), uvec2(0x08190819, 0x2b08082b), uvec2(0x08080819, 0x2b081908), + uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x192b0808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x1908192b, 0x2b081919), uvec2(0x2b191908, 0x2b081919), + uvec2(0x08082b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x192b0808, 0x2b08192b), uvec2(0x0808082b, 0x2b082b08), + uvec2(0x08081908, 0x2b082b19), uvec2(0x08190819, 0x2b082b2b), uvec2(0x08081908, 0x2b190808), uvec2(0x08190808, 0x2b190808), + uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x2b2b0819, 0x2b190808), uvec2(0x0819192b, 0x2b190819), + uvec2(0x2b080808, 0x2b190819), uvec2(0x19081919, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x082b082b, 0x2b191908), + uvec2(0x19081908, 0x2b191908), uvec2(0x19190819, 0x2b191919), uvec2(0x2b080819, 0x2b192b08), uvec2(0x082b0808, 0x2b192b19), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b081919, 0x2b2b0808), uvec2(0x08082b19, 0x2b2b0819), + uvec2(0x08080808, 0x2b2b082b), uvec2(0x08192b08, 0x2b2b1908), uvec2(0x19190808, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19) +}; + +shared uvec2 iq2xxs_grid[256]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq2xxs_grid.length(); i += wgsize.x) { + if (iq2xxs_grid_const.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xxs_grid_const.length()) { + iq2xxs_grid[i + gl_LocalInvocationIndex.x] = iq2xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XXS +#define QUANT_R QUANT_R_IQ2_XXS +#define A_TYPE block_iq2_xxs +#define A_TYPE_PACKED16 block_iq2_xxs_packed16 +#endif + +#define QUANT_K_IQ2_XS 256 +#define QUANT_R_IQ2_XS 1 + +struct block_iq2_xs +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint8_t scales[QUANT_K_IQ2_XS/32]; +}; + +struct block_iq2_xs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint16_t scales[QUANT_K_IQ2_XS/64]; +}; + +#if defined(DATA_A_IQ2_XS) + +const uvec2 iq2xs_grid_const[512] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), + uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), + uvec2(0x2b191908, 0x08080808), uvec2(0x2b192b19, 0x08080808), uvec2(0x2b2b0808, 0x08080808), uvec2(0x08080819, 0x08080819), + uvec2(0x08081908, 0x08080819), uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x0819082b, 0x08080819), uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x08192b2b, 0x08080819), + uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), + uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), uvec2(0x2b081908, 0x08080819), + uvec2(0x2b190808, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x08081919, 0x0808082b), + uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), uvec2(0x082b0808, 0x0808082b), + uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), + uvec2(0x0808192b, 0x08081908), uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), + uvec2(0x08191919, 0x08081908), uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), + uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), uvec2(0x1919192b, 0x08081908), uvec2(0x192b0808, 0x08081908), + uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x08080808, 0x08081919), + uvec2(0x0808082b, 0x08081919), uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x082b0808, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x19190808, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x2b080808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x082b192b, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x1908082b, 0x0808192b), uvec2(0x2b081908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08082b2b, 0x08082b08), uvec2(0x08190819, 0x08082b08), + uvec2(0x08191908, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x19192b08, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b2b0808, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), uvec2(0x08081908, 0x08082b19), + uvec2(0x08190808, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x2b080819, 0x08082b19), uvec2(0x2b082b19, 0x08082b19), + uvec2(0x08080808, 0x08082b2b), uvec2(0x082b0808, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x2b19192b, 0x08082b2b), + uvec2(0x2b2b0808, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x0808192b, 0x08190808), + uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), uvec2(0x08191919, 0x08190808), + uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), uvec2(0x19080808, 0x08190808), + uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), + uvec2(0x19191908, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b2b2b, 0x08190808), uvec2(0x2b080819, 0x08190808), + uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), + uvec2(0x08081919, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x082b0808, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), uvec2(0x19190808, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x2b19192b, 0x08190819), uvec2(0x08080819, 0x0819082b), + uvec2(0x08081908, 0x0819082b), uvec2(0x0808192b, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x19080808, 0x0819082b), + uvec2(0x192b0808, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), + uvec2(0x08082b08, 0x08191908), uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x19080819, 0x08191908), uvec2(0x19081908, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x08191908, 0x0819192b), + uvec2(0x19082b19, 0x0819192b), uvec2(0x08080819, 0x08192b08), uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x0819082b, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x19191908, 0x08192b08), uvec2(0x2b08192b, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x192b192b, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x2b2b2b19, 0x08192b2b), uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), + uvec2(0x08082b08, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x19080819, 0x082b0808), uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), + uvec2(0x2b080808, 0x082b0808), uvec2(0x2b2b0808, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), + uvec2(0x08190808, 0x082b0819), uvec2(0x19080808, 0x082b0819), uvec2(0x19082b08, 0x082b0819), uvec2(0x192b1919, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x2b080808, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), + uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x082b2b19, 0x082b1908), + uvec2(0x19080808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x1919082b, 0x082b1919), + uvec2(0x2b192b19, 0x082b1919), uvec2(0x08080819, 0x082b192b), uvec2(0x08192b2b, 0x082b192b), uvec2(0x2b2b192b, 0x082b192b), + uvec2(0x08080808, 0x082b2b08), uvec2(0x08082b08, 0x082b2b08), uvec2(0x08082b2b, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), + uvec2(0x19191919, 0x082b2b08), uvec2(0x2b082b08, 0x082b2b08), uvec2(0x2b2b082b, 0x082b2b08), uvec2(0x192b2b08, 0x082b2b19), + uvec2(0x2b190808, 0x082b2b19), uvec2(0x08082b08, 0x082b2b2b), uvec2(0x082b0808, 0x082b2b2b), uvec2(0x2b08082b, 0x082b2b2b), + uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), + uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x0819082b, 0x19080808), + uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), + uvec2(0x19080808, 0x19080808), uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x19082b2b, 0x19080808), uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x192b0808, 0x19080808), + uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), + uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), + uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x19080819, 0x19080819), + uvec2(0x19081908, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), + uvec2(0x2b2b082b, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x0819082b, 0x1908082b), uvec2(0x082b2b19, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), uvec2(0x08082b08, 0x19081908), uvec2(0x08190819, 0x19081908), + uvec2(0x08191908, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x19080819, 0x19081908), + uvec2(0x19081908, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b191908, 0x19081908), + uvec2(0x08080819, 0x19081919), uvec2(0x08081908, 0x19081919), uvec2(0x08190808, 0x19081919), uvec2(0x082b1908, 0x19081919), + uvec2(0x19080808, 0x19081919), uvec2(0x2b192b2b, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08082b2b, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), + uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), uvec2(0x19191908, 0x19082b08), + uvec2(0x192b082b, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x19081908, 0x19082b19), + uvec2(0x19190808, 0x19082b19), uvec2(0x192b2b19, 0x19082b19), uvec2(0x08081908, 0x19082b2b), uvec2(0x08080808, 0x19190808), + uvec2(0x0808082b, 0x19190808), uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), + uvec2(0x08191908, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), + uvec2(0x19081908, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x2b080808, 0x19190808), uvec2(0x08080819, 0x19190819), + uvec2(0x08081908, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x08191919, 0x19190819), uvec2(0x19080808, 0x19190819), + uvec2(0x1908082b, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x2b2b2b2b, 0x1919082b), + uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x082b0819, 0x19191908), + uvec2(0x19080808, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b2b0819, 0x19191908), + uvec2(0x08080808, 0x19191919), uvec2(0x08082b08, 0x19191919), uvec2(0x2b080808, 0x19191919), uvec2(0x2b082b08, 0x19191919), + uvec2(0x082b0819, 0x1919192b), uvec2(0x192b2b08, 0x1919192b), uvec2(0x2b2b0819, 0x1919192b), uvec2(0x08080808, 0x19192b08), + uvec2(0x08191908, 0x19192b08), uvec2(0x19080819, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x2b192b19, 0x19192b08), + uvec2(0x08192b2b, 0x19192b19), uvec2(0x19080808, 0x19192b19), uvec2(0x1908082b, 0x19192b19), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), + uvec2(0x19191908, 0x192b0808), uvec2(0x192b082b, 0x192b0808), uvec2(0x2b08192b, 0x192b0808), uvec2(0x2b2b2b19, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x082b1908, 0x192b082b), uvec2(0x19082b2b, 0x192b082b), uvec2(0x2b19082b, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x0819192b, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x19080808, 0x192b1919), + uvec2(0x19081919, 0x192b1919), uvec2(0x2b2b1908, 0x192b1919), uvec2(0x08080819, 0x192b2b08), uvec2(0x192b2b2b, 0x192b2b08), + uvec2(0x082b1919, 0x192b2b19), uvec2(0x0808192b, 0x192b2b2b), uvec2(0x19191908, 0x192b2b2b), uvec2(0x192b082b, 0x192b2b2b), + uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), + uvec2(0x08190819, 0x2b080808), uvec2(0x08191908, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b2b2b, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b08082b, 0x2b080808), uvec2(0x2b2b2b08, 0x2b080808), uvec2(0x2b2b2b2b, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x0808192b, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x19080808, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19192b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x082b0808, 0x2b08082b), + uvec2(0x2b080808, 0x2b08082b), uvec2(0x2b08082b, 0x2b08082b), uvec2(0x2b2b0808, 0x2b08082b), uvec2(0x2b2b2b08, 0x2b08082b), + uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b082b19, 0x2b081908), + uvec2(0x08080808, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x2b2b1919, 0x2b081919), uvec2(0x08192b08, 0x2b08192b), + uvec2(0x192b2b2b, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08082b08, 0x2b082b08), uvec2(0x082b1919, 0x2b082b08), + uvec2(0x19192b2b, 0x2b082b08), uvec2(0x2b080808, 0x2b082b08), uvec2(0x2b08082b, 0x2b082b08), uvec2(0x2b2b2b08, 0x2b082b08), + uvec2(0x0808192b, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x2b080808, 0x2b082b2b), uvec2(0x2b082b08, 0x2b082b2b), + uvec2(0x2b19192b, 0x2b082b2b), uvec2(0x2b2b2b08, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), uvec2(0x08081908, 0x2b190808), + uvec2(0x08190808, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x1919192b, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x08080808, 0x2b190819), uvec2(0x082b082b, 0x2b190819), uvec2(0x192b1908, 0x2b190819), uvec2(0x1919192b, 0x2b19082b), + uvec2(0x2b082b19, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x08081919, 0x2b191908), uvec2(0x19081908, 0x2b191908), + uvec2(0x19190808, 0x2b191908), uvec2(0x19192b08, 0x2b191908), uvec2(0x082b2b19, 0x2b191919), uvec2(0x2b190808, 0x2b191919), + uvec2(0x2b19082b, 0x2b191919), uvec2(0x19080819, 0x2b19192b), uvec2(0x19190819, 0x2b192b08), uvec2(0x2b2b192b, 0x2b192b08), + uvec2(0x19082b19, 0x2b192b19), uvec2(0x08191919, 0x2b192b2b), uvec2(0x192b0808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x08082b08, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), uvec2(0x082b0808, 0x2b2b0808), + uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x2b2b0808, 0x2b2b0808), uvec2(0x19190819, 0x2b2b0819), uvec2(0x19192b19, 0x2b2b0819), + uvec2(0x2b2b192b, 0x2b2b0819), uvec2(0x08080808, 0x2b2b082b), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b08, 0x2b2b082b), + uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b080808, 0x2b2b082b), uvec2(0x2b2b0808, 0x2b2b082b), uvec2(0x19080808, 0x2b2b1908), + uvec2(0x2b191919, 0x2b2b1908), uvec2(0x192b1919, 0x2b2b192b), uvec2(0x2b192b08, 0x2b2b192b), uvec2(0x08082b2b, 0x2b2b2b08), + uvec2(0x082b0808, 0x2b2b2b08), uvec2(0x082b082b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b0808, 0x2b2b2b08), + uvec2(0x2b2b2b08, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19), uvec2(0x2b081908, 0x2b2b2b19), uvec2(0x2b08192b, 0x2b2b2b19), + uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x082b2b2b, 0x2b2b2b2b), uvec2(0x2b190819, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b), +}; + +shared uvec2 iq2xs_grid[512]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq2xs_grid.length(); i += wgsize.x) { + if (iq2xs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xs_grid_const.length()) { + iq2xs_grid[i + gl_LocalInvocationIndex.x] = iq2xs_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XS +#define QUANT_R QUANT_R_IQ2_XS +#define A_TYPE block_iq2_xs +#define A_TYPE_PACKED16 block_iq2_xs_packed16 +#endif + +#define QUANT_K_IQ2_S 256 +#define QUANT_R_IQ2_S 1 + +struct block_iq2_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_S/4]; + uint8_t qh[QUANT_K_IQ2_S/32]; + uint8_t scales[QUANT_K_IQ2_S/32]; +}; + +struct block_iq2_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_S/8]; + uint16_t qh[QUANT_K_IQ2_S/64]; + uint16_t scales[QUANT_K_IQ2_S/64]; +}; + +#if defined(DATA_A_IQ2_S) + +const uvec2 iq2s_grid_const[1024] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x192b192b, 0x08080808), + uvec2(0x192b2b19, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), + uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), uvec2(0x2b191908, 0x08080808), uvec2(0x2b2b0808, 0x08080808), + uvec2(0x2b2b1919, 0x08080808), uvec2(0x2b2b2b2b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), + uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), uvec2(0x0819082b, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), + uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), + uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), uvec2(0x1919192b, 0x08080819), uvec2(0x19192b19, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b1919, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), + uvec2(0x2b081908, 0x08080819), uvec2(0x2b190808, 0x08080819), uvec2(0x2b19082b, 0x08080819), uvec2(0x2b191919, 0x08080819), + uvec2(0x2b2b0819, 0x08080819), uvec2(0x2b2b1908, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), + uvec2(0x08081919, 0x0808082b), uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), + uvec2(0x082b0808, 0x0808082b), uvec2(0x082b2b2b, 0x0808082b), uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), + uvec2(0x1908192b, 0x0808082b), uvec2(0x19082b19, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b081919, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x2b191908, 0x0808082b), + uvec2(0x2b2b082b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x0808192b, 0x08081908), + uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), uvec2(0x08191919, 0x08081908), + uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), uvec2(0x082b192b, 0x08081908), + uvec2(0x082b2b19, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), + uvec2(0x19082b08, 0x08081908), uvec2(0x19082b2b, 0x08081908), uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), + uvec2(0x1919192b, 0x08081908), uvec2(0x19192b19, 0x08081908), uvec2(0x192b0808, 0x08081908), uvec2(0x192b082b, 0x08081908), + uvec2(0x192b1919, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b08192b, 0x08081908), + uvec2(0x2b082b19, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x2b191919, 0x08081908), uvec2(0x2b192b08, 0x08081908), + uvec2(0x2b2b0819, 0x08081908), uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), + uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08082b2b, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x0819192b, 0x08081919), uvec2(0x08192b19, 0x08081919), uvec2(0x082b0808, 0x08081919), + uvec2(0x082b1919, 0x08081919), uvec2(0x082b2b08, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x1908192b, 0x08081919), uvec2(0x19082b19, 0x08081919), uvec2(0x19190808, 0x08081919), uvec2(0x1919082b, 0x08081919), + uvec2(0x19191919, 0x08081919), uvec2(0x19192b08, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x192b1908, 0x08081919), + uvec2(0x2b080808, 0x08081919), uvec2(0x2b08082b, 0x08081919), uvec2(0x2b081919, 0x08081919), uvec2(0x2b082b08, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x2b191908, 0x08081919), uvec2(0x2b2b0808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x0808192b, 0x0808192b), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), + uvec2(0x08191919, 0x0808192b), uvec2(0x19080808, 0x0808192b), uvec2(0x19081919, 0x0808192b), uvec2(0x19082b08, 0x0808192b), + uvec2(0x19190819, 0x0808192b), uvec2(0x19191908, 0x0808192b), uvec2(0x192b0808, 0x0808192b), uvec2(0x2b080819, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b190808, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08190819, 0x08082b08), uvec2(0x08191908, 0x08082b08), + uvec2(0x0819192b, 0x08082b08), uvec2(0x08192b19, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), + uvec2(0x082b2b2b, 0x08082b08), uvec2(0x19080819, 0x08082b08), uvec2(0x19081908, 0x08082b08), uvec2(0x1908192b, 0x08082b08), + uvec2(0x19082b19, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x19191919, 0x08082b08), + uvec2(0x19192b08, 0x08082b08), uvec2(0x192b0819, 0x08082b08), uvec2(0x192b1908, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b081919, 0x08082b08), uvec2(0x2b191908, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), + uvec2(0x08081908, 0x08082b19), uvec2(0x08190808, 0x08082b19), uvec2(0x0819082b, 0x08082b19), uvec2(0x08191919, 0x08082b19), + uvec2(0x08192b08, 0x08082b19), uvec2(0x082b0819, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x19081919, 0x08082b19), + uvec2(0x19082b08, 0x08082b19), uvec2(0x19190819, 0x08082b19), uvec2(0x19191908, 0x08082b19), uvec2(0x192b0808, 0x08082b19), + uvec2(0x2b080819, 0x08082b19), uvec2(0x2b190808, 0x08082b19), uvec2(0x08080808, 0x08082b2b), uvec2(0x08190819, 0x08082b2b), + uvec2(0x08191908, 0x08082b2b), uvec2(0x082b082b, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x082b2b2b, 0x08082b2b), + uvec2(0x19190808, 0x08082b2b), uvec2(0x2b192b19, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), + uvec2(0x0808192b, 0x08190808), uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), + uvec2(0x08191919, 0x08190808), uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), + uvec2(0x082b192b, 0x08190808), uvec2(0x19080808, 0x08190808), uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), + uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), uvec2(0x19191908, 0x08190808), uvec2(0x1919192b, 0x08190808), + uvec2(0x19192b19, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b082b, 0x08190808), uvec2(0x192b1919, 0x08190808), + uvec2(0x192b2b08, 0x08190808), uvec2(0x2b080819, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b08192b, 0x08190808), + uvec2(0x2b190808, 0x08190808), uvec2(0x2b191919, 0x08190808), uvec2(0x2b192b08, 0x08190808), uvec2(0x2b2b0819, 0x08190808), + uvec2(0x2b2b1908, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), uvec2(0x08081919, 0x08190819), + uvec2(0x08082b08, 0x08190819), uvec2(0x08082b2b, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x0819192b, 0x08190819), uvec2(0x08192b19, 0x08190819), uvec2(0x082b0808, 0x08190819), uvec2(0x082b082b, 0x08190819), + uvec2(0x082b1919, 0x08190819), uvec2(0x082b2b08, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), + uvec2(0x1908192b, 0x08190819), uvec2(0x19082b19, 0x08190819), uvec2(0x19190808, 0x08190819), uvec2(0x1919082b, 0x08190819), + uvec2(0x19191919, 0x08190819), uvec2(0x19192b08, 0x08190819), uvec2(0x192b0819, 0x08190819), uvec2(0x192b1908, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b08082b, 0x08190819), uvec2(0x2b081919, 0x08190819), uvec2(0x2b082b08, 0x08190819), + uvec2(0x2b190819, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x08080819, 0x0819082b), uvec2(0x08081908, 0x0819082b), + uvec2(0x08082b19, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x08191919, 0x0819082b), uvec2(0x082b0819, 0x0819082b), + uvec2(0x082b1908, 0x0819082b), uvec2(0x19080808, 0x0819082b), uvec2(0x19081919, 0x0819082b), uvec2(0x19190819, 0x0819082b), + uvec2(0x19191908, 0x0819082b), uvec2(0x2b080819, 0x0819082b), uvec2(0x2b081908, 0x0819082b), uvec2(0x2b190808, 0x0819082b), + uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), uvec2(0x08082b08, 0x08191908), + uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x0819192b, 0x08191908), uvec2(0x08192b19, 0x08191908), + uvec2(0x082b0808, 0x08191908), uvec2(0x082b1919, 0x08191908), uvec2(0x082b2b08, 0x08191908), uvec2(0x19080819, 0x08191908), + uvec2(0x19081908, 0x08191908), uvec2(0x1908192b, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x1919082b, 0x08191908), uvec2(0x19191919, 0x08191908), uvec2(0x19192b08, 0x08191908), uvec2(0x192b0819, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x2b08082b, 0x08191908), uvec2(0x2b081919, 0x08191908), + uvec2(0x2b082b08, 0x08191908), uvec2(0x2b190819, 0x08191908), uvec2(0x2b191908, 0x08191908), uvec2(0x2b2b0808, 0x08191908), + uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), uvec2(0x0808192b, 0x08191919), uvec2(0x08082b19, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x0819082b, 0x08191919), uvec2(0x08191919, 0x08191919), uvec2(0x08192b08, 0x08191919), + uvec2(0x082b0819, 0x08191919), uvec2(0x082b1908, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x1908082b, 0x08191919), + uvec2(0x19081919, 0x08191919), uvec2(0x19082b08, 0x08191919), uvec2(0x19190819, 0x08191919), uvec2(0x19191908, 0x08191919), + uvec2(0x192b0808, 0x08191919), uvec2(0x2b080819, 0x08191919), uvec2(0x2b081908, 0x08191919), uvec2(0x2b190808, 0x08191919), + uvec2(0x08080808, 0x0819192b), uvec2(0x08081919, 0x0819192b), uvec2(0x08082b08, 0x0819192b), uvec2(0x08190819, 0x0819192b), + uvec2(0x08191908, 0x0819192b), uvec2(0x082b0808, 0x0819192b), uvec2(0x19080819, 0x0819192b), uvec2(0x19081908, 0x0819192b), + uvec2(0x19190808, 0x0819192b), uvec2(0x2b080808, 0x0819192b), uvec2(0x2b2b2b2b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x0808192b, 0x08192b08), uvec2(0x08082b19, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x08191919, 0x08192b08), uvec2(0x08192b08, 0x08192b08), uvec2(0x082b0819, 0x08192b08), uvec2(0x19080808, 0x08192b08), + uvec2(0x1908082b, 0x08192b08), uvec2(0x19081919, 0x08192b08), uvec2(0x19082b08, 0x08192b08), uvec2(0x19190819, 0x08192b08), + uvec2(0x19191908, 0x08192b08), uvec2(0x192b0808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), uvec2(0x2b081908, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x0808082b, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x08082b08, 0x08192b19), + uvec2(0x08190819, 0x08192b19), uvec2(0x08191908, 0x08192b19), uvec2(0x082b0808, 0x08192b19), uvec2(0x19080819, 0x08192b19), + uvec2(0x19081908, 0x08192b19), uvec2(0x19190808, 0x08192b19), uvec2(0x192b2b19, 0x08192b19), uvec2(0x2b2b082b, 0x08192b19), + uvec2(0x08081908, 0x08192b2b), uvec2(0x08190808, 0x08192b2b), uvec2(0x19080808, 0x08192b2b), uvec2(0x1919192b, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), uvec2(0x08082b08, 0x082b0808), + uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), uvec2(0x0819192b, 0x082b0808), uvec2(0x08192b19, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x082b1919, 0x082b0808), uvec2(0x082b2b2b, 0x082b0808), uvec2(0x19080819, 0x082b0808), + uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), uvec2(0x1919082b, 0x082b0808), uvec2(0x19191919, 0x082b0808), + uvec2(0x192b1908, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b082b2b, 0x082b0808), uvec2(0x2b191908, 0x082b0808), + uvec2(0x2b2b2b2b, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), uvec2(0x08190808, 0x082b0819), + uvec2(0x0819082b, 0x082b0819), uvec2(0x08191919, 0x082b0819), uvec2(0x082b0819, 0x082b0819), uvec2(0x19080808, 0x082b0819), + uvec2(0x1908082b, 0x082b0819), uvec2(0x19081919, 0x082b0819), uvec2(0x19190819, 0x082b0819), uvec2(0x19191908, 0x082b0819), + uvec2(0x192b0808, 0x082b0819), uvec2(0x2b080819, 0x082b0819), uvec2(0x2b081908, 0x082b0819), uvec2(0x2b190808, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x08082b2b, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x082b2b08, 0x082b082b), + uvec2(0x082b2b2b, 0x082b082b), uvec2(0x19081908, 0x082b082b), uvec2(0x19190808, 0x082b082b), uvec2(0x2b082b08, 0x082b082b), + uvec2(0x2b082b2b, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), + uvec2(0x0808192b, 0x082b1908), uvec2(0x08082b19, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x08191919, 0x082b1908), + uvec2(0x08192b08, 0x082b1908), uvec2(0x082b0819, 0x082b1908), uvec2(0x082b1908, 0x082b1908), uvec2(0x19080808, 0x082b1908), + uvec2(0x1908082b, 0x082b1908), uvec2(0x19081919, 0x082b1908), uvec2(0x19082b08, 0x082b1908), uvec2(0x19190819, 0x082b1908), + uvec2(0x19191908, 0x082b1908), uvec2(0x192b0808, 0x082b1908), uvec2(0x2b080819, 0x082b1908), uvec2(0x2b081908, 0x082b1908), + uvec2(0x2b190808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x08081919, 0x082b1919), uvec2(0x08082b08, 0x082b1919), + uvec2(0x08190819, 0x082b1919), uvec2(0x08191908, 0x082b1919), uvec2(0x082b0808, 0x082b1919), uvec2(0x19080819, 0x082b1919), + uvec2(0x19081908, 0x082b1919), uvec2(0x19190808, 0x082b1919), uvec2(0x192b192b, 0x082b1919), uvec2(0x2b080808, 0x082b1919), + uvec2(0x08080819, 0x082b192b), uvec2(0x08081908, 0x082b192b), uvec2(0x08190808, 0x082b192b), uvec2(0x19080808, 0x082b192b), + uvec2(0x19192b19, 0x082b192b), uvec2(0x08080808, 0x082b2b08), uvec2(0x08081919, 0x082b2b08), uvec2(0x08190819, 0x082b2b08), + uvec2(0x08191908, 0x082b2b08), uvec2(0x19080819, 0x082b2b08), uvec2(0x19081908, 0x082b2b08), uvec2(0x19190808, 0x082b2b08), + uvec2(0x2b082b2b, 0x082b2b08), uvec2(0x2b2b2b2b, 0x082b2b08), uvec2(0x08080819, 0x082b2b19), uvec2(0x08081908, 0x082b2b19), + uvec2(0x08190808, 0x082b2b19), uvec2(0x2b191919, 0x082b2b19), uvec2(0x08082b2b, 0x082b2b2b), uvec2(0x082b082b, 0x082b2b2b), + uvec2(0x192b1908, 0x082b2b2b), uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), + uvec2(0x08081908, 0x19080808), uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), + uvec2(0x0819082b, 0x19080808), uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x08192b2b, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x082b192b, 0x19080808), uvec2(0x19080808, 0x19080808), + uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), uvec2(0x19082b2b, 0x19080808), + uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x1919192b, 0x19080808), uvec2(0x19192b19, 0x19080808), + uvec2(0x192b0808, 0x19080808), uvec2(0x192b082b, 0x19080808), uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), + uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), uvec2(0x2b191919, 0x19080808), uvec2(0x2b192b08, 0x19080808), + uvec2(0x2b2b0819, 0x19080808), uvec2(0x2b2b1908, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), + uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), + uvec2(0x0819192b, 0x19080819), uvec2(0x08192b19, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x082b082b, 0x19080819), + uvec2(0x082b1919, 0x19080819), uvec2(0x19080819, 0x19080819), uvec2(0x19081908, 0x19080819), uvec2(0x1908192b, 0x19080819), + uvec2(0x19082b19, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x1919082b, 0x19080819), uvec2(0x19191919, 0x19080819), + uvec2(0x19192b08, 0x19080819), uvec2(0x192b0819, 0x19080819), uvec2(0x192b1908, 0x19080819), uvec2(0x2b080808, 0x19080819), + uvec2(0x2b08082b, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x2b082b08, 0x19080819), uvec2(0x2b190819, 0x19080819), + uvec2(0x2b191908, 0x19080819), uvec2(0x2b2b0808, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), + uvec2(0x08190808, 0x1908082b), uvec2(0x0819082b, 0x1908082b), uvec2(0x08191919, 0x1908082b), uvec2(0x08192b08, 0x1908082b), + uvec2(0x082b1908, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x19081919, 0x1908082b), uvec2(0x19082b08, 0x1908082b), + uvec2(0x19190819, 0x1908082b), uvec2(0x19191908, 0x1908082b), uvec2(0x192b0808, 0x1908082b), uvec2(0x2b080819, 0x1908082b), + uvec2(0x2b081908, 0x1908082b), uvec2(0x08080808, 0x19081908), uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x08082b2b, 0x19081908), uvec2(0x08190819, 0x19081908), uvec2(0x08191908, 0x19081908), + uvec2(0x0819192b, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x082b082b, 0x19081908), + uvec2(0x082b1919, 0x19081908), uvec2(0x082b2b08, 0x19081908), uvec2(0x19080819, 0x19081908), uvec2(0x19081908, 0x19081908), + uvec2(0x1908192b, 0x19081908), uvec2(0x19082b19, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x1919082b, 0x19081908), + uvec2(0x19191919, 0x19081908), uvec2(0x19192b08, 0x19081908), uvec2(0x192b0819, 0x19081908), uvec2(0x192b1908, 0x19081908), + uvec2(0x2b080808, 0x19081908), uvec2(0x2b08082b, 0x19081908), uvec2(0x2b081919, 0x19081908), uvec2(0x2b082b08, 0x19081908), + uvec2(0x2b190819, 0x19081908), uvec2(0x2b191908, 0x19081908), uvec2(0x2b2b0808, 0x19081908), uvec2(0x08080819, 0x19081919), + uvec2(0x08081908, 0x19081919), uvec2(0x0808192b, 0x19081919), uvec2(0x08082b19, 0x19081919), uvec2(0x08190808, 0x19081919), + uvec2(0x0819082b, 0x19081919), uvec2(0x08191919, 0x19081919), uvec2(0x08192b08, 0x19081919), uvec2(0x082b0819, 0x19081919), + uvec2(0x082b1908, 0x19081919), uvec2(0x19080808, 0x19081919), uvec2(0x1908082b, 0x19081919), uvec2(0x19081919, 0x19081919), + uvec2(0x19082b08, 0x19081919), uvec2(0x19190819, 0x19081919), uvec2(0x19191908, 0x19081919), uvec2(0x192b0808, 0x19081919), + uvec2(0x192b2b2b, 0x19081919), uvec2(0x2b080819, 0x19081919), uvec2(0x2b081908, 0x19081919), uvec2(0x2b190808, 0x19081919), + uvec2(0x08080808, 0x1908192b), uvec2(0x0808082b, 0x1908192b), uvec2(0x08081919, 0x1908192b), uvec2(0x08082b08, 0x1908192b), + uvec2(0x08190819, 0x1908192b), uvec2(0x08191908, 0x1908192b), uvec2(0x082b0808, 0x1908192b), uvec2(0x19080819, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x2b080808, 0x1908192b), uvec2(0x2b2b1919, 0x1908192b), + uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), uvec2(0x08082b19, 0x19082b08), uvec2(0x08190808, 0x19082b08), + uvec2(0x0819082b, 0x19082b08), uvec2(0x08191919, 0x19082b08), uvec2(0x08192b08, 0x19082b08), uvec2(0x082b0819, 0x19082b08), + uvec2(0x082b1908, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x1908082b, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x19082b08, 0x19082b08), uvec2(0x19190819, 0x19082b08), uvec2(0x19191908, 0x19082b08), uvec2(0x192b0808, 0x19082b08), + uvec2(0x2b081908, 0x19082b08), uvec2(0x2b190808, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x0808082b, 0x19082b19), + uvec2(0x08081919, 0x19082b19), uvec2(0x08082b08, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x08191908, 0x19082b19), + uvec2(0x082b0808, 0x19082b19), uvec2(0x19080819, 0x19082b19), uvec2(0x19081908, 0x19082b19), uvec2(0x19190808, 0x19082b19), + uvec2(0x2b080808, 0x19082b19), uvec2(0x2b19192b, 0x19082b19), uvec2(0x08080819, 0x19082b2b), uvec2(0x08081908, 0x19082b2b), + uvec2(0x08190808, 0x19082b2b), uvec2(0x19080808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x0808082b, 0x19190808), + uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), uvec2(0x08191908, 0x19190808), + uvec2(0x0819192b, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b082b, 0x19190808), + uvec2(0x082b1919, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), uvec2(0x19081908, 0x19190808), + uvec2(0x1908192b, 0x19190808), uvec2(0x19082b19, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x1919082b, 0x19190808), + uvec2(0x19191919, 0x19190808), uvec2(0x19192b08, 0x19190808), uvec2(0x192b0819, 0x19190808), uvec2(0x192b1908, 0x19190808), + uvec2(0x2b080808, 0x19190808), uvec2(0x2b08082b, 0x19190808), uvec2(0x2b081919, 0x19190808), uvec2(0x2b082b08, 0x19190808), + uvec2(0x2b190819, 0x19190808), uvec2(0x2b191908, 0x19190808), uvec2(0x08080819, 0x19190819), uvec2(0x08081908, 0x19190819), + uvec2(0x0808192b, 0x19190819), uvec2(0x08082b19, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x0819082b, 0x19190819), + uvec2(0x08191919, 0x19190819), uvec2(0x08192b08, 0x19190819), uvec2(0x082b0819, 0x19190819), uvec2(0x082b1908, 0x19190819), + uvec2(0x19080808, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x19081919, 0x19190819), uvec2(0x19082b08, 0x19190819), + uvec2(0x19190819, 0x19190819), uvec2(0x19191908, 0x19190819), uvec2(0x192b0808, 0x19190819), uvec2(0x2b080819, 0x19190819), + uvec2(0x2b081908, 0x19190819), uvec2(0x2b190808, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x08081919, 0x1919082b), + uvec2(0x08082b08, 0x1919082b), uvec2(0x08190819, 0x1919082b), uvec2(0x08191908, 0x1919082b), uvec2(0x082b0808, 0x1919082b), + uvec2(0x19080819, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x19190808, 0x1919082b), uvec2(0x192b2b19, 0x1919082b), + uvec2(0x2b080808, 0x1919082b), uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x0808192b, 0x19191908), + uvec2(0x08082b19, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x0819082b, 0x19191908), uvec2(0x08191919, 0x19191908), + uvec2(0x08192b08, 0x19191908), uvec2(0x082b0819, 0x19191908), uvec2(0x082b1908, 0x19191908), uvec2(0x19080808, 0x19191908), + uvec2(0x1908082b, 0x19191908), uvec2(0x19081919, 0x19191908), uvec2(0x19082b08, 0x19191908), uvec2(0x19190819, 0x19191908), + uvec2(0x19191908, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b081908, 0x19191908), + uvec2(0x2b190808, 0x19191908), uvec2(0x08080808, 0x19191919), uvec2(0x0808082b, 0x19191919), uvec2(0x08081919, 0x19191919), + uvec2(0x08082b08, 0x19191919), uvec2(0x08190819, 0x19191919), uvec2(0x08191908, 0x19191919), uvec2(0x082b0808, 0x19191919), + uvec2(0x19080819, 0x19191919), uvec2(0x19081908, 0x19191919), uvec2(0x19190808, 0x19191919), uvec2(0x2b080808, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x08081908, 0x1919192b), uvec2(0x08190808, 0x1919192b), uvec2(0x082b192b, 0x1919192b), + uvec2(0x19080808, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x0808082b, 0x19192b08), uvec2(0x08081919, 0x19192b08), + uvec2(0x08082b08, 0x19192b08), uvec2(0x08190819, 0x19192b08), uvec2(0x08191908, 0x19192b08), uvec2(0x082b0808, 0x19192b08), + uvec2(0x19080819, 0x19192b08), uvec2(0x19081908, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x19192b2b, 0x19192b08), + uvec2(0x2b080808, 0x19192b08), uvec2(0x08080819, 0x19192b19), uvec2(0x08081908, 0x19192b19), uvec2(0x08190808, 0x19192b19), + uvec2(0x19080808, 0x19192b19), uvec2(0x08080808, 0x19192b2b), uvec2(0x08192b19, 0x19192b2b), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x2b2b2b08, 0x19192b2b), uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x0808192b, 0x192b0808), + uvec2(0x08190808, 0x192b0808), uvec2(0x0819082b, 0x192b0808), uvec2(0x08191919, 0x192b0808), uvec2(0x08192b08, 0x192b0808), + uvec2(0x082b0819, 0x192b0808), uvec2(0x082b1908, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x19081919, 0x192b0808), + uvec2(0x19082b08, 0x192b0808), uvec2(0x19190819, 0x192b0808), uvec2(0x19191908, 0x192b0808), uvec2(0x192b0808, 0x192b0808), + uvec2(0x2b081908, 0x192b0808), uvec2(0x2b190808, 0x192b0808), uvec2(0x08080808, 0x192b0819), uvec2(0x0808082b, 0x192b0819), + uvec2(0x08081919, 0x192b0819), uvec2(0x08082b08, 0x192b0819), uvec2(0x08190819, 0x192b0819), uvec2(0x08191908, 0x192b0819), + uvec2(0x082b0808, 0x192b0819), uvec2(0x19080819, 0x192b0819), uvec2(0x19081908, 0x192b0819), uvec2(0x19190808, 0x192b0819), + uvec2(0x2b080808, 0x192b0819), uvec2(0x2b192b19, 0x192b0819), uvec2(0x08081908, 0x192b082b), uvec2(0x08190808, 0x192b082b), + uvec2(0x19080808, 0x192b082b), uvec2(0x1919192b, 0x192b082b), uvec2(0x2b2b0819, 0x192b082b), uvec2(0x08080808, 0x192b1908), + uvec2(0x08081919, 0x192b1908), uvec2(0x08082b08, 0x192b1908), uvec2(0x08190819, 0x192b1908), uvec2(0x08191908, 0x192b1908), + uvec2(0x082b0808, 0x192b1908), uvec2(0x19080819, 0x192b1908), uvec2(0x19081908, 0x192b1908), uvec2(0x19190808, 0x192b1908), + uvec2(0x2b080808, 0x192b1908), uvec2(0x08080819, 0x192b1919), uvec2(0x08081908, 0x192b1919), uvec2(0x08190808, 0x192b1919), + uvec2(0x19080808, 0x192b1919), uvec2(0x19082b2b, 0x192b1919), uvec2(0x192b2b08, 0x192b1919), uvec2(0x2b19082b, 0x192b1919), + uvec2(0x08080808, 0x192b192b), uvec2(0x2b191908, 0x192b192b), uvec2(0x08080819, 0x192b2b08), uvec2(0x08081908, 0x192b2b08), + uvec2(0x08190808, 0x192b2b08), uvec2(0x192b1919, 0x192b2b08), uvec2(0x2b192b08, 0x192b2b08), uvec2(0x08080808, 0x192b2b19), + uvec2(0x082b2b2b, 0x192b2b19), uvec2(0x1908082b, 0x192b2b2b), uvec2(0x2b2b0819, 0x192b2b2b), uvec2(0x08080808, 0x2b080808), + uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), uvec2(0x08190819, 0x2b080808), + uvec2(0x08191908, 0x2b080808), uvec2(0x08192b19, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b1919, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x1919082b, 0x2b080808), + uvec2(0x19191919, 0x2b080808), uvec2(0x19192b08, 0x2b080808), uvec2(0x192b0819, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b081919, 0x2b080808), uvec2(0x2b190819, 0x2b080808), uvec2(0x2b191908, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x08082b19, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x0819082b, 0x2b080819), + uvec2(0x08191919, 0x2b080819), uvec2(0x08192b08, 0x2b080819), uvec2(0x082b0819, 0x2b080819), uvec2(0x082b1908, 0x2b080819), + uvec2(0x19080808, 0x2b080819), uvec2(0x1908082b, 0x2b080819), uvec2(0x19081919, 0x2b080819), uvec2(0x19082b08, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19191908, 0x2b080819), uvec2(0x2b080819, 0x2b080819), uvec2(0x2b081908, 0x2b080819), + uvec2(0x2b190808, 0x2b080819), uvec2(0x2b2b2b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x08081919, 0x2b08082b), + uvec2(0x08082b2b, 0x2b08082b), uvec2(0x08190819, 0x2b08082b), uvec2(0x08191908, 0x2b08082b), uvec2(0x19080819, 0x2b08082b), + uvec2(0x19081908, 0x2b08082b), uvec2(0x19190808, 0x2b08082b), uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), + uvec2(0x0808192b, 0x2b081908), uvec2(0x08082b19, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x08192b08, 0x2b081908), uvec2(0x082b0819, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x1908082b, 0x2b081908), uvec2(0x19081919, 0x2b081908), uvec2(0x19082b08, 0x2b081908), uvec2(0x19190819, 0x2b081908), + uvec2(0x19191908, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b080819, 0x2b081908), uvec2(0x2b081908, 0x2b081908), + uvec2(0x2b190808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x0808082b, 0x2b081919), uvec2(0x08081919, 0x2b081919), + uvec2(0x08082b08, 0x2b081919), uvec2(0x08190819, 0x2b081919), uvec2(0x08191908, 0x2b081919), uvec2(0x082b0808, 0x2b081919), + uvec2(0x19080819, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x19190808, 0x2b081919), uvec2(0x2b080808, 0x2b081919), + uvec2(0x2b082b2b, 0x2b081919), uvec2(0x08080819, 0x2b08192b), uvec2(0x08081908, 0x2b08192b), uvec2(0x08190808, 0x2b08192b), + uvec2(0x082b2b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08081919, 0x2b082b08), + uvec2(0x08190819, 0x2b082b08), uvec2(0x08191908, 0x2b082b08), uvec2(0x19080819, 0x2b082b08), uvec2(0x19081908, 0x2b082b08), + uvec2(0x19190808, 0x2b082b08), uvec2(0x2b2b082b, 0x2b082b08), uvec2(0x08080819, 0x2b082b19), uvec2(0x08081908, 0x2b082b19), + uvec2(0x19080808, 0x2b082b19), uvec2(0x192b1919, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x19192b08, 0x2b082b2b), + uvec2(0x19192b2b, 0x2b082b2b), uvec2(0x2b08082b, 0x2b082b2b), uvec2(0x2b2b082b, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), + uvec2(0x08081908, 0x2b190808), uvec2(0x08082b19, 0x2b190808), uvec2(0x08190808, 0x2b190808), uvec2(0x0819082b, 0x2b190808), + uvec2(0x08191919, 0x2b190808), uvec2(0x08192b08, 0x2b190808), uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), + uvec2(0x1908082b, 0x2b190808), uvec2(0x19081919, 0x2b190808), uvec2(0x19082b08, 0x2b190808), uvec2(0x19190819, 0x2b190808), + uvec2(0x19191908, 0x2b190808), uvec2(0x192b0808, 0x2b190808), uvec2(0x2b080819, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x2b190808, 0x2b190808), uvec2(0x08080808, 0x2b190819), uvec2(0x08081919, 0x2b190819), uvec2(0x08190819, 0x2b190819), + uvec2(0x08191908, 0x2b190819), uvec2(0x19080819, 0x2b190819), uvec2(0x19081908, 0x2b190819), uvec2(0x19190808, 0x2b190819), + uvec2(0x19192b2b, 0x2b190819), uvec2(0x08080819, 0x2b19082b), uvec2(0x08081908, 0x2b19082b), uvec2(0x08190808, 0x2b19082b), + uvec2(0x19080808, 0x2b19082b), uvec2(0x2b2b192b, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x0808082b, 0x2b191908), + uvec2(0x08081919, 0x2b191908), uvec2(0x08082b08, 0x2b191908), uvec2(0x08190819, 0x2b191908), uvec2(0x08191908, 0x2b191908), + uvec2(0x082b0808, 0x2b191908), uvec2(0x19080819, 0x2b191908), uvec2(0x19081908, 0x2b191908), uvec2(0x19190808, 0x2b191908), + uvec2(0x2b080808, 0x2b191908), uvec2(0x2b19192b, 0x2b191908), uvec2(0x08080819, 0x2b191919), uvec2(0x08081908, 0x2b191919), + uvec2(0x08190808, 0x2b191919), uvec2(0x19080808, 0x2b191919), uvec2(0x2b192b08, 0x2b191919), uvec2(0x2b2b0819, 0x2b191919), + uvec2(0x08080808, 0x2b19192b), uvec2(0x1908192b, 0x2b19192b), uvec2(0x192b1908, 0x2b19192b), uvec2(0x08080819, 0x2b192b08), + uvec2(0x08081908, 0x2b192b08), uvec2(0x08190808, 0x2b192b08), uvec2(0x082b192b, 0x2b192b08), uvec2(0x19080808, 0x2b192b08), + uvec2(0x2b2b2b19, 0x2b192b08), uvec2(0x08080808, 0x2b192b19), uvec2(0x19082b19, 0x2b192b19), uvec2(0x1919082b, 0x2b192b19), + uvec2(0x2b190808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), uvec2(0x08081919, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), + uvec2(0x08191908, 0x2b2b0808), uvec2(0x082b082b, 0x2b2b0808), uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x19080819, 0x2b2b0808), + uvec2(0x19081908, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b2b082b, 0x2b2b0808), uvec2(0x2b2b2b2b, 0x2b2b0808), + uvec2(0x19080808, 0x2b2b0819), uvec2(0x192b1919, 0x2b2b0819), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b2b, 0x2b2b082b), + uvec2(0x082b082b, 0x2b2b082b), uvec2(0x082b2b08, 0x2b2b082b), uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b08082b, 0x2b2b082b), + uvec2(0x2b082b08, 0x2b2b082b), uvec2(0x2b082b2b, 0x2b2b082b), uvec2(0x2b2b2b08, 0x2b2b082b), uvec2(0x08080819, 0x2b2b1908), + uvec2(0x08081908, 0x2b2b1908), uvec2(0x08190808, 0x2b2b1908), uvec2(0x19080808, 0x2b2b1908), uvec2(0x2b082b19, 0x2b2b1908), + uvec2(0x2b2b1908, 0x2b2b1908), uvec2(0x08080808, 0x2b2b1919), uvec2(0x08192b19, 0x2b2b1919), uvec2(0x19190819, 0x2b2b192b), + uvec2(0x08082b2b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b082b, 0x2b2b2b08), uvec2(0x19191908, 0x2b2b2b19), + uvec2(0x2b08192b, 0x2b2b2b19), uvec2(0x08082b08, 0x2b2b2b2b), uvec2(0x08082b2b, 0x2b2b2b2b), uvec2(0x082b0808, 0x2b2b2b2b), + uvec2(0x082b082b, 0x2b2b2b2b), uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x2b082b08, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b) +}; + +shared uvec2 iq2s_grid[1024]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq2s_grid.length(); i += wgsize.x) { + if (iq2s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2s_grid_const.length()) { + iq2s_grid[i + gl_LocalInvocationIndex.x] = iq2s_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_S +#define QUANT_R QUANT_R_IQ2_S +#define A_TYPE block_iq2_s +#define A_TYPE_PACKED16 block_iq2_s_packed16 +#endif + +#define QUANT_K_IQ3_XXS 256 +#define QUANT_R_IQ3_XXS 1 + +struct block_iq3_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8]; +}; + +struct block_iq3_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16]; +}; + +#if defined(DATA_A_IQ3_XXS) + +const uint32_t iq3xxs_grid_const[256] = { + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +}; + +shared uint32_t iq3xxs_grid[256]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq3xxs_grid.length(); i += wgsize.x) { + if (iq3xxs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3xxs_grid.length()) { + iq3xxs_grid[i + gl_LocalInvocationIndex.x] = iq3xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_XXS +#define QUANT_R QUANT_R_IQ3_XXS +#define A_TYPE block_iq3_xxs +#define A_TYPE_PACKED16 block_iq3_xxs_packed16 +#endif + +#define QUANT_K_IQ3_S 256 +#define QUANT_R_IQ3_S 1 + +struct block_iq3_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_S/4]; + uint8_t qh[QUANT_K_IQ3_S/32]; + uint8_t signs[QUANT_K_IQ3_S/8]; + uint8_t scales[QUANT_K_IQ3_S/64]; +}; + +struct block_iq3_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_S/4/2]; + uint16_t qh[QUANT_K_IQ3_S/32/2]; + uint16_t signs[QUANT_K_IQ3_S/8/2]; + uint16_t scales[QUANT_K_IQ3_S/64/2]; +}; + +#if defined(DATA_A_IQ3_S) + +const uint32_t iq3s_grid_const[512] = { + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, +}; + +shared uint32_t iq3s_grid[512]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq3s_grid.length(); i += wgsize.x) { + if (iq3s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3s_grid.length()) { + iq3s_grid[i + gl_LocalInvocationIndex.x] = iq3s_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_S +#define QUANT_R QUANT_R_IQ3_S +#define A_TYPE block_iq3_s +#define A_TYPE_PACKED16 block_iq3_s_packed16 +#endif + +#define QUANT_K_IQ4_XS 256 +#define QUANT_R_IQ4_XS 1 + +struct block_iq4_xs +{ + float16_t d; + uint16_t scales_h; + uint8_t scales_l[QUANT_K_IQ4_XS/64]; + uint8_t qs[QUANT_K_IQ4_XS/2]; +}; + +#if defined(DATA_A_IQ4_XS) +#define QUANT_K QUANT_K_IQ4_XS +#define QUANT_R QUANT_R_IQ4_XS +#define A_TYPE block_iq4_xs +#endif + +#define QUANT_K_IQ4_NL 32 +#define QUANT_R_IQ4_NL 2 + +struct block_iq4_nl +{ + float16_t d; + uint8_t qs[QUANT_K_IQ4_NL/2]; +}; + +struct block_iq4_nl_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ4_NL/2/2]; +}; + +#if defined(DATA_A_IQ4_NL) +#define QUANT_K QUANT_K_IQ4_NL +#define QUANT_R QUANT_R_IQ4_NL +#define A_TYPE block_iq4_nl +#define A_TYPE_PACKED16 block_iq4_nl_packed16 +#endif + +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) +const int8_t kvalues_iq4nl_const[16] = { + int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), + int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) +}; + +shared FLOAT_TYPE kvalues_iq4nl[16]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) { + kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]); + } + barrier(); +} +#endif + +// returns the bfloat value in the low 16b. +// See ggml_compute_fp32_to_bf16 +uint32_t fp32_to_bf16(float f) +{ + uint32_t u = floatBitsToUint(f); + u = (u + (0x7fff + ((u >> 16) & 1))) >> 16; + return u; +} + +float bf16_to_fp32(uint32_t u) +{ + return uintBitsToFloat(u << 16); +} + +#endif // !defined(GGML_TYPES_COMP) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp new file mode 100644 index 0000000000000000000000000000000000000000..6f607380df8bf556def8c1a1cac8c5128df80c48 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -0,0 +1,36 @@ +#version 450 + +layout (push_constant) uniform parameter +{ + uint ne; uint a_offset; uint d_offset; + uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; + float sf0; float sf1; float sf2; float sf3; +} p; + +#include "types.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i10 = idx % p.ne10; + const uint i11 = (idx / p.ne10) % p.ne11; + const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; + const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; + + const uint i00 = uint(i10 / p.sf0); + const uint i01 = uint(i11 / p.sf1); + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + + data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c63345ec8b4b673ba581f8ad1c7cfe64900e9df0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -0,0 +1,753 @@ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 + #include + #include // For _mkdir on Windows +#else + #include + #include + #include +#endif + +#define ASYNCIO_CONCURRENCY 64 + +std::mutex lock; +std::vector> shader_fnames; + +std::string GLSLC = "glslc"; +std::string input_dir = "vulkan-shaders"; +std::string output_dir = "/tmp"; +std::string target_hpp = "ggml-vulkan-shaders.hpp"; +std::string target_cpp = "ggml-vulkan-shaders.cpp"; +bool no_clean = false; + +const std::vector type_names = { + "f32", + "f16", + "q4_0", + "q4_1", + "q5_0", + "q5_1", + "q8_0", + "q2_k", + "q3_k", + "q4_k", + "q5_k", + "q6_k", + "iq1_s", + "iq1_m", + "iq2_xxs", + "iq2_xs", + "iq2_s", + "iq3_xxs", + "iq3_s", + "iq4_xs", + "iq4_nl", + "bf16", +}; + +namespace { +void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { +#ifdef _WIN32 + HANDLE stdout_read, stdout_write; + HANDLE stderr_read, stderr_write; + SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; + + if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) || + !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stdout pipe"); + } + + if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) || + !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stderr pipe"); + } + + PROCESS_INFORMATION pi; + STARTUPINFOA si = {}; + si.cb = sizeof(STARTUPINFOA); + si.dwFlags = STARTF_USESTDHANDLES; + si.hStdOutput = stdout_write; + si.hStdError = stderr_write; + + std::vector cmd(command.begin(), command.end()); + cmd.push_back('\0'); + + if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) { + throw std::runtime_error("Failed to create process"); + } + + CloseHandle(stdout_write); + CloseHandle(stderr_write); + + std::array buffer; + DWORD bytes_read; + + while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + CloseHandle(stdout_read); + CloseHandle(stderr_read); + WaitForSingleObject(pi.hProcess, INFINITE); + CloseHandle(pi.hProcess); + CloseHandle(pi.hThread); +#else +int stdout_pipe[2]; + int stderr_pipe[2]; + + if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { + throw std::runtime_error("Failed to create pipes"); + } + + pid_t pid = fork(); + if (pid < 0) { + throw std::runtime_error("Failed to fork process"); + } + + if (pid == 0) { + close(stdout_pipe[0]); + close(stderr_pipe[0]); + dup2(stdout_pipe[1], STDOUT_FILENO); + dup2(stderr_pipe[1], STDERR_FILENO); + close(stdout_pipe[1]); + close(stderr_pipe[1]); + execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr); + _exit(EXIT_FAILURE); + } else { + close(stdout_pipe[1]); + close(stderr_pipe[1]); + + std::array buffer; + ssize_t bytes_read; + + while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + close(stdout_pipe[0]); + close(stderr_pipe[0]); + waitpid(pid, nullptr, 0); + } +#endif +} + +bool directory_exists(const std::string& path) { + struct stat info; + if (stat(path.c_str(), &info) != 0) { + return false; // Path doesn't exist or can't be accessed + } + return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory +} + +bool create_directory(const std::string& path) { +#ifdef _WIN32 + return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists +#else + return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions +#endif +} + +std::string to_uppercase(const std::string& input) { + std::string result = input; + for (char& c : result) { + c = std::toupper(c); + } + return result; +} + +bool string_starts_with(const std::string& str, const std::string& prefix) { + if (prefix.size() > str.size()) { + return false; + } + return std::equal(prefix.begin(), prefix.end(), str.begin()); +} + +bool string_ends_with(const std::string& str, const std::string& suffix) { + if (suffix.size() > str.size()) { + return false; + } + return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); +} + +static const char path_separator = '/'; + +std::string join_paths(const std::string& path1, const std::string& path2) { + return path1 + path_separator + path2; +} + +std::string basename(const std::string &path) { + return path.substr(path.find_last_of("/\\") + 1); +} + +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; + +void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string out_fname = join_paths(output_dir, name + ".spv"); + std::string in_path = join_paths(input_dir, in_fname); + + std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; + + // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 + std::string opt_level = coopmat ? "" : "-O"; + + #ifdef _WIN32 + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; + #else + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname}; + #endif + + #ifdef GGML_VULKAN_SHADER_DEBUG_INFO + cmd.push_back("-g"); + #endif + + for (const auto& define : defines) { + cmd.push_back("-D" + define.first + "=" + define.second); + } + + std::string command; + for (const auto& part : cmd) { + command += part + " "; + } + + std::string stdout_str, stderr_str; + try { + // std::cout << "Executing command: "; + // for (const auto& part : cmd) { + // std::cout << part << " "; + // } + // std::cout << std::endl; + + execute_command(command, stdout_str, stderr_str); + if (!stderr_str.empty()) { + std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl; + return; + } + + std::lock_guard guard(lock); + shader_fnames.push_back(std::make_pair(name, out_fname)); + } catch (const std::exception& e) { + std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; + } + { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + } + compile_count_cond.notify_all(); +} + +std::map merge_maps(const std::map& a, const std::map& b) { + std::map result = a; + result.insert(b.begin(), b.end()); + return result; +} + +static std::vector> compiles; +void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + { + // wait until fewer than N compiles are in progress. + // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. + uint32_t N = 16; + std::unique_lock guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); +} + +void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) { + std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; + std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; + std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; + + std::map base_dict = { + {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"}, + }; + std::string shader_name = "matmul"; + + if (matmul_id) { + base_dict["MUL_MAT_ID"] = "1"; + shader_name = "matmul_id"; + } + + if (fp16) { + base_dict["FLOAT16"] = "1"; + } + + base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + + if (coopmat) { + base_dict["COOPMAT"] = "1"; + } + + const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; + + auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string { + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "float"; + } + return "bfloat16_t"; + } + if (coopmat2 || fp16) { + return "float16_t"; + } + return "float"; + }; + + // Shaders with f16 B_TYPE + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + + // bf16 + { + std::string load_vec_a_unaligned = "1"; + // For aligned matmul loads + std::string load_vec_a = coopmat2 ? "1" : "4"; + + // scalar path promotes to float + std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; + + // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader +#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!(coopmat || coopmat2)) +#endif + { + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + } + + for (const auto& tname : type_names) { + std::string load_vec_quant = "2"; + if ((tname == "q4_0") || (tname == "q4_1")) + load_vec_quant = "8"; + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + load_vec_quant = "4"; + + if (tname == "bf16") { + continue; + } + + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + // For unaligned, load one at a time for f32/f16, or two at a time for quants + std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; + // For aligned matmul loads + std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; + + // don't generate f32 variants for coopmat2 + if (!coopmat2) { + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + + if (tname != "f16" && tname != "f32") { + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) { + string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + } +#endif + } +} + +void process_shaders() { + std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; + std::map base_dict = {{"FLOAT_TYPE", "float"}}; + + // matmul + for (const auto& matmul_id : {false, true}) { + // No coopmats + // fp32 + matmul_shaders(false, matmul_id, false, false, false); + + // fp16, fp32acc and fp16acc + matmul_shaders(true, matmul_id, false, false, false); + matmul_shaders(true, matmul_id, false, false, true); + +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + // Coopmat, fp32acc and fp16acc + matmul_shaders(true, matmul_id, true, false, false); + matmul_shaders(true, matmul_id, true, false, true); +#endif + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + // Coopmat2, fp32acc and fp16acc + matmul_shaders(true, matmul_id, false, true, false); + matmul_shaders(true, matmul_id, false, true, true); +#endif + } + + // flash attention + for (const auto& f16acc : {false, true}) { + std::string acctype = f16acc ? "float16_t" : "float"; + std::string acctypev4 = f16acc ? "f16vec4" : "vec4"; + + for (const auto& tname : type_names) { + if (tname == "f32") { + continue; + } + if (tname == "bf16") continue; + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); + } else { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } +#endif + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } + } + } + + for (const auto& tname : type_names) { + // mul mat vec + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + + string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + + // Dequant shaders + if (tname != "f16" && tname != "bf16") { + string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); + } + + if (!string_ends_with(tname, "_k")) { + shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp"; + + if (tname == "f16") { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); + } else { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); + } + string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); + } + } + + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + + // Norms + string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + + for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + } + + auto get_type_str = [](bool f16) { + return f16 ? "float16_t" : "float"; + }; + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + for (std::string op : {"add", "sub", "mul", "div"}) { + for (auto src0_f16 : {false, true}) { + for (auto src1_f16 : {false, true}) { + for (auto dst_f16 : {false, true}) { + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); + string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}}); + } + } + } + } + + string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); + string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); + + string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); + + string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + + string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); + string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); + + string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); + string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + + string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + + for (auto &c : compiles) { + c.wait(); + } +} + +void write_output_files() { + FILE* hdr = fopen(target_hpp.c_str(), "w"); + FILE* src = fopen(target_cpp.c_str(), "w"); + + fprintf(hdr, "#include \n\n"); + fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + + std::sort(shader_fnames.begin(), shader_fnames.end()); + for (const auto& pair : shader_fnames) { + const std::string& name = pair.first; + #ifdef _WIN32 + std::string path = pair.second; + std::replace(path.begin(), path.end(), '/', '\\' ); + #else + const std::string& path = pair.second; + #endif + + FILE* spv = fopen(path.c_str(), "rb"); + if (!spv) { + std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; + continue; + } + + fseek(spv, 0, SEEK_END); + size_t size = ftell(spv); + fseek(spv, 0, SEEK_SET); + + std::vector data(size); + size_t read_size = fread(data.data(), 1, size, spv); + fclose(spv); + if (read_size != size) { + std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; + continue; + } + + fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size); + fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size); + + fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size); + for (size_t i = 0; i < size; ++i) { + fprintf(src, "0x%02x,", data[i]); + if ((i + 1) % 12 == 0) fprintf(src, "\n"); + } + fprintf(src, "\n};\n\n"); + + if (!no_clean) { + std::remove(path.c_str()); + } + } + for (const char *op : {"add", "sub", "mul", "div"}) { + fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op); + fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op); + fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op); + fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op); + } + fclose(hdr); + fclose(src); +} +} + +int main(int argc, char** argv) { + std::map args; + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg.rfind("--", 0) == 0) { + if (i + 1 < argc && argv[i + 1][0] != '-') { + args[arg] = argv[i + 1]; + ++i; + } else { + args[arg] = ""; + } + } + } + + if (args.find("--glslc") != args.end()) { + GLSLC = args["--glslc"]; // Path to glslc + } + if (args.find("--input-dir") != args.end()) { + input_dir = args["--input-dir"]; // Directory containing shader sources + } + if (args.find("--output-dir") != args.end()) { + output_dir = args["--output-dir"]; // Directory for containing SPIR-V output + } + if (args.find("--target-hpp") != args.end()) { + target_hpp = args["--target-hpp"]; // Path to generated header file + } + if (args.find("--target-cpp") != args.end()) { + target_cpp = args["--target-cpp"]; // Path to generated cpp file + } + if (args.find("--no-clean") != args.end()) { + no_clean = true; // Keep temporary SPIR-V files in output-dir after build + } + + if (!directory_exists(input_dir)) { + std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl; + return EXIT_FAILURE; + } + + if (!directory_exists(output_dir)) { + if (!create_directory(output_dir)) { + std::cerr << "Error creating output directory: " << output_dir << "\n"; + return EXIT_FAILURE; + } + } + + process_shaders(); + + write_output_files(); + + return EXIT_SUCCESS; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp new file mode 100644 index 0000000000000000000000000000000000000000..35cc6c45f90a572eece67491e06475d1fa33a91c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp @@ -0,0 +1,87 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; }; +layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; }; +layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 6) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + barrier(); + _tf[tid] = tf[head_id * head_size + tid]; + barrier(); + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + barrier(); + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + + vec4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + s_vec = s_vec * td_vec + kv; + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp new file mode 100644 index 0000000000000000000000000000000000000000..88c1c02b32b8c3fa7a027821a1c9b7857164772e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp @@ -0,0 +1,91 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; }; +layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; }; +layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; }; +layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 7) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i]; + } + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + barrier(); + + A_TYPE sa = 0.0; + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + sa += dot(s_vec, a_vec); + } + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + s_vec = s_vec * w_vec + kv + sa * b_vec; + y += dot(r_vec, s_vec); + + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i] = state[i]; + } +} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c new file mode 100644 index 0000000000000000000000000000000000000000..196b7b8f3e2ae6edaf3a88c44f11f7034bf51c44 --- /dev/null +++ b/ggml/src/ggml.c @@ -0,0 +1,6577 @@ +#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows +#define _USE_MATH_DEFINES // For M_PI on MSVC + +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "ggml-threading.h" +#include "ggml-cpu.h" +#include "ggml.h" + +// FIXME: required here for quantization functions +#include "ggml-quants.h" + +#ifdef GGML_USE_CPU_HBM +#include +#endif + +#if defined(_MSC_VER) || defined(__MINGW32__) +#include // using malloc.h with MSC/MINGW +#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__gnu_linux__) +#include +#endif + +#if defined(__APPLE__) +#include +#include +#include +#endif + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX + #define NOMINMAX +#endif +#include +#endif + +#define UNUSED GGML_UNUSED + +#if defined(_MSC_VER) +#define m512bh(p) p +#define m512i(p) p +#else +#define m512bh(p) (__m512bh)(p) +#define m512i(p) (__m512i)(p) +#endif + +// precomputed f32 table for f16 (256 KB) (ggml-impl.h) +float ggml_table_f32_f16[1 << 16]; + +#if defined(__linux__) || \ + defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \ + (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH) + +#include +#include +#include +#include +#if defined(__linux__) +#include +#endif + +#if defined(__ANDROID__) +#include +#include +#include + +struct backtrace_state { + void ** current; + void ** end; +}; + +static _Unwind_Reason_Code unwind_callback(struct _Unwind_Context* context, void* arg) { + struct backtrace_state * state = (struct backtrace_state *)arg; + uintptr_t pc = _Unwind_GetIP(context); + if (pc) { + if (state->current == state->end) { + return _URC_END_OF_STACK; + } else { + *state->current++ = (void*)pc; + } + } + return _URC_NO_REASON; +} + +static void ggml_print_backtrace_symbols(void) { + const int max = 100; + void* buffer[max]; + + struct backtrace_state state = {buffer, buffer + max}; + _Unwind_Backtrace(unwind_callback, &state); + + int count = state.current - buffer; + + for (int idx = 0; idx < count; ++idx) { + const void * addr = buffer[idx]; + const char * symbol = ""; + + Dl_info info; + if (dladdr(addr, &info) && info.dli_sname) { + symbol = info.dli_sname; + } + + fprintf(stderr, "%d: %p %s\n", idx, addr, symbol); + } +} +#elif defined(__linux__) && defined(__GLIBC__) +#include +static void ggml_print_backtrace_symbols(void) { + void * trace[100]; + int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); + backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); +} +#else +static void ggml_print_backtrace_symbols(void) { + // platform not supported +} +#endif + +void ggml_print_backtrace(void) { + const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE"); + if (GGML_NO_BACKTRACE) { + return; + } +#if defined(__linux__) + FILE * f = fopen("/proc/self/status", "r"); + size_t size = 0; + char * line = NULL; + ssize_t length = 0; + while ((length = getline(&line, &size, f)) > 0) { + if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) && + (length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) { + // Already being debugged, and the breakpoint is the later abort() + free(line); + fclose(f); + return; + } + } + free(line); + fclose(f); + int lock[2] = { -1, -1 }; + (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER +#endif + const int parent_pid = getpid(); + const int child_pid = fork(); + if (child_pid < 0) { // error +#if defined(__linux__) + close(lock[1]); + close(lock[0]); +#endif + return; + } else if (child_pid == 0) { // child + char attach[32]; + snprintf(attach, sizeof(attach), "attach %d", parent_pid); +#if defined(__linux__) + close(lock[1]); + (void) !read(lock[0], lock, 1); + close(lock[0]); +#endif + // try gdb + execlp("gdb", "gdb", "--batch", + "-ex", "set style enabled on", + "-ex", attach, + "-ex", "bt -frame-info source-and-location", + "-ex", "detach", + "-ex", "quit", + (char *) NULL); + // try lldb + execlp("lldb", "lldb", "--batch", + "-o", "bt", + "-o", "quit", + "-p", &attach[sizeof("attach ") - 1], + (char *) NULL); + // gdb failed, fallback to backtrace_symbols + ggml_print_backtrace_symbols(); + _Exit(0); + } else { // parent +#if defined(__linux__) + prctl(PR_SET_PTRACER, child_pid); + close(lock[1]); + close(lock[0]); +#endif + waitpid(child_pid, NULL, 0); + } +} +#else +void ggml_print_backtrace(void) { + // platform not supported +} +#endif + +void ggml_abort(const char * file, int line, const char * fmt, ...) { + fflush(stdout); + + fprintf(stderr, "%s:%d: ", file, line); + + va_list args; + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + + fprintf(stderr, "\n"); + + ggml_print_backtrace(); + abort(); +} + +// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp + +// +// logging +// + +struct ggml_logger_state { + ggml_log_callback log_callback; + void * log_callback_user_data; +}; +static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL}; + +static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) { + if (format == NULL) { + return; + } + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data); + } else { + char * buffer2 = (char *) calloc(len + 1, sizeof(char)); + vsnprintf(buffer2, len + 1, format, args_copy); + buffer2[len] = 0; + g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data); + free(buffer2); + } + va_end(args_copy); +} + +void ggml_log_internal(enum ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + ggml_log_internal_v(level, format, args); + va_end(args); +} + +void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +// +// end of logging block +// + +#ifdef GGML_USE_ACCELERATE +// uncomment to use vDSP for soft max computation +// note: not sure if it is actually faster +//#define GGML_SOFT_MAX_ACCELERATE +#endif + + +void * ggml_aligned_malloc(size_t size) { +#if defined(__s390x__) + const int alignment = 256; +#else + const int alignment = 64; +#endif + +#if defined(_MSC_VER) || defined(__MINGW32__) + return _aligned_malloc(size, alignment); +#else + if (size == 0) { + GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n"); + return NULL; + } + void * aligned_memory = NULL; + #ifdef GGML_USE_CPU_HBM + int result = hbw_posix_memalign(&aligned_memory, alignment, size); + #elif TARGET_OS_OSX + GGML_UNUSED(alignment); + kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE); + int result = EFAULT; + switch (alloc_status) { + case KERN_SUCCESS: + result = 0; + break; + case KERN_INVALID_ADDRESS: + result = EINVAL; + break; + case KERN_NO_SPACE: + result = ENOMEM; + break; + default: + result = EFAULT; + break; + } + #else + int result = posix_memalign(&aligned_memory, alignment, size); + #endif + if (result != 0) { + // Handle allocation failure + const char *error_desc = "unknown allocation error"; + switch (result) { + case EINVAL: + error_desc = "invalid alignment value"; + break; + case ENOMEM: + error_desc = "insufficient memory"; + break; + } + GGML_LOG_ERROR("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); + return NULL; + } + return aligned_memory; +#endif +} + +void ggml_aligned_free(void * ptr, size_t size) { + GGML_UNUSED(size); +#if defined(_MSC_VER) || defined(__MINGW32__) + _aligned_free(ptr); +#elif GGML_USE_CPU_HBM + if (ptr != NULL) { + hbw_free(ptr); + } +#elif TARGET_OS_OSX + if (ptr != NULL) { + vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size); + } +#else + free(ptr); +#endif +} + + +inline static void * ggml_malloc(size_t size) { + if (size == 0) { + GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_malloc!\n"); + return NULL; + } + void * result = malloc(size); + if (result == NULL) { + GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); + GGML_ABORT("fatal error"); + } + return result; +} + +// calloc +inline static void * ggml_calloc(size_t num, size_t size) { + if (num == 0 || size == 0) { + GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_calloc!\n"); + return NULL; + } + void * result = calloc(num, size); + if (result == NULL) { + GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); + GGML_ABORT("fatal error"); + } + return result; +} + +#define GGML_MALLOC(size) ggml_malloc(size) +#define GGML_CALLOC(num, size) ggml_calloc(num, size) + +#define GGML_FREE(ptr) free(ptr) + +const char * ggml_status_to_string(enum ggml_status status) { + switch (status) { + case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)"; + case GGML_STATUS_FAILED: return "GGML status: error (operation failed)"; + case GGML_STATUS_SUCCESS: return "GGML status: success"; + case GGML_STATUS_ABORTED: return "GGML status: warning (operation aborted)"; + } + + return "GGML status: unknown"; +} + +float ggml_fp16_to_fp32(ggml_fp16_t x) { +#define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml + return GGML_FP16_TO_FP32(x); +} + +ggml_fp16_t ggml_fp32_to_fp16(float x) { +#define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml + return GGML_FP32_TO_FP16(x); +} + +float ggml_bf16_to_fp32(ggml_bf16_t x) { +#define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml + return GGML_BF16_TO_FP32(x); // it just left shifts +} + +ggml_bf16_t ggml_fp32_to_bf16(float x) { +#define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml + return GGML_FP32_TO_BF16(x); +} + +void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) { + for (int64_t i = 0; i < n; i++) { + y[i] = GGML_FP16_TO_FP32(x[i]); + } +} + +void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) { + int i = 0; + for (; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(x[i]); + } +} + +void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) { + int i = 0; + for (; i < n; ++i) { + y[i] = GGML_BF16_TO_FP32(x[i]); + } +} + +void ggml_fp32_to_bf16_row_ref(const float * x, ggml_bf16_t * y, int64_t n) { + for (int i = 0; i < n; i++) { + y[i] = ggml_compute_fp32_to_bf16(x[i]); + } +} + +void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) { + int i = 0; +#if defined(__AVX512BF16__) + // subnormals are flushed to zero on this platform + for (; i + 32 <= n; i += 32) { + _mm512_storeu_si512( + (__m512i *)(y + i), + m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16), + _mm512_loadu_ps(x + i)))); + } +#endif + for (; i < n; i++) { + y[i] = GGML_FP32_TO_BF16(x[i]); + } +} + +bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) { + return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0; +} + +// +// timing +// + +#if defined(_MSC_VER) || defined(__MINGW32__) +static int64_t timer_freq, timer_start; +void ggml_time_init(void) { + LARGE_INTEGER t; + QueryPerformanceFrequency(&t); + timer_freq = t.QuadPart; + + // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq + // and the uptime is high enough. + // We subtract the program start time to reduce the likelihood of that happening. + QueryPerformanceCounter(&t); + timer_start = t.QuadPart; +} +int64_t ggml_time_ms(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart-timer_start) * 1000) / timer_freq; +} +int64_t ggml_time_us(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart-timer_start) * 1000000) / timer_freq; +} +#else +void ggml_time_init(void) {} +int64_t ggml_time_ms(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; +} + +int64_t ggml_time_us(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; +} +#endif + +int64_t ggml_cycles(void) { + return clock(); +} + +int64_t ggml_cycles_per_ms(void) { + return CLOCKS_PER_SEC/1000; +} + +// +// cross-platform UTF-8 file paths +// + +#ifdef _WIN32 +static wchar_t * ggml_mbstowcs(const char * mbs) { + int wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, NULL, 0); + if (!wlen) { + errno = EINVAL; + return NULL; + } + + wchar_t * wbuf = GGML_MALLOC(wlen * sizeof(wchar_t)); + wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, wbuf, wlen); + if (!wlen) { + GGML_FREE(wbuf); + errno = EINVAL; + return NULL; + } + + return wbuf; +} +#endif + +FILE * ggml_fopen(const char * fname, const char * mode) { +#ifdef _WIN32 + FILE * file = NULL; + + // convert fname (UTF-8) + wchar_t * wfname = ggml_mbstowcs(fname); + if (wfname) { + // convert mode (ANSI) + wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t)); + wchar_t * wmode_p = wmode; + do { + *wmode_p++ = (wchar_t)*mode; + } while (*mode++); + + // open file + file = _wfopen(wfname, wmode); + + GGML_FREE(wfname); + GGML_FREE(wmode); + } + + return file; +#else + return fopen(fname, mode); +#endif + +} +static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc); +static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); +static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc); + +static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { + [GGML_TYPE_I8] = { + .type_name = "i8", + .blck_size = 1, + .type_size = sizeof(int8_t), + .is_quantized = false, + }, + [GGML_TYPE_I16] = { + .type_name = "i16", + .blck_size = 1, + .type_size = sizeof(int16_t), + .is_quantized = false, + }, + [GGML_TYPE_I32] = { + .type_name = "i32", + .blck_size = 1, + .type_size = sizeof(int32_t), + .is_quantized = false, + }, + [GGML_TYPE_I64] = { + .type_name = "i64", + .blck_size = 1, + .type_size = sizeof(int64_t), + .is_quantized = false, + }, + [GGML_TYPE_F64] = { + .type_name = "f64", + .blck_size = 1, + .type_size = sizeof(double), + .is_quantized = false, + }, + [GGML_TYPE_F32] = { + .type_name = "f32", + .blck_size = 1, + .type_size = sizeof(float), + .is_quantized = false, + }, + [GGML_TYPE_F16] = { + .type_name = "f16", + .blck_size = 1, + .type_size = sizeof(ggml_fp16_t), + .is_quantized = false, + .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, + .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row, + }, + [GGML_TYPE_Q4_0] = { + .type_name = "q4_0", + .blck_size = QK4_0, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_0, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref, + }, + [GGML_TYPE_Q4_1] = { + .type_name = "q4_1", + .blck_size = QK4_1, + .type_size = sizeof(block_q4_1), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_1, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, + }, + [4] = { // GGML_TYPE_Q4_2 + .type_name = "DEPRECATED", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [5] = { // GGML_TYPE_Q4_3 + .type_name = "DEPRECATED", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [GGML_TYPE_Q5_0] = { + .type_name = "q5_0", + .blck_size = QK5_0, + .type_size = sizeof(block_q5_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_0, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref, + }, + [GGML_TYPE_Q5_1] = { + .type_name = "q5_1", + .blck_size = QK5_1, + .type_size = sizeof(block_q5_1), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_1, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref, + }, + [GGML_TYPE_Q8_0] = { + .type_name = "q8_0", + .blck_size = QK8_0, + .type_size = sizeof(block_q8_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_0, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref, + }, + [GGML_TYPE_Q8_1] = { + .type_name = "q8_1", + .blck_size = QK8_1, + .type_size = sizeof(block_q8_1), + .is_quantized = true, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, + }, + [GGML_TYPE_Q2_K] = { + .type_name = "q2_K", + .blck_size = QK_K, + .type_size = sizeof(block_q2_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q2_K, + .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref, + }, + [GGML_TYPE_Q3_K] = { + .type_name = "q3_K", + .blck_size = QK_K, + .type_size = sizeof(block_q3_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q3_K, + .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref, + }, + [GGML_TYPE_Q4_K] = { + .type_name = "q4_K", + .blck_size = QK_K, + .type_size = sizeof(block_q4_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_K, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref, + }, + [GGML_TYPE_Q5_K] = { + .type_name = "q5_K", + .blck_size = QK_K, + .type_size = sizeof(block_q5_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_K, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref, + }, + [GGML_TYPE_Q6_K] = { + .type_name = "q6_K", + .blck_size = QK_K, + .type_size = sizeof(block_q6_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q6_K, + .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref, + }, + [GGML_TYPE_IQ2_XXS] = { + .type_name = "iq2_xxs", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_xxs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs, + .from_float_ref = NULL, + }, + [GGML_TYPE_IQ2_XS] = { + .type_name = "iq2_xs", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_xs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_xs, + .from_float_ref = NULL, + }, + [GGML_TYPE_IQ3_XXS] = { + .type_name = "iq3_xxs", + .blck_size = QK_K, + .type_size = sizeof(block_iq3_xxs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref, + }, + [GGML_TYPE_IQ3_S] = { + .type_name = "iq3_s", + .blck_size = QK_K, + .type_size = sizeof(block_iq3_s), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq3_s, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref, + }, + [GGML_TYPE_IQ2_S] = { + .type_name = "iq2_s", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_s), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_s, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref, + }, + [GGML_TYPE_IQ1_S] = { + .type_name = "iq1_s", + .blck_size = QK_K, + .type_size = sizeof(block_iq1_s), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq1_s, + .from_float_ref = NULL, + }, + [GGML_TYPE_IQ1_M] = { + .type_name = "iq1_m", + .blck_size = QK_K, + .type_size = sizeof(block_iq1_m), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq1_m, + .from_float_ref = NULL, + }, + [GGML_TYPE_IQ4_NL] = { + .type_name = "iq4_nl", + .blck_size = QK4_NL, + .type_size = sizeof(block_iq4_nl), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_nl, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref, + }, + [GGML_TYPE_IQ4_XS] = { + .type_name = "iq4_xs", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_xs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_xs, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref, + }, + [GGML_TYPE_Q8_K] = { + .type_name = "q8_K", + .blck_size = QK_K, + .type_size = sizeof(block_q8_K), + .is_quantized = true, + }, + [GGML_TYPE_BF16] = { + .type_name = "bf16", + .blck_size = 1, + .type_size = sizeof(ggml_bf16_t), + .is_quantized = false, + .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, + .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref, + }, + [31] = { // GGML_TYPE_Q4_0_4_4 + .type_name = "TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [32] = { // GGML_TYPE_Q4_0_4_8 + .type_name = "TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [33] = { // GGML_TYPE_Q4_0_8_8 + .type_name = "TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [GGML_TYPE_TQ1_0] = { + .type_name = "tq1_0", + .blck_size = QK_K, + .type_size = sizeof(block_tq1_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_tq1_0, + .from_float_ref = (ggml_from_float_t) quantize_row_tq1_0_ref, + }, + [GGML_TYPE_TQ2_0] = { + .type_name = "tq2_0", + .blck_size = QK_K, + .type_size = sizeof(block_tq2_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_tq2_0, + .from_float_ref = (ggml_from_float_t) quantize_row_tq2_0_ref, + }, + [36] = { // GGML_TYPE_IQ4_NL_4_4 + .type_name = "TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [37] = { // GGML_TYPE_IQ4_NL_4_8 + .type_name = "TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [38] = { // GGML_TYPE_IQ4_NL_8_8 + .type_name = "TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, +}; + +const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) { + GGML_ASSERT(type < GGML_TYPE_COUNT); + return &type_traits[type]; +} + +// +// ggml object +// + +struct ggml_object { + size_t offs; + size_t size; + + struct ggml_object * next; + + enum ggml_object_type type; + + char padding[4]; +}; + +static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); + +// +// ggml context +// + +struct ggml_context { + size_t mem_size; + void * mem_buffer; + bool mem_buffer_owned; + bool no_alloc; + + int n_objects; + + struct ggml_object * objects_begin; + struct ggml_object * objects_end; +}; + +struct ggml_context_container { + bool used; + + struct ggml_context context; +}; + +// +// data types +// + +static const char * GGML_OP_NAME[GGML_OP_COUNT] = { + "NONE", + + "DUP", + "ADD", + "ADD1", + "ACC", + "SUB", + "MUL", + "DIV", + "SQR", + "SQRT", + "LOG", + "SIN", + "COS", + "SUM", + "SUM_ROWS", + "MEAN", + "ARGMAX", + "COUNT_EQUAL", + "REPEAT", + "REPEAT_BACK", + "CONCAT", + "SILU_BACK", + "NORM", + "RMS_NORM", + "RMS_NORM_BACK", + "GROUP_NORM", + "L2_NORM", + + "MUL_MAT", + "MUL_MAT_ID", + "OUT_PROD", + + "SCALE", + "SET", + "CPY", + "CONT", + "RESHAPE", + "VIEW", + "PERMUTE", + "TRANSPOSE", + "GET_ROWS", + "GET_ROWS_BACK", + "DIAG", + "DIAG_MASK_INF", + "DIAG_MASK_ZERO", + "SOFT_MAX", + "SOFT_MAX_BACK", + "ROPE", + "ROPE_BACK", + "CLAMP", + "CONV_TRANSPOSE_1D", + "IM2COL", + "IM2COL_BACK", + "CONV_2D_DW", + "CONV_TRANSPOSE_2D", + "POOL_1D", + "POOL_2D", + "POOL_2D_BACK", + "UPSCALE", + "PAD", + "PAD_REFLECT_1D", + "ARANGE", + "TIMESTEP_EMBEDDING", + "ARGSORT", + "LEAKY_RELU", + + "FLASH_ATTN_EXT", + "FLASH_ATTN_BACK", + "SSM_CONV", + "SSM_SCAN", + "WIN_PART", + "WIN_UNPART", + "GET_REL_POS", + "ADD_REL_POS", + "RWKV_WKV6", + "GATED_LINEAR_ATTN", + "RWKV_WKV7", + + "UNARY", + + "MAP_CUSTOM1", + "MAP_CUSTOM2", + "MAP_CUSTOM3", + + "CUSTOM", + + "CROSS_ENTROPY_LOSS", + "CROSS_ENTROPY_LOSS_BACK", + "OPT_STEP_ADAMW", +}; + +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); + +static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { + "none", + + "x", + "x+y", + "x+y", + "view(x,nb,offset)+=y->x", + "x-y", + "x*y", + "x/y", + "x^2", + "√x", + "log(x)", + "sin(x)", + "cos(x)", + "Σx", + "Σx_k", + "Σx/n", + "argmax(x)", + "count_equal(x)", + "repeat(x)", + "repeat_back(x)", + "concat(x, y)", + "silu_back(x)", + "norm(x)", + "rms_norm(x)", + "rms_norm_back(x)", + "group_norm(x)", + "l2_norm(x)", + + "X*Y", + "X[i]*Y", + "X*Y", + + "x*v", + "y-\\>view(x)", + "x-\\>y", + "cont(x)", + "reshape(x)", + "view(x)", + "permute(x)", + "transpose(x)", + "get_rows(x)", + "get_rows_back(x)", + "diag(x)", + "diag_mask_inf(x)", + "diag_mask_zero(x)", + "soft_max(x)", + "soft_max_back(x)", + "rope(x)", + "rope_back(x)", + "clamp(x)", + "conv_transpose_1d(x)", + "im2col(x)", + "im2col_back(x)", + "conv_2d_dw(x)", + "conv_transpose_2d(x)", + "pool_1d(x)", + "pool_2d(x)", + "pool_2d_back(x)", + "upscale(x)", + "pad(x)", + "pad_reflect_1d(x)", + "arange(start, stop, step)", + "timestep_embedding(timesteps, dim, max_period)", + "argsort(x)", + "leaky_relu(x)", + + "flash_attn_ext(x)", + "flash_attn_back(x)", + "ssm_conv(x)", + "ssm_scan(x)", + "win_part(x)", + "win_unpart(x)", + "get_rel_pos(x)", + "add_rel_pos(x)", + "rwkv_wkv6(k, v, r, tf, td, s)", + "gated_linear_attn(k, v, q, gate, s)", + "rwkv_wkv7(r, w, k, v, a, b, s)", + + "unary(x)", + + "map_custom(x)", + "map_custom(x,y)", + "map_custom(x,y,z)", + + "custom(x)", + + "cross_entropy_loss(x,y)", + "cross_entropy_loss_back(x,y)", + "adamw(x)", +}; + +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); + +static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); + + +static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { + "ABS", + "SGN", + "NEG", + "STEP", + "TANH", + "ELU", + "RELU", + "SIGMOID", + "GELU", + "GELU_QUICK", + "SILU", + "HARDSWISH", + "HARDSIGMOID", + "EXP", + "GELU_ERF", +}; + +static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); + + +static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); +static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); + + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_print_object(const struct ggml_object * obj) { + GGML_LOG_INFO(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", + obj->type, obj->offs, obj->size, (const void *) obj->next); +} + +void ggml_print_objects(const struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + GGML_LOG_INFO("%s: objects in context %p:\n", __func__, (const void *) ctx); + + while (obj != NULL) { + ggml_print_object(obj); + obj = obj->next; + } + + GGML_LOG_INFO("%s: --- end ---\n", __func__); +} + +int64_t ggml_nelements(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +int64_t ggml_nrows(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +size_t ggml_nbytes(const struct ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] <= 0) { + return 0; + } + } + + size_t nbytes; + const size_t blck_size = ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } + } + else { + nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } + } + + return nbytes; +} + +size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { + return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); +} + +int64_t ggml_blck_size(enum ggml_type type) { + return type_traits[type].blck_size; +} + +size_t ggml_type_size(enum ggml_type type) { + return type_traits[type].type_size; +} + +size_t ggml_row_size(enum ggml_type type, int64_t ne) { + assert(ne % ggml_blck_size(type) == 0); + return ggml_type_size(type)*ne/ggml_blck_size(type); +} + +double ggml_type_sizef(enum ggml_type type) { + return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; +} + +const char * ggml_type_name(enum ggml_type type) { + return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE"; +} + +bool ggml_is_quantized(enum ggml_type type) { + return type_traits[type].is_quantized; +} + +const char * ggml_op_name(enum ggml_op op) { + return GGML_OP_NAME[op]; +} + +const char * ggml_op_symbol(enum ggml_op op) { + return GGML_OP_SYMBOL[op]; +} + +const char * ggml_unary_op_name(enum ggml_unary_op op) { + return GGML_UNARY_OP_NAME[op]; +} + +const char * ggml_op_desc(const struct ggml_tensor * t) { + if (t->op == GGML_OP_UNARY) { + enum ggml_unary_op uop = ggml_get_unary_op(t); + return ggml_unary_op_name(uop); + } + return ggml_op_name(t->op); +} + +size_t ggml_element_size(const struct ggml_tensor * tensor) { + return ggml_type_size(tensor->type); +} + +bool ggml_is_scalar(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +bool ggml_is_vector(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +bool ggml_is_matrix(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +bool ggml_is_3d(const struct ggml_tensor * tensor) { + return tensor->ne[3] == 1; +} + +int ggml_n_dims(const struct ggml_tensor * tensor) { + for (int i = GGML_MAX_DIMS - 1; i >= 1; --i) { + if (tensor->ne[i] > 1) { + return i + 1; + } + } + return 1; +} + +enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { + enum ggml_type wtype = GGML_TYPE_COUNT; + + switch (ftype) { + case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; + case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; + case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; + case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; + case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; + case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; + case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; + case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; + case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; + case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; + case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; + case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; + case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; + case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; + case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break; + case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break; + case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; + case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; + case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; + case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; + case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; + case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; + } + + GGML_ASSERT(wtype != GGML_TYPE_COUNT); + + return wtype; +} + +size_t ggml_tensor_overhead(void) { + return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE; +} + +bool ggml_is_transposed(const struct ggml_tensor * tensor) { + return tensor->nb[0] > tensor->nb[1]; +} + +static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) { + size_t next_nb = ggml_type_size(tensor->type); + if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) { + return false; + } + next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + if (tensor->ne[i] != 1) { + if (i > n) { + if (tensor->nb[i] != next_nb) { + return false; + } + next_nb *= tensor->ne[i]; + } else { + // this dimension does not need to be contiguous + next_nb = tensor->ne[i]*tensor->nb[i]; + } + } + } + return true; +} + +bool ggml_is_contiguous(const struct ggml_tensor * tensor) { + return ggml_is_contiguous_0(tensor); +} + +bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) { + return ggml_is_contiguous_n(tensor, 0); +} + +bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) { + return ggml_is_contiguous_n(tensor, 1); +} + +bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) { + return ggml_is_contiguous_n(tensor, 2); +} + +bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) { + return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type); +} + +bool ggml_is_permuted(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; +} + +bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) { + return + tensor->nb[0] > tensor->nb[2] && + tensor->nb[1] > tensor->nb[0] && + tensor->nb[2] == ggml_type_size(tensor->type); +} + +static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +bool ggml_is_empty(const struct ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] == 0) { + // empty if any dimension has no elements + return true; + } + } + return false; +} + +bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0]) && + (t0->ne[1] == t1->ne[1]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + +bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->nb[0] == t1->nb[0]) && + (t0->nb[1] == t1->nb[1]) && + (t0->nb[2] == t1->nb[2]) && + (t0->nb[3] == t1->nb[3]); +} + +// check if t1 can be represented as a repetition of t0 +bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return ggml_is_empty(t0) ? ggml_is_empty(t1) : + (t1->ne[0]%t0->ne[0] == 0) && + (t1->ne[1]%t0->ne[1] == 0) && + (t1->ne[2]%t0->ne[2] == 0) && + (t1->ne[3]%t0->ne[3] == 0); +} + +static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1); +} + +// assert that pointer is aligned to GGML_MEM_ALIGN +#define GGML_ASSERT_ALIGNED(ptr) \ + GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) + +//////////////////////////////////////////////////////////////////////////////// + +struct ggml_context * ggml_init(struct ggml_init_params params) { + static bool is_first_call = true; + + ggml_critical_section_start(); + + if (is_first_call) { + // initialize time system (required on Windows) + ggml_time_init(); + + for (int i = 0; i < (1 << 16); ++i) { + union { + uint16_t u16; + ggml_fp16_t fp16; + } u = {i}; + ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16); + } + + is_first_call = false; + } + + ggml_critical_section_end(); + + struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context)); + + // allow to call ggml_init with 0 size + if (params.mem_size == 0) { + params.mem_size = GGML_MEM_ALIGN; + } + + const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN); + + *ctx = (struct ggml_context) { + /*.mem_size =*/ mem_size, + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : ggml_aligned_malloc(mem_size), + /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, + /*.no_alloc =*/ params.no_alloc, + /*.n_objects =*/ 0, + /*.objects_begin =*/ NULL, + /*.objects_end =*/ NULL, + }; + + GGML_ASSERT(ctx->mem_buffer != NULL); + + GGML_ASSERT_ALIGNED(ctx->mem_buffer); + + GGML_PRINT_DEBUG("%s: context initialized\n", __func__); + + return ctx; +} + +void ggml_reset(struct ggml_context * ctx) { + if (ctx == NULL) { + return; + } + + ctx->n_objects = 0; + ctx->objects_begin = NULL; + ctx->objects_end = NULL; +} + +void ggml_free(struct ggml_context * ctx) { + if (ctx == NULL) { + return; + } + + if (ctx->mem_buffer_owned) { + ggml_aligned_free(ctx->mem_buffer, ctx->mem_size); + } + + GGML_FREE(ctx); +} + +size_t ggml_used_mem(const struct ggml_context * ctx) { + return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size; +} + +bool ggml_get_no_alloc(struct ggml_context * ctx) { + return ctx->no_alloc; +} + +void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) { + ctx->no_alloc = no_alloc; +} + +void * ggml_get_mem_buffer(const struct ggml_context * ctx) { + return ctx->mem_buffer; +} + +size_t ggml_get_mem_size(const struct ggml_context * ctx) { + return ctx->mem_size; +} + +size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) { + size_t max_size = 0; + + for (struct ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor != NULL; tensor = ggml_get_next_tensor(ctx, tensor)) { + size_t bytes = ggml_nbytes(tensor); + max_size = MAX(max_size, bytes); + } + + return max_size; +} + +//////////////////////////////////////////////////////////////////////////////// + +static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) { + // always insert objects at the end of the context's memory pool + struct ggml_object * obj_cur = ctx->objects_end; + + const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; + const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; + const size_t cur_end = cur_offs + cur_size; + + // align to GGML_MEM_ALIGN + size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN); + + char * const mem_buffer = ctx->mem_buffer; + struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); + + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size); +#ifndef NDEBUG + GGML_ABORT("not enough space in the context's memory pool"); +#endif + return NULL; + } + + *obj_new = (struct ggml_object) { + .offs = cur_end + GGML_OBJECT_SIZE, + .size = size_needed, + .next = NULL, + .type = type, + }; + + GGML_ASSERT_ALIGNED(mem_buffer + obj_new->offs); + + if (obj_cur != NULL) { + obj_cur->next = obj_new; + } else { + // this is the first object in this context + ctx->objects_begin = obj_new; + } + + ctx->objects_end = obj_new; + + //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); + + return obj_new; +} + +static struct ggml_tensor * ggml_new_tensor_impl( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t * ne, + struct ggml_tensor * view_src, + size_t view_offs) { + + GGML_ASSERT(type >= 0 && type < GGML_TYPE_COUNT); + GGML_ASSERT(n_dims >= 1 && n_dims <= GGML_MAX_DIMS); + + // find the base tensor and absolute offset + if (view_src != NULL && view_src->view_src != NULL) { + view_offs += view_src->view_offs; + view_src = view_src->view_src; + } + + size_t data_size = ggml_row_size(type, ne[0]); + for (int i = 1; i < n_dims; i++) { + data_size *= ne[i]; + } + + GGML_ASSERT(view_src == NULL || data_size == 0 || data_size + view_offs <= ggml_nbytes(view_src)); + + void * data = view_src != NULL ? view_src->data : NULL; + if (data != NULL) { + data = (char *) data + view_offs; + } + + size_t obj_alloc_size = 0; + + if (view_src == NULL && !ctx->no_alloc) { + // allocate tensor data in the context's memory pool + obj_alloc_size = data_size; + } + + struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); + GGML_ASSERT(obj_new); + + struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); + + *result = (struct ggml_tensor) { + /*.type =*/ type, + /*.buffer =*/ NULL, + /*.ne =*/ { 1, 1, 1, 1 }, + /*.nb =*/ { 0, 0, 0, 0 }, + /*.op =*/ GGML_OP_NONE, + /*.op_params =*/ { 0 }, + /*.flags =*/ 0, + /*.src =*/ { NULL }, + /*.view_src =*/ view_src, + /*.view_offs =*/ view_offs, + /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, + /*.name =*/ { 0 }, + /*.extra =*/ NULL, + /*.padding =*/ { 0 }, + }; + + // TODO: this should not be needed as long as we don't rely on aligned SIMD loads + //GGML_ASSERT_ALIGNED(result->data); + + for (int i = 0; i < n_dims; i++) { + result->ne[i] = ne[i]; + } + + result->nb[0] = ggml_type_size(type); + result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type)); + for (int i = 2; i < GGML_MAX_DIMS; i++) { + result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; + } + + ctx->n_objects++; + + return result; +} + +struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t * ne) { + return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0); +} + +struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0) { + return ggml_new_tensor(ctx, type, 1, &ne0); +} + +struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1) { + const int64_t ne[2] = { ne0, ne1 }; + return ggml_new_tensor(ctx, type, 2, ne); +} + +struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + const int64_t ne[3] = { ne0, ne1, ne2 }; + return ggml_new_tensor(ctx, type, 3, ne); +} + +struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + return ggml_new_tensor(ctx, type, 4, ne); +} + +void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes) { + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, nbytes); + + return (uint8_t *)ctx->mem_buffer + obj->offs; +} + +struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) { + return ggml_new_tensor(ctx, src->type, GGML_MAX_DIMS, src->ne); +} + +void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) { + const int64_t ne2 = tensor->ne[2]; + const int64_t ne1 = tensor->ne[1]; + const int64_t ne0 = tensor->ne[0]; + + const int64_t i3_ = (i/(ne2*ne1*ne0)); + const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0); + const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0; + const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0); + + if (i0) { + * i0 = i0_; + } + if (i1) { + * i1 = i1_; + } + if (i2) { + * i2 = i2_; + } + if (i3) { + * i3 = i3_; + } +} + +void * ggml_get_data(const struct ggml_tensor * tensor) { + return tensor->data; +} + +float * ggml_get_data_f32(const struct ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + return (float *)(tensor->data); +} + +enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_UNARY); + return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0); +} + +const char * ggml_get_name(const struct ggml_tensor * tensor) { + return tensor->name; +} + +struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) { + size_t i; + for (i = 0; i < sizeof(tensor->name) - 1 && name[i] != '\0'; i++) { + tensor->name[i] = name[i]; + } + tensor->name[i] = '\0'; + return tensor; +} + +struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) { + va_list args; + va_start(args, fmt); + vsnprintf(tensor->name, sizeof(tensor->name), fmt, args); + va_end(args); + return tensor; +} + +struct ggml_tensor * ggml_view_tensor( + struct ggml_context * ctx, + struct ggml_tensor * src) { + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, GGML_MAX_DIMS, src->ne, src, 0); + ggml_format_name(result, "%s (view)", src->name); + + for (int i = 0; i < GGML_MAX_DIMS; i++) { + result->nb[i] = src->nb[i]; + } + + return result; +} + +struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TYPE_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + +struct ggml_tensor * ggml_get_next_tensor(const struct ggml_context * ctx, struct ggml_tensor * tensor) { + struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE); + obj = obj->next; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TYPE_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + +struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) { + struct ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TYPE_TENSOR) { + struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs); + if (strcmp(cur->name, name) == 0) { + return cur; + } + } + + obj = obj->next; + } + + return NULL; +} + +//////////////////////////////////////////////////////////////////////////////// + +// ggml_dup + +static struct ggml_tensor * ggml_dup_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DUP; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, true); +} + +// ggml_add + +static struct ggml_tensor * ggml_add_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, true); +} + +// ggml_add_cast + +static struct ggml_tensor * ggml_add_cast_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { + // TODO: support less-strict constraint + // GGML_ASSERT(ggml_can_repeat(b, a)); + GGML_ASSERT(ggml_can_repeat_rows(b, a)); + + // currently only supported for quantized input and f16 + GGML_ASSERT(ggml_is_quantized(a->type) || + a->type == GGML_TYPE_F16 || + a->type == GGML_TYPE_BF16); + + struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne); + + result->op = GGML_OP_ADD; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_add_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { + return ggml_add_cast_impl(ctx, a, b, type); +} + +// ggml_add1 + +static struct ggml_tensor * ggml_add1_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_is_scalar(b)); + GGML_ASSERT(ggml_is_padded_1d(a)); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD1; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_add1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add1_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_add1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add1_impl(ctx, a, b, true); +} + +// ggml_acc + +static struct ggml_tensor * ggml_acc_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { + GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(b->type == GGML_TYPE_F32); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_ACC; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_acc( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); +} + +struct ggml_tensor * ggml_acc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); +} + +// ggml_sub + +static struct ggml_tensor * ggml_sub_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SUB; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, true); +} + +// ggml_mul + +static struct ggml_tensor * ggml_mul_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_MUL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, true); +} + +// ggml_div + +static struct ggml_tensor * ggml_div_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DIV; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, true); +} + +// ggml_sqr + +static struct ggml_tensor * ggml_sqr_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQR; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, true); +} + +// ggml_sqrt + +static struct ggml_tensor * ggml_sqrt_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQRT; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, true); +} + +// ggml_log + +static struct ggml_tensor * ggml_log_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_LOG; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_log( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_log_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_log_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_log_impl(ctx, a, true); +} + +// ggml_sin + +static struct ggml_tensor * ggml_sin_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SIN; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_sin( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sin_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sin_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sin_impl(ctx, a, true); +} + +// ggml_cos + +static struct ggml_tensor * ggml_cos_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_COS; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_cos( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cos_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_cos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cos_impl(ctx, a, true); +} + +// ggml_sum + +struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a) { + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_SUM; + result->src[0] = a; + + return result; +} + +// ggml_sum_rows + +struct ggml_tensor * ggml_sum_rows( + struct ggml_context * ctx, + struct ggml_tensor * a) { + int64_t ne[GGML_MAX_DIMS] = { 1 }; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + ne[i] = a->ne[i]; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne); + + result->op = GGML_OP_SUM_ROWS; + result->src[0] = a; + + return result; +} + +// ggml_mean + +struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a) { + int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MEAN; + result->src[0] = a; + + return result; +} + +// ggml_argmax + +struct ggml_tensor * ggml_argmax( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_matrix(a)); + GGML_ASSERT(a->ne[0] <= INT32_MAX); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]); + + result->op = GGML_OP_ARGMAX; + result->src[0] = a; + + return result; +} + +// ggml_count_equal + +struct ggml_tensor * ggml_count_equal( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1); + + result->op = GGML_OP_COUNT_EQUAL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_repeat + +struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_repeat(a, b)); + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne); + + result->op = GGML_OP_REPEAT; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_repeat_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) { + const bool can_repeat = ggml_is_empty(a) || ( + (ne0 % a->ne[0] == 0) && + (ne1 % a->ne[1] == 0) && + (ne2 % a->ne[2] == 0) && + (ne3 % a->ne[3] == 0) + ); + GGML_ASSERT(can_repeat); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + + result->op = GGML_OP_REPEAT; + result->src[0] = a; + + return result; +} + +// ggml_repeat_back + +struct ggml_tensor * ggml_repeat_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne); + + result->op = GGML_OP_REPEAT_BACK; + result->src[0] = a; + + return result; +} + +// ggml_concat + +struct ggml_tensor * ggml_concat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int dim) { + GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + GGML_ASSERT(a->type == b->type); + + int64_t ne[GGML_MAX_DIMS]; + for (int d = 0; d < GGML_MAX_DIMS; ++d) { + if (d == dim) { + ne[d] = a->ne[d] + b->ne[d]; + continue; + } + GGML_ASSERT(a->ne[d] == b->ne[d]); + ne[d] = a->ne[d]; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne); + + ggml_set_op_params_i32(result, 0, dim); + + result->op = GGML_OP_CONCAT; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_abs + +struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_ABS); +} + +struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS); +} + +// ggml_sgn + +struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SGN); +} + +struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN); +} + +// ggml_neg + +struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_NEG); +} + +struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG); +} + +// ggml_step + +struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_STEP); +} + +struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP); +} + +// ggml_tanh + +struct ggml_tensor * ggml_tanh( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_TANH); +} + +struct ggml_tensor * ggml_tanh_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH); +} + +// ggml_elu + +struct ggml_tensor * ggml_elu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_ELU); +} + +struct ggml_tensor * ggml_elu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU); +} + +// ggml_relu + +struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_RELU); +} + +struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU); +} + +// ggml_leaky_relu + +struct ggml_tensor * ggml_leaky_relu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float negative_slope, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); + + result->op = GGML_OP_LEAKY_RELU; + result->src[0] = a; + + return result; +} + +// ggml_sigmoid + +struct ggml_tensor * ggml_sigmoid( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID); +} + +struct ggml_tensor * ggml_sigmoid_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID); +} + +// ggml_gelu + +struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU); +} + +struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU); +} + +// ggml_gelu_erf + +struct ggml_tensor * ggml_gelu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF); +} + +struct ggml_tensor * ggml_gelu_erf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF); +} + +// ggml_gelu_quick + +struct ggml_tensor * ggml_gelu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK); +} + +struct ggml_tensor * ggml_gelu_quick_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK); +} + +// ggml_silu + +struct ggml_tensor * ggml_silu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SILU); +} + +struct ggml_tensor * ggml_silu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); +} + +// ggml_silu_back + +struct ggml_tensor * ggml_silu_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SILU_BACK; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml hardswish + +struct ggml_tensor * ggml_hardswish( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSWISH); +} + +// ggml hardsigmoid + +struct ggml_tensor * ggml_hardsigmoid( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID); +} + +// ggml exp + +struct ggml_tensor * ggml_exp( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_EXP); +} + +struct ggml_tensor * ggml_exp_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); +} + +// ggml_norm + +static struct ggml_tensor * ggml_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = GGML_OP_NORM; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_norm_impl(ctx, a, eps, false); +} + +struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_norm_impl(ctx, a, eps, true); +} + +// ggml_rms_norm + +static struct ggml_tensor * ggml_rms_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = GGML_OP_RMS_NORM; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, false); +} + +struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, true); +} + +// ggml_rms_norm_back + +struct ggml_tensor * ggml_rms_norm_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float eps) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = GGML_OP_RMS_NORM_BACK; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_group_norm + +static struct ggml_tensor * ggml_group_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, n_groups); + ggml_set_op_params_f32(result, 1, eps); + + result->op = GGML_OP_GROUP_NORM; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_group_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps) { + return ggml_group_norm_impl(ctx, a, n_groups, eps, false); +} + +struct ggml_tensor * ggml_group_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps) { + return ggml_group_norm_impl(ctx, a, n_groups, eps, true); +} + +// ggml_l2_norm + +static struct ggml_tensor * ggml_l2_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_f32(result, 0, eps); + + result->op = GGML_OP_L2_NORM; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_l2_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_l2_norm_impl(ctx, a, eps, false); +} + +struct ggml_tensor * ggml_l2_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_l2_norm_impl(ctx, a, eps, true); +} + +// ggml_mul_mat + +static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); +} + +struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_mul_mat(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + + const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MUL_MAT; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +void ggml_mul_mat_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 0, prec_i32); +} + +// ggml_mul_mat_id + +/* + c = ggml_mul_mat_id(ctx, as, b, ids); + + as -> [cols, rows, n_expert] + b -> [cols, n_expert_used, n_tokens] + ids -> [n_expert_used, n_tokens] (i32) + c -> [rows, n_expert_used, n_tokens] + + in b, n_expert_used can be broadcasted to match the n_expert_used of ids + + c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids +*/ +struct ggml_tensor * ggml_mul_mat_id( + struct ggml_context * ctx, + struct ggml_tensor * as, + struct ggml_tensor * b, + struct ggml_tensor * ids) { + GGML_ASSERT(!ggml_is_transposed(as)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert) + GGML_ASSERT(b->ne[3] == 1); // b is 3d + GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d + GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row + GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat + GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast + + const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MUL_MAT_ID; + result->src[0] = as; + result->src[1] = b; + result->src[2] = ids; + + return result; +} + +// ggml_out_prod + +static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[1] == t1->ne[1]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); +} + +struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_out_prod(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + + // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] + const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_OUT_PROD; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_scale + +static struct ggml_tensor * ggml_scale_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + bool inplace) { + GGML_ASSERT(ggml_is_padded_1d(a)); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &s, sizeof(s)); + + result->op = GGML_OP_SCALE; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s) { + return ggml_scale_impl(ctx, a, s, false); +} + +struct ggml_tensor * ggml_scale_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s) { + return ggml_scale_impl(ctx, a, s, true); +} + +// ggml_set + +static struct ggml_tensor * ggml_set_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { + GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b)); + + // make a view of the destination + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + GGML_ASSERT(offset < (size_t)(1 << 30)); + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_SET; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false); +} + +struct ggml_tensor * ggml_set_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true); +} + +struct ggml_tensor * ggml_set_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset) { + return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false); +} + +struct ggml_tensor * ggml_set_1d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset) { + return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true); +} + +struct ggml_tensor * ggml_set_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); +} + +struct ggml_tensor * ggml_set_2d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true); +} + +// ggml_cpy + +static struct ggml_tensor * ggml_cpy_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); + + // make a view of the destination + struct ggml_tensor * result = ggml_view_tensor(ctx, b); + if (strlen(b->name) > 0) { + ggml_format_name(result, "%s (copy of %s)", b->name, a->name); + } else { + ggml_format_name(result, "%s (copy)", a->name); + } + + result->op = GGML_OP_CPY; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b); +} + +struct ggml_tensor * ggml_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_type type) { + struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne); + ggml_format_name(result, "%s (copy)", a->name); + + result->op = GGML_OP_CPY; + result->src[0] = a; + result->src[1] = result; + + return result; +} + +// ggml_cont + +static struct ggml_tensor * ggml_cont_impl( + struct ggml_context * ctx, + struct ggml_tensor * a) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + ggml_format_name(result, "%s (cont)", a->name); + + result->op = GGML_OP_CONT; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_cont( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cont_impl(ctx, a); +} + +// make contiguous, with new shape +GGML_API struct ggml_tensor * ggml_cont_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0) { + return ggml_cont_4d(ctx, a, ne0, 1, 1, 1); +} + +GGML_API struct ggml_tensor * ggml_cont_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1); +} + +GGML_API struct ggml_tensor * ggml_cont_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1); +} + +struct ggml_tensor * ggml_cont_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3)); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + ggml_format_name(result, "%s (cont)", a->name); + + result->op = GGML_OP_CONT; + result->src[0] = a; + + return result; +} + +// ggml_reshape + +struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_is_contiguous(a)); + // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0); + + const int64_t ne[1] = { ne0 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1); + + const int64_t ne[2] = { ne0, ne1 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2); + + const int64_t ne[3] = { ne0, ne1, ne2 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); + + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->src[0] = a; + + return result; +} + +static struct ggml_tensor * ggml_view_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_dims, + const int64_t * ne, + size_t offset) { + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset); + ggml_format_name(result, "%s (view)", a->name); + + ggml_set_op_params(result, &offset, sizeof(offset)); + + result->op = GGML_OP_VIEW; + result->src[0] = a; + + return result; +} + +// ggml_view_1d + +struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + size_t offset) { + struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset); + + return result; +} + +// ggml_view_2d + +struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, + size_t offset) { + const int64_t ne[2] = { ne0, ne1 }; + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset); + + result->nb[1] = nb1; + result->nb[2] = result->nb[1]*ne1; + result->nb[3] = result->nb[2]; + + return result; +} + +// ggml_view_3d + +struct ggml_tensor * ggml_view_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + size_t nb1, + size_t nb2, + size_t offset) { + const int64_t ne[3] = { ne0, ne1, ne2 }; + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset); + + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = result->nb[2]*ne2; + + return result; +} + +// ggml_view_4d + +struct ggml_tensor * ggml_view_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset); + + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = nb3; + + return result; +} + +// ggml_permute + +struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3) { + GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS); + GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS); + GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS); + GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS); + + GGML_ASSERT(axis0 != axis1); + GGML_ASSERT(axis0 != axis2); + GGML_ASSERT(axis0 != axis3); + GGML_ASSERT(axis1 != axis2); + GGML_ASSERT(axis1 != axis3); + GGML_ASSERT(axis2 != axis3); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + ggml_format_name(result, "%s (permuted)", a->name); + + int ne[GGML_MAX_DIMS]; + int nb[GGML_MAX_DIMS]; + + ne[axis0] = a->ne[0]; + ne[axis1] = a->ne[1]; + ne[axis2] = a->ne[2]; + ne[axis3] = a->ne[3]; + + nb[axis0] = a->nb[0]; + nb[axis1] = a->nb[1]; + nb[axis2] = a->nb[2]; + nb[axis3] = a->nb[3]; + + result->ne[0] = ne[0]; + result->ne[1] = ne[1]; + result->ne[2] = ne[2]; + result->ne[3] = ne[3]; + + result->nb[0] = nb[0]; + result->nb[1] = nb[1]; + result->nb[2] = nb[2]; + result->nb[3] = nb[3]; + + result->op = GGML_OP_PERMUTE; + result->src[0] = a; + + int32_t params[] = { axis0, axis1, axis2, axis3 }; + ggml_set_op_params(result, params, sizeof(params)); + + return result; +} + +// ggml_transpose + +struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a) { + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + ggml_format_name(result, "%s (transposed)", a->name); + + result->ne[0] = a->ne[1]; + result->ne[1] = a->ne[0]; + + result->nb[0] = a->nb[1]; + result->nb[1] = a->nb[0]; + + result->op = GGML_OP_TRANSPOSE; + result->src[0] = a; + + return result; +} + +// ggml_get_rows + +struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(b->ne[3] == 1); + GGML_ASSERT(b->type == GGML_TYPE_I32); + + // TODO: implement non F32 return + enum ggml_type type = GGML_TYPE_F32; + if (a->type == GGML_TYPE_I32) { + type = a->type; + } + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); + + result->op = GGML_OP_GET_ROWS; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_get_rows_back + +struct ggml_tensor * ggml_get_rows_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0])); + + // TODO: implement non F32 return + //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]); + + result->op = GGML_OP_GET_ROWS_BACK; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_diag + +struct ggml_tensor * ggml_diag( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(a->ne[1] == 1); + + const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne); + + result->op = GGML_OP_DIAG; + result->src[0] = a; + + return result; +} + +// ggml_diag_mask_inf + +static struct ggml_tensor * ggml_diag_mask_inf_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { n_past }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_DIAG_MASK_INF; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_inf_impl(ctx, a, n_past, false); +} + +struct ggml_tensor * ggml_diag_mask_inf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_inf_impl(ctx, a, n_past, true); +} + +// ggml_diag_mask_zero + +static struct ggml_tensor * ggml_diag_mask_zero_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { n_past }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_DIAG_MASK_ZERO; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_diag_mask_zero( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_zero_impl(ctx, a, n_past, false); +} + +struct ggml_tensor * ggml_diag_mask_zero_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_zero_impl(ctx, a, n_past, true); +} + +// ggml_soft_max + +static struct ggml_tensor * ggml_soft_max_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias, + bool inplace) { + GGML_ASSERT(ggml_is_contiguous(a)); + + if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(ggml_is_matrix(mask)); + GGML_ASSERT(mask->ne[0] == a->ne[0]); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); + } + + if (max_bias > 0.0f) { + GGML_ASSERT(mask); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + float params[] = { scale, max_bias }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_SOFT_MAX; + result->src[0] = a; + result->src[1] = mask; + + return result; +} + +struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false); +} + +struct ggml_tensor * ggml_soft_max_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true); +} + +struct ggml_tensor * ggml_soft_max_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias) { + return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); +} + +// ggml_soft_max_ext_back + +static struct ggml_tensor * ggml_soft_max_ext_back_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float scale, + float max_bias, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SOFT_MAX_BACK; + result->src[0] = a; + result->src[1] = b; + + memcpy((float *) result->op_params + 0, &scale, sizeof(float)); + memcpy((float *) result->op_params + 1, &max_bias, sizeof(float)); + + return result; +} + +struct ggml_tensor * ggml_soft_max_ext_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float scale, + float max_bias) { + return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false); +} + +struct ggml_tensor * ggml_soft_max_ext_back_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float scale, + float max_bias) { + return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true); +} + +// ggml_rope + +static struct ggml_tensor * ggml_rope_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + bool inplace) { + GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); + + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + + if (c) { + GGML_ASSERT(c->type == GGML_TYPE_F32); + GGML_ASSERT(c->ne[0] >= n_dims / 2); + } + + int sections[4] = {0, 0, 0, 0}; + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; + memcpy(params + 5, &freq_base, sizeof(float)); + memcpy(params + 6, &freq_scale, sizeof(float)); + memcpy(params + 7, &ext_factor, sizeof(float)); + memcpy(params + 8, &attn_factor, sizeof(float)); + memcpy(params + 9, &beta_fast, sizeof(float)); + memcpy(params + 10, &beta_slow, sizeof(float)); + memcpy(params + 11, §ions, sizeof(int)*4); + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_ROPE; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode) { + return ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false + ); +} + +struct ggml_tensor * ggml_rope_multi( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + // Multimodal Rotary Position Embedding + GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); + + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token + + if (c) { + GGML_ASSERT(c->type == GGML_TYPE_F32); + GGML_ASSERT(c->ne[0] >= n_dims / 2); + } + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; + memcpy(params + 5, &freq_base, sizeof(float)); + memcpy(params + 6, &freq_scale, sizeof(float)); + memcpy(params + 7, &ext_factor, sizeof(float)); + memcpy(params + 8, &attn_factor, sizeof(float)); + memcpy(params + 9, &beta_fast, sizeof(float)); + memcpy(params + 10, &beta_slow, sizeof(float)); + memcpy(¶ms[11], sections, sizeof(int)*4); + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_ROPE; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +struct ggml_tensor * ggml_rope_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode) { + return ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true + ); +} + +struct ggml_tensor * ggml_rope_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false + ); +} + +struct ggml_tensor * ggml_rope_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true + ); +} + +struct ggml_tensor * ggml_rope_custom( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false + ); +} + +struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true + ); +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +} + +void ggml_rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); + dims[0] = MAX(0, start); + dims[1] = MIN(n_dims - 1, end); +} + +// ggml_rope_back + +struct ggml_tensor * ggml_rope_ext_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + struct ggml_tensor * result = ggml_rope_ext( + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + result->op = GGML_OP_ROPE_BACK; + return result; +} + +struct ggml_tensor * ggml_rope_multi_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + struct ggml_tensor * result = ggml_rope_multi( + ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + result->op = GGML_OP_ROPE_BACK; + return result; +} +// ggml_clamp + +struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max) { + // TODO: when implement backward, fix this: + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + float params[] = { min, max }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CLAMP; + result->src[0] = a; + + return result; +} + +static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; +} + +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OH, OW, IC*KH*KW] +struct ggml_tensor * ggml_im2col( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D, + enum ggml_type dst_type) { + if (is_2D) { + GGML_ASSERT(a->ne[2] == b->ne[2]); + } else { + //GGML_ASSERT(b->ne[1] % a->ne[1] == 0); + GGML_ASSERT(b->ne[1] == a->ne[1]); + GGML_ASSERT(b->ne[3] == 1); + } + + const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; + const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + + GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a"); + GGML_ASSERT((OW > 0) && "b too small compared to a"); + + const int64_t ne[4] = { + is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], + OW, + is_2D ? OH : b->ne[2], + is_2D ? b->ne[3] : 1, + }; + + struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_IM2COL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_im2col_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t * ne, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D) { + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_IM2COL_BACK; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_conv_1d + +struct ggml_tensor * ggml_conv_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K] + + struct ggml_tensor * result = + ggml_mul_mat(ctx, + ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K] + ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K] + + result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL] + + return result; +} + +// ggml_conv_1d_ph + +struct ggml_tensor* ggml_conv_1d_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s, + int d) { + return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); +} + +// ggml_conv_1d_dw + +struct ggml_tensor * ggml_conv_1d_dw( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]); + struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]); + + struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); + + struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a); + + result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1); + + return result; +} + +// ggml_conv_1d_dw_ph + +struct ggml_tensor * ggml_conv_1d_dw_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int d0) { + return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0); +} + +// ggml_conv_transpose_1d + +static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins - 1) * s - 2 * p + d * (ks - 1) + 1; +} + +GGML_API struct ggml_tensor * ggml_conv_transpose_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + GGML_ASSERT(ggml_is_matrix(b)); + GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(a->ne[3] == 1); + + GGML_ASSERT(p0 == 0); + GGML_ASSERT(d0 == 1); + + const int64_t ne[4] = { + ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), + a->ne[1], b->ne[2], 1, + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t params[] = { s0, p0, d0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_TRANSPOSE_1D; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_conv_2d + +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OC, OH, OW] +struct ggml_tensor * ggml_conv_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW] + + struct ggml_tensor * result = + ggml_mul_mat(ctx, + ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW] + ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW] + + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW] + result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW] + + + return result; +} + +// ggml_conv_2d_sk_p0 + +struct ggml_tensor * ggml_conv_2d_sk_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1); +} + +// ggml_conv_2d_s1_ph + +struct ggml_tensor * ggml_conv_2d_s1_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1); +} + +// ggml_conv_2d_dw + +struct ggml_tensor * ggml_conv_2d_dw( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); + struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, + ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), + s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW] + struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] + + new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW] + struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b); + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW] + + return result; +} + +// ggml_conv_2d_dw_direct + +struct ggml_tensor * ggml_conv_2d_dw_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride0, + int stride1, + int pad0, + int pad1, + int dilation0, + int dilation1) { + GGML_ASSERT(a->ne[2] == 1); + GGML_ASSERT(a->ne[3] == b->ne[2]); + int64_t ne[4]; + ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0); + ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1); + ne[2] = b->ne[2]; + ne[3] = b->ne[3]; + + struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); + + if (ggml_is_contiguous_channels(b)) { + // Result will be permuted the same way as input (CWHN order) + const int64_t type_size = ggml_type_size(result->type); + GGML_ASSERT(ggml_blck_size(result->type) == 1); + result->nb[0] = result->ne[2] * type_size; + result->nb[1] = result->ne[0] * result->nb[0]; + result->nb[2] = type_size; + } + + int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_2D_DW; + result->src[0] = a; + result->src[1] = b; + return result; +} + +// ggml_conv_transpose_2d_p0 + +static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { + return (ins - 1) * s - 2 * p + ks; +} + +struct ggml_tensor * ggml_conv_transpose_2d_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride) { + GGML_ASSERT(a->ne[3] == b->ne[2]); + + const int64_t ne[4] = { + ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/), + ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/), + a->ne[2], b->ne[3], + }; + + struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_i32(result, 0, stride); + + result->op = GGML_OP_CONV_TRANSPOSE_2D; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_pool_* + +static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) { + return (ins + 2 * p - ks) / s + 1; +} + +// ggml_pool_1d + +struct ggml_tensor * ggml_pool_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, + int s0, + int p0) { + const int64_t ne[4] = { + ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), + a->ne[1], + a->ne[2], + a->ne[3], + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t params[] = { op, k0, s0, p0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_POOL_1D; + result->src[0] = a; + + return result; +} + +// ggml_pool_2d + +struct ggml_tensor * ggml_pool_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1) { + struct ggml_tensor * result; + const int64_t ne[4] = { + ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), + ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), + a->ne[2], + a->ne[3], + }; + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_POOL_2D; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_pool_2d_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * af, + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1) { + struct ggml_tensor * result; + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne); + + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_POOL_2D_BACK; + result->src[0] = a; + result->src[1] = af; + + return result; +} + +// ggml_upscale + +static struct ggml_tensor * ggml_upscale_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3, + enum ggml_scale_mode mode) { + GGML_ASSERT(a->ne[0] <= ne0); + GGML_ASSERT(a->ne[1] <= ne1); + GGML_ASSERT(a->ne[2] <= ne2); + GGML_ASSERT(a->ne[3] <= ne3); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + + ggml_set_op_params_i32(result, 0, mode); + + result->op = GGML_OP_UPSCALE; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_upscale( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor, + enum ggml_scale_mode mode) { + return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode); +} + +struct ggml_tensor * ggml_upscale_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3, + enum ggml_scale_mode mode) { + return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode); +} + +// ggml_pad + +struct ggml_tensor * ggml_pad( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1, + int p2, + int p3) { + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, + a->ne[0] + p0, + a->ne[1] + p1, + a->ne[2] + p2, + a->ne[3] + p3); + + result->op = GGML_OP_PAD; + result->src[0] = a; + + return result; +} + +// ggml_pad_reflect_1d + +struct ggml_tensor * ggml_pad_reflect_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1) { + GGML_ASSERT(p0 >= 0); + GGML_ASSERT(p1 >= 0); + + GGML_ASSERT(p0 < a->ne[0]); // padding length on each size must be less than the + GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded + + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(a->type == GGML_TYPE_F32); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, + a->ne[0] + p0 + p1, + a->ne[1], + a->ne[2], + a->ne[3]); + + int32_t params[] = { p0, p1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_PAD_REFLECT_1D; + result->src[0] = a; + + return result; +} + +// ggml_arange + +struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step) { + GGML_ASSERT(stop > start); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); + + ggml_set_op_params_f32(result, 0, start); + ggml_set_op_params_f32(result, 1, stop); + ggml_set_op_params_f32(result, 2, step); + + result->op = GGML_OP_ARANGE; + + return result; +} + +// ggml_timestep_embedding + +struct ggml_tensor * ggml_timestep_embedding( + struct ggml_context * ctx, + struct ggml_tensor * timesteps, + int dim, + int max_period) { + int actual_dim = dim; + if (dim % 2 != 0) { + actual_dim = dim + 1; + } + + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]); + + ggml_set_op_params_i32(result, 0, dim); + ggml_set_op_params_i32(result, 1, max_period); + + result->op = GGML_OP_TIMESTEP_EMBEDDING; + result->src[0] = timesteps; + + return result; +} + +// ggml_argsort + +struct ggml_tensor * ggml_argsort( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_sort_order order) { + GGML_ASSERT(a->ne[0] <= INT32_MAX); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); + + ggml_set_op_params_i32(result, 0, (int32_t) order); + + result->op = GGML_OP_ARGSORT; + result->src[0] = a; + + return result; +} + +// ggml_top_k + +struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k) { + GGML_ASSERT(a->ne[0] >= k); + + struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + + result = ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); + + return result; +} + +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale, + float max_bias, + float logit_softcap) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + if (max_bias > 0.0f) { + GGML_ASSERT(mask); + } + + // permute(0, 2, 1, 3) + int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + float params[] = { scale, max_bias, logit_softcap }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second +} + +enum ggml_prec ggml_flash_attn_ext_get_prec( + const struct ggml_tensor * a) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = ggml_get_op_params_i32(a, 3); + + return (enum ggml_prec) prec_i32; +} + +// ggml_flash_attn_back + +struct ggml_tensor * ggml_flash_attn_back( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * d, + bool masked) { + GGML_ABORT("TODO: adapt to ggml_flash_attn_ext() changes"); + + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + // d shape [D,N,ne2,ne3] + // q shape [D,N,ne2,ne3] + // k shape [D,M,kvne2,ne3] + // v shape [M,D,kvne2,ne3] + + const int64_t D = q->ne[0]; + const int64_t N = q->ne[1]; + const int64_t M = k->ne[1]; + const int64_t ne2 = q->ne[2]; + const int64_t ne3 = q->ne[3]; + const int64_t kvne2 = k->ne[2]; + + GGML_ASSERT(k->ne[0] == D); + GGML_ASSERT(v->ne[0] == M); + GGML_ASSERT(v->ne[1] == D); + GGML_ASSERT(d->ne[0] == D); + GGML_ASSERT(d->ne[1] == N); + GGML_ASSERT(k->ne[2] == kvne2); + GGML_ASSERT(k->ne[3] == ne3); + GGML_ASSERT(v->ne[2] == kvne2); + GGML_ASSERT(v->ne[3] == ne3); + GGML_ASSERT(d->ne[2] == ne2); + GGML_ASSERT(d->ne[3] == ne3); + + GGML_ASSERT(ne2 % kvne2 == 0); + + // store gradients of q, k and v as continuous tensors concatenated in result. + // note: v and gradv are actually transposed, i.e. v->ne[0] != D. + const int64_t elem_q = ggml_nelements(q); + const int64_t elem_k = ggml_nelements(k); + const int64_t elem_v = ggml_nelements(v); + + enum ggml_type result_type = GGML_TYPE_F32; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN); + + const size_t nelements = (end + tsize - 1)/tsize; + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements); + + int32_t masked_i = masked ? 1 : 0; + ggml_set_op_params(result, &masked_i, sizeof(masked_i)); + + result->op = GGML_OP_FLASH_ATTN_BACK; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = d; + + return result; +} + +// ggml_ssm_conv + +struct ggml_tensor * ggml_ssm_conv( + struct ggml_context * ctx, + struct ggml_tensor * sx, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_is_3d(sx)); + GGML_ASSERT(ggml_is_matrix(c)); + + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence + const int64_t n_s = sx->ne[2]; + + // TODO: maybe support other strides than 1? + // FIXME: this is always true? + GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); + GGML_ASSERT(sx->ne[1] == d_inner); + GGML_ASSERT(n_t >= 0); + + struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s); + + result->op = GGML_OP_SSM_CONV; + result->src[0] = sx; + result->src[1] = c; + + return result; +} + +// ggml_ssm_scan + +struct ggml_tensor * ggml_ssm_scan( + struct ggml_context * ctx, + struct ggml_tensor * s, + struct ggml_tensor * x, + struct ggml_tensor * dt, + struct ggml_tensor * A, + struct ggml_tensor * B, + struct ggml_tensor * C) { + GGML_ASSERT(ggml_is_contiguous(s)); + GGML_ASSERT(ggml_is_contiguous(x)); + GGML_ASSERT(ggml_is_contiguous(dt)); + GGML_ASSERT(ggml_is_contiguous(A)); + GGML_ASSERT(ggml_is_matrix(A)); + GGML_ASSERT(ggml_is_3d(B)); + GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); + GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); + GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(ggml_are_same_shape(B, C)); + + { + const int64_t d_state = s->ne[0]; + const int64_t d_inner = s->ne[1]; + const int64_t n_seq_tokens = x->ne[1]; + const int64_t n_seqs = x->ne[2]; + + GGML_ASSERT(s->ne[2] == n_seqs); + GGML_ASSERT(x->ne[0] == d_inner); + GGML_ASSERT(A->ne[0] == d_state); + GGML_ASSERT(A->ne[1] == d_inner); + GGML_ASSERT(B->ne[0] == d_state); + GGML_ASSERT(B->ne[1] == n_seq_tokens); + GGML_ASSERT(B->ne[2] == n_seqs); + } + + // concatenated y + ssm_states + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + + result->op = GGML_OP_SSM_SCAN; + result->src[0] = s; + result->src[1] = x; + result->src[2] = dt; + result->src[3] = A; + result->src[4] = B; + result->src[5] = C; + + return result; +} + +// ggml_win_part + +struct ggml_tensor * ggml_win_part( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w) { + GGML_ASSERT(a->ne[3] == 1); + GGML_ASSERT(a->type == GGML_TYPE_F32); + + // padding + const int px = (w - a->ne[1]%w)%w; + const int py = (w - a->ne[2]%w)%w; + + const int npx = (px + a->ne[1])/w; + const int npy = (py + a->ne[2])/w; + const int np = npx*npy; + + const int64_t ne[4] = { a->ne[0], w, w, np, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t params[] = { npx, npy, w }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_WIN_PART; + result->src[0] = a; + + return result; +} + +// ggml_win_unpart + +struct ggml_tensor * ggml_win_unpart( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int w) { + GGML_ASSERT(a->type == GGML_TYPE_F32); + + const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + + int32_t params[] = { w }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_WIN_UNPART; + result->src[0] = a; + + return result; +} + +// ggml_get_rel_pos + +struct ggml_tensor * ggml_get_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + int qh, + int kh) { + GGML_ASSERT(qh == kh); + GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); + + const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne); + + result->op = GGML_OP_GET_REL_POS; + result->src[0] = a; + + return result; +} + +// ggml_add_rel_pos + +static struct ggml_tensor * ggml_add_rel_pos_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(pw, ph)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(pw)); + GGML_ASSERT(ggml_is_contiguous(ph)); + GGML_ASSERT(ph->type == GGML_TYPE_F32); + GGML_ASSERT(pw->type == GGML_TYPE_F32); + GGML_ASSERT(pw->ne[3] == a->ne[2]); + GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]); + GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); + + result->op = GGML_OP_ADD_REL_POS; + result->src[0] = a; + result->src[1] = pw; + result->src[2] = ph; + + return result; +} + +struct ggml_tensor * ggml_add_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph) { + return ggml_add_rel_pos_impl(ctx, a, pw, ph, false); +} + +struct ggml_tensor * ggml_add_rel_pos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph) { + return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); +} + +// ggml_rwkv_wkv6 + +struct ggml_tensor * ggml_rwkv_wkv6( + struct ggml_context * ctx, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * r, + struct ggml_tensor * tf, + struct ggml_tensor * td, + struct ggml_tensor * state) { + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(r)); + GGML_ASSERT(ggml_is_contiguous(tf)); + GGML_ASSERT(ggml_is_contiguous(td)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S = k->ne[0]; + const int64_t H = k->ne[1]; + const int64_t n_tokens = k->ne[2]; + const int64_t n_seqs = state->ne[1]; + { + GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens); + GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens); + GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens); + GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); + } + + // concat output and new_state + const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_RWKV_WKV6; + result->src[0] = k; + result->src[1] = v; + result->src[2] = r; + result->src[3] = tf; + result->src[4] = td; + result->src[5] = state; + + return result; +} + +// ggml_gated_linear_attn + +struct ggml_tensor * ggml_gated_linear_attn( + struct ggml_context * ctx, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * q, + struct ggml_tensor * g, + struct ggml_tensor * state, + float scale) { + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S = k->ne[0]; + const int64_t H = k->ne[1]; + const int64_t n_tokens = k->ne[2]; + const int64_t n_seqs = state->ne[1]; + { + GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens); + GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens); + GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); + } + + // concat output and new_state + const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_f32(result, 0, scale); + + result->op = GGML_OP_GATED_LINEAR_ATTN; + result->src[0] = k; + result->src[1] = v; + result->src[2] = q; + result->src[3] = g; + result->src[4] = state; + + return result; +} + +// ggml_rwkv_wkv7 + +struct ggml_tensor * ggml_rwkv_wkv7( + struct ggml_context * ctx, + struct ggml_tensor * r, + struct ggml_tensor * w, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * state) { + GGML_ASSERT(ggml_is_contiguous(r)); + GGML_ASSERT(ggml_is_contiguous(w)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(b)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S = k->ne[0]; + const int64_t H = k->ne[1]; + const int64_t n_tokens = k->ne[2]; + const int64_t n_seqs = state->ne[1]; + { + GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens); + GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens); + GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens); + GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens); + GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens); + GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); + } + + // concat output and new_state + const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_RWKV_WKV7; + result->src[0] = r; + result->src[1] = w; + result->src[2] = k; + result->src[3] = v; + result->src[4] = a; + result->src[5] = b; + result->src[6] = state; + + return result; +} + +// ggml_unary + +static struct ggml_tensor * ggml_unary_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op, + bool inplace) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op) { + return ggml_unary_impl(ctx, a, op, false); +} + +struct ggml_tensor * ggml_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op) { + return ggml_unary_impl(ctx, a, op, true); +} + +// ggml_map_custom1 + +static struct ggml_tensor * ggml_map_custom1_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + struct ggml_map_custom1_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, ¶ms, sizeof(params)); + + result->op = GGML_OP_MAP_CUSTOM1; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_map_custom1( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false); +} + +struct ggml_tensor * ggml_map_custom1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true); +} + +// ggml_map_custom2 + +static struct ggml_tensor * ggml_map_custom2_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + struct ggml_map_custom2_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, ¶ms, sizeof(params)); + + result->op = GGML_OP_MAP_CUSTOM2; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_map_custom2( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false); +} + +struct ggml_tensor * ggml_map_custom2_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true); +} + +// ggml_map_custom3 + +static struct ggml_tensor * ggml_map_custom3_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + struct ggml_map_custom3_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, ¶ms, sizeof(params)); + + result->op = GGML_OP_MAP_CUSTOM3; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +struct ggml_tensor * ggml_map_custom3( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false); +} + +struct ggml_tensor * ggml_map_custom3_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); +} + +struct ggml_tensor * ggml_custom_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata) { + + GGML_ASSERT(n_args < GGML_MAX_SRC); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3); + + struct ggml_custom_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, ¶ms, sizeof(params)); + + result->op = GGML_OP_CUSTOM; + for (int i = 0; i < n_args; i++) { + result->src[i] = args[i]; + } + + return result; +} + +struct ggml_tensor * ggml_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata) { + + GGML_ASSERT(n_args < GGML_MAX_SRC - 1); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + struct ggml_custom_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, ¶ms, sizeof(params)); + + result->op = GGML_OP_CUSTOM; + result->src[0] = a; + for (int i = 0; i < n_args; i++) { + result->src[i + 1] = args[i]; + } + + return result; +} +// ggml_cross_entropy_loss + +struct ggml_tensor * ggml_cross_entropy_loss( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_CROSS_ENTROPY_LOSS; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_cross_entropy_loss_back + +struct ggml_tensor * ggml_cross_entropy_loss_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_is_scalar(a)); + GGML_ASSERT(ggml_are_same_shape(b, c)); + + struct ggml_tensor * result = ggml_dup_tensor(ctx, b); + + result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +// opt_step_adamw + +struct ggml_tensor * ggml_opt_step_adamw( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * grad, + struct ggml_tensor * m, + struct ggml_tensor * v, + struct ggml_tensor * adamw_params) { + GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM); + GGML_ASSERT(ggml_are_same_shape(a, grad)); + GGML_ASSERT(ggml_are_same_shape(a, m)); + GGML_ASSERT(ggml_are_same_shape(a, v)); + GGML_ASSERT(adamw_params->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_OPT_STEP_ADAMW; + result->src[0] = a; + result->src[1] = grad; + result->src[2] = m; + result->src[3] = v; + result->src[4] = adamw_params; + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +struct ggml_hash_set ggml_hash_set_new(size_t size) { + size = ggml_hash_size(size); + struct ggml_hash_set result; + result.size = size; + result.keys = GGML_MALLOC(sizeof(struct ggml_tensor *) * size); + result.used = GGML_CALLOC(ggml_bitset_size(size), sizeof(ggml_bitset_t)); + return result; +} + +void ggml_hash_set_reset(struct ggml_hash_set * hash_set) { + memset(hash_set->used, 0, sizeof(ggml_bitset_t) * ggml_bitset_size(hash_set->size)); +} + +void ggml_hash_set_free(struct ggml_hash_set * hash_set) { + GGML_FREE(hash_set->used); + GGML_FREE(hash_set->keys); +} + +size_t ggml_hash_size(size_t min_sz) { + // next primes after powers of two + static const size_t primes[] = { + 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031, + 2053, 4099, 8209, 16411, 32771, 65537, 131101, + 262147, 524309, 1048583, 2097169, 4194319, 8388617, + 16777259, 33554467, 67108879, 134217757, 268435459, + 536870923, 1073741827, 2147483659 + }; + static const size_t n_primes = sizeof(primes)/sizeof(primes[0]); + + // find the smallest prime that is larger or equal than min_sz + size_t l = 0; + size_t r = n_primes; + while (l < r) { + size_t m = (l + r)/2; + if (primes[m] < min_sz) { + l = m + 1; + } else { + r = m; + } + } + size_t sz = l < n_primes ? primes[l] : min_sz | 1; + return sz; +} + +struct hash_map { + struct ggml_hash_set set; + struct ggml_tensor ** vals; +}; + +static struct hash_map * ggml_new_hash_map(size_t size) { + struct hash_map * result = GGML_MALLOC(sizeof(struct hash_map)); + result->set = ggml_hash_set_new(size); + result->vals = GGML_CALLOC(result->set.size, sizeof(struct ggml_tensor *)); + return result; +} + +static void ggml_hash_map_free(struct hash_map * map) { + ggml_hash_set_free(&map->set); + GGML_FREE(map->vals); + GGML_FREE(map); +} + +// utility functions to change gradients +// isrc is the index of tensor in cgraph->visited_has_set.keys +// the corresponding gradient (accumulators) are also at position isrc +// if tensor has a gradient accumulator, modify that accumulator in-place +// else if there is no gradient for tensor, set the corresponding value +// else, just add/subtract/etc. the gradients + +static void ggml_add_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor) { + struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + GGML_ASSERT(src); + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = tensor; + } + ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); +} + +static void ggml_acc_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor, + const size_t nb1, + const size_t nb2, + const size_t nb3, + const size_t offset) { + struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + GGML_ASSERT(src); + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]); + } else { + struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN + cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false); + } + ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); +} + +static void ggml_add1_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor) { + struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + GGML_ASSERT(src); + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src); + } + ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); +} + +static void ggml_sub_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor) { + struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + GGML_ASSERT(src); + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = ggml_neg(ctx, tensor); + } + ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); +} + +static void ggml_compute_backward( + struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) { + struct ggml_tensor * tensor = cgraph->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor); + + if (!grad) { + return; + } + + struct ggml_tensor * src0 = tensor->src[0]; + struct ggml_tensor * src1 = tensor->src[1]; + struct ggml_tensor * src2 = tensor->src[2]; + struct ggml_hash_set * hash_set = &cgraph->visited_hash_set; + const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1; + const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1; + const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1; + const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0]; + const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1]; + const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2]; + + switch (tensor->op) { + case GGML_OP_DUP: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); + } + } break; + case GGML_OP_ADD: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); + } + if (src1_needs_grads) { + struct ggml_tensor * tmp = grad; + if (!ggml_are_same_shape(src0, src1)) { + tmp = ggml_repeat_back(ctx, tmp, src1); + } + ggml_add_or_set(ctx, cgraph, isrc1, tmp); + } + } break; + case GGML_OP_ADD1: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); + } + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean + } + } break; + case GGML_OP_ACC: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); + } + if (src1_needs_grads) { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; + + struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, + grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); + + ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1)); + } + } break; + case GGML_OP_SUB: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); + } + if (src1_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc1, grad); + } + } break; + case GGML_OP_MUL: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1)); + } + if (src1_needs_grads) { + struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad); + if (!ggml_are_same_shape(src0, src1)) { + tmp = ggml_repeat_back(ctx, tmp, src1); + } + ggml_add_or_set(ctx, cgraph, isrc1, tmp); + } + } break; + case GGML_OP_DIV: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1)); + } + if (src1_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1))); + } + } break; + case GGML_OP_SQR: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f)); + } + } break; + case GGML_OP_SQRT: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f)); + } + } break; + case GGML_OP_LOG: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0)); + } + } break; + case GGML_OP_SIN: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0))); + } + } break; + case GGML_OP_COS: { + if (src0_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0))); + } + } break; + case GGML_OP_SUM: { + if (src0_needs_grads) { + ggml_add1_or_set(ctx, cgraph, isrc0, grad); + } + } break; + case GGML_OP_SUM_ROWS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0)); + } + } break; + case GGML_OP_MEAN: { + if (src0_needs_grads) { + ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); + } + } break; + case GGML_OP_REPEAT: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0)); + } + } break; + case GGML_OP_REPEAT_BACK: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0)); + } + } break; + case GGML_OP_RMS_NORM: { + if (src0_needs_grads) { + float eps; + memcpy(&eps, tensor->op_params, sizeof(float)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps)); + } + } break; + case GGML_OP_MUL_MAT: { + // https://cs231n.github.io/optimization-2/#staged + // # forward pass + // s0 = np.random.randn(5, 10) + // s1 = np.random.randn(10, 3) + // t = s0.dot(s1) + + // # now suppose we had the gradient on t from above in the circuit + // dt = np.random.randn(*t.shape) # same shape as t + // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix + // ds1 = t.T.dot(dt) + + // tensor.shape [m,p,qq,rr] + // src0.shape [n,m,q1,r1] + // src1.shape [n,p,qq,rr] + + if (src0_needs_grads) { + GGML_ASSERT(grad->ne[2] == src1->ne[2]); + GGML_ASSERT(grad->ne[3] == src1->ne[3]); + struct ggml_tensor * tmp = + ggml_out_prod(ctx, // [n,m,qq,rr] + src1, // [n,p,qq,rr] + grad); // [m,p,qq,rr] + if (!ggml_are_same_shape(tmp, src0)) { + GGML_ASSERT(tmp->ne[0] == src0->ne[0]); + GGML_ASSERT(tmp->ne[1] == src0->ne[1]); + GGML_ASSERT(tmp->ne[3] == 1); + + const int64_t nr2 = tmp->ne[2] / src0->ne[2]; + const size_t nb2 = tmp->nb[2] * nr2; + const size_t nb3 = tmp->nb[2]; + + tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0); + tmp = ggml_repeat_back(ctx, tmp, src0); + } + ggml_add_or_set(ctx, cgraph, isrc0, tmp); + } + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, + // ggml_mul_mat(ctx, // [n,p,qq,rr] + // ggml_cont(ctx, // [m,n,q1,r1] + // ggml_transpose(ctx, src0)), // [m,n,q1,r1] + // grad), // [m,p,qq,rr] + + // when src0 is bigger than tensor->grad (this is mostly the case in llama), + // avoid transpose of src0, rather transpose smaller tensor->grad + // and then use ggml_out_prod + ggml_out_prod(ctx, // [n,p,qq,rr] + src0, // [n,m,q1,r1] + ggml_transpose(ctx, // [p,m,qq,rr] + grad))); // [m,p,qq,rr] + } + } break; + case GGML_OP_SCALE: { + if (src0_needs_grads) { + float s; + memcpy(&s, tensor->op_params, sizeof(float)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false)); + } + } break; + case GGML_OP_SET: { + const size_t nb1 = ((const int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((const int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((const int32_t *) tensor->op_params)[2]; + const size_t offset = ((const int32_t *) tensor->op_params)[3]; + + struct ggml_tensor * tensor_grad_view = NULL; + + if (src0_needs_grads || src1_needs_grads) { + GGML_ASSERT(src0->type == tensor->type); + GGML_ASSERT(!cgraph->grads[isrc0] || cgraph->grads[isrc0]->type == grad->type); + GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type); + + tensor_grad_view = ggml_view_4d(ctx, + grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); + } + + if (src0_needs_grads) { + struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false)); + } + + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1)); + } + } break; + case GGML_OP_CPY: { + // cpy overwrites value of src1 by src0 and returns view(src1) + // the overwriting is mathematically equivalent to: + // tensor = src0 * 1 + src1 * 0 + if (src0_needs_grads) { + // dsrc0 = dtensor * 1 + ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0)); + } + if (src1_needs_grads) { + // dsrc1 = dtensor * 0 -> noop + } + } break; + case GGML_OP_CONT: { + // same as cpy + if (src0_needs_grads) { + GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0])); + GGML_ASSERT(ggml_is_contiguous(grad)); + GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0)); + ggml_add_or_set(ctx, cgraph, isrc0, + ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0)); + } + } break; + case GGML_OP_RESHAPE: { + if (src0_needs_grads) { + struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0)); + } + } break; + case GGML_OP_VIEW: { + if (src0_needs_grads) { + size_t offset; + + memcpy(&offset, tensor->op_params, sizeof(offset)); + + size_t nb1 = tensor->nb[1]; + size_t nb2 = tensor->nb[2]; + size_t nb3 = tensor->nb[3]; + + if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) { + // gradient is typically F32, but src0 could be other type + size_t ng = ggml_element_size(cgraph->grads[isrc0]); + size_t n0 = ggml_element_size(src0); + GGML_ASSERT(offset % n0 == 0); + GGML_ASSERT(nb1 % n0 == 0); + GGML_ASSERT(nb2 % n0 == 0); + GGML_ASSERT(nb3 % n0 == 0); + offset = (offset / n0) * ng; + nb1 = (nb1 / n0) * ng; + nb2 = (nb2 / n0) * ng; + nb3 = (nb3 / n0) * ng; + } + + ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset); + } + } break; + case GGML_OP_PERMUTE: { + if (src0_needs_grads) { + const int32_t * axes = (const int32_t *) tensor->op_params; + const int axis0 = axes[0] & 0x3; + const int axis1 = axes[1] & 0x3; + const int axis2 = axes[2] & 0x3; + const int axis3 = axes[3] & 0x3; + int axb[4] = {0,0,0,0}; // axes backward + axb[axis0] = 0; + axb[axis1] = 1; + axb[axis2] = 2; + axb[axis3] = 3; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3])); + } + } break; + case GGML_OP_TRANSPOSE: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad)); + } + } break; + case GGML_OP_GET_ROWS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0)); + } + if (src1_needs_grads) { + // noop + } + } break; + case GGML_OP_DIAG_MASK_INF: { + if (src0_needs_grads) { + /* ggml_diag_mask_inf_impl() shouldn't be here */ + /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ + const int n_past = ((const int32_t *) tensor->op_params)[0]; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); + } + } break; + case GGML_OP_DIAG_MASK_ZERO: { + if (src0_needs_grads) { + const int n_past = ((const int32_t *) tensor->op_params)[0]; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); + } + } break; + case GGML_OP_SOFT_MAX: { + if (src0_needs_grads) { + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float)); + + ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias)); + } + GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented"); + } break; + case GGML_OP_ROPE: { + if (src0_needs_grads) { + //const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((const int32_t *) tensor->op_params)[1]; + const int mode = ((const int32_t *) tensor->op_params)[2]; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4] = {0, 0, 0, 0}; + + memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const float *) tensor->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float)); + memcpy(§ions, tensor->op_params + 11, sizeof(sections)); + + struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ? + ggml_rope_ext_back(ctx, grad, src1, src2, n_dims, + mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) : + ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections, + mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + ggml_add_or_set(ctx, cgraph, isrc0, rope_back); + } + GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented"); + } break; + case GGML_OP_IM2COL: { + if (src1_needs_grads) { + const int32_t s0 = ggml_get_op_params_i32(tensor, 0); + const int32_t s1 = ggml_get_op_params_i32(tensor, 1); + const int32_t p0 = ggml_get_op_params_i32(tensor, 2); + const int32_t p1 = ggml_get_op_params_i32(tensor, 3); + const int32_t d0 = ggml_get_op_params_i32(tensor, 4); + const int32_t d1 = ggml_get_op_params_i32(tensor, 5); + const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; + + ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); + } + } break; + case GGML_OP_POOL_2D: { + if (src0_needs_grads) { + const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0); + const int32_t k0 = ggml_get_op_params_i32(tensor, 1); + const int32_t k1 = ggml_get_op_params_i32(tensor, 2); + const int32_t s0 = ggml_get_op_params_i32(tensor, 3); + const int32_t s1 = ggml_get_op_params_i32(tensor, 4); + const int32_t p0 = ggml_get_op_params_i32(tensor, 5); + const int32_t p1 = ggml_get_op_params_i32(tensor, 6); + + ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1)); + } + } break; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_UNARY: { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_ABS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad)); + } + } break; + case GGML_UNARY_OP_SGN: { + // noop + } break; + case GGML_UNARY_OP_NEG: { + if (src0_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc0, grad); + } + } break; + case GGML_UNARY_OP_STEP: { + // noop + } break; + case GGML_UNARY_OP_RELU: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad)); + } + } break; + case GGML_UNARY_OP_SILU: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0)); + } + } break; + case GGML_UNARY_OP_EXP: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad)); + } + } break; + default: { + fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n", + __func__, ggml_unary_op_name(ggml_get_unary_op(tensor))); + GGML_ABORT("fatal error"); + } //break; + } + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1)); + } + GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); + } break; + case GGML_OP_NONE: { + // noop + } break; + case GGML_OP_COUNT: + default: { + fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); + GGML_ABORT("fatal error"); + } //break; + } + + GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0])); + GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1])); + GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); +} + +static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { + // check if already visited + if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) { + return; + } + + for (int i = 0; i < GGML_MAX_SRC; ++i) { + const int k = + (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : + (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) : + /* unknown order, just fall back to using i*/ i; + if (node->src[k]) { + ggml_visit_parents(cgraph, node->src[k]); + } + } + + if (node->op == GGML_OP_NONE && !(node->flags & GGML_TENSOR_FLAG_PARAM)) { + // reached a leaf node, not part of the gradient graph (e.g. a constant) + GGML_ASSERT(cgraph->n_leafs < cgraph->size); + + if (strlen(node->name) == 0) { + ggml_format_name(node, "leaf_%d", cgraph->n_leafs); + } + + cgraph->leafs[cgraph->n_leafs] = node; + cgraph->n_leafs++; + } else { + GGML_ASSERT(cgraph->n_nodes < cgraph->size); + + if (strlen(node->name) == 0) { + ggml_format_name(node, "node_%d", cgraph->n_nodes); + } + + cgraph->nodes[cgraph->n_nodes] = node; + cgraph->n_nodes++; + } +} + +static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { + if (!expand) { + // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand + ggml_graph_clear(cgraph); + } + + const int n0 = cgraph->n_nodes; + + ggml_visit_parents(cgraph, tensor); + + const int n_new = cgraph->n_nodes - n0; + GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); + + if (n_new > 0) { + // the last added node should always be starting point + GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); + } +} + +void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { + ggml_build_forward_impl(cgraph, tensor, true); +} + +void ggml_build_backward_expand( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + struct ggml_tensor ** grad_accs) { + GGML_ASSERT(cgraph->n_nodes > 0); + GGML_ASSERT(cgraph->grads); + GGML_ASSERT(cgraph->grad_accs); + + const int n_nodes_f = cgraph->n_nodes; + + memset(cgraph->grads, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *)); + memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *)); + bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool)); + + { + bool any_params = false; + bool any_loss = false; + for (int i = 0; i < n_nodes_f; ++i) { + struct ggml_tensor * node = cgraph->nodes[i]; + any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM); + any_loss = any_loss || (node->flags & GGML_TENSOR_FLAG_LOSS); + } + GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?"); + GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?"); + } + + for (int i = 0; i < n_nodes_f; ++i) { + struct ggml_tensor * node = cgraph->nodes[i]; + + if (node->type == GGML_TYPE_I32) { + continue; + } + + bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS); + bool ignore_src[GGML_MAX_SRC] = {false}; + switch (node->op) { + // gradients in node->src[0] for one reason or another have no effect on output gradients + case GGML_OP_IM2COL: // only used for its shape + case GGML_OP_IM2COL_BACK: // same as IM2COL + ignore_src[0] = true; + break; + case GGML_OP_UNARY: { + const enum ggml_unary_op uop = ggml_get_unary_op(node); + // SGN and STEP unary ops are piecewise constant + if (uop == GGML_UNARY_OP_SGN || uop == GGML_UNARY_OP_STEP) { + ignore_src[0] = true; + } + } break; + + // gradients in node->src[1] for one reason or another have no effect on output gradients + case GGML_OP_CPY: // gradients in CPY target are irrelevant + case GGML_OP_GET_ROWS: // row indices not differentiable + case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS + case GGML_OP_ROPE: // positions not differentiable + ignore_src[1] = true; + break; + + default: + break; + } + for (int j = 0; j < GGML_MAX_SRC; ++j) { + if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) { + continue; + } + GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16); + node_needs_grad = true; + break; + } + if (!node_needs_grad) { + continue; + } + + // inplace operations are currently not supported + GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW || + node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE); + + const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node); + GGML_ASSERT(ihash != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash)); + if (grad_accs && grad_accs[i]) { + cgraph->grad_accs[ihash] = grad_accs[i]; + cgraph->grads[ihash] = cgraph->grad_accs[ihash]; + } else if (node->flags & GGML_TENSOR_FLAG_LOSS) { + // loss tensors always need a gradient accumulator + cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + cgraph->grads[ihash] = cgraph->grad_accs[ihash]; + } + grads_needed[ihash] = true; + } + + for (int i = n_nodes_f - 1; i >= 0; --i) { + // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation + // use allocator to automatically make inplace operations + ggml_compute_backward(ctx, cgraph, i, grads_needed); + } + + free(grads_needed); +} + +static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { + void * ptr = *p; + ptr = (void *) GGML_PAD((uintptr_t) ptr, align); + *p = (void *) ((char *) ptr + size); + return ptr; +} + +static size_t ggml_graph_nbytes(size_t size, bool grads) { + size_t hash_size = ggml_hash_size(size * 2); + void * p = 0; + incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1); + incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes + incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys + if (grads) { + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs + } + incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); + + size_t nbytes = (size_t) p; + return nbytes; +} + +size_t ggml_graph_overhead_custom(size_t size, bool grads) { + return GGML_OBJECT_SIZE + GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN); +} + +size_t ggml_graph_overhead(void) { + return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false); +} + +struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) { + const size_t obj_size = ggml_graph_nbytes(size, grads); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_GRAPH, obj_size); + struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); + + // the size of the hash table is doubled since it needs to hold both nodes and leafs + size_t hash_size = ggml_hash_size(size * 2); + + void * p = cgraph + 1; + + struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + + ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); + + // check that we allocated the correct amount of memory + assert(obj_size == (size_t)((char *)p - (char *)cgraph)); + + *cgraph = (struct ggml_cgraph) { + /*.size =*/ size, + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.nodes =*/ nodes_ptr, + /*.grads =*/ grads_ptr, + /*.grad_accs =*/ grad_accs_ptr, + /*.leafs =*/ leafs_ptr, + /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, + /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, + }; + + ggml_hash_set_reset(&cgraph->visited_hash_set); + if (grads) { + memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *)); + memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *)); + } + + return cgraph; +} + +struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { + return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false); +} + +struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) { + struct ggml_cgraph cgraph = { + /*.size =*/ 0, + /*.n_nodes =*/ i1 - i0, + /*.n_leafs =*/ 0, + /*.nodes =*/ cgraph0->nodes + i0, + /*.grads =*/ NULL, // gradients would need visited_hash_set + /*.grad_accs =*/ NULL, + /*.leafs =*/ NULL, + /*.visited_hash_set =*/ { 0, NULL, NULL }, + /*.order =*/ cgraph0->order, + }; + + return cgraph; +} + +void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { + GGML_ASSERT(dst->size >= src->n_leafs); + GGML_ASSERT(dst->size >= src->n_nodes); + GGML_ASSERT(dst->visited_hash_set.size >= src->visited_hash_set.size); + + dst->n_leafs = src->n_leafs; + dst->n_nodes = src->n_nodes; + dst->order = src->order; + + for (int i = 0; i < src->n_leafs; ++i) { + dst->leafs[i] = src->leafs[i]; + } + + for (int i = 0; i < src->n_nodes; ++i) { + dst->nodes[i] = src->nodes[i]; + } + + for (size_t i = 0; i < src->visited_hash_set.size; ++i) { + // copy all hashset keys (tensors) that are in use + if (ggml_bitset_get(src->visited_hash_set.used, i)) { + ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); + } + } + + if (dst->grads) { + memset(dst->grads, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *)); + memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *)); + } + if (src->grads) { + GGML_ASSERT(dst->grads != NULL); + GGML_ASSERT(dst->grad_accs != NULL); + for (int i = 0; i < src->n_nodes; ++i) { + const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]); + const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]); + + GGML_ASSERT(igrad_src != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src)); + GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst)); + + dst->grads[igrad_dst] = src->grads[igrad_src]; + dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src]; + } + } +} + +struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) { + struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads); + ggml_graph_cpy(cgraph, result); + return result; +} + +struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { + if (ggml_is_empty(tensor)) { + return tensor; + } + if (tensor->buffer) { + ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor)); + } else { + GGML_ASSERT(tensor->data); + memset(tensor->data, 0, ggml_nbytes(tensor)); + } + return tensor; +} + +void ggml_graph_reset(struct ggml_cgraph * cgraph) { + if (!cgraph) { + return; + } + GGML_ASSERT(cgraph->grads != NULL); + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node); + + if (node->op == GGML_OP_OPT_STEP_ADAMW) { + // clear momenta + ggml_set_zero(node->src[2]); + ggml_set_zero(node->src[3]); + } + + // initial gradients of loss should be 1, 0 otherwise + if (grad_acc) { + if (node->flags & GGML_TENSOR_FLAG_LOSS) { + GGML_ASSERT(grad_acc->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_scalar(grad_acc)); + + const float onef = 1.0f; + if (grad_acc->buffer) { + ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float)); + } else { + GGML_ASSERT(grad_acc->data); + *((float *) grad_acc->data) = onef; + } + } else { + ggml_set_zero(grad_acc); + } + } + } +} + +void ggml_graph_clear(struct ggml_cgraph * cgraph) { + cgraph->n_leafs = 0; + cgraph->n_nodes = 0; + ggml_hash_set_reset(&cgraph->visited_hash_set); +} + +int ggml_graph_size(struct ggml_cgraph * cgraph) { + return cgraph->size; +} + +struct ggml_tensor * ggml_graph_node(struct ggml_cgraph * cgraph, int i) { + if (i < 0) { + GGML_ASSERT(cgraph->n_nodes + i >= 0); + return cgraph->nodes[cgraph->n_nodes + i]; + } + + GGML_ASSERT(i < cgraph->n_nodes); + return cgraph->nodes[i]; +} + +struct ggml_tensor ** ggml_graph_nodes(struct ggml_cgraph * cgraph) { + return cgraph->nodes; +} + +int ggml_graph_n_nodes(struct ggml_cgraph * cgraph) { + return cgraph->n_nodes; +} + +void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { + GGML_ASSERT(cgraph->size > cgraph->n_nodes); + cgraph->nodes[cgraph->n_nodes] = tensor; + cgraph->n_nodes++; +} + +struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) { + for (int i = 0; i < cgraph->n_leafs; i++) { + struct ggml_tensor * leaf = cgraph->leafs[i]; + + if (strcmp(leaf->name, name) == 0) { + return leaf; + } + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + if (strcmp(node->name, name) == 0) { + return node; + } + } + + return NULL; +} + +struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grads ? cgraph->grads[igrad] : NULL; +} + +struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grad_accs ? cgraph->grad_accs[igrad] : NULL; +} + +void ggml_graph_print(const struct ggml_cgraph * cgraph) { + GGML_LOG_INFO("=== GRAPH ===\n"); + + GGML_LOG_INFO("n_nodes = %d\n", cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", + i, + node->ne[0], node->ne[1], node->ne[2], + ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : + ggml_graph_get_grad(cgraph, node) ? "g" : " "); + } + + GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs); + for (int i = 0; i < cgraph->n_leafs; i++) { + struct ggml_tensor * node = cgraph->leafs[i]; + + GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", + i, + node->ne[0], node->ne[1], + ggml_op_name(node->op), + ggml_get_name(node)); + } + + GGML_LOG_INFO("========================================\n"); +} + +// check if node is part of the graph +static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + if (cgraph == NULL) { + return true; + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return true; + } + } + + return false; +} + +static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * parent = cgraph->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent); + + if (grad == node) { + return parent; + } + } + + return NULL; +} + +static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { + struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node); + struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent); + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", + gparent0 ? (void *) gparent0 : (void *) parent, + gparent0 ? "g" : "x", + gparent ? (void *) gparent : (void *) node, + gparent ? "g" : "x", + gparent ? "empty" : "vee", + gparent ? "dashed" : "solid", + label); +} + +static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n", + (void *) parent, "x", + (void *) node, "x", + label); +} + +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { + char color[16]; + + FILE * fp = ggml_fopen(filename, "w"); + GGML_ASSERT(fp); + + fprintf(fp, "digraph G {\n"); + fprintf(fp, " newrank = true;\n"); + fprintf(fp, " rankdir = TB;\n"); + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(gb, node); + + if (ggml_graph_get_parent(gb, node) != NULL) { + continue; + } + + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + snprintf(color, sizeof(color), "yellow"); + } else if (grad) { + if (ggml_graph_find(gf, node)) { + snprintf(color, sizeof(color), "green"); + } else { + snprintf(color, sizeof(color), "lightblue"); + } + } else { + snprintf(color, sizeof(color), "white"); + } + + fprintf(fp, " \"%p\" [ " + "style = filled; fillcolor = %s; shape = record; " + "label=\"", + (void *) node, color); + + if (strlen(node->name) > 0) { + fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type)); + } else { + fprintf(fp, "(%s)|", ggml_type_name(node->type)); + } + + if (ggml_is_matrix(node)) { + fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op)); + } else { + fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op)); + } + + if (grad) { + fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(grad->op)); + } else { + fprintf(fp, "\"; ]\n"); + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + snprintf(color, sizeof(color), "pink"); + + fprintf(fp, " \"%p\" [ " + "style = filled; fillcolor = %s; shape = record; " + "label=\"", + (void *) node, color); + + if (strlen(node->name) > 0) { + fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type)); + } else { + fprintf(fp, "(%s)|", ggml_type_name(node->type)); + } + + fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); + if (ggml_nelements(node) < 5 && node->data != NULL) { + fprintf(fp, " | ("); + for (int j = 0; j < ggml_nelements(node); j++) { + // FIXME: use ggml-backend to obtain the tensor data + //if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { + // fprintf(fp, "%d", ggml_get_i32_1d(node, j)); + //} + //else if (node->type == GGML_TYPE_F32 || + // node->type == GGML_TYPE_F16 || + // node->type == GGML_TYPE_BF16) { + // fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j)); + //} + //else + { + fprintf(fp, "#"); + } + if (j < ggml_nelements(node) - 1) { + fprintf(fp, ", "); + } + } + fprintf(fp, ")"); + } + fprintf(fp, "\"; ]\n"); + } + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j]) { + char label[16]; + snprintf(label, sizeof(label), "src %d", j); + ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label); + } + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j]) { + char label[16]; + snprintf(label, sizeof(label), "src %d", j); + ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label); + } + } + } + + fprintf(fp, "}\n"); + + fclose(fp); + + GGML_LOG_INFO("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_set_input(struct ggml_tensor * tensor) { + tensor->flags |= GGML_TENSOR_FLAG_INPUT; +} + +void ggml_set_output(struct ggml_tensor * tensor) { + tensor->flags |= GGML_TENSOR_FLAG_OUTPUT; +} + +void ggml_set_param(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_NONE); + tensor->flags |= GGML_TENSOR_FLAG_PARAM; +} + +void ggml_set_loss(struct ggml_tensor * tensor) { + GGML_ASSERT(ggml_is_scalar(tensor)); + GGML_ASSERT(tensor->type == GGML_TYPE_F32); + tensor->flags |= GGML_TENSOR_FLAG_LOSS; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_quantize_init(enum ggml_type type) { + ggml_critical_section_start(); + + switch (type) { + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: iq2xs_init_impl(type); break; + case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break; + case GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break; + default: // nothing + break; + } + + ggml_critical_section_end(); +} + +void ggml_quantize_free(void) { + ggml_critical_section_start(); + + iq2xs_free_impl(GGML_TYPE_IQ2_XXS); + iq2xs_free_impl(GGML_TYPE_IQ2_XS); + iq2xs_free_impl(GGML_TYPE_IQ1_S); + iq3xs_free_impl(256); + + ggml_critical_section_end(); +} + +bool ggml_quantize_requires_imatrix(enum ggml_type type) { + return + type == GGML_TYPE_IQ2_XXS || + type == GGML_TYPE_IQ2_XS || + type == GGML_TYPE_IQ1_S;// || + //type == GGML_TYPE_IQ1_M; +} + +size_t ggml_quantize_chunk( + enum ggml_type type, + const float * src, + void * dst, + int64_t start, + int64_t nrows, + int64_t n_per_row, + const float * imatrix) { + const int64_t n = (int64_t) nrows * n_per_row; + + if (ggml_quantize_requires_imatrix(type)) { + GGML_ASSERT(imatrix != NULL); + } + + GGML_ASSERT(start % type_traits[type].blck_size == 0); + GGML_ASSERT(start % n_per_row == 0); + + ggml_quantize_init(type); // this is noop if already initialized + + const size_t start_row = start / n_per_row; + const size_t row_size = ggml_row_size(type, n_per_row); + + size_t result = 0; + + switch (type) { + case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ3_S: result = quantize_iq3_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_S: result = quantize_iq2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_F16: + { + size_t elemsize = sizeof(ggml_fp16_t); + ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n); + result = n * elemsize; + } break; + case GGML_TYPE_BF16: + { + size_t elemsize = sizeof(ggml_bf16_t); + ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n); + result = n * elemsize; + } break; + case GGML_TYPE_F32: + { + size_t elemsize = sizeof(float); + result = n * elemsize; + memcpy((uint8_t *)dst + start * elemsize, src + start, result); + } break; + default: + assert(false); + } + + GGML_ASSERT(result == nrows * row_size); + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_log_set(ggml_log_callback log_callback, void * user_data) { + g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default; + g_logger_state.log_callback_user_data = user_data; +} + +void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) { + p->n_threads = n_threads; + p->prio = 0; // default priority (usually means normal or inherited) + p->poll = 50; // hybrid-polling enabled + p->strict_cpu = false; // no strict placement (all threads share same cpumask) + p->paused = false; // threads are ready to go + memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited) +} + +struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) { + struct ggml_threadpool_params p; + ggml_threadpool_params_init(&p, n_threads); + return p; +} + +bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) { + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; + return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; +} diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d388d45536d1fb6fd39ee35a5b8e0583e212b54 --- /dev/null +++ b/ggml/src/ggml.cpp @@ -0,0 +1,26 @@ +#include "ggml-impl.h" + +#include +#include + +static std::terminate_handler previous_terminate_handler; + +GGML_NORETURN static void ggml_uncaught_exception() { + ggml_print_backtrace(); + if (previous_terminate_handler) { + previous_terminate_handler(); + } + abort(); // unreachable unless previous_terminate_handler was nullptr +} + +static bool ggml_uncaught_exception_init = []{ + const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE"); + if (GGML_NO_BACKTRACE) { + return false; + } + const auto prev{std::get_terminate()}; + GGML_ASSERT(prev != ggml_uncaught_exception); + previous_terminate_handler = prev; + std::set_terminate(ggml_uncaught_exception); + return true; +}(); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a0a318a29f5b9ef97e3ce1da3ebe87b030f54611 --- /dev/null +++ b/ggml/src/gguf.cpp @@ -0,0 +1,1347 @@ +#include "ggml.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "gguf.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +struct type_to_gguf_type; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT8; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT8; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT16; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT16; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT32; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT32; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_FLOAT32; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_BOOL; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_STRING; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT64; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT64; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_FLOAT64; +}; + +static const std::map GGUF_TYPE_SIZE = { + {GGUF_TYPE_UINT8, sizeof(uint8_t)}, + {GGUF_TYPE_INT8, sizeof(int8_t)}, + {GGUF_TYPE_UINT16, sizeof(uint16_t)}, + {GGUF_TYPE_INT16, sizeof(int16_t)}, + {GGUF_TYPE_UINT32, sizeof(uint32_t)}, + {GGUF_TYPE_INT32, sizeof(int32_t)}, + {GGUF_TYPE_FLOAT32, sizeof(float)}, + {GGUF_TYPE_BOOL, sizeof(int8_t)}, + {GGUF_TYPE_STRING, 0}, // undefined + {GGUF_TYPE_ARRAY, 0}, // undefined + {GGUF_TYPE_UINT64, sizeof(uint64_t)}, + {GGUF_TYPE_INT64, sizeof(int64_t)}, + {GGUF_TYPE_FLOAT64, sizeof(double)}, +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +static const std::map GGUF_TYPE_NAME = { + {GGUF_TYPE_UINT8, "u8"}, + {GGUF_TYPE_INT8, "i8"}, + {GGUF_TYPE_UINT16, "u16"}, + {GGUF_TYPE_INT16, "i16"}, + {GGUF_TYPE_UINT32, "u32"}, + {GGUF_TYPE_INT32, "i32"}, + {GGUF_TYPE_FLOAT32, "f32"}, + {GGUF_TYPE_BOOL, "bool"}, + {GGUF_TYPE_STRING, "str"}, + {GGUF_TYPE_ARRAY, "arr"}, + {GGUF_TYPE_UINT64, "u64"}, + {GGUF_TYPE_INT64, "i64"}, + {GGUF_TYPE_FLOAT64, "f64"}, +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +size_t gguf_type_size(enum gguf_type type) { + auto it = GGUF_TYPE_SIZE.find(type); + return it == GGUF_TYPE_SIZE.end() ? 0 : it->second; +} + +struct gguf_kv { + std::string key; + + bool is_array; + enum gguf_type type; + + std::vector data; + std::vector data_string; + + template + gguf_kv(const std::string & key, const T value) + : key(key), is_array(false), type(type_to_gguf_type::value) { + GGML_ASSERT(!key.empty()); + data.resize(sizeof(T)); + memcpy(data.data(), &value, sizeof(T)); + } + + template + gguf_kv(const std::string & key, const std::vector & value) + : key(key), is_array(true), type(type_to_gguf_type::value) { + GGML_ASSERT(!key.empty()); + data.resize(value.size()*sizeof(T)); + for (size_t i = 0; i < value.size(); ++i) { + const T tmp = value[i]; + memcpy(data.data() + i*sizeof(T), &tmp, sizeof(T)); + } + } + + gguf_kv(const std::string & key, const std::string & value) + : key(key), is_array(false), type(GGUF_TYPE_STRING) { + GGML_ASSERT(!key.empty()); + data_string.push_back(value); + } + + gguf_kv(const std::string & key, const std::vector & value) + : key(key), is_array(true), type(GGUF_TYPE_STRING) { + GGML_ASSERT(!key.empty()); + data_string = value; + } + + const std::string & get_key() const { + return key; + } + + const enum gguf_type & get_type() const { + return type; + } + + size_t get_ne() const { + if (type == GGUF_TYPE_STRING) { + const size_t ne = data_string.size(); + GGML_ASSERT(is_array || ne == 1); + return ne; + } + const size_t type_size = gguf_type_size(type); + GGML_ASSERT(data.size() % type_size == 0); + const size_t ne = data.size() / type_size; + GGML_ASSERT(is_array || ne == 1); + return ne; + } + + template + const T & get_val(const size_t i = 0) const { + GGML_ASSERT(type_to_gguf_type::value == type); + if constexpr (std::is_same::value) { + GGML_ASSERT(data_string.size() >= i+1); + return data_string[i]; + } + const size_t type_size = gguf_type_size(type); + GGML_ASSERT(data.size() % type_size == 0); + GGML_ASSERT(data.size() >= (i+1)*type_size); + return reinterpret_cast(data.data())[i]; + } + + void cast(const enum gguf_type new_type) { + const size_t new_type_size = gguf_type_size(new_type); + GGML_ASSERT(data.size() % new_type_size == 0); + type = new_type; + } +}; + +struct gguf_tensor_info { + struct ggml_tensor t; // for holding the equivalent info + uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` +}; + +struct gguf_context { + uint32_t version = GGUF_VERSION; + + std::vector kv; + std::vector info; + + size_t alignment = GGUF_DEFAULT_ALIGNMENT; + size_t offset = 0; // offset of `data` from beginning of file + size_t size = 0; // size of `data` in bytes + + void * data = nullptr; +}; + +struct gguf_reader { + FILE * file; + + gguf_reader(FILE * file) : file(file) {} + + template + bool read(T & dst) const { + return fread(&dst, 1, sizeof(dst), file) == sizeof(dst); + } + + template + bool read(std::vector & dst, const size_t n) const { + dst.resize(n); + for (size_t i = 0; i < dst.size(); ++i) { + if constexpr (std::is_same::value) { + bool tmp; + if (!read(tmp)) { + return false; + } + dst[i] = tmp; + } else { + if (!read(dst[i])) { + return false; + } + } + } + return true; + } + + bool read(bool & dst) const { + int8_t tmp = -1; + if (!read(tmp)) { + return false; + } + dst = tmp != 0; + return true; + } + + bool read(enum ggml_type & dst) const { + int32_t tmp = -1; + if (!read(tmp)) { + return false; + } + dst = ggml_type(tmp); + return true; + } + + bool read(enum gguf_type & dst) const { + int32_t tmp = -1; + if (!read(tmp)) { + return false; + } + dst = gguf_type(tmp); + return true; + } + + bool read(std::string & dst) const { + uint64_t size = -1; + if (!read(size)) { + return false; + } + dst.resize(size); + return fread(dst.data(), 1, dst.length(), file) == dst.length(); + } + + bool read(void * dst, const size_t size) const { + return fread(dst, 1, size, file) == size; + } +}; + +struct gguf_context * gguf_init_empty(void) { + return new gguf_context; +} + +template +bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector & kv, const std::string & key, const bool is_array, const size_t n) { + if (is_array) { + std::vector value; + try { + if (!gr.read(value, n)) { + return false; + } + } catch (std::length_error &) { + GGML_LOG_ERROR("%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str()); + return false; + } catch (std::bad_alloc &) { + GGML_LOG_ERROR("%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str()); + return false; + } + kv.emplace_back(key, value); + } else { + T value; + if (!gr.read(value)) { + return false; + } + kv.emplace_back(key, value); + } + return true; +} + +struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) { + const struct gguf_reader gr(file); + struct gguf_context * ctx = new gguf_context; + + bool ok = true; + + // file magic + { + std::vector magic; + ok = ok && gr.read(magic, 4); + + if (!ok) { + GGML_LOG_ERROR("%s: failed to read magic\n", __func__); + gguf_free(ctx); + return nullptr; + } + + for (uint32_t i = 0; i < magic.size(); i++) { + if (magic[i] != GGUF_MAGIC[i]) { + GGML_LOG_ERROR("%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]); + gguf_free(ctx); + return nullptr; + } + } + } + + // header + int64_t n_kv = 0; + int64_t n_tensors = 0; + + if (ok && gr.read(ctx->version)) { + if (ok && ctx->version == 0) { + GGML_LOG_ERROR("%s: bad GGUF version: %" PRIu32 "\n", __func__, ctx->version); + ok = false; + } + + /* + * bit layout is different when reading non-native endian models. + * assuming that the GGUF version is 3, the non-native endian model + * would read it as 0x30000000. we can use the AND operation against + * the last 4 hexadecimal digits to check if the model is the same + * endianness as the host system. + */ + if (ok && (ctx->version & 0x0000FFFF) == 0x00000000) { + GGML_LOG_ERROR("%s: failed to load model: this GGUF file version %" PRIu32 " is extremely large, is there a mismatch between the host and model endianness?\n", __func__, ctx->version); + ok = false; + } + + if (ok && ctx->version == 1) { + GGML_LOG_ERROR("%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); + ok = false; + } + if (ok && ctx->version > GGUF_VERSION) { + GGML_LOG_ERROR("%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", + __func__, ctx->version, GGUF_VERSION); + ok = false; + } + } else { + ok = false; + } + + if (ok && gr.read(n_tensors)) { + static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); + if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) { + GGML_LOG_ERROR("%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n", + __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info)); + ok = false; + } + } else { + ok = false; + } + + if (ok && gr.read(n_kv)) { + static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); + if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) { + GGML_LOG_ERROR("%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n", + __func__, n_kv, SIZE_MAX/sizeof(gguf_kv)); + ok = false; + } + } else { + ok = false; + } + + if (!ok) { + GGML_LOG_ERROR("%s: failed to read header\n", __func__); + gguf_free(ctx); + return nullptr; + } + + // KV pairs + { + for (int64_t i = 0; ok && i < n_kv; ++i) { + std::string key; + gguf_type type = gguf_type(-1); + bool is_array = false; + uint64_t n = 1; + + try { + ok = ok && gr.read(key); + } catch (std::length_error &) { + GGML_LOG_ERROR("%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i); + ok = false; + } catch (std::bad_alloc &) { + GGML_LOG_ERROR("%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i); + ok = false; + } + for (size_t j = 0; ok && j < ctx->kv.size(); ++j) { + if (key == ctx->kv[j].key) { + GGML_LOG_ERROR("%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i); + ok = false; + } + } + if (!ok) { + break; + } + + ok = ok && gr.read(type); + if (type == GGUF_TYPE_ARRAY) { + is_array = true; + ok = ok && gr.read(type); + ok = ok && gr.read(n); + } + if (!ok) { + break; + } + + switch (type) { + case GGUF_TYPE_UINT8: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT8: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_UINT16: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT16: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_UINT32: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT32: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_FLOAT32: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_BOOL: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_STRING: ok = ok && gguf_read_emplace_helper(gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_UINT64: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT64: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_FLOAT64: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_ARRAY: + default: + { + GGML_LOG_ERROR("%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type); + ok = false; + } break; + } + } + + if (!ok) { + GGML_LOG_ERROR("%s: failed to read key-value pairs\n", __func__); + gguf_free(ctx); + return nullptr; + } + GGML_ASSERT(int64_t(ctx->kv.size()) == n_kv); + + const int alignment_idx = gguf_find_key(ctx, GGUF_KEY_GENERAL_ALIGNMENT); + ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx); + + if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) { + GGML_LOG_ERROR("%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment); + gguf_free(ctx); + return nullptr; + } + } + + // read the tensor info + for (int64_t i = 0; ok && i < n_tensors; ++i) { + struct gguf_tensor_info info; + + // tensor name + { + std::string name; + try { + ok = ok && gr.read(name); + } catch (std::length_error &) { + GGML_LOG_ERROR("%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i); + ok = false; + } catch (std::bad_alloc &) { + GGML_LOG_ERROR("%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i); + ok = false; + } + if (name.length() >= GGML_MAX_NAME) { + GGML_LOG_ERROR("%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME); + ok = false; + break; + } + ggml_set_name(&info.t, name.c_str()); + + // make sure there are no duplicate tensor names + for (int64_t j = 0; ok && j < i; ++j) { + if (strcmp(info.t.name, ctx->info[j].t.name) == 0) { + GGML_LOG_ERROR("%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i); + ok = false; + break; + } + } + } + if (!ok) { + break; + } + + // tensor shape + { + uint32_t n_dims = -1; + ok = ok && gr.read(n_dims); + if (n_dims > GGML_MAX_DIMS) { + GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", + __func__, info.t.name, n_dims, GGML_MAX_DIMS); + ok = false; + break; + } + for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) { + info.t.ne[j] = 1; + if (j < n_dims) { + ok = ok && gr.read(info.t.ne[j]); + } + + // check that all ne are non-negative + if (info.t.ne[j] < 0) { + GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", + __func__, info.t.name, j, info.t.ne[j]); + ok = false; + break; + } + } + + // check that the total number of elements is representable + if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) || + (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) || + (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) { + + GGML_LOG_ERROR("%s: total number of elements in tensor '%s' with shape " + "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n", + __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX); + ok = false; + break; + } + } + if (!ok) { + break; + } + + // tensor type + { + ok = ok && gr.read(info.t.type); + + // check that tensor type is within defined range + if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) { + GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n", + __func__, info.t.name, info.t.type, ggml_type_name(info.t.type)); + ok = false; + break; + } + const size_t type_size = ggml_type_size(info.t.type); + const int64_t blck_size = ggml_blck_size(info.t.type); + + // check that row size is divisible by block size + if (blck_size == 0 || info.t.ne[0] % blck_size != 0) { + GGML_LOG_ERROR("%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, " + "not a multiple of block size (%" PRId64 ")\n", + __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size); + ok = false; + break; + } + + // calculate byte offsets given the tensor shape and type + info.t.nb[0] = type_size; + info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size); + for (int j = 2; j < GGML_MAX_DIMS; ++j) { + info.t.nb[j] = info.t.nb[j - 1]*info.t.ne[j - 1]; + } + } + if (!ok) { + break; + } + + // tensor data offset within buffer + ok = ok && gr.read(info.offset); + + ctx->info.push_back(info); + } + + if (!ok) { + GGML_LOG_ERROR("%s: failed to read tensor info\n", __func__); + gguf_free(ctx); + return nullptr; + } + GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors); + + // we require the data section to be aligned, so take into account any padding + if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) { + GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__); + gguf_free(ctx); + return nullptr; + } + + // store the current file offset - this is where the data section starts + ctx->offset = ftell(file); + + // compute the total size of the data section, taking into account the alignment + { + ctx->size = 0; + for (size_t i = 0; i < ctx->info.size(); ++i) { + const gguf_tensor_info & ti = ctx->info[i]; + if (ti.offset != ctx->size) { + GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n", + __func__, ti.t.name, ti.offset, ctx->size); + GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__); + gguf_free(ctx); + return nullptr; + } + ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment); + } + } + + // load the tensor data only if requested + if (params.ctx != nullptr) { + // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob + // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of + // the ggml_tensor structs to the appropriate locations in the binary blob + + // compute the exact size needed for the new ggml_context + const size_t mem_size = + params.no_alloc ? + (n_tensors )*ggml_tensor_overhead() : + (n_tensors + 1)*ggml_tensor_overhead() + ctx->size; + + struct ggml_init_params pdata = { + /*mem_size =*/ mem_size, + /*mem_buffer =*/ nullptr, + /*no_alloc =*/ params.no_alloc, + }; + + *params.ctx = ggml_init(pdata); + if (*params.ctx == nullptr) { + GGML_LOG_ERROR("%s: failed to initialize ggml context for storing tensors\n", __func__); + gguf_free(ctx); + return nullptr; + } + + struct ggml_context * ctx_data = *params.ctx; + + struct ggml_tensor * data = nullptr; + + if (!params.no_alloc) { + data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size); + + ok = ok && data != nullptr; + + if (ok) { + ggml_set_name(data, "GGUF tensor data binary blob"); + } + + // read the binary blob with the tensor data + ok = ok && gr.read(data->data, ctx->size); + + if (!ok) { + GGML_LOG_ERROR("%s: failed to read tensor data binary blob\n", __func__); + ggml_free(ctx_data); + *params.ctx = nullptr; + gguf_free(ctx); + return nullptr; + } + + ctx->data = data->data; + } + + ggml_set_no_alloc(ctx_data, true); + + // create the tensors + for (size_t i = 0; i < ctx->info.size(); ++i) { + const struct gguf_tensor_info & info = ctx->info[i]; + + struct ggml_tensor * cur = ggml_new_tensor(ctx_data, info.t.type, GGML_MAX_DIMS, info.t.ne); + + ok = ok && cur != nullptr; + + if (!ok) { + break; + } + + ggml_set_name(cur, info.t.name); + + // point the data member to the appropriate location in the binary blob using the tensor info + if (!params.no_alloc) { + cur->data = (char *) data->data + info.offset; + } + } + + if (!ok) { + GGML_LOG_ERROR("%s: failed to create tensors\n", __func__); + ggml_free(ctx_data); + *params.ctx = nullptr; + gguf_free(ctx); + return nullptr; + } + + ggml_set_no_alloc(ctx_data, params.no_alloc); + } + + return ctx; +} + +struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { + FILE * file = ggml_fopen(fname, "rb"); + + if (!file) { + GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname); + return nullptr; + } + + struct gguf_context * result = gguf_init_from_file_impl(file, params); + fclose(file); + return result; +} + +void gguf_free(struct gguf_context * ctx) { + if (ctx == nullptr) { + return; + } + delete ctx; +} + +const char * gguf_type_name(enum gguf_type type) { + auto it = GGUF_TYPE_NAME.find(type); + return it == GGUF_TYPE_NAME.end() ? nullptr : it->second; +} + +uint32_t gguf_get_version(const struct gguf_context * ctx) { + return ctx->version; +} + +size_t gguf_get_alignment(const struct gguf_context * ctx) { + return ctx->alignment; +} + +size_t gguf_get_data_offset(const struct gguf_context * ctx) { + return ctx->offset; +} + +int64_t gguf_get_n_kv(const struct gguf_context * ctx) { + return ctx->kv.size(); +} + +int64_t gguf_find_key(const struct gguf_context * ctx, const char * key) { + // return -1 if key not found + int64_t keyfound = -1; + + const int64_t n_kv = gguf_get_n_kv(ctx); + + for (int64_t i = 0; i < n_kv; ++i) { + if (strcmp(key, gguf_get_key(ctx, i)) == 0) { + keyfound = i; + break; + } + } + + return keyfound; +} + +const char * gguf_get_key(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].get_key().c_str(); +} + +enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].is_array ? GGUF_TYPE_ARRAY : ctx->kv[key_id].get_type(); +} + +enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].is_array); + return ctx->kv[key_id].get_type(); +} + +const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING); + return ctx->kv[key_id].data.data(); +} + +const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING); + return ctx->kv[key_id].data_string[i].c_str(); +} + +size_t gguf_get_arr_n(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + + if (ctx->kv[key_id].type == GGUF_TYPE_STRING) { + return ctx->kv[key_id].data_string.size(); + } + + const size_t type_size = gguf_type_size(ctx->kv[key_id].type); + GGML_ASSERT(ctx->kv[key_id].data.size() % type_size == 0); + return ctx->kv[key_id].data.size() / type_size; +} + +uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int8_t gguf_get_val_i8(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int16_t gguf_get_val_i16(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int32_t gguf_get_val_i32(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +float gguf_get_val_f32(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int64_t gguf_get_val_i64(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +double gguf_get_val_f64(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val().c_str(); +} + +const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING); + return ctx->kv[key_id].data.data(); +} + +int64_t gguf_get_n_tensors(const struct gguf_context * ctx) { + return ctx->info.size(); +} + +int64_t gguf_find_tensor(const struct gguf_context * ctx, const char * name) { + // return -1 if tensor not found + int64_t tensor_id = -1; + + const int64_t n_tensors = gguf_get_n_tensors(ctx); + + for (int64_t i = 0; i < n_tensors; ++i) { + if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) { + tensor_id = i; + break; + } + } + + return tensor_id; +} + +size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ctx->info[tensor_id].offset; +} + +const char * gguf_get_tensor_name(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ctx->info[tensor_id].t.name; +} + +enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ctx->info[tensor_id].t.type; +} + +size_t gguf_get_tensor_size(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ggml_nbytes(&ctx->info[tensor_id].t); +} + +int64_t gguf_remove_key(struct gguf_context * ctx, const char * key) { + const int64_t key_id = gguf_find_key(ctx, key); + if (key_id >= 0) { + ctx->kv.erase(ctx->kv.begin() + key_id); + } + return key_id; +} + +template +static void gguf_check_reserved_keys(const std::string & key, const T val) { + if (key == GGUF_KEY_GENERAL_ALIGNMENT) { + if constexpr (std::is_same::value) { + GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT " must be power of 2"); + } else { + GGML_UNUSED(val); + GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT " must be type u32"); + } + } +} + +void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, std::string(val)); +} + +void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n) { + gguf_check_reserved_keys(key, data); + gguf_remove_key(ctx, key); + + const size_t nbytes = n*gguf_type_size(type); + std::vector tmp(nbytes); + if (!tmp.empty()) { + memcpy(tmp.data(), data, nbytes); + } + ctx->kv.emplace_back(key, tmp); + ctx->kv.back().cast(type); +} + +void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, size_t n) { + gguf_check_reserved_keys(key, data); + gguf_remove_key(ctx, key); + + std::vector tmp(n); + for (size_t i = 0; i < n; ++i) { + tmp[i] = data[i]; + } + ctx->kv.emplace_back(key, tmp); +} + +// set or add KV pairs from another context +void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src) { + const int64_t n_kv = gguf_get_n_kv(src); + for (int64_t i = 0; i < n_kv; ++i) { + const struct gguf_kv & kv = src->kv[i]; + + if (!kv.is_array) { + switch (kv.get_type()) { + case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_STRING: gguf_set_val_str (ctx, kv.get_key().c_str(), kv.get_val().c_str()); break; + case GGUF_TYPE_ARRAY: + default: GGML_ABORT("invalid type"); + } + continue; + } + + const size_t ne = kv.get_ne(); + + switch (kv.get_type()) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: + case GGUF_TYPE_BOOL: { + gguf_set_arr_data(ctx, kv.get_key().c_str(), kv.get_type(), kv.data.data(), ne); + } break; + case GGUF_TYPE_STRING: { + std::vector tmp(ne); + for (size_t j = 0; j < ne; ++j) { + tmp[j] = kv.data_string[j].c_str(); + } + gguf_set_arr_str(ctx, kv.get_key().c_str(), tmp.data(), ne); + } break; + case GGUF_TYPE_ARRAY: + default: GGML_ABORT("invalid type"); + } + } +} + +void gguf_add_tensor( + struct gguf_context * ctx, + const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor); + if (gguf_find_tensor(ctx, tensor->name) != -1) { + GGML_ABORT("duplicate tensor name: %s", tensor->name); + } + + struct gguf_tensor_info ti; + ti.t = *tensor; + ti.offset = ctx->info.empty() ? 0 : + ctx->info.back().offset + GGML_PAD(ggml_nbytes(&ctx->info.back().t), ctx->alignment); + ctx->info.push_back(ti); +} + +void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) { + const int64_t tensor_id = gguf_find_tensor(ctx, name); + if (tensor_id < 0) { + GGML_ABORT("tensor not found: %s", name); + } + struct ggml_tensor * tensor = &ctx->info[tensor_id].t; + const size_t type_size = ggml_type_size(type); + const int64_t blck_size = ggml_blck_size(type); + + tensor->type = type; + GGML_ASSERT(tensor->ne[0] % blck_size == 0 && "tensor row size not divisible by block size of new type"); + + tensor->nb[0] = type_size; + tensor->nb[1] = tensor->nb[0]*(tensor->ne[0]/blck_size); + for (int i = 2; i < GGML_MAX_DIMS; i++) { + tensor->nb[i] = tensor->nb[i - 1]*tensor->ne[i - 1]; + } + + // update offsets + const int64_t n_tensors = gguf_get_n_tensors(ctx); + for (int64_t i = tensor_id + 1; i < n_tensors; ++i) { + ctx->info[i].offset = ctx->info[i - 1].offset + GGML_PAD(ggml_nbytes(&ctx->info[i - 1].t), ctx->alignment); + } +} + +void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data) { + const int64_t tensor_id = gguf_find_tensor(ctx, name); + if (tensor_id < 0) { + GGML_ABORT("tensor not found: %s", name); + } + + ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const +} + +struct gguf_writer { + std::vector & buf; + + gguf_writer(std::vector & buf) : buf(buf) {} + + template + void write(const T & val) const { + for (size_t i = 0; i < sizeof(val); ++i) { + buf.push_back(reinterpret_cast(&val)[i]); + } + } + + void write(const std::vector & val) const { + buf.insert(buf.end(), val.begin(), val.end()); + } + + void write(const bool & val) const { + const int8_t val8 = val ? 1 : 0; + write(val8); + } + + void write(const std::string & val) const { + { + const uint64_t n = val.length(); + write(n); + } + for (size_t i = 0; i < val.length(); ++i) { + buf.push_back(reinterpret_cast(val.data())[i]); + } + } + + void write(const char * val) const { + write(std::string(val)); + } + + void write(const enum ggml_type & val) const { + write(int32_t(val)); + } + + void write(const enum gguf_type & val) const { + write(int32_t(val)); + } + + void write(const struct gguf_kv & kv) const { + const uint64_t ne = kv.get_ne(); + + write(kv.get_key()); + + if (kv.is_array) { + write(GGUF_TYPE_ARRAY); + write(kv.get_type()); + write(ne); + } else { + write(kv.get_type()); + } + + switch (kv.get_type()) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: { + write(kv.data); + } break; + case GGUF_TYPE_BOOL: { + for (size_t i = 0; i < ne; ++i) { + write(kv.get_val(i)); + } + } break; + case GGUF_TYPE_STRING: { + for (size_t i = 0; i < ne; ++i) { + write(kv.get_val(i)); + } + } break; + case GGUF_TYPE_ARRAY: + default: GGML_ABORT("invalid type"); + } + } + + void write_tensor_meta(const struct gguf_tensor_info & info) const { + write(info.t.name); + + const uint32_t n_dims = ggml_n_dims(&info.t); + write(n_dims); + + for (uint32_t j = 0; j < n_dims; ++j) { + write(info.t.ne[j]); + } + write(info.t.type); + write(info.offset); + } + + void pad(const size_t alignment) const { + while (buf.size() % alignment != 0) { + const int8_t zero = 0; + write(zero); + } + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const { + GGML_ASSERT(buf.size() - offset_data == info.offset); + + GGML_ASSERT(ggml_is_contiguous(&info.t)); + const size_t offset = buf.size(); + const size_t nbytes = ggml_nbytes(&info.t); + + buf.resize(offset + nbytes); + if (info.t.buffer) { + ggml_backend_tensor_get(&info.t, buf.data() + offset, 0, nbytes); + } else { + GGML_ASSERT(info.t.data); + memcpy(buf.data() + offset, info.t.data, nbytes); + } + + pad(alignment); + } +}; + +void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta) { + const struct gguf_writer gw(buf); + + const int64_t n_kv = gguf_get_n_kv(ctx); + const int64_t n_tensors = gguf_get_n_tensors(ctx); + + // write header + gw.write(GGUF_MAGIC[0]); + gw.write(GGUF_MAGIC[1]); + gw.write(GGUF_MAGIC[2]); + gw.write(GGUF_MAGIC[3]); + gw.write(ctx->version); + gw.write(n_tensors); + gw.write(n_kv); + + // write key-value pairs + for (int64_t i = 0; i < n_kv; ++i) { + gw.write(ctx->kv[i]); + } + + // write tensor info + for (int64_t i = 0; i < n_tensors; ++i) { + gw.write_tensor_meta(ctx->info[i]); + } + + // we require the data section to be aligned + gw.pad(ctx->alignment); + + if (only_meta) { + return; + } + + const size_t offset_data = gw.buf.size(); + + // write tensor data + for (int64_t i = 0; i < n_tensors; ++i) { + gw.write_tensor_data(ctx->info[i], offset_data, ctx->alignment); + } +} + +bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { + FILE * file = ggml_fopen(fname, "wb"); + + if (!file) { + GGML_LOG_ERROR("%s: failed to open file '%s' for writing GGUF data\n", __func__, fname); + return false; + } + + std::vector buf; + gguf_write_to_buf(ctx, buf, only_meta); + const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size(); + fclose(file); + return ok; +} + +size_t gguf_get_meta_size(const struct gguf_context * ctx) { + // only return size + std::vector buf; + gguf_write_to_buf(ctx, buf, /*only_meta =*/ true); + return buf.size(); +} + +void gguf_get_meta_data(const struct gguf_context * ctx, void * data) { + std::vector buf; + gguf_write_to_buf(ctx, buf, /*only_meta =*/ true); + memcpy(data, buf.data(), buf.size()); +} diff --git a/gguf-py/LICENSE b/gguf-py/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..76f67efdc6470081b512a8db5bf2b1d4962d9c3c --- /dev/null +++ b/gguf-py/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Georgi Gerganov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/gguf-py/README.md b/gguf-py/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ca7e09c68184f28b23a35ec7d611ff6ca117433c --- /dev/null +++ b/gguf-py/README.md @@ -0,0 +1,99 @@ +## gguf + +This is a Python package for writing binary files in the [GGUF](https://github.com/ggml-org/ggml/pull/302) +(GGML Universal File) format. + +See [convert_hf_to_gguf.py](https://github.com/ggml-org/llama.cpp/blob/master/convert_hf_to_gguf.py) +as an example for its usage. + +## Installation +```sh +pip install gguf +``` + +Optionally, you can install gguf with the extra 'gui' to enable the visual GGUF editor. +```sh +pip install gguf[gui] +``` + +## API Examples/Simple Tools + +[examples/writer.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model. + +[examples/reader.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/reader.py) — Extracts and displays key-value pairs and tensor details from a GGUF file in a readable format. + +[gguf/scripts/gguf_dump.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console. + +[gguf/scripts/gguf_set_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key. + +[gguf/scripts/gguf_convert_endian.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_convert_endian.py) — Allows converting the endianness of GGUF files. + +[gguf/scripts/gguf_new_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values. + +[gguf/scripts/gguf_editor_gui.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_editor_gui.py) — Allows for viewing, editing, adding, or removing metadata values within a GGUF file as well as viewing its tensors with a Qt interface. + +## Development +Maintainers who participate in development of this package are advised to install it in editable mode: + +```sh +cd /path/to/llama.cpp/gguf-py + +pip install --editable . +``` + +**Note**: This may require to upgrade your Pip installation, with a message saying that editable installation currently requires `setup.py`. +In this case, upgrade Pip to the latest: + +```sh +pip install --upgrade pip +``` + +## Automatic publishing with CI + +There's a GitHub workflow to make a release automatically upon creation of tags in a specified format. + +1. Bump the version in `pyproject.toml`. +2. Create a tag named `gguf-vx.x.x` where `x.x.x` is the semantic version number. + +```sh +git tag -a gguf-v1.0.0 -m "Version 1.0 release" +``` + +3. Push the tags. + +```sh +git push origin --tags +``` + +## Manual publishing +If you want to publish the package manually for any reason, you need to have `twine` and `build` installed: + +```sh +pip install build twine +``` + +Then, follow these steps to release a new version: + +1. Bump the version in `pyproject.toml`. +2. Build the package: + +```sh +python -m build +``` + +3. Upload the generated distribution archives: + +```sh +python -m twine upload dist/* +``` + +## Run Unit Tests + +From root of this repository you can run this command to run all the unit tests + +```bash +python -m unittest discover ./gguf-py -v +``` + +## TODO +- [ ] Include conversion scripts as command line entry points in this package. diff --git a/gguf-py/examples/reader.py b/gguf-py/examples/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..703b782b5fa6672020492918fa498a77f0e96dfa --- /dev/null +++ b/gguf-py/examples/reader.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +import logging +import sys +from pathlib import Path + +logger = logging.getLogger("reader") + +# Necessary to load the local gguf package +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from gguf.gguf_reader import GGUFReader + + +def read_gguf_file(gguf_file_path): + """ + Reads and prints key-value pairs and tensor information from a GGUF file in an improved format. + + Parameters: + - gguf_file_path: Path to the GGUF file. + """ + + reader = GGUFReader(gguf_file_path) + + # List all key-value pairs in a columnized format + print("Key-Value Pairs:") # noqa: NP100 + max_key_length = max(len(key) for key in reader.fields.keys()) + for key, field in reader.fields.items(): + value = field.parts[field.data[0]] + print(f"{key:{max_key_length}} : {value}") # noqa: NP100 + print("----") # noqa: NP100 + + # List all tensors + print("Tensors:") # noqa: NP100 + tensor_info_format = "{:<30} | Shape: {:<15} | Size: {:<12} | Quantization: {}" + print(tensor_info_format.format("Tensor Name", "Shape", "Size", "Quantization")) # noqa: NP100 + print("-" * 80) # noqa: NP100 + for tensor in reader.tensors: + shape_str = "x".join(map(str, tensor.shape)) + size_str = str(tensor.n_elements) + quantization_str = tensor.tensor_type.name + print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str)) # noqa: NP100 + + +if __name__ == '__main__': + if len(sys.argv) < 2: + logger.info("Usage: reader.py ") + sys.exit(1) + gguf_file_path = sys.argv[1] + read_gguf_file(gguf_file_path) diff --git a/gguf-py/examples/writer.py b/gguf-py/examples/writer.py new file mode 100644 index 0000000000000000000000000000000000000000..731873a7d666ca5a58dfb95a42bb809fe92fe747 --- /dev/null +++ b/gguf-py/examples/writer.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +import sys +from pathlib import Path + +import numpy as np + +# Necessary to load the local gguf package +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from gguf import GGUFWriter # noqa: E402 + + +# Example usage: +def writer_example() -> None: + # Example usage with a file + gguf_writer = GGUFWriter("example.gguf", "llama") + + gguf_writer.add_block_count(12) + gguf_writer.add_uint32("answer", 42) # Write a 32-bit integer + gguf_writer.add_float32("answer_in_float", 42.0) # Write a 32-bit float + gguf_writer.add_custom_alignment(64) + + tensor1 = np.ones((32,), dtype=np.float32) * 100.0 + tensor2 = np.ones((64,), dtype=np.float32) * 101.0 + tensor3 = np.ones((96,), dtype=np.float32) * 102.0 + + gguf_writer.add_tensor("tensor1", tensor1) + gguf_writer.add_tensor("tensor2", tensor2) + gguf_writer.add_tensor("tensor3", tensor3) + + gguf_writer.write_header_to_file() + gguf_writer.write_kv_data_to_file() + gguf_writer.write_tensors_to_file() + + gguf_writer.close() + + +if __name__ == '__main__': + writer_example() diff --git a/gguf-py/gguf/__init__.py b/gguf-py/gguf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..243defc4c1ca42d3713017d8902592f54ac849cd --- /dev/null +++ b/gguf-py/gguf/__init__.py @@ -0,0 +1,9 @@ +from .constants import * +from .lazy import * +from .gguf_reader import * +from .gguf_writer import * +from .quants import * +from .tensor_mapping import * +from .vocab import * +from .utility import * +from .metadata import * diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..834a1d5e1a97ed98f01196e0a0542afa95739b8c --- /dev/null +++ b/gguf-py/gguf/constants.py @@ -0,0 +1,2438 @@ +from __future__ import annotations + +from enum import Enum, IntEnum, auto +from typing import Any + +# +# constants +# + +GGUF_MAGIC = 0x46554747 # "GGUF" +GGUF_VERSION = 3 +GGUF_DEFAULT_ALIGNMENT = 32 +GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h + +# +# metadata keys +# + + +class Keys: + class General: + TYPE = "general.type" + ARCHITECTURE = "general.architecture" + QUANTIZATION_VERSION = "general.quantization_version" + ALIGNMENT = "general.alignment" + FILE_TYPE = "general.file_type" + + # Authorship Metadata + NAME = "general.name" + AUTHOR = "general.author" + VERSION = "general.version" + ORGANIZATION = "general.organization" + + FINETUNE = "general.finetune" + BASENAME = "general.basename" + + DESCRIPTION = "general.description" + QUANTIZED_BY = "general.quantized_by" + + SIZE_LABEL = "general.size_label" + + # Licensing details + LICENSE = "general.license" + LICENSE_NAME = "general.license.name" + LICENSE_LINK = "general.license.link" + + # Typically represents the converted GGUF repo (Unless native) + URL = "general.url" # Model Website/Paper + DOI = "general.doi" + UUID = "general.uuid" + REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...) + + # Model Source during conversion + SOURCE_URL = "general.source.url" # Model Website/Paper + SOURCE_DOI = "general.source.doi" + SOURCE_UUID = "general.source.uuid" + SOURCE_REPO_URL = "general.source.repo_url" # Model Source Repository (git/svn/etc...) + + # Base Model Source. There can be more than one source if it's a merged + # model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in + # tracing linage of models as it is finetuned or merged over time. + BASE_MODEL_COUNT = "general.base_model.count" + BASE_MODEL_NAME = "general.base_model.{id}.name" + BASE_MODEL_AUTHOR = "general.base_model.{id}.author" + BASE_MODEL_VERSION = "general.base_model.{id}.version" + BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization" + BASE_MODEL_DESCRIPTION = "general.base_model.{id}.description" + BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper + BASE_MODEL_DOI = "general.base_model.{id}.doi" + BASE_MODEL_UUID = "general.base_model.{id}.uuid" + BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...) + + # Dataset Source + DATASET_COUNT = "general.dataset.count" + DATASET_NAME = "general.dataset.{id}.name" + DATASET_AUTHOR = "general.dataset.{id}.author" + DATASET_VERSION = "general.dataset.{id}.version" + DATASET_ORGANIZATION = "general.dataset.{id}.organization" + DATASET_DESCRIPTION = "general.dataset.{id}.description" + DATASET_URL = "general.dataset.{id}.url" # Model Website/Paper + DATASET_DOI = "general.dataset.{id}.doi" + DATASET_UUID = "general.dataset.{id}.uuid" + DATASET_REPO_URL = "general.dataset.{id}.repo_url" # Model Source Repository (git/svn/etc...) + + # Array based KV stores + TAGS = "general.tags" + LANGUAGES = "general.languages" + + class LLM: + VOCAB_SIZE = "{arch}.vocab_size" + CONTEXT_LENGTH = "{arch}.context_length" + EMBEDDING_LENGTH = "{arch}.embedding_length" + FEATURES_LENGTH = "{arch}.features_length" + BLOCK_COUNT = "{arch}.block_count" + LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" + FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" + EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" + EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length" + USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" + TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" + EXPERT_COUNT = "{arch}.expert_count" + EXPERT_USED_COUNT = "{arch}.expert_used_count" + EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" + EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" + EXPERT_GATING_FUNC = "{arch}.expert_gating_func" + MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers" + POOLING_TYPE = "{arch}.pooling_type" + LOGIT_SCALE = "{arch}.logit_scale" + DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" + ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" + FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" + SWIN_NORM = "{arch}.swin_norm" + RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers" + TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim" + TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim" + RESIDUAL_SCALE = "{arch}.residual_scale" + EMBEDDING_SCALE = "{arch}.embedding_scale" + TOKEN_SHIFT_COUNT = "{arch}.token_shift_count" + INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step" + + class Attention: + HEAD_COUNT = "{arch}.attention.head_count" + HEAD_COUNT_KV = "{arch}.attention.head_count_kv" + MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias" + CLAMP_KQV = "{arch}.attention.clamp_kqv" + KEY_LENGTH = "{arch}.attention.key_length" + VALUE_LENGTH = "{arch}.attention.value_length" + LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" + LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" + GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon" + GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups" + CAUSAL = "{arch}.attention.causal" + Q_LORA_RANK = "{arch}.attention.q_lora_rank" + KV_LORA_RANK = "{arch}.attention.kv_lora_rank" + DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank" + ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank" + VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank" + GATE_LORA_RANK = "{arch}.attention.gate_lora_rank" + REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" + SLIDING_WINDOW = "{arch}.attention.sliding_window" + SCALE = "{arch}.attention.scale" + KEY_LENGTH_MLA = "{arch}.attention.key_length_mla" + VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" + + class Rope: + DIMENSION_COUNT = "{arch}.rope.dimension_count" + DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" + FREQ_BASE = "{arch}.rope.freq_base" + SCALING_TYPE = "{arch}.rope.scaling.type" + SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" + SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" + SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" + + class Split: + LLM_KV_SPLIT_NO = "split.no" + LLM_KV_SPLIT_COUNT = "split.count" + LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count" + + class SSM: + CONV_KERNEL = "{arch}.ssm.conv_kernel" + INNER_SIZE = "{arch}.ssm.inner_size" + STATE_SIZE = "{arch}.ssm.state_size" + TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" + + class WKV: + HEAD_SIZE = "{arch}.wkv.head_size" + + class PosNet: + EMBEDDING_LENGTH = "{arch}.posnet.embedding_length" + BLOCK_COUNT = "{arch}.posnet.block_count" + + class ConvNext: + EMBEDDING_LENGTH = "{arch}.convnext.embedding_length" + BLOCK_COUNT = "{arch}.convnext.block_count" + + class Classifier: + OUTPUT_LABELS = "{arch}.classifier.output_labels" + + class Tokenizer: + MODEL = "tokenizer.ggml.model" + PRE = "tokenizer.ggml.pre" + LIST = "tokenizer.ggml.tokens" + TOKEN_TYPE = "tokenizer.ggml.token_type" + TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types + SCORES = "tokenizer.ggml.scores" + MERGES = "tokenizer.ggml.merges" + BOS_ID = "tokenizer.ggml.bos_token_id" + EOS_ID = "tokenizer.ggml.eos_token_id" + EOT_ID = "tokenizer.ggml.eot_token_id" + EOM_ID = "tokenizer.ggml.eom_token_id" + UNK_ID = "tokenizer.ggml.unknown_token_id" + SEP_ID = "tokenizer.ggml.seperator_token_id" + PAD_ID = "tokenizer.ggml.padding_token_id" + MASK_ID = "tokenizer.ggml.mask_token_id" + ADD_BOS = "tokenizer.ggml.add_bos_token" + ADD_EOS = "tokenizer.ggml.add_eos_token" + ADD_PREFIX = "tokenizer.ggml.add_space_prefix" + REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces" + PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap" + HF_JSON = "tokenizer.huggingface.json" + RWKV = "tokenizer.rwkv.world" + CHAT_TEMPLATE = "tokenizer.chat_template" + CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" + CHAT_TEMPLATES = "tokenizer.chat_templates" + # FIM/Infill special tokens constants + FIM_PRE_ID = "tokenizer.ggml.fim_pre_token_id" + FIM_SUF_ID = "tokenizer.ggml.fim_suf_token_id" + FIM_MID_ID = "tokenizer.ggml.fim_mid_token_id" + FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id" + FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id" + FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id" + # deprecated: + PREFIX_ID = "tokenizer.ggml.prefix_token_id" + SUFFIX_ID = "tokenizer.ggml.suffix_token_id" + MIDDLE_ID = "tokenizer.ggml.middle_token_id" + + class Adapter: + TYPE = "adapter.type" + LORA_ALPHA = "adapter.lora.alpha" + + class Clip: + PROJECTOR_TYPE = "clip.projector_type" + HAS_VISION_ENCODER = "clip.has_vision_encoder" + HAS_AUDIO_ENCODER = "clip.has_audio_encoder" + HAS_LLAVA_PROJECTOR = "clip.has_llava_projector" + + class ClipVision: + IMAGE_SIZE = "clip.vision.image_size" + PATCH_SIZE = "clip.vision.patch_size" + EMBEDDING_LENGTH = "clip.vision.embedding_length" + FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" + PROJECTION_DIM = "clip.vision.projection_dim" + BLOCK_COUNT = "clip.vision.block_count" + IMAGE_MEAN = "clip.vision.image_mean" + IMAGE_STD = "clip.vision.image_std" + SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size" + USE_GELU = "clip.use_gelu" + USE_SILU = "clip.use_silu" + N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + + class Attention: + HEAD_COUNT = "clip.vision.attention.head_count" + LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon" + + class Projector: + SCALE_FACTOR = "clip.vision.projector.scale_factor" + + class ClipAudio: + NUM_MEL_BINS = "clip.audio.num_mel_bins" + EMBEDDING_LENGTH = "clip.audio.embedding_length" + FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length" + PROJECTION_DIM = "clip.audio.projection_dim" + BLOCK_COUNT = "clip.audio.block_count" + + class Attention: + HEAD_COUNT = "clip.audio.attention.head_count" + LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon" + + class Projector: + STACK_FACTOR = "clip.audio.projector.stack_factor" + +# +# recommended mapping of model tensor names for storage in gguf +# + + +class GGUFType: + MODEL = "model" + ADAPTER = "adapter" + MMPROJ = "mmproj" # dummy, unused for now + + +class MODEL_ARCH(IntEnum): + MMPROJ = auto() # dummy arch for clip.cpp + LLAMA = auto() + LLAMA4 = auto() + DECI = auto() + FALCON = auto() + BAICHUAN = auto() + GROK = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + MPT = auto() + STARCODER = auto() + REFACT = auto() + BERT = auto() + NOMIC_BERT = auto() + NOMIC_BERT_MOE = auto() + NEO_BERT = auto() + JINA_BERT_V2 = auto() + BLOOM = auto() + STABLELM = auto() + QWEN = auto() + QWEN2 = auto() + QWEN2MOE = auto() + QWEN2VL = auto() + QWEN3 = auto() + QWEN3MOE = auto() + PHI2 = auto() + PHI3 = auto() + PHIMOE = auto() + PLAMO = auto() + CODESHELL = auto() + ORION = auto() + INTERNLM2 = auto() + MINICPM = auto() + MINICPM3 = auto() + GEMMA = auto() + GEMMA2 = auto() + GEMMA3 = auto() + STARCODER2 = auto() + RWKV6 = auto() + RWKV6QWEN2 = auto() + RWKV7 = auto() + ARWKV7 = auto() + MAMBA = auto() + XVERSE = auto() + COMMAND_R = auto() + COHERE2 = auto() + DBRX = auto() + OLMO = auto() + OLMO2 = auto() + OLMOE = auto() + OPENELM = auto() + ARCTIC = auto() + DEEPSEEK = auto() + DEEPSEEK2 = auto() + CHATGLM = auto() + GLM4 = auto() + BITNET = auto() + T5 = auto() + T5ENCODER = auto() + JAIS = auto() + NEMOTRON = auto() + EXAONE = auto() + GRANITE = auto() + GRANITE_MOE = auto() + CHAMELEON = auto() + WAVTOKENIZER_DEC = auto() + PLM = auto() + BAILINGMOE = auto() + DOTS1 = auto() + ARCEE = auto() + + +class VISION_PROJECTOR_TYPE(IntEnum): + MLP = auto() + LDP = auto() + LDPV2 = auto() + RESAMPLER = auto() + GLM_EDGE = auto() + MERGER = auto() + GEMMA3 = auto() + + +class MODEL_TENSOR(IntEnum): + TOKEN_EMBD = auto() + TOKEN_EMBD_NORM = auto() + TOKEN_TYPES = auto() + POS_EMBD = auto() + OUTPUT = auto() + OUTPUT_NORM = auto() + ROPE_FREQS = auto() + ROPE_FACTORS_LONG = auto() + ROPE_FACTORS_SHORT = auto() + ATTN_Q = auto() + ATTN_K = auto() + ATTN_V = auto() + ATTN_QKV = auto() + ATTN_OUT = auto() + ATTN_NORM = auto() + ATTN_NORM_2 = auto() + ATTN_OUT_NORM = auto() + ATTN_POST_NORM = auto() + ATTN_ROT_EMBD = auto() + FFN_GATE_INP = auto() + FFN_GATE_INP_SHEXP = auto() + FFN_NORM = auto() + FFN_PRE_NORM = auto() + FFN_POST_NORM = auto() + FFN_GATE = auto() + FFN_DOWN = auto() + FFN_UP = auto() + FFN_ACT = auto() + FFN_NORM_EXP = auto() + FFN_GATE_EXP = auto() + FFN_DOWN_EXP = auto() + FFN_UP_EXP = auto() + FFN_GATE_SHEXP = auto() + FFN_DOWN_SHEXP = auto() + FFN_UP_SHEXP = auto() + FFN_EXP_PROBS_B = auto() + ATTN_Q_NORM = auto() + ATTN_K_NORM = auto() + LAYER_OUT_NORM = auto() + SSM_IN = auto() + SSM_CONV1D = auto() + SSM_X = auto() + SSM_DT = auto() + SSM_A = auto() + SSM_D = auto() + SSM_OUT = auto() + TIME_MIX_W0 = auto() + TIME_MIX_W1 = auto() + TIME_MIX_W2 = auto() + TIME_MIX_A0 = auto() + TIME_MIX_A1 = auto() + TIME_MIX_A2 = auto() + TIME_MIX_V0 = auto() + TIME_MIX_V1 = auto() + TIME_MIX_V2 = auto() + TIME_MIX_G1 = auto() + TIME_MIX_G2 = auto() + TIME_MIX_K_K = auto() + TIME_MIX_K_A = auto() + TIME_MIX_R_K = auto() + TIME_MIX_LERP_X = auto() + TIME_MIX_LERP_K = auto() + TIME_MIX_LERP_V = auto() + TIME_MIX_LERP_R = auto() + TIME_MIX_LERP_G = auto() + TIME_MIX_LERP_FUSED = auto() + TIME_MIX_LERP_W = auto() + TIME_MIX_FIRST = auto() + TIME_MIX_DECAY = auto() + TIME_MIX_DECAY_W1 = auto() + TIME_MIX_DECAY_W2 = auto() + TIME_MIX_KEY = auto() + TIME_MIX_VALUE = auto() + TIME_MIX_RECEPTANCE = auto() + TIME_MIX_GATE = auto() + TIME_MIX_LN = auto() + TIME_MIX_OUTPUT = auto() + CHANNEL_MIX_LERP_K = auto() + CHANNEL_MIX_LERP_R = auto() + CHANNEL_MIX_KEY = auto() + CHANNEL_MIX_RECEPTANCE = auto() + CHANNEL_MIX_VALUE = auto() + ATTN_Q_A = auto() + ATTN_Q_B = auto() + ATTN_KV_A_MQA = auto() + ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() + ATTN_Q_A_NORM = auto() + ATTN_KV_A_NORM = auto() + FFN_SUB_NORM = auto() + ATTN_SUB_NORM = auto() + DEC_ATTN_NORM = auto() + DEC_ATTN_Q = auto() + DEC_ATTN_K = auto() + DEC_ATTN_V = auto() + DEC_ATTN_OUT = auto() + DEC_ATTN_REL_B = auto() + DEC_CROSS_ATTN_NORM = auto() + DEC_CROSS_ATTN_Q = auto() + DEC_CROSS_ATTN_K = auto() + DEC_CROSS_ATTN_V = auto() + DEC_CROSS_ATTN_OUT = auto() + DEC_CROSS_ATTN_REL_B = auto() + DEC_FFN_NORM = auto() + DEC_FFN_GATE = auto() + DEC_FFN_DOWN = auto() + DEC_FFN_UP = auto() + DEC_OUTPUT_NORM = auto() + ENC_ATTN_NORM = auto() + ENC_ATTN_Q = auto() + ENC_ATTN_K = auto() + ENC_ATTN_V = auto() + ENC_ATTN_OUT = auto() + ENC_ATTN_REL_B = auto() + ENC_FFN_NORM = auto() + ENC_FFN_GATE = auto() + ENC_FFN_DOWN = auto() + ENC_FFN_UP = auto() + ENC_OUTPUT_NORM = auto() + CLS = auto() # classifier + CLS_OUT = auto() # classifier output projection + CONV1D = auto() + CONVNEXT_DW = auto() + CONVNEXT_NORM = auto() + CONVNEXT_PW1 = auto() + CONVNEXT_PW2 = auto() + CONVNEXT_GAMMA = auto() + POSNET_CONV1 = auto() + POSNET_CONV2 = auto() + POSNET_NORM = auto() + POSNET_NORM1 = auto() + POSNET_NORM2 = auto() + POSNET_ATTN_NORM = auto() + POSNET_ATTN_Q = auto() + POSNET_ATTN_K = auto() + POSNET_ATTN_V = auto() + POSNET_ATTN_OUT = auto() + # vision + V_MMPROJ = auto() + V_MMPROJ_FC = auto() + V_MMPROJ_MLP = auto() + V_MMPROJ_PEG = auto() + V_ENC_EMBD_CLS = auto() + V_ENC_EMBD_PATCH = auto() + V_ENC_EMBD_POS = auto() + V_ENC_INPUT_NORM = auto() + V_ENC_ATTN_Q = auto() + V_ENC_ATTN_Q_NORM = auto() + V_ENC_ATTN_K = auto() + V_ENC_ATTN_K_NORM = auto() + V_ENC_ATTN_V = auto() + V_ENC_ATTN_O = auto() + V_ENC_ATTN_O_NORM = auto() + V_ENC_POST_ATTN_NORM = auto() + V_ENC_FFN_UP = auto() + V_ENC_FFN_GATE = auto() + V_ENC_FFN_DOWN = auto() + V_LAYER_SCALE_1 = auto() + V_LAYER_SCALE_2 = auto() + V_PRE_NORM = auto() + V_POST_NORM = auto() + V_MM_INP_NORM = auto() + V_MM_INP_PROJ = auto() # gemma3 + V_MM_SOFT_EMB_NORM = auto() # gemma3 + V_RESMPL_POS_EMBD_K = auto() # minicpmv + V_RESMPL_ATTN_Q = auto() # minicpmv + V_RESMPL_ATTN_K = auto() # minicpmv + V_RESMPL_ATTN_V = auto() # minicpmv + V_RESMPL_ATTN_OUT = auto() # minicpmv + V_RESMPL_KV = auto() # minicpmv + V_RESMPL_KV_NORM = auto() # minicpmv + V_RESMPL_POST_NORM = auto() # minicpmv + V_RESMPL_Q_NORM = auto() # minicpmv + V_RESMPL_PROJ = auto() # minicpmv + V_RESMPL_QUERY = auto() # minicpmv + V_TOK_EMBD_IMG_BREAK = auto() # pixtral + V_MM_PATCH_MERGER = auto() # mistral small 3.1 + # audio (mtmd) + A_ENC_EMBD_POS = auto() + A_ENC_CONV1D = auto() + A_PRE_NORM = auto() + A_POST_NORM = auto() + A_ENC_ATTN_Q = auto() + A_ENC_ATTN_K = auto() + A_ENC_ATTN_V = auto() + A_ENC_INPUT_NORM = auto() + A_ENC_OUTPUT = auto() + A_ENC_OUTPUT_NORM = auto() + A_ENC_FFN_UP = auto() + A_ENC_FFN_GATE = auto() + A_ENC_FFN_DOWN = auto() + A_MMPROJ = auto() + A_MMPROJ_FC = auto() + A_MM_NORM_PRE = auto() + A_MM_NORM_MID = auto() + + +MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { + MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp + MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.LLAMA4: "llama4", + MODEL_ARCH.DECI: "deci", + MODEL_ARCH.FALCON: "falcon", + MODEL_ARCH.BAICHUAN: "baichuan", + MODEL_ARCH.GROK: "grok", + MODEL_ARCH.GPT2: "gpt2", + MODEL_ARCH.GPTJ: "gptj", + MODEL_ARCH.GPTNEOX: "gptneox", + MODEL_ARCH.MPT: "mpt", + MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.REFACT: "refact", + MODEL_ARCH.BERT: "bert", + MODEL_ARCH.NOMIC_BERT: "nomic-bert", + MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", + MODEL_ARCH.NEO_BERT: "neo-bert", + MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", + MODEL_ARCH.BLOOM: "bloom", + MODEL_ARCH.STABLELM: "stablelm", + MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.QWEN2: "qwen2", + MODEL_ARCH.QWEN2MOE: "qwen2moe", + MODEL_ARCH.QWEN2VL: "qwen2vl", + MODEL_ARCH.QWEN3: "qwen3", + MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PHI3: "phi3", + MODEL_ARCH.PHIMOE: "phimoe", + MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.CODESHELL: "codeshell", + MODEL_ARCH.ORION: "orion", + MODEL_ARCH.INTERNLM2: "internlm2", + MODEL_ARCH.MINICPM: "minicpm", + MODEL_ARCH.MINICPM3: "minicpm3", + MODEL_ARCH.GEMMA: "gemma", + MODEL_ARCH.GEMMA2: "gemma2", + MODEL_ARCH.GEMMA3: "gemma3", + MODEL_ARCH.STARCODER2: "starcoder2", + MODEL_ARCH.RWKV6: "rwkv6", + MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", + MODEL_ARCH.RWKV7: "rwkv7", + MODEL_ARCH.ARWKV7: "arwkv7", + MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.XVERSE: "xverse", + MODEL_ARCH.COMMAND_R: "command-r", + MODEL_ARCH.COHERE2: "cohere2", + MODEL_ARCH.DBRX: "dbrx", + MODEL_ARCH.OLMO: "olmo", + MODEL_ARCH.OLMO2: "olmo2", + MODEL_ARCH.OLMOE: "olmoe", + MODEL_ARCH.OPENELM: "openelm", + MODEL_ARCH.ARCTIC: "arctic", + MODEL_ARCH.DEEPSEEK: "deepseek", + MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.CHATGLM: "chatglm", + MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.T5: "t5", + MODEL_ARCH.T5ENCODER: "t5encoder", + MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.EXAONE: "exaone", + MODEL_ARCH.GRANITE: "granite", + MODEL_ARCH.GRANITE_MOE: "granitemoe", + MODEL_ARCH.CHAMELEON: "chameleon", + MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + MODEL_ARCH.PLM: "plm", + MODEL_ARCH.BAILINGMOE: "bailingmoe", + MODEL_ARCH.DOTS1: "dots1", + MODEL_ARCH.ARCEE: "arcee", +} + +VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { + VISION_PROJECTOR_TYPE.MLP: "mlp", + VISION_PROJECTOR_TYPE.LDP: "ldp", + VISION_PROJECTOR_TYPE.LDPV2: "ldpv2", + VISION_PROJECTOR_TYPE.RESAMPLER: "resampler", + VISION_PROJECTOR_TYPE.GLM_EDGE: "adapter", + VISION_PROJECTOR_TYPE.MERGER: "qwen2vl_merger", + VISION_PROJECTOR_TYPE.GEMMA3: "gemma3", +} + +TENSOR_NAMES: dict[MODEL_TENSOR, str] = { + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", + MODEL_TENSOR.TOKEN_TYPES: "token_types", + MODEL_TENSOR.POS_EMBD: "position_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", + MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", + MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", + MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", + MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", + MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm", + MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", + MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm", + MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp", + MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", + MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", + MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", + MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", + MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", + MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", + MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", + MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", + MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", + MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", + MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", + MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", + MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", + MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", + MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0", + MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1", + MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2", + MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0", + MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1", + MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2", + MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1", + MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2", + MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k", + MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a", + MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k", + MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x", + MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k", + MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v", + MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r", + MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g", + MODEL_TENSOR.TIME_MIX_LERP_FUSED: "blk.{bid}.time_mix_lerp_fused", + MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w", + MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first", + MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay", + MODEL_TENSOR.TIME_MIX_DECAY_W1: "blk.{bid}.time_mix_decay_w1", + MODEL_TENSOR.TIME_MIX_DECAY_W2: "blk.{bid}.time_mix_decay_w2", + MODEL_TENSOR.TIME_MIX_KEY: "blk.{bid}.time_mix_key", + MODEL_TENSOR.TIME_MIX_VALUE: "blk.{bid}.time_mix_value", + MODEL_TENSOR.TIME_MIX_RECEPTANCE: "blk.{bid}.time_mix_receptance", + MODEL_TENSOR.TIME_MIX_GATE: "blk.{bid}.time_mix_gate", + MODEL_TENSOR.TIME_MIX_LN: "blk.{bid}.time_mix_ln", + MODEL_TENSOR.TIME_MIX_OUTPUT: "blk.{bid}.time_mix_output", + MODEL_TENSOR.CHANNEL_MIX_LERP_K: "blk.{bid}.channel_mix_lerp_k", + MODEL_TENSOR.CHANNEL_MIX_LERP_R: "blk.{bid}.channel_mix_lerp_r", + MODEL_TENSOR.CHANNEL_MIX_KEY: "blk.{bid}.channel_mix_key", + MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: "blk.{bid}.channel_mix_receptance", + MODEL_TENSOR.CHANNEL_MIX_VALUE: "blk.{bid}.channel_mix_value", + MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", + MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", + MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", + MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", + MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", + MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", + MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", + MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", + MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm", + MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q", + MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k", + MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v", + MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o", + MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b", + MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm", + MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q", + MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k", + MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v", + MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o", + MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b", + MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm", + MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate", + MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down", + MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up", + MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm", + MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm", + MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q", + MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k", + MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v", + MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o", + MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b", + MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm", + MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate", + MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", + MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", + MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + MODEL_TENSOR.CLS: "cls", + MODEL_TENSOR.CLS_OUT: "cls.output", + MODEL_TENSOR.CONV1D: "conv1d", + MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw", + MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm", + MODEL_TENSOR.CONVNEXT_PW1: "convnext.{bid}.pw1", + MODEL_TENSOR.CONVNEXT_PW2: "convnext.{bid}.pw2", + MODEL_TENSOR.CONVNEXT_GAMMA: "convnext.{bid}.gamma", + MODEL_TENSOR.POSNET_CONV1: "posnet.{bid}.conv1", + MODEL_TENSOR.POSNET_CONV2: "posnet.{bid}.conv2", + MODEL_TENSOR.POSNET_NORM: "posnet.{bid}.norm", + MODEL_TENSOR.POSNET_NORM1: "posnet.{bid}.norm1", + MODEL_TENSOR.POSNET_NORM2: "posnet.{bid}.norm2", + MODEL_TENSOR.POSNET_ATTN_NORM: "posnet.{bid}.attn_norm", + MODEL_TENSOR.POSNET_ATTN_Q: "posnet.{bid}.attn_q", + MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", + MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v", + MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", + # vision + MODEL_TENSOR.V_MMPROJ: "mm.{bid}", + MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc", + MODEL_TENSOR.V_MMPROJ_MLP: "mm.model.mlp.{bid}", + MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}", + MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd", + MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd", + MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd", + MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q", + MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm", + MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k", + MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm", + MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v", + MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1", + MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out", + MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm", + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2", + MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", + MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", + MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down", + MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1", + MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2", + MODEL_TENSOR.V_PRE_NORM: "v.pre_ln", + MODEL_TENSOR.V_POST_NORM: "v.post_ln", + MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection", + MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm", + MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm", + MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k", + MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q", + MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k", + MODEL_TENSOR.V_RESMPL_ATTN_V: "resampler.attn.v", + MODEL_TENSOR.V_RESMPL_ATTN_OUT: "resampler.attn.out", + MODEL_TENSOR.V_RESMPL_KV: "resampler.kv", + MODEL_TENSOR.V_RESMPL_KV_NORM: "resampler.ln_kv", + MODEL_TENSOR.V_RESMPL_POST_NORM: "resampler.ln_post", + MODEL_TENSOR.V_RESMPL_Q_NORM: "resampler.ln_q", + MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj", + MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query", + MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral + MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1 + # audio (mtmd) + MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", + MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", + MODEL_TENSOR.A_PRE_NORM: "a.pre_ln", + MODEL_TENSOR.A_POST_NORM: "a.post_ln", + MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q", + MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k", + MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v", + MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1", + MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out", + MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2", + MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up", + MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate", + MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down", + MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}", + MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", + MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", + MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", +} + +MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { + MODEL_ARCH.MMPROJ: [ + MODEL_TENSOR.V_MMPROJ, + MODEL_TENSOR.V_MMPROJ_FC, + MODEL_TENSOR.V_MMPROJ_MLP, + MODEL_TENSOR.V_MMPROJ_PEG, + MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_Q_NORM, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_K_NORM, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_ATTN_O, + MODEL_TENSOR.V_ENC_ATTN_O_NORM, + MODEL_TENSOR.V_ENC_POST_ATTN_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_GATE, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_LAYER_SCALE_1, + MODEL_TENSOR.V_LAYER_SCALE_2, + MODEL_TENSOR.V_PRE_NORM, + MODEL_TENSOR.V_POST_NORM, + MODEL_TENSOR.V_MM_INP_PROJ, + MODEL_TENSOR.V_MM_INP_NORM, + MODEL_TENSOR.V_MM_SOFT_EMB_NORM, + MODEL_TENSOR.V_RESMPL_POS_EMBD_K, + MODEL_TENSOR.V_RESMPL_ATTN_Q, + MODEL_TENSOR.V_RESMPL_ATTN_K, + MODEL_TENSOR.V_RESMPL_ATTN_V, + MODEL_TENSOR.V_RESMPL_ATTN_OUT, + MODEL_TENSOR.V_RESMPL_KV, + MODEL_TENSOR.V_RESMPL_KV_NORM, + MODEL_TENSOR.V_RESMPL_POST_NORM, + MODEL_TENSOR.V_RESMPL_Q_NORM, + MODEL_TENSOR.V_RESMPL_PROJ, + MODEL_TENSOR.V_RESMPL_QUERY, + MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK, + MODEL_TENSOR.V_MM_PATCH_MERGER, + # audio + MODEL_TENSOR.A_ENC_EMBD_POS, + MODEL_TENSOR.A_ENC_CONV1D, + MODEL_TENSOR.A_PRE_NORM, + MODEL_TENSOR.A_POST_NORM, + MODEL_TENSOR.A_ENC_ATTN_Q, + MODEL_TENSOR.A_ENC_ATTN_K, + MODEL_TENSOR.A_ENC_ATTN_V, + MODEL_TENSOR.A_ENC_INPUT_NORM, + MODEL_TENSOR.A_ENC_OUTPUT, + MODEL_TENSOR.A_ENC_OUTPUT_NORM, + MODEL_TENSOR.A_ENC_FFN_UP, + MODEL_TENSOR.A_ENC_FFN_GATE, + MODEL_TENSOR.A_ENC_FFN_DOWN, + MODEL_TENSOR.A_MMPROJ, + MODEL_TENSOR.A_MMPROJ_FC, + MODEL_TENSOR.A_MM_NORM_PRE, + MODEL_TENSOR.A_MM_NORM_MID, + ], + MODEL_ARCH.LLAMA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], + MODEL_ARCH.LLAMA4: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], + MODEL_ARCH.DECI: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], + MODEL_ARCH.GROK: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.LAYER_OUT_NORM, + ], + MODEL_ARCH.GPTNEOX: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.FALCON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_NORM_2, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.BAICHUAN: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.STARCODER: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.BERT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.TOKEN_TYPES, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, + ], + MODEL_ARCH.NOMIC_BERT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.TOKEN_TYPES, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.LAYER_OUT_NORM, + ], + MODEL_ARCH.NOMIC_BERT_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.TOKEN_TYPES, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.LAYER_OUT_NORM, + ], + MODEL_ARCH.NEO_BERT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ENC_OUTPUT_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, + ], + MODEL_ARCH.JINA_BERT_V2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.TOKEN_TYPES, + MODEL_TENSOR.ATTN_NORM_2, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, + ], + MODEL_ARCH.MPT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_ACT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.POS_EMBD, + ], + MODEL_ARCH.GPTJ: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.REFACT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.BLOOM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.STABLELM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + ], + MODEL_ARCH.QWEN: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN2VL: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN2MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_INP_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], + MODEL_ARCH.QWEN3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN3MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], + MODEL_ARCH.PLAMO: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.GPT2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.PHI2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.PHI3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.PHIMOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], + MODEL_ARCH.CODESHELL: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.ORION: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.INTERNLM2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.MINICPM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], + MODEL_ARCH.MINICPM3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.GEMMA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_NORM, + ], + MODEL_ARCH.GEMMA2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_POST_NORM, + ], + MODEL_ARCH.GEMMA3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_POST_NORM, + ], + MODEL_ARCH.STARCODER2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.RWKV6: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_NORM_2, + MODEL_TENSOR.TIME_MIX_W1, + MODEL_TENSOR.TIME_MIX_W2, + MODEL_TENSOR.TIME_MIX_LERP_X, + MODEL_TENSOR.TIME_MIX_LERP_K, + MODEL_TENSOR.TIME_MIX_LERP_V, + MODEL_TENSOR.TIME_MIX_LERP_R, + MODEL_TENSOR.TIME_MIX_LERP_G, + MODEL_TENSOR.TIME_MIX_LERP_W, + MODEL_TENSOR.TIME_MIX_LERP_FUSED, + MODEL_TENSOR.TIME_MIX_FIRST, + MODEL_TENSOR.TIME_MIX_DECAY, + MODEL_TENSOR.TIME_MIX_DECAY_W1, + MODEL_TENSOR.TIME_MIX_DECAY_W2, + MODEL_TENSOR.TIME_MIX_KEY, + MODEL_TENSOR.TIME_MIX_VALUE, + MODEL_TENSOR.TIME_MIX_RECEPTANCE, + MODEL_TENSOR.TIME_MIX_GATE, + MODEL_TENSOR.TIME_MIX_LN, + MODEL_TENSOR.TIME_MIX_OUTPUT, + MODEL_TENSOR.CHANNEL_MIX_LERP_K, + MODEL_TENSOR.CHANNEL_MIX_LERP_R, + MODEL_TENSOR.CHANNEL_MIX_KEY, + MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE, + MODEL_TENSOR.CHANNEL_MIX_VALUE, + ], + MODEL_ARCH.RWKV6QWEN2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.TIME_MIX_W1, + MODEL_TENSOR.TIME_MIX_W2, + MODEL_TENSOR.TIME_MIX_LERP_X, + MODEL_TENSOR.TIME_MIX_LERP_K, + MODEL_TENSOR.TIME_MIX_LERP_V, + MODEL_TENSOR.TIME_MIX_LERP_R, + MODEL_TENSOR.TIME_MIX_LERP_G, + MODEL_TENSOR.TIME_MIX_LERP_W, + MODEL_TENSOR.TIME_MIX_LERP_FUSED, + MODEL_TENSOR.TIME_MIX_FIRST, + MODEL_TENSOR.TIME_MIX_DECAY, + MODEL_TENSOR.TIME_MIX_DECAY_W1, + MODEL_TENSOR.TIME_MIX_DECAY_W2, + MODEL_TENSOR.TIME_MIX_KEY, + MODEL_TENSOR.TIME_MIX_VALUE, + MODEL_TENSOR.TIME_MIX_RECEPTANCE, + MODEL_TENSOR.TIME_MIX_GATE, + MODEL_TENSOR.TIME_MIX_LN, + MODEL_TENSOR.TIME_MIX_OUTPUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.RWKV7: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_NORM_2, + MODEL_TENSOR.TIME_MIX_LERP_FUSED, + MODEL_TENSOR.TIME_MIX_W0, + MODEL_TENSOR.TIME_MIX_W1, + MODEL_TENSOR.TIME_MIX_W2, + MODEL_TENSOR.TIME_MIX_A0, + MODEL_TENSOR.TIME_MIX_A1, + MODEL_TENSOR.TIME_MIX_A2, + MODEL_TENSOR.TIME_MIX_V0, + MODEL_TENSOR.TIME_MIX_V1, + MODEL_TENSOR.TIME_MIX_V2, + MODEL_TENSOR.TIME_MIX_G1, + MODEL_TENSOR.TIME_MIX_G2, + MODEL_TENSOR.TIME_MIX_K_K, + MODEL_TENSOR.TIME_MIX_K_A, + MODEL_TENSOR.TIME_MIX_R_K, + MODEL_TENSOR.TIME_MIX_KEY, + MODEL_TENSOR.TIME_MIX_VALUE, + MODEL_TENSOR.TIME_MIX_RECEPTANCE, + MODEL_TENSOR.TIME_MIX_LN, + MODEL_TENSOR.TIME_MIX_OUTPUT, + MODEL_TENSOR.CHANNEL_MIX_LERP_K, + MODEL_TENSOR.CHANNEL_MIX_KEY, + MODEL_TENSOR.CHANNEL_MIX_VALUE, + ], + MODEL_ARCH.ARWKV7: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.TIME_MIX_LERP_FUSED, + MODEL_TENSOR.TIME_MIX_W0, + MODEL_TENSOR.TIME_MIX_W1, + MODEL_TENSOR.TIME_MIX_W2, + MODEL_TENSOR.TIME_MIX_A0, + MODEL_TENSOR.TIME_MIX_A1, + MODEL_TENSOR.TIME_MIX_A2, + MODEL_TENSOR.TIME_MIX_V0, + MODEL_TENSOR.TIME_MIX_V1, + MODEL_TENSOR.TIME_MIX_V2, + MODEL_TENSOR.TIME_MIX_G1, + MODEL_TENSOR.TIME_MIX_G2, + MODEL_TENSOR.TIME_MIX_K_K, + MODEL_TENSOR.TIME_MIX_K_A, + MODEL_TENSOR.TIME_MIX_R_K, + MODEL_TENSOR.TIME_MIX_KEY, + MODEL_TENSOR.TIME_MIX_VALUE, + MODEL_TENSOR.TIME_MIX_RECEPTANCE, + MODEL_TENSOR.TIME_MIX_LN, + MODEL_TENSOR.TIME_MIX_OUTPUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.MAMBA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + ], + MODEL_ARCH.XVERSE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.COMMAND_R: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_Q_NORM, + ], + MODEL_ARCH.COHERE2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.DBRX: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], + MODEL_ARCH.OLMO: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.OLMO2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.OLMOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + ], + MODEL_ARCH.OPENELM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.ARCTIC: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_NORM_EXP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], + MODEL_ARCH.DEEPSEEK: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], + MODEL_ARCH.DEEPSEEK2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], + MODEL_ARCH.PLM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, + ], + MODEL_ARCH.CHATGLM : [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.GLM4 : [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_POST_NORM, + ], + MODEL_ARCH.BITNET: [ + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_SUB_NORM, + MODEL_TENSOR.FFN_SUB_NORM, + ], + MODEL_ARCH.T5: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.DEC_ATTN_NORM, + MODEL_TENSOR.DEC_ATTN_Q, + MODEL_TENSOR.DEC_ATTN_K, + MODEL_TENSOR.DEC_ATTN_V, + MODEL_TENSOR.DEC_ATTN_OUT, + MODEL_TENSOR.DEC_ATTN_REL_B, + MODEL_TENSOR.DEC_CROSS_ATTN_NORM, + MODEL_TENSOR.DEC_CROSS_ATTN_Q, + MODEL_TENSOR.DEC_CROSS_ATTN_K, + MODEL_TENSOR.DEC_CROSS_ATTN_V, + MODEL_TENSOR.DEC_CROSS_ATTN_OUT, + MODEL_TENSOR.DEC_CROSS_ATTN_REL_B, + MODEL_TENSOR.DEC_FFN_NORM, + MODEL_TENSOR.DEC_FFN_GATE, + MODEL_TENSOR.DEC_FFN_DOWN, + MODEL_TENSOR.DEC_FFN_UP, + MODEL_TENSOR.DEC_OUTPUT_NORM, + MODEL_TENSOR.ENC_ATTN_NORM, + MODEL_TENSOR.ENC_ATTN_Q, + MODEL_TENSOR.ENC_ATTN_K, + MODEL_TENSOR.ENC_ATTN_V, + MODEL_TENSOR.ENC_ATTN_OUT, + MODEL_TENSOR.ENC_ATTN_REL_B, + MODEL_TENSOR.ENC_FFN_NORM, + MODEL_TENSOR.ENC_FFN_GATE, + MODEL_TENSOR.ENC_FFN_DOWN, + MODEL_TENSOR.ENC_FFN_UP, + MODEL_TENSOR.ENC_OUTPUT_NORM, + ], + MODEL_ARCH.T5ENCODER: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ENC_ATTN_NORM, + MODEL_TENSOR.ENC_ATTN_Q, + MODEL_TENSOR.ENC_ATTN_K, + MODEL_TENSOR.ENC_ATTN_V, + MODEL_TENSOR.ENC_ATTN_OUT, + MODEL_TENSOR.ENC_ATTN_REL_B, + MODEL_TENSOR.ENC_FFN_NORM, + MODEL_TENSOR.ENC_FFN_GATE, + MODEL_TENSOR.ENC_FFN_DOWN, + MODEL_TENSOR.ENC_FFN_UP, + MODEL_TENSOR.ENC_OUTPUT_NORM, + ], + MODEL_ARCH.JAIS: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.NEMOTRON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.EXAONE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.GRANITE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.GRANITE_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + ], + MODEL_ARCH.CHAMELEON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.WAVTOKENIZER_DEC: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.CONV1D, + MODEL_TENSOR.CONVNEXT_DW, + MODEL_TENSOR.CONVNEXT_NORM, + MODEL_TENSOR.CONVNEXT_PW1, + MODEL_TENSOR.CONVNEXT_PW2, + MODEL_TENSOR.CONVNEXT_GAMMA, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.POSNET_CONV1, + MODEL_TENSOR.POSNET_CONV2, + MODEL_TENSOR.POSNET_NORM, + MODEL_TENSOR.POSNET_NORM1, + MODEL_TENSOR.POSNET_NORM2, + MODEL_TENSOR.POSNET_ATTN_NORM, + MODEL_TENSOR.POSNET_ATTN_Q, + MODEL_TENSOR.POSNET_ATTN_K, + MODEL_TENSOR.POSNET_ATTN_V, + MODEL_TENSOR.POSNET_ATTN_OUT, + ], + MODEL_ARCH.BAILINGMOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], + MODEL_ARCH.DOTS1: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], + MODEL_ARCH.ARCEE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + # TODO +} + +# tensors that will not be serialized +MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { + MODEL_ARCH.LLAMA: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.DECI: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.BAICHUAN: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.QWEN: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.CODESHELL: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.ORION: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.STARCODER2: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.XVERSE: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.DEEPSEEK: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.DEEPSEEK2: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.CHATGLM: [ + MODEL_TENSOR.ROPE_FREQS, + ], + MODEL_ARCH.NEMOTRON: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.BAILINGMOE: [ + MODEL_TENSOR.ROPE_FREQS, + ], +} + +# +# types +# + + +class TokenType(IntEnum): + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 + USER_DEFINED = 4 + UNUSED = 5 + BYTE = 6 + + +class RopeScalingType(Enum): + NONE = 'none' + LINEAR = 'linear' + YARN = 'yarn' + LONGROPE = 'longrope' + + +class PoolingType(IntEnum): + NONE = 0 + MEAN = 1 + CLS = 2 + LAST = 3 + RANK = 4 + + +class GGMLQuantizationType(IntEnum): + F32 = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 + IQ2_XXS = 16 + IQ2_XS = 17 + IQ3_XXS = 18 + IQ1_S = 19 + IQ4_NL = 20 + IQ3_S = 21 + IQ2_S = 22 + IQ4_XS = 23 + I8 = 24 + I16 = 25 + I32 = 26 + I64 = 27 + F64 = 28 + IQ1_M = 29 + BF16 = 30 + TQ1_0 = 34 + TQ2_0 = 35 + + +class ExpertGatingFuncType(IntEnum): + SOFTMAX = 1 + SIGMOID = 2 + + +# TODO: add GGMLFileType from ggml_ftype in ggml.h + + +# from llama_ftype in llama.h +# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE. +class LlamaFileType(IntEnum): + ALL_F32 = 0 + MOSTLY_F16 = 1 # except 1d tensors + MOSTLY_Q4_0 = 2 # except 1d tensors + MOSTLY_Q4_1 = 3 # except 1d tensors + # MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16 + # MOSTLY_Q4_2 = 5 # support has been removed + # MOSTLY_Q4_3 = 6 # support has been removed + MOSTLY_Q8_0 = 7 # except 1d tensors + MOSTLY_Q5_0 = 8 # except 1d tensors + MOSTLY_Q5_1 = 9 # except 1d tensors + MOSTLY_Q2_K = 10 # except 1d tensors + MOSTLY_Q3_K_S = 11 # except 1d tensors + MOSTLY_Q3_K_M = 12 # except 1d tensors + MOSTLY_Q3_K_L = 13 # except 1d tensors + MOSTLY_Q4_K_S = 14 # except 1d tensors + MOSTLY_Q4_K_M = 15 # except 1d tensors + MOSTLY_Q5_K_S = 16 # except 1d tensors + MOSTLY_Q5_K_M = 17 # except 1d tensors + MOSTLY_Q6_K = 18 # except 1d tensors + MOSTLY_IQ2_XXS = 19 # except 1d tensors + MOSTLY_IQ2_XS = 20 # except 1d tensors + MOSTLY_Q2_K_S = 21 # except 1d tensors + MOSTLY_IQ3_XS = 22 # except 1d tensors + MOSTLY_IQ3_XXS = 23 # except 1d tensors + MOSTLY_IQ1_S = 24 # except 1d tensors + MOSTLY_IQ4_NL = 25 # except 1d tensors + MOSTLY_IQ3_S = 26 # except 1d tensors + MOSTLY_IQ3_M = 27 # except 1d tensors + MOSTLY_IQ2_S = 28 # except 1d tensors + MOSTLY_IQ2_M = 29 # except 1d tensors + MOSTLY_IQ4_XS = 30 # except 1d tensors + MOSTLY_IQ1_M = 31 # except 1d tensors + MOSTLY_BF16 = 32 # except 1d tensors + # MOSTLY_Q4_0_4_4 = 33 # removed from gguf files, use Q4_0 and runtime repack + # MOSTLY_Q4_0_4_8 = 34 # removed from gguf files, use Q4_0 and runtime repack + # MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack + MOSTLY_TQ1_0 = 36 # except 1d tensors + MOSTLY_TQ2_0 = 37 # except 1d tensors + + GUESSED = 1024 # not specified in the model file + + +class GGUFEndian(IntEnum): + LITTLE = 0 + BIG = 1 + + +class GGUFValueType(IntEnum): + UINT8 = 0 + INT8 = 1 + UINT16 = 2 + INT16 = 3 + UINT32 = 4 + INT32 = 5 + FLOAT32 = 6 + BOOL = 7 + STRING = 8 + ARRAY = 9 + UINT64 = 10 + INT64 = 11 + FLOAT64 = 12 + + @staticmethod + def get_type(val: Any) -> GGUFValueType: + if isinstance(val, (str, bytes, bytearray)): + return GGUFValueType.STRING + elif isinstance(val, list): + return GGUFValueType.ARRAY + elif isinstance(val, float): + return GGUFValueType.FLOAT32 + elif isinstance(val, bool): + return GGUFValueType.BOOL + elif isinstance(val, int): + return GGUFValueType.INT32 + # TODO: need help with 64-bit types in Python + else: + raise ValueError(f"Unknown type: {type(val)}") + + +class VisionProjectorType: + GEMMA3 = "gemma3" + IDEFICS3 = "idefics3" + PIXTRAL = "pixtral" + LLAMA4 = "llama4" + QWEN2VL = "qwen2vl_merger" + QWEN25VL = "qwen2.5vl_merger" + ULTRAVOX = "ultravox" + INTERNVL = "internvl" + QWEN2A = "qwen2a" # audio + QWEN25O = "qwen2.5o" # omni + + +# Items here are (block size, type size) +QK_K = 256 +GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { + GGMLQuantizationType.F32: (1, 4), + GGMLQuantizationType.F16: (1, 2), + GGMLQuantizationType.Q4_0: (32, 2 + 16), + GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), + GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), + GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), + GGMLQuantizationType.Q8_0: (32, 2 + 32), + GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), + GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), + GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), + GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), + GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), + GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), + GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), + GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4), + GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32), + GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8), + GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16), + GGMLQuantizationType.IQ4_NL: (32, 2 + 16), + GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), + GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), + GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), + GGMLQuantizationType.I8: (1, 1), + GGMLQuantizationType.I16: (1, 2), + GGMLQuantizationType.I32: (1, 4), + GGMLQuantizationType.I64: (1, 8), + GGMLQuantizationType.F64: (1, 8), + GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), + GGMLQuantizationType.BF16: (1, 2), + GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), + GGMLQuantizationType.TQ2_0: (256, 2 + 64), +} + + +# Aliases for backward compatibility. + +# general +KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE +KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION +KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT +KEY_GENERAL_NAME = Keys.General.NAME +KEY_GENERAL_AUTHOR = Keys.General.AUTHOR +KEY_GENERAL_URL = Keys.General.URL +KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION +KEY_GENERAL_LICENSE = Keys.General.LICENSE +KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL +KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE + +# LLM +KEY_VOCAB_SIZE = Keys.LLM.VOCAB_SIZE +KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH +KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH +KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT +KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH +KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL +KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT + +# attention +KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT +KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV +KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS +KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV +KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS +KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS + +# RoPE +KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT +KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE +KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE +KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR +KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN +KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED + +# SSM +KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL +KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE +KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE +KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS + +# tokenization +KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL +KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE +KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST +KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE +KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES +KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES +KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID +KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID +KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID +KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID +KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID +KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID +KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID +KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID +KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON +KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV + +KEY_TOKENIZER_FIM_PRE_ID = Keys.Tokenizer.FIM_PRE_ID +KEY_TOKENIZER_FIM_SUF_ID = Keys.Tokenizer.FIM_SUF_ID +KEY_TOKENIZER_FIM_MID_ID = Keys.Tokenizer.FIM_MID_ID +KEY_TOKENIZER_FIM_PAD_ID = Keys.Tokenizer.FIM_PAD_ID +KEY_TOKENIZER_FIM_REP_ID = Keys.Tokenizer.FIM_REP_ID +KEY_TOKENIZER_FIM_SEP_ID = Keys.Tokenizer.FIM_SEP_ID + +# deprecated +KEY_TOKENIZER_PREFIX_ID = Keys.Tokenizer.PREFIX_ID +KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID +KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py new file mode 100644 index 0000000000000000000000000000000000000000..651a81eb828248728f854c85c1a437b52892f275 --- /dev/null +++ b/gguf-py/gguf/gguf.py @@ -0,0 +1,15 @@ +# This file left for compatibility. If you want to use the GGUF API from Python +# then don't import gguf/gguf.py directly. If you're looking for examples, see the +# examples/ directory for gguf-py + +import importlib +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +# Compatibility for people trying to import gguf/gguf.py directly instead of as a package. +importlib.invalidate_caches() +import gguf # noqa: E402 + +importlib.reload(gguf) diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..d87e8f72321b3f776cf214e117e21cdde9e1e18b --- /dev/null +++ b/gguf-py/gguf/gguf_reader.py @@ -0,0 +1,367 @@ +# +# GGUF file reading/modification support. For API usage information, +# please see the files scripts/ for some fairly simple examples. +# +from __future__ import annotations + +import logging +import os +import sys +from collections import OrderedDict +from typing import Any, Literal, NamedTuple, TypeVar, Union + +import numpy as np +import numpy.typing as npt + +from .quants import quant_shape_to_byte_shape + +if __name__ == "__main__": + from pathlib import Path + + # Allow running file in package as a script. + sys.path.insert(0, str(Path(__file__).parent.parent)) + +from gguf.constants import ( + GGML_QUANT_SIZES, + GGUF_DEFAULT_ALIGNMENT, + GGUF_MAGIC, + GGUF_VERSION, + GGMLQuantizationType, + GGUFValueType, + GGUFEndian, +) + +logger = logging.getLogger(__name__) + +READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION] + + +class ReaderField(NamedTuple): + # Offset to start of this field. + offset: int + + # Name of the field (not necessarily from file data). + name: str + + # Data parts. Some types have multiple components, such as strings + # that consist of a length followed by the string data. + parts: list[npt.NDArray[Any]] = [] + + # Indexes into parts that we can call the actual data. For example + # an array of strings will be populated with indexes to the actual + # string data. + data: list[int] = [-1] + + types: list[GGUFValueType] = [] + + def contents(self, index_or_slice: int | slice = slice(None)) -> Any: + if self.types: + to_string = lambda x: str(x.tobytes(), encoding='utf-8') # noqa: E731 + main_type = self.types[0] + + if main_type == GGUFValueType.ARRAY: + sub_type = self.types[-1] + + if sub_type == GGUFValueType.STRING: + indices = self.data[index_or_slice] + + if isinstance(index_or_slice, int): + return to_string(self.parts[indices]) # type: ignore + else: + return [to_string(self.parts[idx]) for idx in indices] # type: ignore + else: + # FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too + + # Check if it's unsafe to perform slice optimization on data + # if any(True for idx in self.data if len(self.parts[idx]) != 1): + # optim_slice = slice(None) + # else: + # optim_slice = index_or_slice + # index_or_slice = slice(None) + + # if isinstance(optim_slice, int): + # return self.parts[self.data[optim_slice]].tolist()[0] + # else: + # return [pv for idx in self.data[optim_slice] for pv in self.parts[idx].tolist()][index_or_slice] + + if isinstance(index_or_slice, int): + return self.parts[self.data[index_or_slice]].tolist()[0] + else: + return [pv for idx in self.data[index_or_slice] for pv in self.parts[idx].tolist()] + + if main_type == GGUFValueType.STRING: + return to_string(self.parts[-1]) + else: + return self.parts[-1].tolist()[0] + + return None + + +class ReaderTensor(NamedTuple): + name: str + tensor_type: GGMLQuantizationType + shape: npt.NDArray[np.uint32] + n_elements: int + n_bytes: int + data_offset: int + data: npt.NDArray[Any] + field: ReaderField + + +class GGUFReader: + # I - same as host, S - swapped + byte_order: Literal['I', 'S'] = 'I' + alignment: int = GGUF_DEFAULT_ALIGNMENT + data_offset: int + + # Note: Internal helper, API may change. + gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = { + GGUFValueType.UINT8: np.uint8, + GGUFValueType.INT8: np.int8, + GGUFValueType.UINT16: np.uint16, + GGUFValueType.INT16: np.int16, + GGUFValueType.UINT32: np.uint32, + GGUFValueType.INT32: np.int32, + GGUFValueType.FLOAT32: np.float32, + GGUFValueType.UINT64: np.uint64, + GGUFValueType.INT64: np.int64, + GGUFValueType.FLOAT64: np.float64, + GGUFValueType.BOOL: np.bool_, + } + + def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'): + self.data = np.memmap(path, mode = mode) + offs = 0 + + # Check for GGUF magic + if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: + raise ValueError('GGUF magic invalid') + offs += 4 + + # Check GGUF version + temp_version = self._get(offs, np.uint32) + if temp_version[0] & 65535 == 0: + # If we get 0 here that means it's (probably) a GGUF file created for + # the opposite byte order of the machine this script is running on. + self.byte_order = 'S' + temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order)) + version = temp_version[0] + if version not in READER_SUPPORTED_VERSIONS: + raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle') + if sys.byteorder == "little": + # Host is little endian + host_endian = GGUFEndian.LITTLE + swapped_endian = GGUFEndian.BIG + else: + # Sorry PDP or other weird systems that don't use BE or LE. + host_endian = GGUFEndian.BIG + swapped_endian = GGUFEndian.LITTLE + self.endianess = swapped_endian if self.byte_order == "S" else host_endian + self.fields: OrderedDict[str, ReaderField] = OrderedDict() + self.tensors: list[ReaderTensor] = [] + offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32])) + + # Check tensor count and kv count + temp_counts = self._get(offs, np.uint64, 2) + offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64])) + offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64])) + tensor_count, kv_count = temp_counts + offs = self._build_fields(offs, kv_count) + + # Build Tensor Info Fields + offs, tensors_fields = self._build_tensor_info(offs, tensor_count) + new_align = self.fields.get('general.alignment') + if new_align is not None: + if new_align.types != [GGUFValueType.UINT32]: + raise ValueError('Bad type for general.alignment field') + self.alignment = new_align.parts[-1][0] + padding = offs % self.alignment + if padding != 0: + offs += self.alignment - padding + self.data_offset = offs + self._build_tensors(offs, tensors_fields) + + _DT = TypeVar('_DT', bound = npt.DTypeLike) + + # Fetch a key/value metadata field by key. + def get_field(self, key: str) -> Union[ReaderField, None]: + return self.fields.get(key, None) + + # Fetch a tensor from the list by index. + def get_tensor(self, idx: int) -> ReaderTensor: + return self.tensors[idx] + + def _get( + self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None, + ) -> npt.NDArray[Any]: + count = int(count) + itemsize = int(np.empty([], dtype = dtype).itemsize) + end_offs = offset + itemsize * count + arr = self.data[offset:end_offs].view(dtype=dtype)[:count] + return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order)) + + def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: + if field.name in self.fields: + # TODO: add option to generate error on duplicate keys + # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}') + + logger.warning(f'Duplicate key {field.name} at offset {field.offset}') + self.fields[field.name + '_{}'.format(field.offset)] = field + else: + self.fields[field.name] = field + return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts) + + def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: + slen = self._get(offset, np.uint64) + return slen, self._get(offset + 8, np.uint8, slen[0]) + + def _get_field_parts( + self, orig_offs: int, raw_type: int, + ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]: + offs = orig_offs + types: list[GGUFValueType] = [] + gtype = GGUFValueType(raw_type) + types.append(gtype) + # Handle strings. + if gtype == GGUFValueType.STRING: + sparts: list[npt.NDArray[Any]] = list(self._get_str(offs)) + size = sum(int(part.nbytes) for part in sparts) + return size, sparts, [1], types + # Check if it's a simple scalar type. + nptype = self.gguf_scalar_to_np.get(gtype) + if nptype is not None: + val = self._get(offs, nptype) + return int(val.nbytes), [val], [0], types + # Handle arrays. + if gtype == GGUFValueType.ARRAY: + raw_itype = self._get(offs, np.uint32) + offs += int(raw_itype.nbytes) + alen = self._get(offs, np.uint64) + offs += int(alen.nbytes) + aparts: list[npt.NDArray[Any]] = [raw_itype, alen] + data_idxs: list[int] = [] + # FIXME: Handle multi-dimensional arrays properly instead of flattening + for idx in range(alen[0]): + curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0]) + if idx == 0: + types += curr_types + idxs_offs = len(aparts) + aparts += curr_parts + data_idxs += (idx + idxs_offs for idx in curr_idxs) + offs += curr_size + return offs - orig_offs, aparts, data_idxs, types + # We can't deal with this one. + raise ValueError(f'Unknown/unhandled field type {gtype}') + + def _get_tensor_info_field(self, orig_offs: int) -> ReaderField: + offs = orig_offs + + # Get Tensor Name + name_len, name_data = self._get_str(offs) + offs += int(name_len.nbytes + name_data.nbytes) + + # Get Tensor Dimensions Count + n_dims = self._get(offs, np.uint32) + offs += int(n_dims.nbytes) + + # Get Tensor Dimension Array + dims = self._get(offs, np.uint64, n_dims[0]) + offs += int(dims.nbytes) + + # Get Tensor Encoding Scheme Type + raw_dtype = self._get(offs, np.uint32) + offs += int(raw_dtype.nbytes) + + # Get Tensor Offset + offset_tensor = self._get(offs, np.uint64) + offs += int(offset_tensor.nbytes) + + return ReaderField( + orig_offs, + str(bytes(name_data), encoding = 'utf-8'), + [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor], + [1, 3, 4, 5], + ) + + def _build_fields(self, offs: int, count: int) -> int: + for _ in range(count): + orig_offs = offs + kv_klen, kv_kdata = self._get_str(offs) + offs += int(kv_klen.nbytes + kv_kdata.nbytes) + raw_kv_type = self._get(offs, np.uint32) + offs += int(raw_kv_type.nbytes) + parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type] + idxs_offs = len(parts) + field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0]) + parts += field_parts + self._push_field(ReaderField( + orig_offs, + str(bytes(kv_kdata), encoding = 'utf-8'), + parts, + [idx + idxs_offs for idx in field_idxs], + field_types, + ), skip_sum = True) + offs += field_size + return offs + + def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]: + tensor_fields = [] + for _ in range(count): + field = self._get_tensor_info_field(offs) + offs += sum(int(part.nbytes) for part in field.parts) + tensor_fields.append(field) + return offs, tensor_fields + + def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None: + tensors = [] + tensor_names = set() # keep track of name to prevent duplicated tensors + for field in fields: + _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts + # check if there's any tensor having same name already in the list + tensor_name = str(bytes(name_data), encoding = 'utf-8') + if tensor_name in tensor_names: + raise ValueError(f'Found duplicated tensor with name {tensor_name}') + tensor_names.add(tensor_name) + ggml_type = GGMLQuantizationType(raw_dtype[0]) + n_elems = int(np.prod(dims)) + np_dims = tuple(reversed(dims.tolist())) + block_size, type_size = GGML_QUANT_SIZES[ggml_type] + n_bytes = n_elems * type_size // block_size + data_offs = int(start_offs + offset_tensor[0]) + item_type: npt.DTypeLike + if ggml_type == GGMLQuantizationType.F16: + item_count = n_elems + item_type = np.float16 + elif ggml_type == GGMLQuantizationType.F32: + item_count = n_elems + item_type = np.float32 + elif ggml_type == GGMLQuantizationType.F64: + item_count = n_elems + item_type = np.float64 + elif ggml_type == GGMLQuantizationType.I8: + item_count = n_elems + item_type = np.int8 + elif ggml_type == GGMLQuantizationType.I16: + item_count = n_elems + item_type = np.int16 + elif ggml_type == GGMLQuantizationType.I32: + item_count = n_elems + item_type = np.int32 + elif ggml_type == GGMLQuantizationType.I64: + item_count = n_elems + item_type = np.int64 + else: + item_count = n_bytes + item_type = np.uint8 + np_dims = quant_shape_to_byte_shape(np_dims, ggml_type) + tensors.append(ReaderTensor( + name = tensor_name, + tensor_type = ggml_type, + shape = dims, + n_elements = n_elems, + n_bytes = n_bytes, + data_offset = data_offs, + data = self._get(data_offs, item_type, item_count).reshape(np_dims), + field = field, + )) + self.tensors = tensors diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..54ca0c33fd3368daa6a19a15023b792da88951d2 --- /dev/null +++ b/gguf-py/gguf/gguf_writer.py @@ -0,0 +1,1076 @@ +from __future__ import annotations + +import logging +import os +import shutil +import struct +import tempfile +from dataclasses import dataclass +from enum import Enum, auto +from math import prod +from pathlib import Path +from io import BufferedWriter +from typing import IO, Any, Sequence, Mapping +from string import ascii_letters, digits + +import numpy as np + +from .constants import ( + GGUF_DEFAULT_ALIGNMENT, + GGUF_MAGIC, + GGUF_VERSION, + GGMLQuantizationType, + GGUFEndian, + GGUFValueType, + Keys, + RopeScalingType, + PoolingType, + TokenType, + ExpertGatingFuncType, +) + +from .quants import quant_shape_from_byte_shape + +logger = logging.getLogger(__name__) + + +SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf" + + +@dataclass +class TensorInfo: + shape: Sequence[int] + dtype: GGMLQuantizationType + nbytes: int + tensor: np.ndarray[Any, Any] | None = None + + +@dataclass +class GGUFValue: + value: Any + type: GGUFValueType + sub_type: GGUFValueType | None = None + + +class WriterState(Enum): + NO_FILE = auto() + EMPTY = auto() + HEADER = auto() + KV_DATA = auto() + TI_DATA = auto() + WEIGHTS = auto() + + +class GGUFWriter: + fout: list[BufferedWriter] | None + path: Path | None + temp_file: tempfile.SpooledTemporaryFile[bytes] | None + tensors: list[dict[str, TensorInfo]] + kv_data: list[dict[str, GGUFValue]] + state: WriterState + _simple_value_packing = { + GGUFValueType.UINT8: "B", + GGUFValueType.INT8: "b", + GGUFValueType.UINT16: "H", + GGUFValueType.INT16: "h", + GGUFValueType.UINT32: "I", + GGUFValueType.INT32: "i", + GGUFValueType.FLOAT32: "f", + GGUFValueType.UINT64: "Q", + GGUFValueType.INT64: "q", + GGUFValueType.FLOAT64: "d", + GGUFValueType.BOOL: "?", + } + + def __init__( + self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE, + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False + ): + self.fout = None + self.path = Path(path) if path else None + self.arch = arch + self.endianess = endianess + self.data_alignment = GGUF_DEFAULT_ALIGNMENT + self.use_temp_file = use_temp_file + self.temp_file = None + self.tensors = [{}] + self.kv_data = [{}] + self.split_max_tensors = split_max_tensors + self.split_max_size = split_max_size + self.dry_run = dry_run + self.small_first_shard = small_first_shard + logger.info("gguf: This GGUF file is for {0} Endian only".format( + "Big" if self.endianess == GGUFEndian.BIG else "Little", + )) + self.state = WriterState.NO_FILE + + if self.small_first_shard: + self.tensors.append({}) + + self.add_architecture() + + def get_total_parameter_count(self) -> tuple[int, int, int, int]: + total_params = 0 + shared_params = 0 + expert_params = 0 + + expert_sum = 0 + n_expert_tensors = 0 + + last_lora_a: tuple[str, TensorInfo] | None = None + + for tensors in self.tensors: + for name, info in tensors.items(): + + shape = info.shape + + if name.endswith(".lora_a"): + last_lora_a = (name, info) + continue + elif name.endswith(".lora_b"): + if last_lora_a is None or last_lora_a[0] != name[:-1] + "a": + # Bail when the LoRA pair can't be found trivially + logger.warning("can't measure LoRA size correctly, tensor order is unusual") + return 0, 0, 0, 0 + else: + shape = (*shape[:-1], last_lora_a[1].shape[-1]) + + size = prod(shape) + + if "_exps." in name: + expert_params += (size // shape[-3]) + expert_sum += shape[-3] + n_expert_tensors += 1 + else: + shared_params += size + + total_params += size + + # Hopefully this should work even for variable-expert-count models + expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0 + + # Negate the total to signal it's likely not exact + if last_lora_a is not None: + total_params = -total_params + + # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py + return total_params, shared_params, expert_params, expert_count + + def format_shard_names(self, path: Path) -> list[Path]: + if len(self.tensors) == 1: + return [path] + return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))] + + def open_output_file(self, path: Path | None = None) -> None: + if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path): + # allow calling this multiple times as long as the path is the same + return + + if self.state is not WriterState.NO_FILE: + raise ValueError(f'Expected output file to be not yet opened, got {self.state}') + + if path is not None: + self.path = path + + if self.path is not None: + filenames = self.print_plan() + self.fout = [open(filename, "wb") for filename in filenames] + self.state = WriterState.EMPTY + + def print_plan(self) -> list[Path]: + logger.info("Writing the following files:") + assert self.path is not None + filenames = self.format_shard_names(self.path) + assert len(filenames) == len(self.tensors) + for name, tensors in zip(filenames, self.tensors): + logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}") + + if self.dry_run: + logger.info("Dry run, not writing files") + for name in filenames: + print(name) # noqa: NP100 + exit() + + return filenames + + def add_shard_kv_data(self) -> None: + if len(self.tensors) == 1: + return + + total_tensors = sum(len(t) for t in self.tensors) + assert self.fout is not None + total_splits = len(self.fout) + self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits)) + for i, kv_data in enumerate(self.kv_data): + kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16) + kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16) + kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32) + + def write_header_to_file(self, path: Path | None = None) -> None: + if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0): + logger.warning("Model fails split requirements, not splitting") + + self.open_output_file(path) + + if self.state is not WriterState.EMPTY: + raise ValueError(f'Expected output file to be empty, got {self.state}') + + assert self.fout is not None + assert len(self.fout) == len(self.tensors) + assert len(self.kv_data) == 1 + + self.add_shard_kv_data() + + for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data): + fout.write(self._pack(" None: + if self.state is not WriterState.HEADER: + raise ValueError(f'Expected output file to contain the header, got {self.state}') + assert self.fout is not None + + for fout, kv_data in zip(self.fout, self.kv_data): + kv_bytes = bytearray() + + for key, val in kv_data.items(): + kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) + kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type) + + fout.write(kv_bytes) + + self.flush() + self.state = WriterState.KV_DATA + + def write_ti_data_to_file(self) -> None: + if self.state is not WriterState.KV_DATA: + raise ValueError(f'Expected output file to contain KV data, got {self.state}') + assert self.fout is not None + + for fout, tensors in zip(self.fout, self.tensors): + ti_data = bytearray() + offset_tensor = 0 + + for name, ti in tensors.items(): + ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) + n_dims = len(ti.shape) + ti_data += self._pack("I", n_dims) + for j in range(n_dims): + ti_data += self._pack("Q", ti.shape[n_dims - 1 - j]) + ti_data += self._pack("I", ti.dtype) + ti_data += self._pack("Q", offset_tensor) + offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) + + fout.write(ti_data) + fout.flush() + self.state = WriterState.TI_DATA + + def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None: + if any(key in kv_data for kv_data in self.kv_data): + logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}') + + self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type) + + def add_uint8(self, key: str, val: int) -> None: + self.add_key_value(key,val, GGUFValueType.UINT8) + + def add_int8(self, key: str, val: int) -> None: + self.add_key_value(key, val, GGUFValueType.INT8) + + def add_uint16(self, key: str, val: int) -> None: + self.add_key_value(key, val, GGUFValueType.UINT16) + + def add_int16(self, key: str, val: int) -> None: + self.add_key_value(key, val, GGUFValueType.INT16) + + def add_uint32(self, key: str, val: int) -> None: + self.add_key_value(key, val, GGUFValueType.UINT32) + + def add_int32(self, key: str, val: int) -> None: + self.add_key_value(key, val, GGUFValueType.INT32) + + def add_float32(self, key: str, val: float) -> None: + self.add_key_value(key, val, GGUFValueType.FLOAT32) + + def add_uint64(self, key: str, val: int) -> None: + self.add_key_value(key, val, GGUFValueType.UINT64) + + def add_int64(self, key: str, val: int) -> None: + self.add_key_value(key, val, GGUFValueType.INT64) + + def add_float64(self, key: str, val: float) -> None: + self.add_key_value(key, val, GGUFValueType.FLOAT64) + + def add_bool(self, key: str, val: bool) -> None: + self.add_key_value(key, val, GGUFValueType.BOOL) + + def add_string(self, key: str, val: str) -> None: + if not val: + return + self.add_key_value(key, val, GGUFValueType.STRING) + + def add_array(self, key: str, val: Sequence[Any]) -> None: + if len(val) == 0: + return + self.add_key_value(key, val, GGUFValueType.ARRAY) + + @staticmethod + def ggml_pad(x: int, n: int) -> int: + return ((x + n - 1) // n) * n + + def add_tensor_info( + self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype, + tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None, + ) -> None: + if self.state is not WriterState.NO_FILE: + raise ValueError(f'Expected output file to be not yet opened, got {self.state}') + + if any(name in tensors for tensors in self.tensors): + raise ValueError(f'Duplicated tensor name {name!r}') + + if raw_dtype is None: + if tensor_dtype == np.float16: + dtype = GGMLQuantizationType.F16 + elif tensor_dtype == np.float32: + dtype = GGMLQuantizationType.F32 + elif tensor_dtype == np.float64: + dtype = GGMLQuantizationType.F64 + elif tensor_dtype == np.int8: + dtype = GGMLQuantizationType.I8 + elif tensor_dtype == np.int16: + dtype = GGMLQuantizationType.I16 + elif tensor_dtype == np.int32: + dtype = GGMLQuantizationType.I32 + elif tensor_dtype == np.int64: + dtype = GGMLQuantizationType.I64 + else: + raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now") + else: + dtype = raw_dtype + if tensor_dtype == np.uint8: + tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype) + + # make sure there is at least one tensor before splitting + if len(self.tensors[-1]) > 0: + if ( # split when over tensor limit + self.split_max_tensors != 0 + and len(self.tensors[-1]) >= self.split_max_tensors + ) or ( # split when over size limit + self.split_max_size != 0 + and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size + ): + self.tensors.append({}) + + self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) + + def add_tensor( + self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, + raw_dtype: GGMLQuantizationType | None = None, + ) -> None: + if self.endianess == GGUFEndian.BIG: + tensor.byteswap(inplace=True) + if self.use_temp_file and self.temp_file is None: + fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) + fp.seek(0) + self.temp_file = fp + + shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape + self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype) + + if self.temp_file is None: + self.tensors[-1][name].tensor = tensor + return + + tensor.tofile(self.temp_file) + self.write_padding(self.temp_file, tensor.nbytes) + + def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None: + pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n + if pad != 0: + fp.write(bytes([0] * pad)) + + def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: + if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS: + raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') + assert self.fout is not None + + if self.endianess == GGUFEndian.BIG: + tensor.byteswap(inplace=True) + + file_id = -1 + for i, tensors in enumerate(self.tensors): + if len(tensors) > 0: + file_id = i + break + + fout = self.fout[file_id] + + # pop the first tensor info + # TODO: cleaner way to get the first key + first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0] + ti = self.tensors[file_id].pop(first_tensor_name) + assert ti.nbytes == tensor.nbytes + + self.write_padding(fout, fout.tell()) + tensor.tofile(fout) + self.write_padding(fout, tensor.nbytes) + + self.state = WriterState.WEIGHTS + + def write_tensors_to_file(self, *, progress: bool = False) -> None: + self.write_ti_data_to_file() + + assert self.fout is not None + + for fout in self.fout: + self.write_padding(fout, fout.tell()) + + if self.temp_file is None: + shard_bar = None + bar = None + + if progress: + from tqdm import tqdm + + total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values()) + + if len(self.fout) > 1: + shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True) + bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) + + for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)): + if shard_bar is not None: + shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})") + total = sum(ti.nbytes for ti in tensors.values()) + shard_bar.reset(total=(total if total > 0 else None)) + + # relying on the fact that Python dicts preserve insertion order (since 3.7) + for ti in tensors.values(): + assert ti.tensor is not None # can only iterate once over the tensors + assert ti.tensor.nbytes == ti.nbytes + ti.tensor.tofile(fout) + if shard_bar is not None: + shard_bar.update(ti.nbytes) + if bar is not None: + bar.update(ti.nbytes) + self.write_padding(fout, ti.nbytes) + ti.tensor = None + else: + self.temp_file.seek(0) + + shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1]) + self.flush() + self.temp_file.close() + + self.state = WriterState.WEIGHTS + + def flush(self) -> None: + assert self.fout is not None + for fout in self.fout: + fout.flush() + + def close(self) -> None: + if self.fout is not None: + for fout in self.fout: + fout.close() + self.fout = None + + def add_type(self, type_name: str) -> None: + self.add_string(Keys.General.TYPE, type_name) + + def add_architecture(self) -> None: + self.add_string(Keys.General.ARCHITECTURE, self.arch) + + def add_quantization_version(self, quantization_version: int) -> None: + self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version) + + def add_custom_alignment(self, alignment: int) -> None: + self.data_alignment = alignment + self.add_uint32(Keys.General.ALIGNMENT, alignment) + + def add_file_type(self, ftype: int) -> None: + self.add_uint32(Keys.General.FILE_TYPE, ftype) + + def add_name(self, name: str) -> None: + self.add_string(Keys.General.NAME, name) + + def add_author(self, author: str) -> None: + self.add_string(Keys.General.AUTHOR, author) + + def add_version(self, version: str) -> None: + self.add_string(Keys.General.VERSION, version) + + def add_organization(self, organization: str) -> None: + self.add_string(Keys.General.ORGANIZATION, organization) + + def add_finetune(self, finetune: str) -> None: + self.add_string(Keys.General.FINETUNE, finetune) + + def add_basename(self, basename: str) -> None: + self.add_string(Keys.General.BASENAME, basename) + + def add_description(self, description: str) -> None: + self.add_string(Keys.General.DESCRIPTION, description) + + def add_quantized_by(self, quantized: str) -> None: + self.add_string(Keys.General.QUANTIZED_BY, quantized) + + def add_size_label(self, size_label: str) -> None: + self.add_string(Keys.General.SIZE_LABEL, size_label) + + def add_license(self, license: str) -> None: + self.add_string(Keys.General.LICENSE, license) + + def add_license_name(self, license: str) -> None: + self.add_string(Keys.General.LICENSE_NAME, license) + + def add_license_link(self, license: str) -> None: + self.add_string(Keys.General.LICENSE_LINK, license) + + def add_url(self, url: str) -> None: + self.add_string(Keys.General.URL, url) + + def add_doi(self, doi: str) -> None: + self.add_string(Keys.General.DOI, doi) + + def add_uuid(self, uuid: str) -> None: + self.add_string(Keys.General.UUID, uuid) + + def add_repo_url(self, repo_url: str) -> None: + self.add_string(Keys.General.REPO_URL, repo_url) + + def add_source_url(self, url: str) -> None: + self.add_string(Keys.General.SOURCE_URL, url) + + def add_source_doi(self, doi: str) -> None: + self.add_string(Keys.General.SOURCE_DOI, doi) + + def add_source_uuid(self, uuid: str) -> None: + self.add_string(Keys.General.SOURCE_UUID, uuid) + + def add_source_repo_url(self, repo_url: str) -> None: + self.add_string(Keys.General.SOURCE_REPO_URL, repo_url) + + def add_base_model_count(self, source_count: int) -> None: + self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count) + + def add_base_model_name(self, source_id: int, name: str) -> None: + self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name) + + def add_base_model_author(self, source_id: int, author: str) -> None: + self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author) + + def add_base_model_version(self, source_id: int, version: str) -> None: + self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version) + + def add_base_model_organization(self, source_id: int, organization: str) -> None: + self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization) + + def add_base_model_description(self, source_id: int, description: str) -> None: + self.add_string(Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description) + + def add_base_model_url(self, source_id: int, url: str) -> None: + self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url) + + def add_base_model_doi(self, source_id: int, doi: str) -> None: + self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi) + + def add_base_model_uuid(self, source_id: int, uuid: str) -> None: + self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid) + + def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None: + self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url) + + def add_dataset_count(self, source_count: int) -> None: + self.add_uint32(Keys.General.DATASET_COUNT, source_count) + + def add_dataset_name(self, source_id: int, name: str) -> None: + self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name) + + def add_dataset_author(self, source_id: int, author: str) -> None: + self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author) + + def add_dataset_version(self, source_id: int, version: str) -> None: + self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version) + + def add_dataset_organization(self, source_id: int, organization: str) -> None: + self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization) + + def add_dataset_description(self, source_id: int, description: str) -> None: + self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description) + + def add_dataset_url(self, source_id: int, url: str) -> None: + self.add_string(Keys.General.DATASET_URL.format(id=source_id), url) + + def add_dataset_doi(self, source_id: int, doi: str) -> None: + self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi) + + def add_dataset_uuid(self, source_id: int, uuid: str) -> None: + self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid) + + def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None: + self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url) + + def add_tags(self, tags: Sequence[str]) -> None: + self.add_array(Keys.General.TAGS, tags) + + def add_languages(self, languages: Sequence[str]) -> None: + self.add_array(Keys.General.LANGUAGES, languages) + + def add_tensor_data_layout(self, layout: str) -> None: + self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) + + def add_vocab_size(self, size: int) -> None: + self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size) + + def add_context_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length) + + def add_embedding_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length) + + def add_features_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length) + + def add_posnet_embedding_length(self, length: int) -> None: + self.add_uint32(Keys.PosNet.EMBEDDING_LENGTH.format(arch=self.arch), length) + + def add_posnet_block_count(self, length: int) -> None: + self.add_uint32(Keys.PosNet.BLOCK_COUNT.format(arch=self.arch), length) + + def add_convnext_embedding_length(self, length: int) -> None: + self.add_uint32(Keys.ConvNext.EMBEDDING_LENGTH.format(arch=self.arch), length) + + def add_convnext_block_count(self, length: int) -> None: + self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length) + + def add_block_count(self, length: int) -> None: + self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length) + + def add_leading_dense_block_count(self, length: int) -> None: + self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length) + + def add_feed_forward_length(self, length: int | Sequence[int]) -> None: + if isinstance(length, int): + self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) + else: + self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) + + def add_expert_feed_forward_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + + def add_expert_shared_feed_forward_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + + def add_parallel_residual(self, use: bool) -> None: + self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) + + def add_decoder_start_token_id(self, id: int) -> None: + self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id) + + def add_head_count(self, count: int | Sequence[int]) -> None: + if isinstance(count, int): + self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) + else: + self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) + + def add_head_count_kv(self, count: int | Sequence[int]) -> None: + if isinstance(count, int): + self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + else: + self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + + def add_key_length(self, length: int) -> None: + self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length) + + def add_value_length(self, length: int) -> None: + self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length) + + def add_key_length_mla(self, length: int) -> None: + self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length) + + def add_value_length_mla(self, length: int) -> None: + self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length) + + def add_max_alibi_bias(self, bias: float) -> None: + self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias) + + def add_clamp_kqv(self, value: float) -> None: + self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value) + + def add_logit_scale(self, value: float) -> None: + self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value) + + def add_attn_logit_softcapping(self, value: float) -> None: + self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value) + + def add_final_logit_softcapping(self, value: float) -> None: + self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value) + + def add_expert_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count) + + def add_expert_used_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count) + + def add_expert_shared_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count) + + def add_expert_weights_scale(self, value: float) -> None: + self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + + def add_expert_weights_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value) + + def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None: + self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value) + + def add_moe_every_n_layers(self, value: int) -> None: + self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value) + + def add_swin_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) + + def add_rescale_every_n_layers(self, count: int) -> None: + self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count) + + def add_time_mix_extra_dim(self, dim: int) -> None: + self.add_uint32(Keys.LLM.TIME_MIX_EXTRA_DIM.format(arch=self.arch), dim) + + def add_time_decay_extra_dim(self, dim: int) -> None: + self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim) + + def add_residual_scale(self, value: float) -> None: + self.add_float32(Keys.LLM.RESIDUAL_SCALE.format(arch=self.arch), value) + + def add_embedding_scale(self, value: float) -> None: + self.add_float32(Keys.LLM.EMBEDDING_SCALE.format(arch=self.arch), value) + + def add_wkv_head_size(self, size: int) -> None: + self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size) + + def add_token_shift_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count) + + def add_interleave_moe_layer_step(self, value: int) -> None: + self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value) + + def add_layer_norm_eps(self, value: float) -> None: + self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) + + def add_layer_norm_rms_eps(self, value: float) -> None: + self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) + + def add_group_norm_eps(self, value: float) -> None: + self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value) + + def add_group_norm_groups(self, value: int) -> None: + self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value) + + def add_causal_attention(self, value: bool) -> None: + self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) + + def add_q_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length) + + def add_kv_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length) + + def add_decay_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length) + + def add_iclr_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length) + + def add_value_residual_mix_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length) + + def add_gate_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length) + + def add_relative_attn_buckets_count(self, value: int) -> None: + self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value) + + def add_sliding_window(self, value: int) -> None: + self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value) + + def add_attention_scale(self, value: float) -> None: + self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value) + + def add_pooling_type(self, value: PoolingType) -> None: + self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) + + def add_rope_dimension_count(self, count: int) -> None: + self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count) + + def add_rope_dimension_sections(self, dims: Sequence[int]) -> None: + self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims) + + def add_rope_freq_base(self, value: float) -> None: + self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value) + + def add_rope_scaling_type(self, value: RopeScalingType) -> None: + self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value) + + def add_rope_scaling_factor(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) + + def add_rope_scaling_attn_factors(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value) + + def add_rope_scaling_orig_ctx_len(self, value: int) -> None: + self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value) + + def add_rope_scaling_finetuned(self, value: bool) -> None: + self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value) + + def add_rope_scaling_yarn_log_mul(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value) + + def add_ssm_conv_kernel(self, value: int) -> None: + self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value) + + def add_ssm_inner_size(self, value: int) -> None: + self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value) + + def add_ssm_state_size(self, value: int) -> None: + self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value) + + def add_ssm_time_step_rank(self, value: int) -> None: + self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + + def add_ssm_dt_b_c_rms(self, value: bool) -> None: + self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) + + def add_tokenizer_model(self, model: str) -> None: + self.add_string(Keys.Tokenizer.MODEL, model) + + def add_tokenizer_pre(self, pre: str) -> None: + self.add_string(Keys.Tokenizer.PRE, pre) + + def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: + self.add_array(Keys.Tokenizer.LIST, tokens) + + def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: + self.add_array(Keys.Tokenizer.MERGES, merges) + + def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None: + self.add_array(Keys.Tokenizer.TOKEN_TYPE, types) + + def add_token_type_count(self, value: int) -> None: + self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value) + + def add_token_scores(self, scores: Sequence[float]) -> None: + self.add_array(Keys.Tokenizer.SCORES, scores) + + def add_bos_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.BOS_ID, id) + + def add_eos_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.EOS_ID, id) + + def add_unk_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.UNK_ID, id) + + def add_sep_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.SEP_ID, id) + + def add_pad_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.PAD_ID, id) + + def add_mask_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.MASK_ID, id) + + def add_add_bos_token(self, value: bool) -> None: + self.add_bool(Keys.Tokenizer.ADD_BOS, value) + + def add_add_eos_token(self, value: bool) -> None: + self.add_bool(Keys.Tokenizer.ADD_EOS, value) + + def add_add_space_prefix(self, value: bool) -> None: + self.add_bool(Keys.Tokenizer.ADD_PREFIX, value) + + def add_remove_extra_whitespaces(self, value: bool) -> None: + self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value) + + def add_precompiled_charsmap(self, charsmap: bytes) -> None: + self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) + + def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: + if not isinstance(value, str): + template_default = None + template_names = set() + + for choice in value: + name = choice.get('name', '') + template = choice.get('template') + + # Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it + name = ''.join((c if c in ascii_letters + digits else '_' for c in name)) + + if name and template is not None: + if name == 'default': + template_default = template + else: + template_names.add(name) + self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template) + + if template_names: + self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names)) + + if template_default is None: + return + + value = template_default + + self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value) + + def add_eot_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.EOT_ID, id) + + def add_eom_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.EOM_ID, id) + + def add_classifier_output_labels(self, labels: Sequence[str]) -> None: + self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels) + + # for vision models + + def add_clip_has_vision_encoder(self, value: bool) -> None: + self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value) + + def add_clip_has_audio_encoder(self, value: bool) -> None: + self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value) + + def add_clip_projector_type(self, value: str) -> None: + self.add_string(Keys.Clip.PROJECTOR_TYPE, value) + + def add_vision_projection_dim(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value) + + def add_vision_patch_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.PATCH_SIZE, value) + + def add_vision_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value) + + def add_vision_feed_forward_length(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value) + + def add_vision_block_count(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.BLOCK_COUNT, value) + + def add_vision_head_count(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value) + + def add_vision_attention_layernorm_eps(self, value: float) -> None: + self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value) + + def add_vision_image_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value) + + def add_vision_image_mean(self, values: Sequence[float]) -> None: + self.add_array(Keys.ClipVision.IMAGE_MEAN, values) + + def add_vision_image_std(self, values: Sequence[float]) -> None: + self.add_array(Keys.ClipVision.IMAGE_STD, values) + + def add_vision_spatial_merge_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value) + + def add_vision_use_gelu(self, value: bool) -> None: + self.add_bool(Keys.ClipVision.USE_GELU, value) + + def add_vision_use_silu(self, value: bool) -> None: + self.add_bool(Keys.ClipVision.USE_SILU, value) + + def add_vision_projector_scale_factor(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value) + + def add_vision_n_wa_pattern(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value) + + # audio models + + def add_audio_projection_dim(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value) + + def add_audio_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value) + + def add_audio_feed_forward_length(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value) + + def add_audio_block_count(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value) + + def add_audio_head_count(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value) + + def add_audio_attention_layernorm_eps(self, value: float) -> None: + self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value) + + def add_audio_num_mel_bins(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value) + + def add_audio_stack_factor(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value) + + def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: + pack_prefix = '' + if not skip_pack_prefix: + pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>' + return struct.pack(f'{pack_prefix}{fmt}', value) + + def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes: + kv_data = bytearray() + + if add_vtype: + kv_data += self._pack("I", vtype) + + pack_fmt = self._simple_value_packing.get(vtype) + if pack_fmt is not None: + kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) + elif vtype == GGUFValueType.STRING: + encoded_val = val.encode("utf-8") if isinstance(val, str) else val + kv_data += self._pack("Q", len(encoded_val)) + kv_data += encoded_val + elif vtype == GGUFValueType.ARRAY: + + if not isinstance(val, Sequence): + raise ValueError("Invalid GGUF metadata array, expecting sequence") + + if len(val) == 0: + raise ValueError("Invalid GGUF metadata array. Empty array") + + if sub_type is not None: + ltype = sub_type + elif isinstance(val, bytes): + ltype = GGUFValueType.UINT8 + else: + ltype = GGUFValueType.get_type(val[0]) + if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): + raise ValueError("All items in a GGUF array should be of the same type") + kv_data += self._pack("I", ltype) + kv_data += self._pack("Q", len(val)) + for item in val: + kv_data += self._pack_val(item, ltype, add_vtype=False) + else: + raise ValueError("Invalid GGUF metadata value type or value") + + return kv_data + + @staticmethod + def format_n_bytes_to_str(num: int) -> str: + if num == 0: + return "negligible - metadata only" + fnum = float(num) + for unit in ("", "K", "M", "G"): + if abs(fnum) < 1000.0: + return f"{fnum:3.1f}{unit}" + fnum /= 1000.0 + return f"{fnum:.1f}T - over 1TB, split recommended" diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..f9bcadae0224bac5f922ce6565dd1e3883a923e1 --- /dev/null +++ b/gguf-py/gguf/lazy.py @@ -0,0 +1,223 @@ +from __future__ import annotations +from abc import ABC, ABCMeta, abstractmethod + +import logging +from typing import Any, Callable + +import numpy as np +from numpy.typing import DTypeLike + + +logger = logging.getLogger(__name__) + + +class LazyMeta(ABCMeta): + + def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs): + def __getattr__(self, name: str) -> Any: + meta_attr = getattr(self._meta, name) + if callable(meta_attr): + return type(self)._wrap_fn( + (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)), + use_self=self, + ) + elif isinstance(meta_attr, self._tensor_type): + # e.g. self.T with torch.Tensor should still be wrapped + return type(self)._wrap_fn(lambda s: getattr(s, name))(self) + else: + # no need to wrap non-tensor properties, + # and they likely don't depend on the actual contents of the tensor + return meta_attr + + namespace["__getattr__"] = __getattr__ + + # need to make a builder for the wrapped wrapper to copy the name, + # or else it fails with very cryptic error messages, + # because somehow the same string would end up in every closures + def mk_wrap(op_name: str, *, meta_noop: bool = False): + # need to wrap the wrapper to get self + def wrapped_special_op(self, *args, **kwargs): + return type(self)._wrap_fn( + getattr(type(self)._tensor_type, op_name), + meta_noop=meta_noop, + )(self, *args, **kwargs) + return wrapped_special_op + + # special methods bypass __getattr__, so they need to be added manually + # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup + # NOTE: doing this from a metaclass is very convenient + # TODO: make this even more comprehensive + for binary_op in ( + "lt", "le", "eq", "ne", "ge", "gt", "not" + "abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul", + "neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor", + "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor", + "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor", + ): + attr_name = f"__{binary_op}__" + # the result of these operators usually has the same shape and dtype as the input, + # so evaluation on the meta tensor can be skipped. + namespace[attr_name] = mk_wrap(attr_name, meta_noop=True) + + for special_op in ( + "getitem", "setitem", "len", + ): + attr_name = f"__{special_op}__" + namespace[attr_name] = mk_wrap(attr_name, meta_noop=False) + + return super().__new__(cls, name, bases, namespace, **kwargs) + + +# Tree of lazy tensors +class LazyBase(ABC, metaclass=LazyMeta): + _tensor_type: type + _meta: Any + _data: Any | None + _args: tuple + _kwargs: dict[str, Any] + _func: Callable[[Any], Any] | None + + def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None): + super().__init__() + self._meta = meta + self._data = data + self._args = args + self._kwargs = kwargs if kwargs is not None else {} + self._func = func + assert self._func is not None or self._data is not None + + def __init_subclass__(cls) -> None: + if "_tensor_type" not in cls.__dict__: + raise TypeError(f"property '_tensor_type' must be defined for {cls!r}") + return super().__init_subclass__() + + @staticmethod + def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any: + # TODO: dict and set + if isinstance(o, (list, tuple)): + L = [] + for item in o: + L.append(LazyBase._recurse_apply(item, fn)) + if isinstance(o, tuple): + L = tuple(L) + return L + elif isinstance(o, LazyBase): + return fn(o) + else: + return o + + @classmethod + def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]: + def wrapped_fn(*args, **kwargs): + if kwargs is None: + kwargs = {} + args = ((use_self,) if use_self is not None else ()) + args + + meta_args = LazyBase._recurse_apply(args, lambda t: t._meta) + # TODO: maybe handle tensors in kwargs too + + if isinstance(meta_noop, bool) and not meta_noop: + try: + res = fn(*meta_args, **kwargs) + except NotImplementedError: + # running some operations on PyTorch's Meta tensors can cause this exception + res = None + else: + # some operators don't need to actually run on the meta tensors + assert len(args) > 0 + res = args[0] + assert isinstance(res, cls) + res = res._meta + # allow operations to override the dtype and shape + if meta_noop is not True: + if isinstance(meta_noop, tuple): + dtype, shape = meta_noop + assert callable(shape) + res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape)) + else: + res = cls.meta_with_dtype_and_shape(meta_noop, res.shape) + + if isinstance(res, cls._tensor_type): + return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn) + elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res): + # share the evaluation between lazy tuple elements + shared_args: list = [args, None] + + def eager_tuple_element(a: list[Any], i: int = 0, /, **kw) -> LazyBase: + assert len(a) == 2 + if a[1] is None: + a[1] = fn(*a[0], **kw) + return a[1][i] + return tuple(cls(meta=cls.eager_to_meta(res[i]), args=(shared_args, i), kwargs=kwargs, func=eager_tuple_element) for i in range(len(res))) + else: + del res # not needed + # non-tensor return likely relies on the contents of the args + # (e.g. the result of torch.equal) + eager_args = cls.to_eager(args) + return fn(*eager_args, **kwargs) + return wrapped_fn + + @classmethod + def to_eager(cls, t: Any) -> Any: + def simple_to_eager(_t: LazyBase) -> Any: + if _t._data is not None: + return _t._data + + # NOTE: there's a recursion limit in Python (usually 1000) + + assert _t._func is not None + _t._args = cls._recurse_apply(_t._args, simple_to_eager) + _t._data = _t._func(*_t._args, **_t._kwargs) + # sanity check + assert _t._data is not None + assert _t._data.dtype == _t._meta.dtype + assert _t._data.shape == _t._meta.shape + + return _t._data + + # recurse into lists and/or tuples, keeping their structure + return cls._recurse_apply(t, simple_to_eager) + + @classmethod + def eager_to_meta(cls, t: Any) -> Any: + return cls.meta_with_dtype_and_shape(t.dtype, t.shape) + + # must be overridden, meta tensor init is backend-specific + @classmethod + @abstractmethod + def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass + + @classmethod + def from_eager(cls, t: Any) -> Any: + if type(t) is cls: + # already lazy + return t + elif isinstance(t, cls._tensor_type): + return cls(meta=cls.eager_to_meta(t), data=t) + else: + return TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}") + + +class LazyNumpyTensor(LazyBase): + _tensor_type = np.ndarray + + shape: tuple[int, ...] # Makes the type checker happy in quants.py + + @classmethod + def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]: + # The initial idea was to use np.nan as the fill value, + # but non-float types like np.int16 can't use that. + # So zero it is. + cheat = np.zeros(1, dtype) + return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape)) + + def astype(self, dtype, *args, **kwargs): + meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape) + full_args = (self, dtype,) + args + return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs))) + + def tofile(self, *args, **kwargs): + eager = LazyNumpyTensor.to_eager(self) + return eager.tofile(*args, **kwargs) + + # TODO: __array_function__ diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..e807f434689de669816504921f54809cac15fc71 --- /dev/null +++ b/gguf-py/gguf/metadata.py @@ -0,0 +1,642 @@ +from __future__ import annotations + +import re +import json +import yaml +import logging +from pathlib import Path +from typing import Any, Literal, Optional +from dataclasses import dataclass + +from .constants import Keys + +import gguf + +logger = logging.getLogger("metadata") + + +@dataclass +class Metadata: + # Authorship Metadata to be written to GGUF KV Store + name: Optional[str] = None + author: Optional[str] = None + version: Optional[str] = None + organization: Optional[str] = None + finetune: Optional[str] = None + basename: Optional[str] = None + description: Optional[str] = None + quantized_by: Optional[str] = None + size_label: Optional[str] = None + url: Optional[str] = None + doi: Optional[str] = None + uuid: Optional[str] = None + repo_url: Optional[str] = None + source_url: Optional[str] = None + source_doi: Optional[str] = None + source_uuid: Optional[str] = None + source_repo_url: Optional[str] = None + license: Optional[str] = None + license_name: Optional[str] = None + license_link: Optional[str] = None + base_models: Optional[list[dict]] = None + tags: Optional[list[str]] = None + languages: Optional[list[str]] = None + datasets: Optional[list[dict]] = None + + @staticmethod + def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata: + # This grabs as many contextual authorship metadata as possible from the model repository + # making any conversion as required to match the gguf kv store metadata format + # as well as giving users the ability to override any authorship metadata that may be incorrect + + # Create a new Metadata instance + metadata = Metadata() + + model_card = Metadata.load_model_card(model_path) + hf_params = Metadata.load_hf_parameters(model_path) + # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter + + # heuristics + metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) + + # Metadata Override File Provided + # This is based on LLM_KV_NAMES mapping in llama.cpp + metadata_override = Metadata.load_metadata_override(metadata_override_path) + + metadata.name = metadata_override.get(Keys.General.NAME, metadata.name) + metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) + metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) + metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization) + + metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune) + metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename) + + metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description) + metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by) + + metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label) + metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name) + metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link) + + metadata.url = metadata_override.get(Keys.General.URL, metadata.url) + metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi) + metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid) + metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url) + + metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url) + metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi) + metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid) + metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url) + + # Base Models is received here as an array of models + metadata.base_models = metadata_override.get("general.base_models", metadata.base_models) + + # Datasets is received here as an array of datasets + metadata.datasets = metadata_override.get("general.datasets", metadata.datasets) + + metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags) + metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages) + + # Direct Metadata Override (via direct cli argument) + if model_name is not None: + metadata.name = model_name + + return metadata + + @staticmethod + def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]: + if metadata_override_path is None or not metadata_override_path.is_file(): + return {} + + with open(metadata_override_path, "r", encoding="utf-8") as f: + return json.load(f) + + @staticmethod + def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]: + if model_path is None or not model_path.is_dir(): + return {} + + model_card_path = model_path / "README.md" + + if not model_card_path.is_file(): + return {} + + # The model card metadata is assumed to always be in YAML (frontmatter) + # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473 + yaml_content: str = "" + with open(model_card_path, "r", encoding="utf-8") as f: + content = f.read() + lines = content.splitlines() + lines_yaml = [] + if len(lines) == 0: + # Empty file + return {} + if len(lines) > 0 and lines[0] != "---": + # No frontmatter + return {} + for line in lines[1:]: + if line == "---": + break # End of frontmatter + else: + lines_yaml.append(line) + yaml_content = "\n".join(lines_yaml) + "\n" + + # Quick hack to fix the Norway problem + # https://hitchdev.com/strictyaml/why/implicit-typing-removed/ + yaml_content = yaml_content.replace("- no\n", "- \"no\"\n") + + if yaml_content: + data = yaml.safe_load(yaml_content) + if isinstance(data, dict): + return data + else: + logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict") + return {} + else: + return {} + + @staticmethod + def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: + if model_path is None or not model_path.is_dir(): + return {} + + config_path = model_path / "config.json" + + if not config_path.is_file(): + return {} + + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + + @staticmethod + def id_to_title(string): + # Convert capitalization into title form unless acronym or version number + return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()]) + + @staticmethod + def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]: + # Huggingface often store model id as '/' + # so let's parse it and apply some heuristics if possible for model name components + + if model_id is None: + # model ID missing + return None, None, None, None, None, None + + if ' ' in model_id: + # model ID is actually a normal human sentence + # which means its most likely a normal model name only + # not part of the hugging face naming standard, but whatever + return model_id, None, None, None, None, None + + if '/' in model_id: + # model ID (huggingface style) + org_component, model_full_name_component = model_id.split('/', 1) + else: + # model ID but missing org components + org_component, model_full_name_component = None, model_id + + # Check if we erroneously matched against './' or '../' etc... + if org_component is not None and len(org_component) > 0 and org_component[0] == '.': + org_component = None + + name_parts: list[str] = model_full_name_component.split('-') + + # Remove empty parts + for i in reversed(range(len(name_parts))): + if len(name_parts[i]) == 0: + del name_parts[i] + + name_types: list[ + set[Literal["basename", "size_label", "finetune", "version", "type"]] + ] = [set() for _ in name_parts] + + # Annotate the name + for i, part in enumerate(name_parts): + # Version + if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE): + name_types[i].add("version") + # Quant type (should not be there for base models, but still annotated) + elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE): + name_types[i].add("type") + name_parts[i] = part.upper() + # Model size + elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE): + part = part.replace("_", ".") + # Handle weird bloom-7b1 notation + if part[-1].isdecimal(): + part = part[:-2] + "." + part[-1] + part[-2] + # Normalize the size suffixes + if len(part) > 1 and part[-2].isdecimal(): + if part[-1] in "kmbt": + part = part[:-1] + part[-1].upper() + if total_params != 0: + try: + label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1])) + # Only use it as a size label if it's close or bigger than the model size + # Note that LoRA adapters don't necessarily include all layers, + # so this is why bigger label sizes are accepted. + # Do not use the size label when it's smaller than 1/8 of the model size + if (total_params < 0 and label_params < abs(total_params) // 8) or ( + # Check both directions when the current model isn't a LoRA adapter + total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8 + ): + # Likely a context length + name_types[i].add("finetune") + # Lowercase the size when it's a context length + part = part[:-1] + part[-1].lower() + except ValueError: + # Failed to convert the size label to float, use it anyway + pass + if len(name_types[i]) == 0: + name_types[i].add("size_label") + name_parts[i] = part + # Some easy to recognize finetune names + elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE): + if total_params < 0 and part.lower() == "lora": + # ignore redundant "lora" in the finetune part when the output is a lora adapter + name_types[i].add("type") + else: + name_types[i].add("finetune") + + # Ignore word-based size labels when there is at least a number-based one present + # TODO: should word-based size labels always be removed instead? + if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n): + for n, t in zip(name_parts, name_types): + if "size_label" in t: + if all(c.isalpha() for c in n): + t.remove("size_label") + + at_start = True + # Find the basename through the annotated name + for part, t in zip(name_parts, name_types): + if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t): + t.add("basename") + else: + if at_start: + at_start = False + if len(t) == 0: + t.add("finetune") + + # Remove the basename annotation from trailing version + for part, t in zip(reversed(name_parts), reversed(name_types)): + if "basename" in t and len(t) > 1: + t.remove("basename") + else: + break + + basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None + # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys) + size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None + finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None + # TODO: should the basename version always be excluded? + # NOTE: multiple finetune versions are joined together + version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None + + if size_label is None and finetune is None and version is None: + # Too ambiguous, output nothing + basename = None + + return model_full_name_component, org_component, basename, finetune, version, size_label + + @staticmethod + def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata: + # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 + + # Model Card Heuristics + ######################## + if model_card is not None: + + def use_model_card_metadata(metadata_key: str, model_card_key: str): + if model_card_key in model_card and getattr(metadata, metadata_key, None) is None: + setattr(metadata, metadata_key, model_card.get(model_card_key)) + + def use_array_model_card_metadata(metadata_key: str, model_card_key: str): + # Note: Will append rather than replace if already exist + tags_value = model_card.get(model_card_key, None) + if tags_value is None: + return + + current_value = getattr(metadata, metadata_key, None) + if current_value is None: + current_value = [] + + if isinstance(tags_value, str): + current_value.append(tags_value) + elif isinstance(tags_value, list): + current_value.extend(tags_value) + + setattr(metadata, metadata_key, current_value) + + # LLAMA.cpp's direct internal convention + # (Definitely not part of hugging face formal/informal standard) + ######################################### + use_model_card_metadata("name", "name") + use_model_card_metadata("author", "author") + use_model_card_metadata("version", "version") + use_model_card_metadata("organization", "organization") + use_model_card_metadata("description", "description") + use_model_card_metadata("finetune", "finetune") + use_model_card_metadata("basename", "basename") + use_model_card_metadata("size_label", "size_label") + use_model_card_metadata("source_url", "url") + use_model_card_metadata("source_doi", "doi") + use_model_card_metadata("source_uuid", "uuid") + use_model_card_metadata("source_repo_url", "repo_url") + + # LLAMA.cpp's huggingface style convention + # (Definitely not part of hugging face formal/informal standard... but with model_ appended to match their style) + ########################################### + use_model_card_metadata("name", "model_name") + use_model_card_metadata("author", "model_author") + use_model_card_metadata("version", "model_version") + use_model_card_metadata("organization", "model_organization") + use_model_card_metadata("description", "model_description") + use_model_card_metadata("finetune", "model_finetune") + use_model_card_metadata("basename", "model_basename") + use_model_card_metadata("size_label", "model_size_label") + use_model_card_metadata("source_url", "model_url") + use_model_card_metadata("source_doi", "model_doi") + use_model_card_metadata("source_uuid", "model_uuid") + use_model_card_metadata("source_repo_url", "model_repo_url") + + # Hugging Face Direct Convention + ################################# + + # Not part of huggingface model card standard but notice some model creator using it + # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF' + use_model_card_metadata("name", "model_name") + use_model_card_metadata("author", "model_creator") + use_model_card_metadata("basename", "model_type") + + if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card: + # This represents the parent models that this is based on + # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges) + # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md + metadata_base_models = [] + base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None))) + + if base_model_value is not None: + if isinstance(base_model_value, str): + metadata_base_models.append(base_model_value) + elif isinstance(base_model_value, list): + metadata_base_models.extend(base_model_value) + + if metadata.base_models is None: + metadata.base_models = [] + + for model_id in metadata_base_models: + # NOTE: model size of base model is assumed to be similar to the size of the current model + base_model = {} + if isinstance(model_id, str): + if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"): + base_model["repo_url"] = model_id + + # Check if Hugging Face ID is present in URL + if "huggingface.co" in model_id: + match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id) + if match: + model_id_component = match.group(1) + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params) + + # Populate model dictionary with extracted components + if model_full_name_component is not None: + base_model["name"] = Metadata.id_to_title(model_full_name_component) + if org_component is not None: + base_model["organization"] = Metadata.id_to_title(org_component) + if version is not None: + base_model["version"] = version + + else: + # Likely a Hugging Face ID + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + + # Populate model dictionary with extracted components + if model_full_name_component is not None: + base_model["name"] = Metadata.id_to_title(model_full_name_component) + if org_component is not None: + base_model["organization"] = Metadata.id_to_title(org_component) + if version is not None: + base_model["version"] = version + if org_component is not None and model_full_name_component is not None: + base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}" + + elif isinstance(model_id, dict): + base_model = model_id + + else: + logger.error(f"base model entry '{str(model_id)}' not in a known format") + + metadata.base_models.append(base_model) + + if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card: + # This represents the datasets that this was trained from + metadata_datasets = [] + dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None))) + + if dataset_value is not None: + if isinstance(dataset_value, str): + metadata_datasets.append(dataset_value) + elif isinstance(dataset_value, list): + metadata_datasets.extend(dataset_value) + + if metadata.datasets is None: + metadata.datasets = [] + + for dataset_id in metadata_datasets: + # NOTE: model size of base model is assumed to be similar to the size of the current model + dataset = {} + if isinstance(dataset_id, str): + if dataset_id.startswith(("http://", "https://", "ssh://")): + dataset["repo_url"] = dataset_id + + # Check if Hugging Face ID is present in URL + if "huggingface.co" in dataset_id: + match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id) + if match: + dataset_id_component = match.group(1) + dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params) + + # Populate dataset dictionary with extracted components + if dataset_name_component is not None: + dataset["name"] = Metadata.id_to_title(dataset_name_component) + if org_component is not None: + dataset["organization"] = Metadata.id_to_title(org_component) + if version is not None: + dataset["version"] = version + + else: + # Likely a Hugging Face ID + dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params) + + # Populate dataset dictionary with extracted components + if dataset_name_component is not None: + dataset["name"] = Metadata.id_to_title(dataset_name_component) + if org_component is not None: + dataset["organization"] = Metadata.id_to_title(org_component) + if version is not None: + dataset["version"] = version + if org_component is not None and dataset_name_component is not None: + dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}" + + elif isinstance(dataset_id, dict): + dataset = dataset_id + + else: + logger.error(f"dataset entry '{str(dataset_id)}' not in a known format") + + metadata.datasets.append(dataset) + + use_model_card_metadata("license", "license") + use_model_card_metadata("license_name", "license_name") + use_model_card_metadata("license_link", "license_link") + + use_array_model_card_metadata("tags", "tags") + use_array_model_card_metadata("tags", "pipeline_tag") + + use_array_model_card_metadata("languages", "languages") + use_array_model_card_metadata("languages", "language") + + # Hugging Face Parameter Heuristics + #################################### + + if hf_params is not None: + + hf_name_or_path = hf_params.get("_name_or_path") + if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1: + # Use _name_or_path only if its actually a model name and not some computer path + # e.g. 'meta-llama/Llama-2-7b-hf' + model_id = hf_name_or_path + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + if metadata.name is None and model_full_name_component is not None: + metadata.name = Metadata.id_to_title(model_full_name_component) + if metadata.organization is None and org_component is not None: + metadata.organization = Metadata.id_to_title(org_component) + if metadata.basename is None and basename is not None: + metadata.basename = basename + if metadata.finetune is None and finetune is not None: + metadata.finetune = finetune + if metadata.version is None and version is not None: + metadata.version = version + if metadata.size_label is None and size_label is not None: + metadata.size_label = size_label + + # Directory Folder Name Fallback Heuristics + ############################################ + if model_path is not None: + model_id = model_path.name + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + if metadata.name is None and model_full_name_component is not None: + metadata.name = Metadata.id_to_title(model_full_name_component) + if metadata.organization is None and org_component is not None: + metadata.organization = Metadata.id_to_title(org_component) + if metadata.basename is None and basename is not None: + metadata.basename = basename + if metadata.finetune is None and finetune is not None: + metadata.finetune = finetune + if metadata.version is None and version is not None: + metadata.version = version + if metadata.size_label is None and size_label is not None: + metadata.size_label = size_label + + return metadata + + def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): + assert self.name is not None + gguf_writer.add_name(self.name) + + if self.author is not None: + gguf_writer.add_author(self.author) + if self.version is not None: + gguf_writer.add_version(self.version) + if self.organization is not None: + gguf_writer.add_organization(self.organization) + + if self.finetune is not None: + gguf_writer.add_finetune(self.finetune) + if self.basename is not None: + gguf_writer.add_basename(self.basename) + + if self.description is not None: + gguf_writer.add_description(self.description) + if self.quantized_by is not None: + gguf_writer.add_quantized_by(self.quantized_by) + + if self.size_label is not None: + gguf_writer.add_size_label(self.size_label) + + if self.license is not None: + if isinstance(self.license, list): + gguf_writer.add_license(",".join(self.license)) + else: + gguf_writer.add_license(self.license) + if self.license_name is not None: + gguf_writer.add_license_name(self.license_name) + if self.license_link is not None: + gguf_writer.add_license_link(self.license_link) + + if self.url is not None: + gguf_writer.add_url(self.url) + if self.doi is not None: + gguf_writer.add_doi(self.doi) + if self.uuid is not None: + gguf_writer.add_uuid(self.uuid) + if self.repo_url is not None: + gguf_writer.add_repo_url(self.repo_url) + + if self.source_url is not None: + gguf_writer.add_source_url(self.source_url) + if self.source_doi is not None: + gguf_writer.add_source_doi(self.source_doi) + if self.source_uuid is not None: + gguf_writer.add_source_uuid(self.source_uuid) + if self.source_repo_url is not None: + gguf_writer.add_source_repo_url(self.source_repo_url) + + if self.base_models is not None: + gguf_writer.add_base_model_count(len(self.base_models)) + for key, base_model_entry in enumerate(self.base_models): + if "name" in base_model_entry: + gguf_writer.add_base_model_name(key, base_model_entry["name"]) + if "author" in base_model_entry: + gguf_writer.add_base_model_author(key, base_model_entry["author"]) + if "version" in base_model_entry: + gguf_writer.add_base_model_version(key, base_model_entry["version"]) + if "organization" in base_model_entry: + gguf_writer.add_base_model_organization(key, base_model_entry["organization"]) + if "description" in base_model_entry: + gguf_writer.add_base_model_description(key, base_model_entry["description"]) + if "url" in base_model_entry: + gguf_writer.add_base_model_url(key, base_model_entry["url"]) + if "doi" in base_model_entry: + gguf_writer.add_base_model_doi(key, base_model_entry["doi"]) + if "uuid" in base_model_entry: + gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"]) + if "repo_url" in base_model_entry: + gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"]) + + if self.datasets is not None: + gguf_writer.add_dataset_count(len(self.datasets)) + for key, dataset_entry in enumerate(self.datasets): + if "name" in dataset_entry: + gguf_writer.add_dataset_name(key, dataset_entry["name"]) + if "author" in dataset_entry: + gguf_writer.add_dataset_author(key, dataset_entry["author"]) + if "version" in dataset_entry: + gguf_writer.add_dataset_version(key, dataset_entry["version"]) + if "organization" in dataset_entry: + gguf_writer.add_dataset_organization(key, dataset_entry["organization"]) + if "description" in dataset_entry: + gguf_writer.add_dataset_description(key, dataset_entry["description"]) + if "url" in dataset_entry: + gguf_writer.add_dataset_url(key, dataset_entry["url"]) + if "doi" in dataset_entry: + gguf_writer.add_dataset_doi(key, dataset_entry["doi"]) + if "uuid" in dataset_entry: + gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"]) + if "repo_url" in dataset_entry: + gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"]) + + if self.tags is not None: + gguf_writer.add_tags(self.tags) + if self.languages is not None: + gguf_writer.add_languages(self.languages) diff --git a/gguf-py/gguf/py.typed b/gguf-py/gguf/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py new file mode 100644 index 0000000000000000000000000000000000000000..3c8ba82e19d3d9e984ba39caf5cf865b0ee8e72a --- /dev/null +++ b/gguf-py/gguf/quants.py @@ -0,0 +1,1269 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Any, Callable, Sequence +from math import log2, ceil + +from numpy.typing import DTypeLike + +from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K +from .lazy import LazyNumpyTensor + +import numpy as np + + +def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]: + block_size, type_size = GGML_QUANT_SIZES[quant_type] + if shape[-1] % block_size != 0: + raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})") + return (*shape[:-1], shape[-1] // block_size * type_size) + + +def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]: + block_size, type_size = GGML_QUANT_SIZES[quant_type] + if shape[-1] % type_size != 0: + raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})") + return (*shape[:-1], shape[-1] // type_size * block_size) + + +# This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time +def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray: + rows = arr.reshape((-1, arr.shape[-1])) + osize = 1 + for dim in oshape: + osize *= dim + out = np.empty(shape=osize, dtype=otype) + # compute over groups of 16 rows (arbitrary, but seems good for performance) + n_groups = (rows.shape[0] // 16) or 1 + np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out) + return out.reshape(oshape) + + +# round away from zero +# ref: https://stackoverflow.com/a/59143326/22827863 +def np_roundf(n: np.ndarray) -> np.ndarray: + a = abs(n) + floored = np.floor(a) + b = floored + np.floor(2 * (a - floored)) + return np.sign(n) * b + + +class QuantError(Exception): ... + + +_type_traits: dict[GGMLQuantizationType, type[__Quant]] = {} + + +def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: + if qtype == GGMLQuantizationType.F32: + return data.astype(np.float32, copy=False) + elif qtype == GGMLQuantizationType.F16: + return data.astype(np.float16, copy=False) + elif (q := _type_traits.get(qtype)) is not None: + return q.quantize(data) + else: + raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented") + + +def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: + if qtype == GGMLQuantizationType.F32: + return data.view(np.float32) + elif qtype == GGMLQuantizationType.F16: + return data.view(np.float16).astype(np.float32) + elif (q := _type_traits.get(qtype)) is not None: + return q.dequantize(data) + else: + raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented") + + +class __Quant(ABC): + qtype: GGMLQuantizationType + block_size: int + type_size: int + + grid: np.ndarray[Any, np.dtype[np.float32]] | None = None + grid_shape: tuple[int, int] = (0, 0) + grid_map: tuple[int | float, ...] = () + grid_hex: bytes | None = None + + def __init__(self): + return TypeError("Quant conversion classes can't have instances") + + def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None: + cls.qtype = qtype + cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype] + cls.__quantize_lazy = LazyNumpyTensor._wrap_fn( + cls.__quantize_array, + meta_noop=(np.uint8, cls.__shape_to_bytes) + ) + cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn( + cls.__dequantize_array, + meta_noop=(np.float32, cls.__shape_from_bytes) + ) + assert qtype not in _type_traits + _type_traits[qtype] = cls + + @classmethod + def init_grid(cls): + if cls.grid is not None or cls.grid_hex is None: + return + + bits_per_elem = ceil(log2(len(cls.grid_map))) + assert bits_per_elem != 0, cls.qtype.name + elems_per_byte = 8 // bits_per_elem + + grid = np.frombuffer(cls.grid_hex, dtype=np.uint8) + # decode hexadecimal chars from grid + grid = grid.reshape((-1, 2)) + grid = (np.where(grid > 0x40, grid + 9, grid) & 0x0F) << np.array([4, 0], dtype=np.uint8).reshape((1, 2)) + grid = grid[..., 0] | grid[..., 1] + # unpack the grid values + grid = grid.reshape((-1, 1)) >> np.array([i for i in range(0, 8, 8 // elems_per_byte)], dtype=np.uint8).reshape((1, elems_per_byte)) + grid = (grid & ((1 << bits_per_elem) - 1)).reshape((-1, 1)) + grid_map = np.array(cls.grid_map, dtype=np.float32).reshape((1, -1)) + grid = np.take_along_axis(grid_map, grid, axis=-1) + cls.grid = grid.reshape((1, 1, *cls.grid_shape)) + + @classmethod + @abstractmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + raise NotImplementedError + + @classmethod + @abstractmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + raise NotImplementedError + + @classmethod + def quantize_rows(cls, rows: np.ndarray) -> np.ndarray: + rows = rows.astype(np.float32, copy=False) + shape = rows.shape + n_blocks = rows.size // cls.block_size + blocks = rows.reshape((n_blocks, cls.block_size)) + blocks = cls.quantize_blocks(blocks) + assert blocks.dtype == np.uint8 + assert blocks.shape[-1] == cls.type_size + return blocks.reshape(cls.__shape_to_bytes(shape)) + + @classmethod + def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray: + rows = rows.view(np.uint8) + shape = rows.shape + n_blocks = rows.size // cls.type_size + blocks = rows.reshape((n_blocks, cls.type_size)) + blocks = cls.dequantize_blocks(blocks) + assert blocks.dtype == np.float32 + assert blocks.shape[-1] == cls.block_size + return blocks.reshape(cls.__shape_from_bytes(shape)) + + @classmethod + def __shape_to_bytes(cls, shape: Sequence[int]): + return quant_shape_to_byte_shape(shape, cls.qtype) + + @classmethod + def __shape_from_bytes(cls, shape: Sequence[int]): + return quant_shape_from_byte_shape(shape, cls.qtype) + + @classmethod + def __quantize_array(cls, array: np.ndarray) -> np.ndarray: + return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape)) + + @classmethod + def __dequantize_array(cls, array: np.ndarray) -> np.ndarray: + cls.init_grid() + return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape)) + + @classmethod + def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any: + pass + + @classmethod + def __dequantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any: + pass + + @classmethod + def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool: + return tensor.shape[-1] % cls.block_size == 0 + + @classmethod + def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray: + if not cls.can_quantize(tensor): + raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}") + if isinstance(tensor, LazyNumpyTensor): + return cls.__quantize_lazy(tensor) + else: + return cls.__quantize_array(tensor) + + @classmethod + def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray: + if isinstance(tensor, LazyNumpyTensor): + return cls.__dequantize_lazy(tensor) + else: + return cls.__dequantize_array(tensor) + + +class BF16(__Quant, qtype=GGMLQuantizationType.BF16): + @classmethod + # same as ggml_compute_fp32_to_bf16 in ggml-impl.h + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n = blocks.view(np.uint32) + # force nan to quiet + n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n) + # round to nearest even + n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16 + return n.astype(np.uint16).view(np.uint8) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32) + + +class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0): + @classmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + imax = abs(blocks).argmax(axis=-1, keepdims=True) + max = np.take_along_axis(blocks, imax, axis=-1) + + d = max / -8 + with np.errstate(divide="ignore"): + id = np.where(d == 0, 0, 1 / d) + # FIXME: Q4_0's reference rounding is cursed and depends on FMA + qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15) + + qs = qs.reshape((n_blocks, 2, cls.block_size // 2)) + qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4)) + + d = d.astype(np.float16).view(np.uint8) + + return np.concatenate([d, qs], axis=-1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, qs = np.hsplit(blocks, [2]) + + d = d.view(np.float16).astype(np.float32) + + qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.int8) - np.int8(8) + + return (d * qs.astype(np.float32)) + + +class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): + @classmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + max = blocks.max(axis=-1, keepdims=True) + min = blocks.min(axis=-1, keepdims=True) + + d = (max - min) / 15 + with np.errstate(divide="ignore"): + id = np.where(d == 0, 0, 1 / d) + qs = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 15) + + qs = qs.reshape((n_blocks, 2, cls.block_size // 2)) + qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4)) + + d = d.astype(np.float16).view(np.uint8) + m = min.astype(np.float16).view(np.uint8) + + return np.concatenate([d, m, qs], axis=-1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + m, qs = np.hsplit(rest, [2]) + + d = d.view(np.float16).astype(np.float32) + m = m.view(np.float16).astype(np.float32) + + qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.float32) + + return (d * qs) + m + + +class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0): + @classmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + imax = abs(blocks).argmax(axis=-1, keepdims=True) + max = np.take_along_axis(blocks, imax, axis=-1) + + d = max / -16 + with np.errstate(divide="ignore"): + id = np.where(d == 0, 0, 1 / d) + # FIXME: Q5_0's reference rounding is cursed and depends on FMA + q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31) + + qs = q.reshape((n_blocks, 2, cls.block_size // 2)) + qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4)) + + qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4) + + d = d.astype(np.float16).view(np.uint8) + + return np.concatenate([d, qh, qs], axis=-1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + qh, qs = np.hsplit(rest, [4]) + + d = d.view(np.float16).astype(np.float32) + qh = qh.view(np.uint32) + + qh = qh.reshape((n_blocks, 1)) >> np.array([i for i in range(32)], dtype=np.uint32).reshape((1, 32)) + ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qh = (qh & np.uint32(0x01)).astype(np.uint8) + ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1)) + + qs = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(16) + + return (d * qs.astype(np.float32)) + + +class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1): + @classmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + max = blocks.max(axis=-1, keepdims=True) + min = blocks.min(axis=-1, keepdims=True) + + d = (max - min) / 31 + with np.errstate(divide="ignore"): + id = np.where(d == 0, 0, 1 / d) + q = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 31) + + qs = q.reshape((n_blocks, 2, cls.block_size // 2)) + qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4)) + + qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4) + + d = d.astype(np.float16).view(np.uint8) + m = min.astype(np.float16).view(np.uint8) + + return np.concatenate([d, m, qh, qs], axis=-1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + m, rest = np.hsplit(rest, [2]) + qh, qs = np.hsplit(rest, [4]) + + d = d.view(np.float16).astype(np.float32) + m = m.view(np.float16).astype(np.float32) + qh = qh.view(np.uint32) + + qh = qh.reshape((n_blocks, 1)) >> np.array([i for i in range(32)], dtype=np.uint32).reshape((1, 32)) + ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qh = (qh & np.uint32(0x01)).astype(np.uint8) + ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1)) + + qs = (ql | (qh << np.uint8(4))).astype(np.float32) + + return (d * qs) + m + + +class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0): + @classmethod + # Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + + d = abs(blocks).max(axis=1, keepdims=True) / 127 + with np.errstate(divide="ignore"): + id = np.where(d == 0, 0, 1 / d) + qs = np_roundf(blocks * id) + + # (n_blocks, 2) + d = d.astype(np.float16).view(np.uint8) + # (n_blocks, block_size) + qs = qs.astype(np.int8).view(np.uint8) + + return np.concatenate([d, qs], axis=1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + d, x = np.split(blocks, [2], axis=1) + d = d.view(np.float16).astype(np.float32) + x = x.view(np.int8).astype(np.float32) + + return (x * d) + + +class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K): + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + scales, rest = np.hsplit(blocks, [QK_K // 16]) + qs, rest = np.hsplit(rest, [QK_K // 4]) + d, dmin = np.hsplit(rest, [2]) + + d = d.view(np.float16).astype(np.float32) + dmin = dmin.view(np.float16).astype(np.float32) + + # (n_blocks, 16, 1) + dl = (d * (scales & 0xF).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1)) + ml = (dmin * (scales >> 4).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1)) + + shift = np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + + qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & np.uint8(3) + + qs = qs.reshape((n_blocks, QK_K // 16, 16)).astype(np.float32) + + qs = dl * qs - ml + + return qs.reshape((n_blocks, -1)) + + +class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K): + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + hmask, rest = np.hsplit(blocks, [QK_K // 8]) + qs, rest = np.hsplit(rest, [QK_K // 4]) + scales, d = np.hsplit(rest, [12]) + + d = d.view(np.float16).astype(np.float32) + + # The scales are packed at 6-bit each in this pattern: + # 0: IIIIAAAA + # 1: JJJJBBBB + # 2: KKKKCCCC + # 3: LLLLDDDD + # 4: MMMMEEEE + # 5: NNNNFFFF + # 6: OOOOGGGG + # 7: PPPPHHHH + # 8: MMIIEEAA + # 9: NNJJFFBB + # 10: OOKKGGCC + # 11: PPLLHHDD + lscales, hscales = np.hsplit(scales, [8]) + lscales = lscales.reshape((n_blocks, 1, 8)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1)) + lscales = lscales.reshape((n_blocks, 16)) + hscales = hscales.reshape((n_blocks, 1, 4)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 4, 1)) + hscales = hscales.reshape((n_blocks, 16)) + scales = (lscales & np.uint8(0x0F)) | ((hscales & np.uint8(0x03)) << np.uint8(4)) + scales = (scales.astype(np.int8) - np.int8(32)).astype(np.float32) + + dl = (d * scales).reshape((n_blocks, 16, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + qh = hmask.reshape(n_blocks, -1, 1, 32) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8, 1)) + ql = ql.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(3) + qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(1)) + qh = qh ^ np.uint8(1) # strangely, the offset is zero when the bitmask is 1 + q = (ql.astype(np.int8) - (qh << np.uint8(2)).astype(np.int8)).astype(np.float32) + + return (dl * q).reshape((n_blocks, QK_K)) + + +class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): + K_SCALE_SIZE = 12 + + @staticmethod + def get_scale_min(scales: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + n_blocks = scales.shape[0] + scales = scales.view(np.uint8) + ### Unpacking the following: ### + # 0 EEAAAAAA + # 1 FFBBBBBB + # 2 GGCCCCCC + # 3 HHDDDDDD + # 4 eeaaaaaa + # 5 ffbbbbbb + # 6 ggcccccc + # 7 hhdddddd + # 8 eeeeEEEE + # 9 ffffFFFF + # 10 ggggGGGG + # 11 hhhhHHHH + scales = scales.reshape((n_blocks, 3, 4)) + d, m, m_d = np.split(scales, 3, axis=-2) + + sc = np.concatenate([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], axis=-1) + min = np.concatenate([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], axis=-1) + + return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8))) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + dmin, rest = np.hsplit(rest, [2]) + scales, qs = np.hsplit(rest, [cls.K_SCALE_SIZE]) + + d = d.view(np.float16).astype(np.float32) + dmin = dmin.view(np.float16).astype(np.float32) + + sc, m = Q4_K.get_scale_min(scales) + + d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1)) + dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1)) + + qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 32)).astype(np.float32) + + return (d * qs - dm).reshape((n_blocks, QK_K)) + + +class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K): + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + dmin, rest = np.hsplit(rest, [2]) + scales, rest = np.hsplit(rest, [Q4_K.K_SCALE_SIZE]) + qh, qs = np.hsplit(rest, [QK_K // 8]) + + d = d.view(np.float16).astype(np.float32) + dmin = dmin.view(np.float16).astype(np.float32) + + sc, m = Q4_K.get_scale_min(scales) + + d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1)) + dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8, 1)) + ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32)) + qh = (qh & np.uint8(0x01)).reshape((n_blocks, -1, 32)) + q = (ql | (qh << np.uint8(4))).astype(np.float32) + + return (d * q - dm).reshape((n_blocks, QK_K)) + + +class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K): + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + ql, rest = np.hsplit(blocks, [QK_K // 2]) + qh, rest = np.hsplit(rest, [QK_K // 4]) + scales, d = np.hsplit(rest, [QK_K // 16]) + + scales = scales.view(np.int8).astype(np.float32) + d = d.view(np.float16).astype(np.float32) + d = (d * scales).reshape((n_blocks, QK_K // 16, 1)) + + ql = ql.reshape((n_blocks, -1, 1, 64)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + qh = (qh & np.uint8(0x03)).reshape((n_blocks, -1, 32)) + q = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(32) + q = q.reshape((n_blocks, QK_K // 16, -1)).astype(np.float32) + + return (d * q).reshape((n_blocks, QK_K)) + + +class TQ1_0(__Quant, qtype=GGMLQuantizationType.TQ1_0): + @classmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d = abs(blocks).max(axis=-1, keepdims=True) + with np.errstate(divide="ignore"): + id = np.where(d == 0, 0, 1 / d) + qs = np_roundf(blocks * id) + qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8) + + qs0, qs1, qh = qs[..., :(32 * 5)], qs[..., (32 * 5):(48 * 5)], qs[..., (48 * 5):] + qs0 = qs0.reshape((n_blocks, -1, 5, 32)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs0 = np.sum(qs0, axis=-2).reshape((n_blocks, -1)) + qs1 = qs1.reshape((n_blocks, -1, 5, 16)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs1 = np.sum(qs1, axis=-2).reshape((n_blocks, -1)) + qh = qh.reshape((n_blocks, -1, 4, 4)) * np.array([81, 27, 9, 3], dtype=np.uint8).reshape((1, 1, 4, 1)) + qh = np.sum(qh, axis=-2).reshape((n_blocks, -1)) + qs = np.concatenate([qs0, qs1, qh], axis=-1) + qs = (qs.astype(np.uint16) * 256 + (243 - 1)) // 243 + + qs = qs.astype(np.uint8) + d = d.astype(np.float16).view(np.uint8) + + return np.concatenate([qs, d], axis=-1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + qs, rest = np.hsplit(blocks, [(QK_K - 4 * QK_K // 64) // 5]) + qh, d = np.hsplit(rest, [QK_K // 64]) + + d = d.view(np.float16).astype(np.float32) + + qs0, qs1 = qs[..., :32], qs[..., 32:] + qs0 = qs0.reshape((n_blocks, -1, 1, 32)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs0 = qs0.reshape((n_blocks, -1)) + qs1 = qs1.reshape((n_blocks, -1, 1, 16)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs1 = qs1.reshape((n_blocks, -1)) + qh = qh.reshape((n_blocks, -1, 1, 4)) * np.array([1, 3, 9, 27], dtype=np.uint8).reshape((1, 1, 4, 1)) + qh = qh.reshape((n_blocks, -1)) + qs = np.concatenate([qs0, qs1, qh], axis=-1) + qs = ((qs.astype(np.uint16) * 3) >> 8).astype(np.int8) - np.int8(1) + + return (d * qs.astype(np.float32)) + + +class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0): + @classmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d = abs(blocks).max(axis=-1, keepdims=True) + with np.errstate(divide="ignore"): + id = np.where(d == 0, 0, 1 / d) + qs = np_roundf(blocks * id) + qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8) + + qs = qs.reshape((n_blocks, -1, 4, 32)) << np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + qs = qs[..., 0, :] | qs[..., 1, :] | qs[..., 2, :] | qs[..., 3, :] + qs = qs.reshape((n_blocks, -1)) + + d = d.astype(np.float16).view(np.uint8) + + return np.concatenate([qs, d], axis=-1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + qs, d = np.hsplit(blocks, [QK_K // 4]) + + d = d.view(np.float16).astype(np.float32) + + qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + qs = (qs & 0x03).reshape((n_blocks, -1)).astype(np.int8) - np.int8(1) + + return (d * qs.astype(np.float32)) + + +class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS): + ksigns: bytes = ( + b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f" + b"\x90\x11\x12\x93\x14\x95\x96\x17\x18\x99\x9a\x1b\x9c\x1d\x1e\x9f" + b"\xa0\x21\x22\xa3\x24\xa5\xa6\x27\x28\xa9\xaa\x2b\xac\x2d\x2e\xaf" + b"\x30\xb1\xb2\x33\xb4\x35\x36\xb7\xb8\x39\x3a\xbb\x3c\xbd\xbe\x3f" + b"\xc0\x41\x42\xc3\x44\xc5\xc6\x47\x48\xc9\xca\x4b\xcc\x4d\x4e\xcf" + b"\x50\xd1\xd2\x53\xd4\x55\x56\xd7\xd8\x59\x5a\xdb\x5c\xdd\xde\x5f" + b"\x60\xe1\xe2\x63\xe4\x65\x66\xe7\xe8\x69\x6a\xeb\x6c\xed\xee\x6f" + b"\xf0\x71\x72\xf3\x74\xf5\xf6\x77\x78\xf9\xfa\x7b\xfc\x7d\x7e\xff" + ) + + # iq2xxs_grid, but with each byte of the original packed in 2 bits, + # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2. + grid_shape = (256, 8) + grid_map = (0x08, 0x19, 0x2b) + grid_hex = ( + b"00000200050008000a00110014002000220028002a0041004400500058006100" + b"6400800082008a00a20001010401100115014001840198010002020222028202" + b"010404041004210424044004420448046004810484049004a404000502050805" + b"200546056905800591050906100640068406a406000805080808140828084108" + b"440850085208880804094009020a140a01100410101021104010601084109010" + b"951000110811201150115a118011241245120014081420142514491480141815" + b"6215001616160118041810184018811800190519a019511a002002200a204420" + b"6120802082202921482100220222012404241024402456240025412564259026" + b"082820289428442a014004401040184021402440404048405640604081408440" + b"9040004120416141804185410142104248425642684200440844204480449944" + b"124524450046014804481048404845480049584961498249454a904a00500850" + b"1150195020508050885004514251a4519152905492540a550156545600581158" + b"195864584059085a046010604060686000615561186260620064056410651265" + b"84654268008002800a8041808280048118814081118201840484108415844084" + b"608400854685948509864086608602880489118a0490109024904090a1901691" + b"8091459200942294449451958198209902a050a085a009a100a218a450a804a9" + ) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, qs = np.hsplit(blocks, [2]) + + d = d.view(np.float16).astype(np.float32) + + qs = qs.view(np.uint32).reshape(n_blocks, -1, 2) + + db = d * (np.float32(0.5) + (qs[..., 1] >> 28).astype(np.float32)) * np.float32(0.25) + db = db.reshape((n_blocks, -1, 1, 1)) + + # get the sign indices and unpack the bits + signs = qs[..., 1].reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4)) + ksigns = np.frombuffer(cls.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128)) + signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1)) + signs = np.take_along_axis(ksigns, signs, axis=-1) + signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 1, 8)) + signs = signs & np.uint8(0x01) + signs = np.where(signs == 0, np.float32(1), np.float32(-1)) + signs = signs.reshape((n_blocks, -1, 4, 8)) + + assert cls.grid is not None + grid = np.take_along_axis(cls.grid, qs[..., 0].copy().view(np.uint8).reshape((n_blocks, -1, 1, 1)), axis=-2) + grid = grid.reshape((n_blocks, -1, 4, 8)) + + return (db * grid * signs).reshape((n_blocks, -1)) + + +class IQ2_XS(__Quant, qtype=GGMLQuantizationType.IQ2_XS): + # iq2xs_grid, but with each byte of the original packed in 2 bits, + # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2. + grid_shape = (512, 8) + grid_map = (0x08, 0x19, 0x2b) + grid_hex = ( + b"00000200050008000a0011001400160019002000220025002800410044004600" + b"49005000520055005800610064008000820085008800910094009900a0000101" + b"04010601090110011201150118011a0121012401400142014501480151015401" + b"6001680181018401900100020202050208021102140220024102440250025502" + b"80028a0201040404060409041004120415041804210424044004420445044804" + b"5104540456046004810484049004000502050505080511051405200541054405" + b"500561058005010604061006260640064206840600080208050808080a081108" + b"14082008250841084408500858088008a008aa08010904091009400981098909" + b"000a200a280a960aa00a01100410061009101010121015101810211024104010" + b"4210451048105110541060106a10811084109010001102110511081111111411" + b"2011411144115011801194119611011204120612101240126012001402140514" + b"0814111414142014411444144914501464148014011504151015401500161416" + b"49160118041810181218401854188618001905196619511aa91a002002200520" + b"08200a201120142020204120442050208020a020012104211021402148216521" + b"002222228022a82201240424102429244024002541255225992501261a26a626" + b"002808280a28202855288828a22868299029082a202a822a882a8a2a01400440" + b"0640094010401240154018402140244040404240454048404a40514054406040" + b"6540814084409040004102410541084111411441204141414441504180418541" + b"a241014204421042124229424042004402440544084411441444194420444144" + b"4444504480449444014504451045244540459a4500460a464446504601480448" + b"1048404845485448624800491149444950496949044a00500250055008501150" + b"145020502850415044505050805001510451105115514051425100524452aa52" + b"0154045410542154405460548154a154005508558055885521566856a1560058" + b"14584158505899581a5940594259855a0160046010604060546062608660a960" + b"006124624a62926200641664106540654565a46501686a682569066a546a626a" + b"00800280058008801180148020802a8041804480508080808280a880aa800181" + b"0481068110814081518159810082208280828282a082a8820184048410841284" + b"158440846084898400854485a58518866a860088088825885a8880888288a888" + b"0689228a808a888a968aa88a0190049010904090569084900091229164915692" + b"89920094059444945094589429959095929541965198a6984999159a609a00a0" + b"02a008a00aa020a02aa0a0a051a159a1a6a100a202a208a22aa280a2a0a240a4" + b"95a465a698a60aa820a822a828a8a0a8a8a804a984a986a928aa2aaa91aaaaaa" + ) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + qs, scales = np.hsplit(rest, [2 * QK_K // 8]) + + d = d.view(np.float16).astype(np.float32) + qs = qs.view(np.uint16) + + scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2)) + scales = (scales & 0x0F).reshape((n_blocks, -1)) + db = d * (np.float32(0.5) + scales) * np.float32(0.25) + db = db.reshape((n_blocks, -1, 1, 1)) + + # get the sign indices and unpack the bits + signs = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape(1, 1, 128) + signs = np.take_along_axis(signs, (qs >> 9).reshape((n_blocks, -1, 1)), axis=-1) + signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8)) + signs = signs & np.uint8(0x01) + signs = np.where(signs == 0, np.float32(1), np.float32(-1)) + signs = signs.reshape((n_blocks, -1, 2, 8)) + + assert cls.grid is not None + grid = np.take_along_axis(cls.grid, (qs & np.uint16(511)).reshape((n_blocks, -1, 1, 1)), axis=-2) + grid = grid.reshape((n_blocks, -1, 2, 8)) + + return (db * grid * signs).reshape((n_blocks, -1)) + + +class IQ2_S(__Quant, qtype=GGMLQuantizationType.IQ2_S): + # iq2s_grid, but with each byte of the original packed in 2 bits, + # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2. + grid_shape = (1024, 8) + grid_map = (0x08, 0x19, 0x2b) + grid_hex = ( + b"00000200050008000a0011001400160019002000220025002800410044004600" + b"490050005200550058006100640066006900800082008500880091009400a000" + b"a500aa0001010401060109011001120115011801210124014001420145014801" + b"510154015601590160016501680181018401900192019501a101a40100020202" + b"050208021102140220022a02410244024602490250025502800285028a029402" + b"a202010404040604090410041204150418042104240426042904400442044504" + b"48044a0451045404560459046004620465048104840486048904900495049804" + b"a104a40400050205050508050a05110514051605190520052505280541054405" + b"46054905500552055505580561056405800582058505880591059405a0050106" + b"0406060609061006150640064506480651065406600681068406900600080208" + b"050808081108140816081908200825082a084108440846084908500852085508" + b"580861086408800885089408aa08010904091009120915091809210940094509" + b"480951095409600981099009000a110a140a220a280a2a0a500a990a01100410" + b"0610091010101210151018102110241026104010421045104810511054105610" + b"59106010621065106810811084108610901095109810a110a410001102110511" + b"08110a1111111411161119112011221125112811411144114611491150115211" + b"5511581161116411801182118511881191119411011204120912101215122112" + b"2412401245125112541281128412901200140214051408141114141416141914" + b"2014251428144114441446144914501452145514581461146414801482148514" + b"881491149414a014011504150615091510151215151518152115241540154215" + b"4515481551155415601581158415901500160516081611161416201641164416" + b"50168016aa160118041806180918101815181818211840184218451848185118" + b"541860188118841800190219051908191119141920194119441950196919a219" + b"041a101a401a561a00200220052008201120142016201920202025202a204120" + b"4420502052205520642080208a209420aa200121042110211221152121214021" + b"4221452151215421602181218421902100220a22222228222a22442250228822" + b"8a22a82201240424062409241024152418242124242440244224452448245124" + b"5424602481248424902400250525082511251425202541254425502566258025" + b"0126042610264026592600280528112814284128442850288a28aa2801290429" + b"102995290a2a222a642a882a8a2a014004400640094010401240154018401a40" + b"21402440264040404240454048404a4051405440564059406040624065408140" + b"8440904095409840a140a4400041024105410841114114411641194120412241" + b"2541414144414641494150415241554158416141644180418241854188419141" + b"9441a04101420442104212421542184224424042454248425142544260428142" + b"844200440244054408440a441144144416441944204422442544284441444444" + b"46444944504452445544584461446444804482448544884491449444a0440145" + b"0445064509451045124515451845214524454045424545454845514554456045" + b"6a4581458445904500460246054608461146144620464146444650468046a546" + b"0148044809481048124815481848214824484048424845484848514854486048" + b"84489048004902490549084911491449204941494449504980499649014a044a" + b"104a404a00500250055008501150145016501950205022502550285041504450" + b"4650495050505250555058506150645080508250855088509150945001510451" + b"0651095110511251155118512151245140514251455148515151545160518151" + b"8451905100520552085211521452205241524452505269528052015404540654" + b"0954105412541554185421542454405442544554485451545454605481548454" + b"9054005502550555085511551455205541554455505580550156045610562656" + b"405600580258055808581158145820584158445850585a588058015904591059" + b"4059005a195a855aa85a01600460066010601260156018602160246040604560" + b"4860516054606060846090600061026105610861116114612061416144615061" + b"806199610462106240625662a162006405640864116414642064416444645064" + b"806401650465106540654a656865926500669466016804681068656898680069" + b"2a69426aa16a0080028005800880118014801980208025804180448050805280" + b"5580588061808080858091809480018104810981108112811581188121812481" + b"408142814581488151815481818184819081a981008205820a82118214824182" + b"4482508201840484068409841084128415841884218440844284458448845184" + b"5484608481848484908400850285058508851185148520854185448550858085" + b"8a85018604861086298640860088058811881488418844885088a28801890489" + b"40896589228a588a5a8a828aa28a019004900990109012901590189024904090" + b"4290459048905190549060908190849090900091059111911491419144915091" + b"5a910192049210924092a6920094029405940894119414942094419444945094" + b"8094969401950495109540959895a19500964696649601980498109826984098" + b"a998009949995299909a00a005a00aa014a022a02aa041a044a050a0a2a0aaa0" + b"40a165a102a20aa222a228a22aa282a288a28aa2a8a201a404a410a440a489a4" + b"a4a400a519a551a60aa828a8a2a854a986a908aa0aaa20aa22aa28aa88aaaaaa" + ) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + qs, rest = np.hsplit(rest, [QK_K // 8]) + signs, rest = np.hsplit(rest, [QK_K // 8]) + qh, scales = np.hsplit(rest, [QK_K // 32]) + + d = d.view(np.float16).astype(np.float32) + + scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2)) + scales = (scales & 0x0F).reshape((n_blocks, -1)) + db = d * (np.float32(0.5) + scales) * np.float32(0.25) + db = db.reshape((n_blocks, -1, 1, 1)) + + # unpack the sign bits + signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8)) + signs = signs & np.uint8(0x01) + signs = np.where(signs == 0, np.float32(1), np.float32(-1)) + signs = signs.reshape((n_blocks, -1, 2, 8)) + + qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4)) + qs = qs.astype(np.uint16) | ((qh & 0x03).astype(np.uint16) << 8).reshape((n_blocks, -1)) + + assert cls.grid is not None + grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2) + grid = grid.reshape((n_blocks, -1, 2, 8)) + + return (db * grid * signs).reshape((n_blocks, -1)) + + +class IQ3_XXS(__Quant, qtype=GGMLQuantizationType.IQ3_XXS): + grid_shape = (256, 4) + grid_map = (0x04, 0x0c, 0x14, 0x1c, 0x24, 0x2c, 0x34, 0x3e) + grid_hex = ( + b"0000020004001100130017002000220031004200730075000101030110011201" + b"2101250130013201410154017001000202020402110220022202310233023702" + b"5102570275020103070310031203250370031304370444045704730475040105" + b"0705320552053506640610071407160743076107011003101010121021102310" + b"3010321034104710501000110211111120112211011203121012121221123012" + b"7212001302132013311346136613011405145014201524154615711505162217" + b"4017002002201120132020202220262031204220012103210521102112212121" + b"3021632167217021002202221122172220222222372240225522012310231423" + b"7023742335245324032527254125742501270327162745270130103012302130" + b"2330503065307230003102312031313144314631013203321032253252327232" + b"1133333330344734723400350635223555351436363663363337603704401740" + b"3540374053405740744120423742404260426642074345430444514464442545" + b"4345704505471047124730471250415070500051065126515551145232527252" + b"0253535310542354275472540255315550562457425724604460466064602161" + b"6161176264623063366344640565526533660367216703700570077010703270" + b"5270267140711272457252720073157333736073217441740075027524753076" + ) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + qs, scales = np.hsplit(rest, [QK_K // 4]) + + d = d.view(np.float16).astype(np.float32) + scales = scales.view(np.uint32) + + db = d * (np.float32(0.5) + (scales >> 28).astype(np.float32)) * np.float32(0.5) + db = db.reshape((n_blocks, -1, 1, 1)) + + # get the sign indices and unpack the bits + signs = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4)) + ksigns = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128)) + signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1)) + signs = np.take_along_axis(ksigns, signs, axis=-1) + signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 1, 8)) + signs = signs & np.uint8(0x01) + signs = np.where(signs == 0, np.float32(1), np.float32(-1)) + signs = signs.reshape((n_blocks, -1, 4, 8)) + + assert cls.grid is not None + grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2) + grid = grid.reshape((n_blocks, -1, 4, 8)) + + return (db * grid * signs).reshape((n_blocks, -1)) + + +class IQ3_S(__Quant, qtype=GGMLQuantizationType.IQ3_S): + grid_shape = (512, 4) + grid_map = (0x01, 0x03, 0x05, 0x07, 0x09, 0x0b, 0x0d, 0x0f) + grid_hex = ( + b"0000010002000500070010001100120014001600200021002500330040004200" + b"4500470051005300600062007100740077000001010102010401100111011501" + b"2001230127013101350144016101650172010002010205020702100213021602" + b"2102250230023402420245024702510253027002730203031103150320032203" + b"3103330336034403500352036703710375030004130417042104240432044004" + b"4304510470040205040520052205260533054105450547056605730506061106" + b"1306310652067106000702070407200722072607330750075407001001100210" + b"0410101011101310151017102010221031103410361054105610611072100011" + b"0111031106111011141121113011331141115011521170117611001212121512" + b"1712201224123212401243125512601272120113041307131013131321132713" + b"3013341341136213701303140514121414143114331442144614501454140115" + b"1015131521153015321551152016241627164416461601170317101712172117" + b"3517411762177017002001200320052007201020122014201620212023202720" + b"3020322041204320452050205220672070207320752000210221102113211721" + b"2221252131213421422151210122042207222122232230223722412253225722" + b"7122742200230223052311232223242331233323422350236623012407242024" + b"2324322435244124722475240425112522253725402553257025002602260726" + b"2126552661260527112726273027432750270230113013301530173022303130" + b"3330353042304430473051306330713001310331053114312131233140316031" + b"7231763100321232203232323432503201331033143321332333273330334133" + b"4333473355337333033411341634223431345234603464340135103512352535" + b"3235443556357335163641360137033720372237353700400440124020402440" + b"2740324041405040704002410741114113412241304135414341514155410142" + b"0342104215422142334240425742624270420443114313432043224331433543" + b"0044024424443744404471440545074521456245134634466046104715473047" + b"4347514702501050145022504050445047505250665074500151035105511251" + b"2151325172510052115223523052365253520253075310532753445351536553" + b"7353015404542054325446541255265551555355425602570457225711601360" + b"1560316033606060006120612761646112623462426255626262706200631463" + b"2163406325644364626400650365346560650566406611671367007004700770" + b"2070227036704070547062700271117124714371457101720472107216722172" + b"3072517202733273357353730174057413742074507422754275027631760077" + ) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + qs, rest = np.hsplit(rest, [QK_K // 4]) + qh, rest = np.hsplit(rest, [QK_K // 32]) + signs, scales = np.hsplit(rest, [QK_K // 8]) + + d = d.view(np.float16).astype(np.float32) + + scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2)) + scales = (scales & 0x0F).reshape((n_blocks, -1)) + db = d * (1 + 2 * scales) + db = db.reshape((n_blocks, -1, 1, 1)) + + # unpack the sign bits + signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8)) + signs = signs & np.uint8(0x01) + signs = np.where(signs == 0, np.float32(1), np.float32(-1)) + signs = signs.reshape((n_blocks, -1, 4, 8)) + + qh = qh.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8) + qh = (qh & 0x01).astype(np.uint16).reshape((n_blocks, -1)) + qs = qs.astype(np.uint16) | (qh << 8) + + assert cls.grid is not None + grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2) + grid = grid.reshape((n_blocks, -1, 4, 8)) + + return (db * grid * signs).reshape((n_blocks, -1)) + + +class IQ1_S(__Quant, qtype=GGMLQuantizationType.IQ1_S): + # iq1s_grid, with each byte packed into 2 bits + # -1, 0, 1 <=> 0, 1, 2 + grid_shape = (2048, 8) + grid_map = (-1, 0, 1) + grid_hex = ( + b"00000200050008000a00110015002000220028002a0045005100540056006500" + b"8000820088008a009500a000a200a800aa000401050111011401160119011a01" + b"2501410146014901520155015a0161016401660168018501910194019601a501" + b"0002020208020a0215022002220228022a024502510259026402690280028202" + b"88028a02910295029902a002a202a802aa021104140416042504410449045504" + b"5a046404650491049904a5040105040505050605150518051a05290540054505" + b"4a0550055105540555055605590560056205650568056a058105910595059805" + b"9a05a105a405a505a605a9051406190641064406500652065506580660066106" + b"6606690685069106940699060008020808080a0815082008220828082a084508" + b"5108560865088008820888088a089508a008a208a808aa080509110914091909" + b"2409250941095009510955096109640969099109940996099909a509000a020a" + b"080a0a0a150a200a220a280a2a0a450a510a590a610a650a800a820a850a880a" + b"8a0a950aa00aa20aa80aaa0a1010111014101910241025104110441050105510" + b"58106110641065106910911094109610a110a510011104110611091110111211" + b"1511181121112411291145114a11501151115211541155115611591160116511" + b"841192119511a111a41111121412161225124012461249125212551258125a12" + b"641266128512911294129612a512011406140914141415141814191421142614" + b"41144514461448144a1451145414551456145914621465146814841489149014" + b"94149514981499149a14a114a414a514a914021505150a151115141515151615" + b"191520152215251528152a154115441545154615511552155415551556155915" + b"5a1561156415651566156915801582158415851588158a159015911594159515" + b"961599159a15a015a215a51501160416051606161516161618161a1621162616" + b"401642164416451648164a165116551656165816591661166416651668166916" + b"6a1686168a1692169516a416a916111816182518411844184618491850185518" + b"58185a1860186118641866186918851891189418a5181019121915191a192119" + b"25194219441945194819511954195519561959195a19601965196a1989199119" + b"921995199819a119a619a919091a161a241a261a441a461a491a501a521a551a" + b"581a611a661a691a851a911a961a9a1a0020022008200a201520202022202520" + b"28202a20452051205920612065208020822088208a209520a020a220a520a820" + b"aa2005211121142119212521422144214921552158215a216121642165216621" + b"8521902196219921a521012208220a22112215222022222228222a2245225122" + b"562259226522812288228a2291229522a022a222a822aa220524142416241924" + b"252444244524462449245224552458245a2466248524912494249924a124a524" + b"0925152521252925402545254825512554255525592562256525682589259025" + b"9425952598259a25a125a425a625a92505261026122619262526412649265526" + b"6026612669268426862690269a260028022808280a2815282028222828282a28" + b"45285128542865288028822888288a28a028a228a828aa280929112914291929" + b"2529462949295229552961296429662969298529902996299929a429a529002a" + b"022a082a0a2a202a222a282a2a2a452a512a562a592a652a802a822a882a8a2a" + b"952aa02aa22aa82aaa2a054011401640254049405240554058405a4061406440" + b"664094409940a140a6400041014104410641094112411541164118411a412141" + b"26412941454148414a41514154415541564159415a41654168416a4181418441" + b"8641904192419541a041a141a241054211421442164225424142524255425a42" + b"6442694289429442a5420144154419442944454448444a445144544455445644" + b"61446244654468446a44814486448944904492449544a044a144a94401450245" + b"05450a4511451445154516451945204525452a45414544454545464549455045" + b"5145544555455645584559456145644565456645694582458445854588459145" + b"94459545964599459a45a545a845aa450146054609461446154618461a462146" + b"2446294640464246454648465046514652465546564659466246654668468146" + b"85468a4694469546a146a446a6460548114815481a4825484248494850485548" + b"5848614864486648694885489148944896489948a5480149054906490a491049" + b"144915491849214924492649404945494a495149524954495549564959496049" + b"6249654966496a49864989499249954996499849a149a449a649a949164a444a" + b"464a494a554a584a5a4a644a694a944aa54a0150045005500650095012501550" + b"1a50215024502950405045504850515054505550565059506550685086508950" + b"95509850a050a150a650a9500551085109510a51115114511551165118511951" + b"20512551265128512a5141514451455146514951505151515251545155515651" + b"585159515a51615164516551665169518251855191519451955196519951a051" + b"a551aa5101520652125215521a5221522452425245524a525152545255525652" + b"595262526552855290529252955299529a52a452045405541154145415541654" + b"185419542154255428542a54415444544554465449544a545054515454545554" + b"5654585459545a54615462546454655466546954805488548a54915494549554" + b"96549954a154a454a554aa540155025504550555065509551055115512551455" + b"1555165519551a55215524552555265529554055415542554455455546554855" + b"4955505551555255545555555655585559555a55605561556455655566556855" + b"69556a5581558455855589558a559055915594559555965598559955a155a455" + b"a555a655a9550056015602560456065608560956115614561556185619562056" + b"2156225624562556265628562956415645564656485649564a56505651565256" + b"545655565656585659565a566156645665566956825685568656885689568a56" + b"915695569a56a256a556a656a856a95604580558065809581058155818582158" + b"2a58455848584a58515854585558565858585958605862586458655882588958" + b"9058925895589858a158a9580159025905590a59115914591559165919592559" + b"41594459455946594959505951595259545955595659585959595a5961596459" + b"655966596959815985598959915994599559965998599959a559045a085a155a" + b"1a5a205a255a265a295a455a485a495a515a555a565a585a595a625a655a685a" + b"6a5a815a8a5a925a955a965a985a9a5aa15a0560146016601960256044605060" + b"5560566058605a60616064606660696081609660a56001610461066109611261" + b"15612161226126612961456149615161556156615961656166616a6184618a61" + b"92619561a161a661a96111621662196240624162466255625662586260628562" + b"91629662a56211641264156416641a6421642664296440644264456448644a64" + b"516454645564566459645a646064626465648464856489649064926494649564" + b"966498649a64a164a464a964056508650a651165156516651965446545654665" + b"496550655165546555655665596561656465656566656965866589658a659165" + b"9565966599659a65a265a565a665a86502660966156620662666286629664066" + b"456648664a66516654665566566658665a666066656668668066826685668a66" + b"9466966698669966a066a466a666aa661668196825684168526855685a686168" + b"6968856891689868a66801690469106915692169246926692969406941694569" + b"4669486951695469556956695969606965696a69826984698a699569a169a469" + b"a569a969116a166a186a416a446a496a506a556a586a5a6a646a656a696a866a" + b"946a986a9a6aa66a0080028008800a802080228028802a804580508051805480" + b"5680598065808080828088808a809580a080a280a880aa800581118114811681" + b"1981258141814481498150815281558156815881598164816681698185818981" + b"948196819981a5810082028208820a8215822082228228822a82518254825982" + b"65828082828288828a829582a082a282a882aa82148419844184448451845584" + b"5a846184648469849484998401850985128515851a8526852985408541854585" + b"4885518554855585568559855a856585668568856a8581858485868589859085" + b"928595859885a68511861686198625864186448649864a865086558659865a86" + b"618666866a86858691869a86a4860088028808880a8815882088228828882a88" + b"41884588518854885988658869888088828888888a889588a088a288a888aa88" + b"05890689118914891689258941894489468949895089528955895a8961896489" + b"858996899989a589008a028a088a0a8a158a208a228a288a2a8a458a518a548a" + b"568a808a828a888a8a8a958aa08aa28aa88aaa8a059011901690189019902590" + b"419046904990559058905a9069906a9085909190949096909990a59001910491" + b"069109911091159118911a912191249126912991409145915091519154915591" + b"569159916291659184918691929195919891a191a491a691a991059211921492" + b"19922592449246924992509252925592589266926992859294929692a9920194" + b"04940694109415941894269440944a9451945494559456945894599460946194" + b"62946594849486949294949495949894a194a9940095059508950a9510951195" + b"14951595169519952195259529952a9541954495459546954995509551955295" + b"549555955695589559955a956195649565956695699581958595889591959295" + b"94959595969599959a95a095a295a595a895aa95019604961096159619962096" + b"2696299645964896499651965296559656965996659668968296849689968a96" + b"929694969596a496a696a9960598169819982598419846985098529855985698" + b"5a98649865988598919896989998a59804990699099910991299159918991a99" + b"209921992499269940994299459948994a995199549955995699599962996599" + b"66996a99819984999099929995999a99a199a699059a159a259a449a469a499a" + b"509a559a589a619a859a919a949a959a969a00a002a008a00aa015a020a022a0" + b"28a02aa045a051a054a056a059a080a082a088a08aa095a0a0a0a2a0a8a0aaa0" + b"05a109a111a114a116a119a11aa146a149a151a155a158a15aa161a164a185a1" + b"90a192a196a199a102a208a20aa210a219a222a228a22aa245a251a256a259a2" + b"65a280a282a288a28aa295a2a0a2a2a2a8a2aaa219a425a441a444a450a454a4" + b"55a458a45aa461a465a466a468a469a485a406a509a510a512a515a518a526a5" + b"29a542a545a551a554a555a556a559a565a56aa581a584a585a586a589a592a5" + b"95a598a505a611a616a61aa621a625a644a646a64aa652a655a656a658a660a6" + b"62a686a690a695a696a699a6a1a6a4a6a6a600a802a808a80aa820a822a828a8" + b"2aa851a854a856a859a880a882a888a88aa895a8a0a8a2a8a8a8aaa805a914a9" + b"19a921a925a941a950a955a95aa961a966a969a990a996a900aa02aa08aa0aaa" + b"20aa22aa28aa2aaa51aa54aa56aa80aa82aa88aa8aaa95aaa0aaa2aaa8aaaaaa" + ) + + delta = np.float32(0.125) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + qs, qh = np.hsplit(rest, [QK_K // 8]) + + d = d.view(np.float16).astype(np.float32) + qh = qh.view(np.uint16) + + dl = d * (2 * ((qh >> 12) & 7) + 1) + dl = dl.reshape((n_blocks, -1, 1, 1)) + delta = np.where((qh & np.uint16(0x8000)) == 0, cls.delta, -cls.delta) + delta = delta.reshape((n_blocks, -1, 1, 1)) + + qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4)) + qs = qs.astype(np.uint16) | ((qh & 7) << 8).reshape((n_blocks, -1)) + + assert cls.grid is not None + grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2) + grid = grid.reshape((n_blocks, -1, 4, 8)) + + return (dl * (grid + delta)).reshape((n_blocks, -1)) + + +class IQ1_M(__Quant, qtype=GGMLQuantizationType.IQ1_M): + grid_shape = IQ1_S.grid_shape + grid_map = IQ1_S.grid_map + grid_hex = IQ1_S.grid_hex + + delta = IQ1_S.delta + + # Okay *this* type is weird. It's the only one which stores the f16 scales in multiple parts. + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + qs, rest = np.hsplit(blocks, [QK_K // 8]) + qh, scales = np.hsplit(rest, [QK_K // 16]) + + # The f16 scale is packed across multiple bytes + scales = scales.view(np.uint16) + d = (scales.reshape((n_blocks, 4)) & np.uint16(0xF000)) >> np.array([12, 8, 4, 0], dtype=np.uint16).reshape((1, 4)) + d = d[..., 0] | d[..., 1] | d[..., 2] | d[..., 3] + d = d.view(np.float16).astype(np.float32).reshape((n_blocks, 1)) + + scales = scales.reshape(n_blocks, -1, 1) >> np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4)) + scales = (scales & 0x07).reshape((n_blocks, -1)) + dl = d * (2 * scales + 1) + dl = dl.reshape((n_blocks, -1, 2, 1, 1)) + + qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2)) + qs = qs.astype(np.uint16) | ((qh & 0x07).astype(np.uint16) << 8).reshape((n_blocks, -1)) + + delta = np.where(qh & 0x08 == 0, cls.delta, -cls.delta) + delta = delta.reshape((n_blocks, -1, 2, 2, 1)) + + assert cls.grid is not None + grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2) + grid = grid.reshape((n_blocks, -1, 2, 2, 8)) + + return (dl * (grid + delta)).reshape((n_blocks, -1)) + + +class IQ4_NL(__Quant, qtype=GGMLQuantizationType.IQ4_NL): + kvalues = (-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, qs = np.hsplit(blocks, [2]) + + d = d.view(np.float16).astype(np.float32) + + qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + + qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 1)) + + kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16) + qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1)) + + return (d * qs) + + +class IQ4_XS(__Quant, qtype=GGMLQuantizationType.IQ4_XS): + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d, rest = np.hsplit(blocks, [2]) + scales_h, rest = np.hsplit(rest, [2]) + scales_l, qs = np.hsplit(rest, [QK_K // 64]) + + d = d.view(np.float16).astype(np.float32) + scales_h = scales_h.view(np.uint16) + + scales_l = scales_l.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2)) + scales_h = scales_h.reshape((n_blocks, 1, -1)) >> np.array([2 * i for i in range(QK_K // 32)], dtype=np.uint16).reshape((1, -1, 1)) + scales_l = scales_l.reshape((n_blocks, -1)) & np.uint8(0x0F) + scales_h = scales_h.reshape((n_blocks, -1)).astype(np.uint8) & np.uint8(0x03) + + scales = (scales_l | (scales_h << np.uint8(4))).astype(np.int8) - np.int8(32) + dl = (d * scales.astype(np.float32)).reshape((n_blocks, -1, 1)) + + qs = qs.reshape((n_blocks, -1, 1, 16)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qs = qs.reshape((n_blocks, -1, 32, 1)) & np.uint8(0x0F) + + kvalues = np.array(IQ4_NL.kvalues, dtype=np.int8).reshape((1, 1, 1, -1)) + qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 32)) + + return (dl * qs).reshape((n_blocks, -1)) diff --git a/gguf-py/gguf/scripts/gguf_convert_endian.py b/gguf-py/gguf/scripts/gguf_convert_endian.py new file mode 100644 index 0000000000000000000000000000000000000000..0e0febaa791786fb9d5c713478377e5d54642df8 --- /dev/null +++ b/gguf-py/gguf/scripts/gguf_convert_endian.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import argparse +import os +import sys +from tqdm import tqdm +from pathlib import Path + +import numpy as np + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +import gguf + +logger = logging.getLogger("gguf-convert-endian") + + +def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None: + file_endian = reader.endianess.name + if reader.byte_order == 'S': + host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE' + else: + host_endian = file_endian + order = host_endian if args.order == "native" else args.order.upper() + logger.info(f"* Host is {host_endian} endian, GGUF file seems to be {file_endian} endian") + if file_endian == order: + logger.info(f"* File is already {order} endian. Nothing to do.") + sys.exit(0) + logger.info("* Checking tensors for conversion compatibility") + for tensor in reader.tensors: + if tensor.tensor_type not in ( + gguf.GGMLQuantizationType.F32, + gguf.GGMLQuantizationType.F16, + gguf.GGMLQuantizationType.Q8_0, + gguf.GGMLQuantizationType.Q4_K, + gguf.GGMLQuantizationType.Q6_K, + ): + raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}") + logger.info(f"* Preparing to convert from {file_endian} to {order}") + if args.dry_run: + return + logger.warning("*** Warning *** Warning *** Warning **") + logger.warning("* This conversion process may damage the file. Ensure you have a backup.") + if order != host_endian: + logger.warning("* Requested endian differs from host, you will not be able to load the model on this machine.") + logger.warning("* The file will be modified immediately, so if conversion fails or is interrupted") + logger.warning("* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:") + response = input("YES, I am sure> ") + if response != "YES": + logger.warning("You didn't enter YES. Okay then, see ya!") + sys.exit(0) + logger.info(f"* Converting fields ({len(reader.fields)})") + for idx, field in enumerate(reader.fields.values()): + logger.info(f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}") + for part in field.parts: + part.byteswap(inplace=True) + logger.info(f"* Converting tensors ({len(reader.tensors)})") + + for idx, tensor in enumerate(pbar := tqdm(reader.tensors, desc="Converting tensor")): + log_message = ( + f"Converting tensor {repr(tensor.name)}, " + f"type={tensor.tensor_type.name}, " + f"elements={tensor.n_elements} " + ) + + # Byte-swap each part of the tensor's field + for part in tensor.field.parts: + part.byteswap(inplace=True) + + # Byte-swap tensor data if necessary + if tensor.tensor_type == gguf.GGMLQuantizationType.Q8_0: + # Handle Q8_0 tensor blocks (block_q8_0) + # Specific handling of block_q8_0 is required. + # Each block_q8_0 consists of an f16 delta (scaling factor) followed by 32 int8 quantizations. + + block_size = 34 # 34 bytes = + 32 * + + n_blocks = len(tensor.data) // block_size + for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): + block_offs = block_num * block_size + + # Byte-Swap f16 sized delta field + delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) + delta.byteswap(inplace=True) + + # Byte-Swap Q8 weights + if block_num % 100000 == 0: + inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]") + + elif tensor.tensor_type == gguf.GGMLQuantizationType.Q4_K: + # Handle Q4_K tensor blocks (block_q4_k) + # Specific handling of block_q4_k is required. + # Each block_q4_k consists of 2 f16 values followed by 140 int8 values. + + # first flatten structure + newshape = 1 + for i in tensor.data.shape: + newshape *= i + + tensor.data.resize(newshape) + + block_size = 144 + n_blocks = len(tensor.data) // block_size + for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): + block_offs = block_num * block_size + + # Byte-Swap f16 sized fields + delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) + delta.byteswap(inplace=True) + + delta = tensor.data[block_offs + 2:block_offs + 4].view(dtype=np.uint16) + delta.byteswap(inplace=True) + + # Byte-Swap + if block_num % 100000 == 0: + inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]") + + elif tensor.tensor_type == gguf.GGMLQuantizationType.Q6_K: + # Handle Q6_K tensor blocks (block_q6_k) + # Specific handling of block_q6_k is required. + # Each block_q6_k consists of 208 int8 values followed by 1 f16 value. + + # first flatten structure + newshape = 1 + for i in tensor.data.shape: + newshape *= i + + tensor.data.resize(newshape) + + block_size = 210 + n_blocks = len(tensor.data) // block_size + for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): + block_offs = block_num * block_size + + # Byte-Swap f16 sized field + delta = tensor.data[block_offs + 208:block_offs + 210].view(dtype=np.uint16) + delta.byteswap(inplace=True) + + # Byte-Swap + if block_num % 100000 == 0: + inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]") + + else: + # Handle other tensor types + tensor.data.byteswap(inplace=True) + + pbar.set_description(log_message) + + logger.info("* Completion") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Convert GGUF file byte order") + parser.add_argument( + "model", type=str, + help="GGUF format model filename", + ) + parser.add_argument( + "order", type=str, choices=['big', 'little', 'native'], + help="Requested byte order", + ) + parser.add_argument( + "--dry-run", action="store_true", + help="Don't actually change anything", + ) + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + + args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) + + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + logger.info(f'* Loading: {args.model}') + reader = gguf.GGUFReader(args.model, 'r' if args.dry_run else 'r+') + convert_byteorder(reader, args) + + +if __name__ == "__main__": + main() diff --git a/gguf-py/gguf/scripts/gguf_dump.py b/gguf-py/gguf/scripts/gguf_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..e282892d645c7c31a1d5e00f383dcaaa5890a127 --- /dev/null +++ b/gguf-py/gguf/scripts/gguf_dump.py @@ -0,0 +1,454 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import argparse +import os +import re +import sys +from pathlib import Path +from typing import Any + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from gguf import GGUFReader, GGUFValueType, ReaderTensor # noqa: E402 + +logger = logging.getLogger("gguf-dump") + + +def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]: + file_endian = reader.endianess.name + if reader.byte_order == 'S': + host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE' + else: + host_endian = file_endian + return (host_endian, file_endian) + + +# For more information about what field.parts and field.data represent, +# please see the comments in the modify_gguf.py example. +def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: + host_endian, file_endian = get_file_host_endian(reader) + print(f'* File is {file_endian} endian, script is running on a {host_endian} endian host.') # noqa: NP100 + print(f'* Dumping {len(reader.fields)} key/value pair(s)') # noqa: NP100 + for n, field in enumerate(reader.fields.values(), 1): + if not field.types: + pretty_type = 'N/A' + elif field.types[0] == GGUFValueType.ARRAY: + nest_count = len(field.types) - 1 + pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count + else: + pretty_type = str(field.types[-1].name) + + log_message = f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}' + if field.types: + curr_type = field.types[0] + if curr_type == GGUFValueType.STRING: + content = field.contents() + if len(content) > 60: + content = content[:57] + '...' + log_message += ' = {0}'.format(repr(content)) + elif curr_type in reader.gguf_scalar_to_np: + log_message += ' = {0}'.format(field.contents()) + else: + content = repr(field.contents(slice(6))) + if len(field.data) > 6: + content = content[:-1] + ', ...]' + log_message += ' = {0}'.format(content) + print(log_message) # noqa: NP100 + if args.no_tensors: + return + print(f'* Dumping {len(reader.tensors)} tensor(s)') # noqa: NP100 + for n, tensor in enumerate(reader.tensors, 1): + prettydims = ', '.join('{0:5}'.format(d) for d in list(tensor.shape) + [1] * (4 - len(tensor.shape))) + print(f' {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}') # noqa: NP100 + + +def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: + import json + host_endian, file_endian = get_file_host_endian(reader) + metadata: dict[str, Any] = {} + tensors: dict[str, Any] = {} + result = { + "filename": args.model, + "endian": file_endian, + "metadata": metadata, + "tensors": tensors, + } + for idx, field in enumerate(reader.fields.values()): + curr: dict[str, Any] = { + "index": idx, + "type": field.types[0].name if field.types else 'UNKNOWN', + "offset": field.offset, + } + metadata[field.name] = curr + if field.types[:1] == [GGUFValueType.ARRAY]: + curr["array_types"] = [t.name for t in field.types][1:] + if not args.json_array: + continue + curr["value"] = field.contents() + else: + curr["value"] = field.contents() + if not args.no_tensors: + for idx, tensor in enumerate(reader.tensors): + tensors[tensor.name] = { + "index": idx, + "shape": tensor.shape.tolist(), + "type": tensor.tensor_type.name, + "offset": tensor.field.offset, + } + json.dump(result, sys.stdout) + + +def markdown_table_with_alignment_support(header_map: list[dict[str, str]], data: list[dict[str, Any]]): + # JSON to Markdown table formatting: https://stackoverflow.com/a/72983854/2850957 + + # Alignment Utility Function + def strAlign(padding: int, alignMode: str | None, strVal: str): + if alignMode == 'center': + return strVal.center(padding) + elif alignMode == 'right': + return strVal.rjust(padding - 1) + ' ' + elif alignMode == 'left': + return ' ' + strVal.ljust(padding - 1) + else: # default left + return ' ' + strVal.ljust(padding - 1) + + def dashAlign(padding: int, alignMode: str | None): + if alignMode == 'center': + return ':' + '-' * (padding - 2) + ':' + elif alignMode == 'right': + return '-' * (padding - 1) + ':' + elif alignMode == 'left': + return ':' + '-' * (padding - 1) + else: # default left + return '-' * (padding) + + # Calculate Padding For Each Column Based On Header and Data Length + rowsPadding = {} + for index, columnEntry in enumerate(header_map): + padCount = max([len(str(v)) for d in data for k, v in d.items() if k == columnEntry['key_name']], default=0) + 2 + headerPadCount = len(columnEntry['header_name']) + 2 + rowsPadding[index] = headerPadCount if padCount <= headerPadCount else padCount + + # Render Markdown Header + rows = [] + rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(columnEntry['header_name'])) for index, columnEntry in enumerate(header_map))) + rows.append('|'.join(dashAlign(rowsPadding[index], columnEntry.get('align')) for index, columnEntry in enumerate(header_map))) + + # Render Tabular Data + for item in data: + rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(item[columnEntry['key_name']])) for index, columnEntry in enumerate(header_map))) + + # Convert Tabular String Rows Into String + tableString = "" + for row in rows: + tableString += f'|{row}|\n' + + return tableString + + +def element_count_rounded_notation(count: int) -> str: + if count > 1e15 : + # Quadrillion + scaled_amount = count * 1e-15 + scale_suffix = "Q" + elif count > 1e12 : + # Trillions + scaled_amount = count * 1e-12 + scale_suffix = "T" + elif count > 1e9 : + # Billions + scaled_amount = count * 1e-9 + scale_suffix = "B" + elif count > 1e6 : + # Millions + scaled_amount = count * 1e-6 + scale_suffix = "M" + elif count > 1e3 : + # Thousands + scaled_amount = count * 1e-3 + scale_suffix = "K" + else: + # Under Thousands + scaled_amount = count + scale_suffix = "" + return f"{'~' if count > 1e3 else ''}{round(scaled_amount)}{scale_suffix}" + + +def translate_tensor_name(name): + words = name.split(".") + + # Source: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#standardized-tensor-names + abbreviation_dictionary = { + 'token_embd': 'Token embedding', + 'pos_embd': 'Position embedding', + 'output_norm': 'Output normalization', + 'output': 'Output', + 'attn_norm': 'Attention normalization', + 'attn_norm_2': 'Attention normalization', + 'attn_qkv': 'Attention query-key-value', + 'attn_q': 'Attention query', + 'attn_k': 'Attention key', + 'attn_v': 'Attention value', + 'attn_output': 'Attention output', + 'ffn_norm': 'Feed-forward network normalization', + 'ffn_up': 'Feed-forward network "up"', + 'ffn_gate': 'Feed-forward network "gate"', + 'ffn_down': 'Feed-forward network "down"', + 'ffn_gate_inp': 'Expert-routing layer for the Feed-forward network in Mixture of Expert models', + 'ffn_gate_exp': 'Feed-forward network "gate" layer per expert in Mixture of Expert models', + 'ffn_down_exp': 'Feed-forward network "down" layer per expert in Mixture of Expert models', + 'ffn_up_exp': 'Feed-forward network "up" layer per expert in Mixture of Expert models', + 'ssm_in': 'State space model input projections', + 'ssm_conv1d': 'State space model rolling/shift', + 'ssm_x': 'State space model selective parametrization', + 'ssm_a': 'State space model state compression', + 'ssm_d': 'State space model skip connection', + 'ssm_dt': 'State space model time step', + 'ssm_out': 'State space model output projection', + 'blk': 'Block', + 'enc': 'Encoder', + 'dec': 'Decoder', + } + + expanded_words = [] + for word in words: + word_norm = word.strip().lower() + if word_norm in abbreviation_dictionary: + expanded_words.append(abbreviation_dictionary[word_norm].title()) + else: + expanded_words.append(word.title()) + + return ' '.join(expanded_words) + + +def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: + host_endian, file_endian = get_file_host_endian(reader) + markdown_content = "" + markdown_content += f'# {args.model} - GGUF Internal File Dump\n\n' + markdown_content += f'- Endian: {file_endian} endian\n' + markdown_content += '\n' + markdown_content += '## Key Value Metadata Store\n\n' + markdown_content += f'There are {len(reader.fields)} key-value pairs in this file\n' + markdown_content += '\n' + + kv_dump_table: list[dict[str, str | int]] = [] + for n, field in enumerate(reader.fields.values(), 1): + if not field.types: + pretty_type = 'N/A' + elif field.types[0] == GGUFValueType.ARRAY: + nest_count = len(field.types) - 1 + pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count + else: + pretty_type = str(field.types[-1].name) + + def escape_markdown_inline_code(value_string): + # Find the longest contiguous sequence of backticks in the string then + # wrap string with appropriate number of backticks required to escape it + max_backticks = max((len(match.group(0)) for match in re.finditer(r'`+', value_string)), default=0) + inline_code_marker = '`' * (max_backticks + 1) + + # If the string starts or ends with a backtick, add a space at the beginning and end + if value_string.startswith('`') or value_string.endswith('`'): + value_string = f" {value_string} " + + return f"{inline_code_marker}{value_string}{inline_code_marker}" + + total_elements = len(field.data) + value = "" + if len(field.types) == 1: + curr_type = field.types[0] + if curr_type == GGUFValueType.STRING: + truncate_length = 60 + value_string = str(bytes(field.parts[-1]), encoding='utf-8') + if len(value_string) > truncate_length: + head = escape_markdown_inline_code(value_string[:truncate_length // 2]) + tail = escape_markdown_inline_code(value_string[-truncate_length // 2:]) + value = "{head}...{tail}".format(head=head, tail=tail) + else: + value = escape_markdown_inline_code(value_string) + elif curr_type in reader.gguf_scalar_to_np: + value = str(field.parts[-1][0]) + else: + if field.types[0] == GGUFValueType.ARRAY: + curr_type = field.types[1] + array_elements = [] + + if curr_type == GGUFValueType.STRING: + render_element = min(5, total_elements) + for element_pos in range(render_element): + truncate_length = 30 + value_string = str(bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), encoding='utf-8') + if len(value_string) > truncate_length: + head = escape_markdown_inline_code(value_string[:truncate_length // 2]) + tail = escape_markdown_inline_code(value_string[-truncate_length // 2:]) + value = "{head}...{tail}".format(head=head, tail=tail) + else: + value = escape_markdown_inline_code(value_string) + array_elements.append(value) + + elif curr_type in reader.gguf_scalar_to_np: + render_element = min(7, total_elements) + for element_pos in range(render_element): + array_elements.append(str(field.parts[-1 - (total_elements - element_pos - 1)][0])) + + value = f'[ {", ".join(array_elements).strip()}{", ..." if total_elements > len(array_elements) else ""} ]' + + kv_dump_table.append({"n":n, "pretty_type":pretty_type, "total_elements":total_elements, "field_name":field.name, "value":value}) + + kv_dump_table_header_map = [ + {'key_name':'n', 'header_name':'POS', 'align':'right'}, + {'key_name':'pretty_type', 'header_name':'TYPE', 'align':'left'}, + {'key_name':'total_elements', 'header_name':'Count', 'align':'right'}, + {'key_name':'field_name', 'header_name':'Key', 'align':'left'}, + {'key_name':'value', 'header_name':'Value', 'align':'left'}, + ] + + markdown_content += markdown_table_with_alignment_support(kv_dump_table_header_map, kv_dump_table) + + markdown_content += "\n" + + if not args.no_tensors: + # Group tensors by their prefix and maintain order + tensor_prefix_order: list[str] = [] + tensor_name_to_key: dict[str, int] = {} + tensor_groups: dict[str, list[ReaderTensor]] = {} + total_elements = sum(tensor.n_elements for tensor in reader.tensors) + + # Parsing Tensors Record + for key, tensor in enumerate(reader.tensors): + tensor_components = tensor.name.split('.') + + # Classify Tensor Group + tensor_group_name = "base" + if tensor_components[0] == 'blk': + tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}" + elif tensor_components[0] in ['enc', 'dec'] and tensor_components[1] == 'blk': + tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}.{tensor_components[2]}" + elif tensor_components[0] in ['enc', 'dec']: + tensor_group_name = f"{tensor_components[0]}" + + # Check if new Tensor Group + if tensor_group_name not in tensor_groups: + tensor_groups[tensor_group_name] = [] + tensor_prefix_order.append(tensor_group_name) + + # Record Tensor and Tensor Position + tensor_groups[tensor_group_name].append(tensor) + tensor_name_to_key[tensor.name] = key + + # Tensors Mapping Dump + markdown_content += f'## Tensors Overview {element_count_rounded_notation(total_elements)} Elements\n\n' + markdown_content += f'Total number of elements in all tensors: {total_elements} Elements\n' + markdown_content += '\n' + + for group in tensor_prefix_order: + tensors = tensor_groups[group] + group_elements = sum(tensor.n_elements for tensor in tensors) + markdown_content += f"- [{translate_tensor_name(group)} Tensor Group - {element_count_rounded_notation(group_elements)} Elements](#{group.replace('.', '_')})\n" + + markdown_content += "\n" + + markdown_content += "### Tensor Data Offset\n" + markdown_content += '\n' + markdown_content += 'This table contains the offset and data segment relative to start of file\n' + markdown_content += '\n' + + tensor_mapping_table: list[dict[str, str | int]] = [] + for key, tensor in enumerate(reader.tensors): + data_offset_pretty = '{0:#16x}'.format(tensor.data_offset) + data_size_pretty = '{0:#16x}'.format(tensor.n_bytes) + tensor_mapping_table.append({"t_id":key, "layer_name":tensor.name, "data_offset":data_offset_pretty, "data_size":data_size_pretty}) + + tensors_mapping_table_header_map = [ + {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'}, + {'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'}, + {'key_name':'data_offset', 'header_name':'Data Offset (B)', 'align':'right'}, + {'key_name':'data_size', 'header_name':'Data Size (B)', 'align':'right'}, + ] + + markdown_content += markdown_table_with_alignment_support(tensors_mapping_table_header_map, tensor_mapping_table) + markdown_content += "\n" + + for group in tensor_prefix_order: + tensors = tensor_groups[group] + group_elements = sum(tensor.n_elements for tensor in tensors) + group_percentage = group_elements / total_elements * 100 + markdown_content += f"### {translate_tensor_name(group)} Tensor Group : {element_count_rounded_notation(group_elements)} Elements\n\n" + + # Precalculate column sizing for visual consistency + prettify_element_est_count_size: int = 1 + prettify_element_count_size: int = 1 + prettify_dimension_max_widths: dict[int, int] = {} + for tensor in tensors: + prettify_element_est_count_size = max(prettify_element_est_count_size, len(str(element_count_rounded_notation(tensor.n_elements)))) + prettify_element_count_size = max(prettify_element_count_size, len(str(tensor.n_elements))) + for i, dimension_size in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape))): + prettify_dimension_max_widths[i] = max(prettify_dimension_max_widths.get(i,1), len(str(dimension_size))) + + # Generate Tensor Layer Table Content + tensor_dump_table: list[dict[str, str | int]] = [] + for tensor in tensors: + human_friendly_name = translate_tensor_name(tensor.name.replace(".weight", ".(W)").replace(".bias", ".(B)")) + pretty_dimension = ' x '.join(f'{str(d):>{prettify_dimension_max_widths[i]}}' for i, d in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape)))) + element_count_est = f"({element_count_rounded_notation(tensor.n_elements):>{prettify_element_est_count_size}})" + element_count_string = f"{element_count_est} {tensor.n_elements:>{prettify_element_count_size}}" + type_name_string = f"{tensor.tensor_type.name}" + tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string}) + + tensor_dump_table_header_map = [ + {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'}, + {'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'}, + {'key_name':'human_layer_name', 'header_name':'Human Friendly Tensor Layer Name', 'align':'left'}, + {'key_name':'element_count', 'header_name':'Elements', 'align':'left'}, + {'key_name':'pretty_dimension', 'header_name':'Shape', 'align':'left'}, + {'key_name':'tensor_type', 'header_name':'Type', 'align':'left'}, + ] + + markdown_content += markdown_table_with_alignment_support(tensor_dump_table_header_map, tensor_dump_table) + + markdown_content += "\n" + markdown_content += f"- Total elements in {group}: ({element_count_rounded_notation(group_elements):>4}) {group_elements}\n" + markdown_content += f"- Percentage of total elements: {group_percentage:.2f}%\n" + markdown_content += "\n\n" + + print(markdown_content) # noqa: NP100 + + +def main() -> None: + parser = argparse.ArgumentParser(description="Dump GGUF file metadata") + parser.add_argument("model", type=str, help="GGUF format model filename") + parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata") + parser.add_argument("--json", action="store_true", help="Produce JSON output") + parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)") + parser.add_argument("--data-offset", action="store_true", help="Start of data offset") + parser.add_argument("--data-alignment", action="store_true", help="Data alignment applied globally to data field") + parser.add_argument("--markdown", action="store_true", help="Produce markdown output") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + + args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) + + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + if not args.json and not args.markdown and not args.data_offset and not args.data_alignment: + logger.info(f'* Loading: {args.model}') + + reader = GGUFReader(args.model, 'r') + + if args.json: + dump_metadata_json(reader, args) + elif args.markdown: + dump_markdown_metadata(reader, args) + elif args.data_offset: + print(reader.data_offset) # noqa: NP100 + elif args.data_alignment: + print(reader.alignment) # noqa: NP100 + else: + dump_metadata(reader, args) + + +if __name__ == '__main__': + main() diff --git a/gguf-py/gguf/scripts/gguf_editor_gui.py b/gguf-py/gguf/scripts/gguf_editor_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..05f4db0f8cdc89049a7e8719fcaa1fe741391413 --- /dev/null +++ b/gguf-py/gguf/scripts/gguf_editor_gui.py @@ -0,0 +1,1621 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import argparse +import os +import sys +import numpy +import enum +from pathlib import Path +from typing import Any, Optional, Tuple, Type +import warnings + +import numpy as np +from PySide6.QtWidgets import ( + QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, + QPushButton, QLabel, QLineEdit, QFileDialog, QTableWidget, + QTableWidgetItem, QComboBox, QMessageBox, QTabWidget, + QTextEdit, QFormLayout, + QHeaderView, QDialog, QDialogButtonBox +) +from PySide6.QtCore import Qt + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +import gguf +from gguf import GGUFReader, GGUFWriter, GGUFValueType, ReaderField +from gguf.constants import TokenType, RopeScalingType, PoolingType, GGMLQuantizationType + +logger = logging.getLogger("gguf-editor-gui") + +# Map of key names to enum types for automatic enum interpretation +KEY_TO_ENUM_TYPE = { + gguf.Keys.Tokenizer.TOKEN_TYPE: TokenType, + gguf.Keys.Rope.SCALING_TYPE: RopeScalingType, + gguf.Keys.LLM.POOLING_TYPE: PoolingType, + gguf.Keys.General.FILE_TYPE: GGMLQuantizationType, +} + +# Define the tokenizer keys that should be edited together +TOKENIZER_LINKED_KEYS = [ + gguf.Keys.Tokenizer.LIST, + gguf.Keys.Tokenizer.TOKEN_TYPE, + gguf.Keys.Tokenizer.SCORES +] + + +class TokenizerEditorDialog(QDialog): + def __init__(self, tokens, token_types, scores, parent=None): + super().__init__(parent) + self.setWindowTitle("Edit Tokenizer Data") + self.resize(900, 600) + + self.tokens = tokens.copy() if tokens else [] + self.token_types = token_types.copy() if token_types else [] + self.scores = scores.copy() if scores else [] + + # Ensure all arrays have the same length + max_len = max(len(self.tokens), len(self.token_types), len(self.scores)) + if len(self.tokens) < max_len: + self.tokens.extend([""] * (max_len - len(self.tokens))) + if len(self.token_types) < max_len: + self.token_types.extend([0] * (max_len - len(self.token_types))) + if len(self.scores) < max_len: + self.scores.extend([0.0] * (max_len - len(self.scores))) + + layout = QVBoxLayout(self) + + # Add filter controls + filter_layout = QHBoxLayout() + filter_layout.addWidget(QLabel("Filter:")) + self.filter_edit = QLineEdit() + self.filter_edit.setPlaceholderText("Type to filter tokens...") + self.filter_edit.textChanged.connect(self.apply_filter) + filter_layout.addWidget(self.filter_edit) + + # Add page controls + self.page_size = 100 # Show 100 items per page + self.current_page = 0 + self.total_pages = max(1, (len(self.tokens) + self.page_size - 1) // self.page_size) + + self.page_label = QLabel(f"Page 1 of {self.total_pages}") + filter_layout.addWidget(self.page_label) + + prev_page = QPushButton("Previous") + prev_page.clicked.connect(self.previous_page) + filter_layout.addWidget(prev_page) + + next_page = QPushButton("Next") + next_page.clicked.connect(self.next_page) + filter_layout.addWidget(next_page) + + layout.addLayout(filter_layout) + + # Tokenizer data table + self.tokens_table = QTableWidget() + self.tokens_table.setColumnCount(4) + self.tokens_table.setHorizontalHeaderLabels(["Index", "Token", "Type", "Score"]) + self.tokens_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + self.tokens_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + self.tokens_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + self.tokens_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) + + layout.addWidget(self.tokens_table) + + # Controls + controls_layout = QHBoxLayout() + + add_button = QPushButton("Add Token") + add_button.clicked.connect(self.add_token) + controls_layout.addWidget(add_button) + + remove_button = QPushButton("Remove Selected") + remove_button.clicked.connect(self.remove_selected) + controls_layout.addWidget(remove_button) + + controls_layout.addStretch() + + layout.addLayout(controls_layout) + + # Buttons + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + layout.addWidget(buttons) + + # Initialize the filtered values + self.filtered_indices = list(range(len(self.tokens))) + + # Load data for the first page + self.load_page() + + def apply_filter(self): + """Filter the tokens based on the search text.""" + filter_text = self.filter_edit.text().lower() + + if not filter_text: + # No filter, show all values + self.filtered_indices = list(range(len(self.tokens))) + else: + # Apply filter + self.filtered_indices = [] + for i, token in enumerate(self.tokens): + if filter_text in str(token).lower(): + self.filtered_indices.append(i) + + # Reset to first page and reload + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.current_page = 0 + self.page_label.setText(f"Page 1 of {self.total_pages}") + self.load_page() + + def previous_page(self): + """Go to the previous page of results.""" + if self.current_page > 0: + self.current_page -= 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.load_page() + + def next_page(self): + """Go to the next page of results.""" + if self.current_page < self.total_pages - 1: + self.current_page += 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.load_page() + + def load_page(self): + """Load the current page of tokenizer data.""" + self.tokens_table.setRowCount(0) # Clear the table + + # Calculate start and end indices for the current page + start_idx = self.current_page * self.page_size + end_idx = min(start_idx + self.page_size, len(self.filtered_indices)) + + # Pre-allocate rows for better performance + self.tokens_table.setRowCount(end_idx - start_idx) + + for row, i in enumerate(range(start_idx, end_idx)): + orig_idx = self.filtered_indices[i] + + # Index + index_item = QTableWidgetItem(str(orig_idx)) + index_item.setData(Qt.ItemDataRole.UserRole, orig_idx) # Store original index + index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tokens_table.setItem(row, 0, index_item) + + # Token + token_item = QTableWidgetItem(str(self.tokens[orig_idx])) + self.tokens_table.setItem(row, 1, token_item) + + # Token Type + token_type = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0 + try: + enum_val = TokenType(token_type) + display_text = f"{enum_val.name} ({token_type})" + except (ValueError, KeyError): + display_text = f"Unknown ({token_type})" + + type_item = QTableWidgetItem(display_text) + type_item.setData(Qt.ItemDataRole.UserRole, token_type) + + # Make type cell editable with a double-click handler + type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tokens_table.setItem(row, 2, type_item) + + # Score + score = self.scores[orig_idx] if orig_idx < len(self.scores) else 0.0 + score_item = QTableWidgetItem(str(score)) + self.tokens_table.setItem(row, 3, score_item) + + # Connect double-click handler for token type cells + self.tokens_table.cellDoubleClicked.connect(self.handle_cell_double_click) + + def handle_cell_double_click(self, row, column): + """Handle double-click on a cell, specifically for token type editing.""" + if column == 2: # Token Type column + orig_item = self.tokens_table.item(row, 0) + if orig_item: + orig_idx = orig_item.data(Qt.ItemDataRole.UserRole) + self.edit_token_type(row, orig_idx) + + def edit_token_type(self, row, orig_idx): + """Edit a token type using a dialog with a dropdown of all enum options.""" + current_value = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0 + + # Create a dialog with enum options + dialog = QDialog(self) + dialog.setWindowTitle("Select Token Type") + layout = QVBoxLayout(dialog) + + combo = QComboBox() + for enum_val in TokenType: + combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value) + + # Set current value + try: + if isinstance(current_value, int): + enum_val = TokenType(current_value) + combo.setCurrentText(f"{enum_val.name} ({current_value})") + except (ValueError, KeyError): + pass + + layout.addWidget(combo) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(dialog.accept) + buttons.rejected.connect(dialog.reject) + layout.addWidget(buttons) + + if dialog.exec() == QDialog.DialogCode.Accepted: + # Get the selected value + new_value = combo.currentData() + enum_val = TokenType(new_value) + display_text = f"{enum_val.name} ({new_value})" + + # Update the display + type_item = self.tokens_table.item(row, 2) + if type_item: + type_item.setText(display_text) + type_item.setData(Qt.ItemDataRole.UserRole, new_value) + + # Update the actual value + self.token_types[orig_idx] = new_value + + def add_token(self): + """Add a new token to the end of the list.""" + # Add to the end of the arrays + self.tokens.append("") + self.token_types.append(0) # Default to normal token + self.scores.append(0.0) + + orig_idx = len(self.tokens) - 1 + + # Add to filtered indices if it matches the current filter + filter_text = self.filter_edit.text().lower() + if not filter_text or filter_text in "": + self.filtered_indices.append(orig_idx) + + # Update pagination + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + + # Go to the last page to show the new item + self.current_page = self.total_pages - 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + + # Reload the page + self.load_page() + + def remove_selected(self): + """Remove selected tokens from all arrays.""" + selected_rows = [] + for item in self.tokens_table.selectedItems(): + row = item.row() + if row not in selected_rows: + selected_rows.append(row) + + if not selected_rows: + return + + # Get original indices in descending order to avoid index shifting + orig_indices = [] + for row in selected_rows: + orig_item = self.tokens_table.item(row, 0) + if orig_item: + orig_indices.append(orig_item.data(Qt.ItemDataRole.UserRole)) + orig_indices.sort(reverse=True) + + # Remove from all arrays + for idx in orig_indices: + if idx < len(self.tokens): + del self.tokens[idx] + if idx < len(self.token_types): + del self.token_types[idx] + if idx < len(self.scores): + del self.scores[idx] + + # Rebuild filtered_indices + self.filtered_indices = [] + filter_text = self.filter_edit.text().lower() + + for i, token in enumerate(self.tokens): + if not filter_text or filter_text in str(token).lower(): + self.filtered_indices.append(i) + + # Update pagination + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.current_page = min(self.current_page, self.total_pages - 1) + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + + # Reload the page + self.load_page() + + def get_data(self): + """Return the edited tokenizer data.""" + return self.tokens, self.token_types, self.scores + + +class ArrayEditorDialog(QDialog): + def __init__(self, array_values, element_type, key=None, parent=None): + super().__init__(parent) + self.setWindowTitle("Edit Array Values") + self.resize(700, 500) + + self.array_values = array_values + self.element_type = element_type + self.key = key + + # Get enum type for this array if applicable + self.enum_type = None + if key in KEY_TO_ENUM_TYPE and element_type == GGUFValueType.INT32: + self.enum_type = KEY_TO_ENUM_TYPE[key] + + layout = QVBoxLayout(self) + + # Add enum type information if applicable + if self.enum_type is not None: + enum_info_layout = QHBoxLayout() + enum_label = QLabel(f"Editing {self.enum_type.__name__} values:") + enum_info_layout.addWidget(enum_label) + + # Add a legend for the enum values + enum_values = ", ".join([f"{e.name}={e.value}" for e in self.enum_type]) + enum_values_label = QLabel(f"Available values: {enum_values}") + enum_values_label.setWordWrap(True) + enum_info_layout.addWidget(enum_values_label, 1) + + layout.addLayout(enum_info_layout) + + # Add search/filter controls + filter_layout = QHBoxLayout() + filter_layout.addWidget(QLabel("Filter:")) + self.filter_edit = QLineEdit() + self.filter_edit.setPlaceholderText("Type to filter values...") + self.filter_edit.textChanged.connect(self.apply_filter) + filter_layout.addWidget(self.filter_edit) + + # Add page controls for large arrays + self.page_size = 100 # Show 100 items per page + self.current_page = 0 + self.total_pages = max(1, (len(array_values) + self.page_size - 1) // self.page_size) + + self.page_label = QLabel(f"Page 1 of {self.total_pages}") + filter_layout.addWidget(self.page_label) + + prev_page = QPushButton("Previous") + prev_page.clicked.connect(self.previous_page) + filter_layout.addWidget(prev_page) + + next_page = QPushButton("Next") + next_page.clicked.connect(self.next_page) + filter_layout.addWidget(next_page) + + layout.addLayout(filter_layout) + + # Array items table + self.items_table = QTableWidget() + + # Set up columns based on whether we have an enum type + if self.enum_type is not None: + self.items_table.setColumnCount(3) + self.items_table.setHorizontalHeaderLabels(["Index", "Value", "Actions"]) + self.items_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + self.items_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + self.items_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + else: + self.items_table.setColumnCount(2) + self.items_table.setHorizontalHeaderLabels(["Index", "Value"]) + self.items_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + self.items_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + + layout.addWidget(self.items_table) + + # Controls + controls_layout = QHBoxLayout() + + add_button = QPushButton("Add Item") + add_button.clicked.connect(self.add_item) + controls_layout.addWidget(add_button) + + remove_button = QPushButton("Remove Selected") + remove_button.clicked.connect(self.remove_selected) + controls_layout.addWidget(remove_button) + + # Add bulk edit button for enum arrays + if self.enum_type is not None: + bulk_edit_button = QPushButton("Bulk Edit Selected") + bulk_edit_button.clicked.connect(self.bulk_edit_selected) + controls_layout.addWidget(bulk_edit_button) + + controls_layout.addStretch() + + layout.addLayout(controls_layout) + + # Buttons + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + layout.addWidget(buttons) + + # Initialize the filtered values + self.filtered_indices = list(range(len(self.array_values))) + + # Load array values for the first page + self.load_page() + + def apply_filter(self): + """Filter the array values based on the search text.""" + filter_text = self.filter_edit.text().lower() + + if not filter_text: + # No filter, show all values + self.filtered_indices = list(range(len(self.array_values))) + else: + # Apply filter + self.filtered_indices = [] + for i, value in enumerate(self.array_values): + # For enum values, search in both name and value + if self.enum_type is not None and isinstance(value, int): + try: + enum_val = self.enum_type(value) + display_text = f"{enum_val.name} ({value})".lower() + if filter_text in display_text: + self.filtered_indices.append(i) + except (ValueError, KeyError): + # If not a valid enum value, just check the raw value + if filter_text in str(value).lower(): + self.filtered_indices.append(i) + else: + # For non-enum values, just check the string representation + if filter_text in str(value).lower(): + self.filtered_indices.append(i) + + # Reset to first page and reload + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.current_page = 0 + self.page_label.setText(f"Page 1 of {self.total_pages}") + self.load_page() + + def previous_page(self): + """Go to the previous page of results.""" + if self.current_page > 0: + self.current_page -= 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.load_page() + + def next_page(self): + """Go to the next page of results.""" + if self.current_page < self.total_pages - 1: + self.current_page += 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.load_page() + + def load_page(self): + """Load the current page of array values.""" + self.items_table.setRowCount(0) # Clear the table + + # Calculate start and end indices for the current page + start_idx = self.current_page * self.page_size + end_idx = min(start_idx + self.page_size, len(self.filtered_indices)) + + # Pre-allocate rows for better performance + self.items_table.setRowCount(end_idx - start_idx) + + for row, i in enumerate(range(start_idx, end_idx)): + orig_idx = self.filtered_indices[i] + value = self.array_values[orig_idx] + + # Index + index_item = QTableWidgetItem(str(orig_idx)) + index_item.setData(Qt.ItemDataRole.UserRole, orig_idx) # Store original index + index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.items_table.setItem(row, 0, index_item) + + # Value + if self.enum_type is not None: + # Display enum value and name + try: + if isinstance(value, (int, numpy.signedinteger)): + enum_val = self.enum_type(value) + display_text = f"{enum_val.name} ({value})" + else: + display_text = str(value) + except (ValueError, KeyError): + display_text = f"Unknown ({value})" + + # Store the enum value in the item + value_item = QTableWidgetItem(display_text) + value_item.setData(Qt.ItemDataRole.UserRole, value) + value_item.setFlags(value_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.items_table.setItem(row, 1, value_item) + + # Add an edit button in a separate column + edit_button = QPushButton("Edit") + edit_button.setProperty("row", row) + edit_button.clicked.connect(self.edit_array_enum_value) + + # Create a widget to hold the button + button_widget = QWidget() + button_layout = QHBoxLayout(button_widget) + button_layout.setContentsMargins(2, 2, 2, 2) + button_layout.addWidget(edit_button) + button_layout.addStretch() + + self.items_table.setCellWidget(row, 2, button_widget) + else: + value_item = QTableWidgetItem(str(value)) + self.items_table.setItem(row, 1, value_item) + + def edit_array_enum_value(self): + """Handle editing an enum value in the array editor.""" + button = self.sender() + row = button.property("row") + + # Get the original index from the table item + orig_item = self.items_table.item(row, 0) + new_item = self.items_table.item(row, 1) + if orig_item and new_item and self.enum_type and self.edit_enum_value(row, self.enum_type): + orig_idx = orig_item.data(Qt.ItemDataRole.UserRole) + new_value = new_item.data(Qt.ItemDataRole.UserRole) + # Update the stored value in the array + if isinstance(new_value, (int, float, str, bool)): + self.array_values[orig_idx] = new_value + + def bulk_edit_selected(self): + """Edit multiple enum values at once.""" + if not self.enum_type: + return + + selected_rows = set() + for item in self.items_table.selectedItems(): + selected_rows.add(item.row()) + + if not selected_rows: + QMessageBox.information(self, "No Selection", "Please select at least one row to edit.") + return + + # Create a dialog with enum options + dialog = QDialog(self) + dialog.setWindowTitle(f"Bulk Edit {self.enum_type.__name__} Values") + layout = QVBoxLayout(dialog) + + layout.addWidget(QLabel(f"Set {len(selected_rows)} selected items to:")) + + combo = QComboBox() + for enum_val in self.enum_type: + combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value) + + layout.addWidget(combo) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(dialog.accept) + buttons.rejected.connect(dialog.reject) + layout.addWidget(buttons) + + if dialog.exec() == QDialog.DialogCode.Accepted: + # Get the selected value + new_value = combo.currentData() + enum_val = self.enum_type(new_value) + display_text = f"{enum_val.name} ({new_value})" + + # Update all selected rows + for row in selected_rows: + orig_item = self.items_table.item(row, 0) + new_item = self.items_table.item(row, 1) + if orig_item and new_item: + orig_idx = orig_item.data(Qt.ItemDataRole.UserRole) + self.array_values[orig_idx] = new_value + + # Update the display + new_item.setText(display_text) + new_item.setData(Qt.ItemDataRole.UserRole, new_value) + + def add_item(self): + # Add to the end of the array + orig_idx = len(self.array_values) + + # Add default value based on type + if self.enum_type is not None: + # Default to first enum value + default_value = list(self.enum_type)[0].value + self.array_values.append(default_value) + else: + if self.element_type == GGUFValueType.STRING: + self.array_values.append("") + else: + self.array_values.append(0) + + # Add to filtered indices if it matches the current filter + self.filtered_indices.append(orig_idx) + + # Update pagination + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + + # Go to the last page to show the new item + self.current_page = self.total_pages - 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + + # Reload the page + self.load_page() + + def remove_selected(self): + selected_rows = [] + for item in self.items_table.selectedItems(): + row = item.row() + if row not in selected_rows: + selected_rows.append(row) + + if not selected_rows: + return + + # Get original indices in descending order to avoid index shifting + orig_indices = list() + for row in selected_rows: + orig_item = self.items_table.item(row, 0) + if orig_item: + orig_indices.append(orig_item.data(Qt.ItemDataRole.UserRole)) + orig_indices.sort(reverse=True) + + # Remove from array_values + for idx in orig_indices: + del self.array_values[idx] + + # Rebuild filtered_indices + self.filtered_indices = [] + filter_text = self.filter_edit.text().lower() + + for i, value in enumerate(self.array_values): + if not filter_text: + self.filtered_indices.append(i) + else: + # Apply filter + if self.enum_type is not None and isinstance(value, int): + try: + enum_val = self.enum_type(value) + display_text = f"{enum_val.name} ({value})".lower() + if filter_text in display_text: + self.filtered_indices.append(i) + except (ValueError, KeyError): + if filter_text in str(value).lower(): + self.filtered_indices.append(i) + else: + if filter_text in str(value).lower(): + self.filtered_indices.append(i) + + # Update pagination + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.current_page = min(self.current_page, self.total_pages - 1) + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + + # Reload the page + self.load_page() + + def edit_enum_value(self, row: int, enum_type: Type[enum.Enum]): + """Edit an enum value using a dialog with a dropdown of all enum options.""" + # Get the original index from the table item + orig_item = self.items_table.item(row, 0) + if orig_item: + orig_idx = orig_item.data(Qt.ItemDataRole.UserRole) + else: + return + current_value = self.array_values[orig_idx] + + # Create a dialog with enum options + dialog = QDialog(self) + dialog.setWindowTitle(f"Select {enum_type.__name__} Value") + layout = QVBoxLayout(dialog) + + # Add description + description = QLabel(f"Select a {enum_type.__name__} value:") + layout.addWidget(description) + + # Use a combo box for quick selection + combo = QComboBox() + for enum_val in enum_type: + combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value) + + # Set current value + try: + if isinstance(current_value, int): + enum_val = enum_type(current_value) + combo.setCurrentText(f"{enum_val.name} ({current_value})") + except (ValueError, KeyError): + pass + + layout.addWidget(combo) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(dialog.accept) + buttons.rejected.connect(dialog.reject) + layout.addWidget(buttons) + + if dialog.exec() == QDialog.DialogCode.Accepted: + # Update the value display and stored data + new_value = combo.currentData() + enum_val = enum_type(new_value) + display_text = f"{enum_val.name} ({new_value})" + + new_item = self.items_table.item(row, 1) + if new_item: + new_item.setText(display_text) + new_item.setData(Qt.ItemDataRole.UserRole, new_value) + + # Update the actual array value + self.array_values[orig_idx] = new_value + return True + return False + + def get_array_values(self): + # The array_values list is kept up-to-date as edits are made + return self.array_values + + +class AddMetadataDialog(QDialog): + def __init__(self, parent=None): + super().__init__(parent) + self.setWindowTitle("Add Metadata") + self.resize(400, 200) + + layout = QVBoxLayout(self) + + form_layout = QFormLayout() + + self.key_edit = QLineEdit() + form_layout.addRow("Key:", self.key_edit) + + self.type_combo = QComboBox() + for value_type in GGUFValueType: + if value_type != GGUFValueType.ARRAY: # Skip array type for simplicity + self.type_combo.addItem(value_type.name, value_type) + form_layout.addRow("Type:", self.type_combo) + + self.value_edit = QTextEdit() + form_layout.addRow("Value:", self.value_edit) + + layout.addLayout(form_layout) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + layout.addWidget(buttons) + + def get_data(self) -> Tuple[str, GGUFValueType, Any]: + key = self.key_edit.text() + value_type = self.type_combo.currentData() + value_text = self.value_edit.toPlainText() + + # Convert value based on type + if value_type == GGUFValueType.UINT8: + value = np.uint8(int(value_text)) + elif value_type == GGUFValueType.INT8: + value = np.int8(int(value_text)) + elif value_type == GGUFValueType.UINT16: + value = np.uint16(int(value_text)) + elif value_type == GGUFValueType.INT16: + value = np.int16(int(value_text)) + elif value_type == GGUFValueType.UINT32: + value = np.uint32(int(value_text)) + elif value_type == GGUFValueType.INT32: + value = np.int32(int(value_text)) + elif value_type == GGUFValueType.FLOAT32: + value = np.float32(float(value_text)) + elif value_type == GGUFValueType.BOOL: + value = value_text.lower() in ('true', 'yes', '1') + elif value_type == GGUFValueType.STRING: + value = value_text + else: + value = value_text + + return key, value_type, value + + +class GGUFEditorWindow(QMainWindow): + def __init__(self): + super().__init__() + + self.setWindowTitle("GGUF Editor") + self.resize(1000, 800) + + self.current_file = None + self.reader = None + self.modified = False + self.metadata_changes = {} # Store changes to apply when saving + self.metadata_to_remove = set() # Store keys to remove when saving + self.on_metadata_changed_is_connected = False + + self.setup_ui() + + def setup_ui(self): + central_widget = QWidget() + self.setCentralWidget(central_widget) + + main_layout = QVBoxLayout(central_widget) + + # File controls + file_layout = QHBoxLayout() + + self.file_path_edit = QLineEdit() + self.file_path_edit.setReadOnly(True) + file_layout.addWidget(self.file_path_edit) + + open_button = QPushButton("Open GGUF") + open_button.clicked.connect(self.open_file) + file_layout.addWidget(open_button) + + save_button = QPushButton("Save As...") + save_button.clicked.connect(self.save_file) + file_layout.addWidget(save_button) + + main_layout.addLayout(file_layout) + + # Tabs for different views + self.tabs = QTabWidget() + + # Metadata tab + self.metadata_tab = QWidget() + metadata_layout = QVBoxLayout(self.metadata_tab) + + # Metadata table + self.metadata_table = QTableWidget() + self.metadata_table.setColumnCount(4) + self.metadata_table.setHorizontalHeaderLabels(["Key", "Type", "Value", "Actions"]) + self.metadata_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) + self.metadata_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) + self.metadata_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.Stretch) + self.metadata_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) + metadata_layout.addWidget(self.metadata_table) + + # Metadata controls + metadata_controls = QHBoxLayout() + + add_metadata_button = QPushButton("Add Metadata") + add_metadata_button.clicked.connect(self.add_metadata) + metadata_controls.addWidget(add_metadata_button) + + metadata_controls.addStretch() + + metadata_layout.addLayout(metadata_controls) + + # Tensors tab + self.tensors_tab = QWidget() + tensors_layout = QVBoxLayout(self.tensors_tab) + + self.tensors_table = QTableWidget() + self.tensors_table.setColumnCount(5) + self.tensors_table.setHorizontalHeaderLabels(["Name", "Type", "Shape", "Elements", "Size (bytes)"]) + self.tensors_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) + self.tensors_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) + self.tensors_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + self.tensors_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) + self.tensors_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeMode.ResizeToContents) + tensors_layout.addWidget(self.tensors_table) + + # Add tabs to tab widget + self.tabs.addTab(self.metadata_tab, "Metadata") + self.tabs.addTab(self.tensors_tab, "Tensors") + + main_layout.addWidget(self.tabs) + + # Status bar + self.statusBar().showMessage("Ready") + + def load_file(self, file_path): + """Load a GGUF file by path""" + try: + self.statusBar().showMessage(f"Loading {file_path}...") + QApplication.processEvents() + + self.reader = GGUFReader(file_path, 'r') + self.current_file = file_path + self.file_path_edit.setText(file_path) + + self.load_metadata() + self.load_tensors() + + self.metadata_changes = {} + self.metadata_to_remove = set() + self.modified = False + + self.statusBar().showMessage(f"Loaded {file_path}") + return True + except Exception as e: + QMessageBox.critical(self, "Error", f"Failed to open file: {str(e)}") + self.statusBar().showMessage("Error loading file") + return False + + def open_file(self): + file_path, _ = QFileDialog.getOpenFileName( + self, "Open GGUF File", "", "GGUF Files (*.gguf);;All Files (*)" + ) + + if not file_path: + return + + self.load_file(file_path) + + def load_metadata(self): + self.metadata_table.setRowCount(0) + + if not self.reader: + return + + # Disconnect to prevent triggering during loading + if self.on_metadata_changed_is_connected: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore') + self.metadata_table.itemChanged.disconnect(self.on_metadata_changed) + self.on_metadata_changed_is_connected = False + + for i, (key, field) in enumerate(self.reader.fields.items()): + self.metadata_table.insertRow(i) + + # Key + key_item = QTableWidgetItem(key) + key_item.setFlags(key_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.metadata_table.setItem(i, 0, key_item) + + # Type + if not field.types: + type_str = "N/A" + elif field.types[0] == GGUFValueType.ARRAY: + nest_count = len(field.types) - 1 + element_type = field.types[-1].name + # Check if this is an enum array + enum_type = self.get_enum_for_key(key) + if enum_type is not None and field.types[-1] == GGUFValueType.INT32: + element_type = enum_type.__name__ + type_str = '[' * nest_count + element_type + ']' * nest_count + else: + type_str = str(field.types[0].name) + # Check if this is an enum field + enum_type = self.get_enum_for_key(key) + if enum_type is not None and field.types[0] == GGUFValueType.INT32: + type_str = enum_type.__name__ + + type_item = QTableWidgetItem(type_str) + type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.metadata_table.setItem(i, 1, type_item) + + # Value + value_str = self.format_field_value(field) + value_item = QTableWidgetItem(value_str) + + # Make only simple values editable + if len(field.types) == 1 and field.types[0] != GGUFValueType.ARRAY: + value_item.setFlags(value_item.flags() | Qt.ItemFlag.ItemIsEditable) + else: + value_item.setFlags(value_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + + self.metadata_table.setItem(i, 2, value_item) + + # Actions + actions_widget = QWidget() + actions_layout = QHBoxLayout(actions_widget) + actions_layout.setContentsMargins(2, 2, 2, 2) + + # Add Edit button for arrays and enum fields + if field.types and field.types[0] == GGUFValueType.ARRAY: + edit_button = QPushButton("Edit") + edit_button.setProperty("row", i) + edit_button.setProperty("key", key) + edit_button.clicked.connect(self.edit_array_metadata) + actions_layout.addWidget(edit_button) + + # Add special label for tokenizer linked fields + if key in TOKENIZER_LINKED_KEYS: + edit_button.setText("Edit Tokenizer") + edit_button.setToolTip("Edit all tokenizer data together") + elif len(field.types) == 1 and self.get_enum_for_key(key) is not None: + edit_button = QPushButton("Edit") + edit_button.setProperty("row", i) + edit_button.setProperty("key", key) + edit_button.clicked.connect(self.edit_metadata_enum) + actions_layout.addWidget(edit_button) + + remove_button = QPushButton("Remove") + remove_button.setProperty("row", i) + remove_button.setProperty("key", key) + remove_button.clicked.connect(self.remove_metadata) + actions_layout.addWidget(remove_button) + + self.metadata_table.setCellWidget(i, 3, actions_widget) + + # Reconnect after loading + self.metadata_table.itemChanged.connect(self.on_metadata_changed) + self.on_metadata_changed_is_connected = True + + def extract_array_values(self, field: ReaderField) -> list: + """Extract all values from an array field.""" + if not field.types or field.types[0] != GGUFValueType.ARRAY: + return [] + + curr_type = field.types[1] + array_values = [] + total_elements = len(field.data) + + if curr_type == GGUFValueType.STRING: + for element_pos in range(total_elements): + value_string = str(bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), encoding='utf-8') + array_values.append(value_string) + elif self.reader and curr_type in self.reader.gguf_scalar_to_np: + for element_pos in range(total_elements): + array_values.append(field.parts[-1 - (total_elements - element_pos - 1)][0]) + + return array_values + + def get_enum_for_key(self, key: str) -> Optional[Type[enum.Enum]]: + """Get the enum type for a given key if it exists.""" + return KEY_TO_ENUM_TYPE.get(key) + + def format_enum_value(self, value: Any, enum_type: Type[enum.Enum]) -> str: + """Format a value as an enum if possible.""" + try: + if isinstance(value, (int, str)): + enum_value = enum_type(value) + return f"{enum_value.name} ({value})" + except (ValueError, KeyError): + pass + return str(value) + + def format_field_value(self, field: ReaderField) -> str: + if not field.types: + return "N/A" + + if len(field.types) == 1: + curr_type = field.types[0] + if curr_type == GGUFValueType.STRING: + return str(bytes(field.parts[-1]), encoding='utf-8') + elif self.reader and curr_type in self.reader.gguf_scalar_to_np: + value = field.parts[-1][0] + # Check if this field has an enum type + enum_type = self.get_enum_for_key(field.name) + if enum_type is not None: + return self.format_enum_value(value, enum_type) + return str(value) + + if field.types[0] == GGUFValueType.ARRAY: + array_values = self.extract_array_values(field) + render_element = min(5, len(array_values)) + + # Get enum type for this array if applicable + enum_type = self.get_enum_for_key(field.name) + + if enum_type is not None: + array_elements = [] + for i in range(render_element): + array_elements.append(self.format_enum_value(array_values[i], enum_type)) + else: + array_elements = [str(array_values[i]) for i in range(render_element)] + + return f"[ {', '.join(array_elements).strip()}{', ...' if len(array_values) > len(array_elements) else ''} ]" + + return "Complex value" + + def load_tensors(self): + self.tensors_table.setRowCount(0) + + if not self.reader: + return + + for i, tensor in enumerate(self.reader.tensors): + self.tensors_table.insertRow(i) + + # Name + name_item = QTableWidgetItem(tensor.name) + name_item.setFlags(name_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tensors_table.setItem(i, 0, name_item) + + # Type + type_item = QTableWidgetItem(tensor.tensor_type.name) + type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tensors_table.setItem(i, 1, type_item) + + # Shape + shape_str = " × ".join(str(d) for d in tensor.shape) + shape_item = QTableWidgetItem(shape_str) + shape_item.setFlags(shape_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tensors_table.setItem(i, 2, shape_item) + + # Elements + elements_item = QTableWidgetItem(str(tensor.n_elements)) + elements_item.setFlags(elements_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tensors_table.setItem(i, 3, elements_item) + + # Size + size_item = QTableWidgetItem(f"{tensor.n_bytes:,}") + size_item.setFlags(size_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tensors_table.setItem(i, 4, size_item) + + def on_metadata_changed(self, item): + if item.column() != 2: # Only handle value column changes + return + + row = item.row() + orig_item = self.metadata_table.item(row, 0) + key = None + if orig_item: + key = orig_item.text() + new_value = item.text() + + field = None + if self.reader and key: + field = self.reader.get_field(key) + if not field or not field.types or not key: + return + + value_type = field.types[0] + + # Check if this is an enum field + enum_type = self.get_enum_for_key(key) + if enum_type is not None and value_type == GGUFValueType.INT32: + # Try to parse the enum value from the text + try: + # Check if it's a name + try: + enum_val = enum_type[new_value] + converted_value = enum_val.value + except (KeyError, AttributeError): + # Check if it's a number or "NAME (value)" format + if '(' in new_value and ')' in new_value: + # Extract the value from "NAME (value)" format + value_part = new_value.split('(')[1].split(')')[0].strip() + converted_value = int(value_part) + else: + # Try to convert directly to int + converted_value = int(new_value) + + # Validate that it's a valid enum value + enum_type(converted_value) + + # Store the change + self.metadata_changes[key] = (value_type, converted_value) + self.modified = True + + # Update display with formatted enum value + formatted_value = self.format_enum_value(converted_value, enum_type) + item.setText(formatted_value) + + self.statusBar().showMessage(f"Changed {key} to {formatted_value}") + return + except (ValueError, KeyError) as e: + QMessageBox.warning( + self, + f"Invalid Enum Value ({e})", + f"'{new_value}' is not a valid {enum_type.__name__} value.\n" + f"Valid values are: {', '.join(v.name for v in enum_type)}") + + # Revert to original value + original_value = self.format_field_value(field) + item.setText(original_value) + return + + try: + # Convert the string value to the appropriate type + if value_type == GGUFValueType.UINT8: + converted_value = np.uint8(int(new_value)) + elif value_type == GGUFValueType.INT8: + converted_value = np.int8(int(new_value)) + elif value_type == GGUFValueType.UINT16: + converted_value = np.uint16(int(new_value)) + elif value_type == GGUFValueType.INT16: + converted_value = np.int16(int(new_value)) + elif value_type == GGUFValueType.UINT32: + converted_value = np.uint32(int(new_value)) + elif value_type == GGUFValueType.INT32: + converted_value = np.int32(int(new_value)) + elif value_type == GGUFValueType.FLOAT32: + converted_value = np.float32(float(new_value)) + elif value_type == GGUFValueType.BOOL: + converted_value = new_value.lower() in ('true', 'yes', '1') + elif value_type == GGUFValueType.STRING: + converted_value = new_value + else: + # Unsupported type for editing + return + + # Store the change + self.metadata_changes[key] = (value_type, converted_value) + self.modified = True + + self.statusBar().showMessage(f"Changed {key} to {new_value}") + except ValueError: + QMessageBox.warning(self, "Invalid Value", f"The value '{new_value}' is not valid for type {value_type.name}") + + # Revert to original value + original_value = self.format_field_value(field) + item.setText(original_value) + + def remove_metadata(self): + button = self.sender() + key = button.property("key") + row = button.property("row") + + reply = QMessageBox.question( + self, "Confirm Removal", + f"Are you sure you want to remove the metadata key '{key}'?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.No + ) + + if reply == QMessageBox.StandardButton.Yes: + self.metadata_table.removeRow(row) + self.metadata_to_remove.add(key) + + # If we previously had changes for this key, remove them + if key in self.metadata_changes: + del self.metadata_changes[key] + + self.modified = True + self.statusBar().showMessage(f"Marked {key} for removal") + + def edit_metadata_enum(self): + """Edit an enum metadata field.""" + button = self.sender() + key = button.property("key") + row = button.property("row") + + field = None + if self.reader: + field = self.reader.get_field(key) + if not field or not field.types: + return + + enum_type = self.get_enum_for_key(key) + if enum_type is None: + return + + # Get current value + current_value = field.contents() + + # Create a dialog with enum options + dialog = QDialog(self) + dialog.setWindowTitle(f"Select {enum_type.__name__} Value") + layout = QVBoxLayout(dialog) + + combo = QComboBox() + for enum_val in enum_type: + combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value) + + # Set current value + try: + if isinstance(current_value, (int, str)): + enum_val = enum_type(current_value) + combo.setCurrentText(f"{enum_val.name} ({current_value})") + except (ValueError, KeyError): + pass + + layout.addWidget(combo) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(dialog.accept) + buttons.rejected.connect(dialog.reject) + layout.addWidget(buttons) + + if dialog.exec() == QDialog.DialogCode.Accepted: + # Get the selected value + new_value = combo.currentData() + enum_val = enum_type(new_value) + + # Store the change + self.metadata_changes[key] = (field.types[0], new_value) + self.modified = True + + # Update display + display_text = f"{enum_val.name} ({new_value})" + target_item = self.metadata_table.item(row, 2) + if target_item: + target_item.setText(display_text) + + self.statusBar().showMessage(f"Changed {key} to {display_text}") + + def edit_array_metadata(self): + button = self.sender() + key = button.property("key") + row = button.property("row") + + # Check if this is one of the linked tokenizer keys + if key in TOKENIZER_LINKED_KEYS: + self.edit_tokenizer_metadata(key) + return + + field = None + if self.reader: + field = self.reader.get_field(key) + if not field or not field.types or field.types[0] != GGUFValueType.ARRAY: + return + + # Get array element type + element_type = field.types[1] + + # Extract array values + array_values = self.extract_array_values(field) + + # Open array editor dialog + dialog = ArrayEditorDialog(array_values, element_type, key, self) + if dialog.exec() == QDialog.DialogCode.Accepted: + new_values = dialog.get_array_values() + + # Store the change + self.metadata_changes[key] = (GGUFValueType.ARRAY, (element_type, new_values)) + self.modified = True + + # Update display + enum_type = self.get_enum_for_key(key) + if enum_type is not None and element_type == GGUFValueType.INT32: + value_str = f"[ {', '.join(self.format_enum_value(v, enum_type) for v in new_values[:5])}{', ...' if len(new_values) > 5 else ''} ]" + else: + value_str = f"[ {', '.join(str(v) for v in new_values[:5])}{', ...' if len(new_values) > 5 else ''} ]" + target_item = self.metadata_table.item(row, 2) + if target_item: + target_item.setText(value_str) + + self.statusBar().showMessage(f"Updated array values for {key}") + + def edit_tokenizer_metadata(self, trigger_key): + """Edit the linked tokenizer metadata arrays together.""" + if not self.reader: + return + + # Get all three fields + tokens_field = self.reader.get_field(gguf.Keys.Tokenizer.LIST) + token_types_field = self.reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE) + scores_field = self.reader.get_field(gguf.Keys.Tokenizer.SCORES) + + # Extract values from each field + tokens = self.extract_array_values(tokens_field) if tokens_field else [] + token_types = self.extract_array_values(token_types_field) if token_types_field else [] + scores = self.extract_array_values(scores_field) if scores_field else [] + + # Apply any pending changes + if gguf.Keys.Tokenizer.LIST in self.metadata_changes: + _, (_, tokens) = self.metadata_changes[gguf.Keys.Tokenizer.LIST] + if gguf.Keys.Tokenizer.TOKEN_TYPE in self.metadata_changes: + _, (_, token_types) = self.metadata_changes[gguf.Keys.Tokenizer.TOKEN_TYPE] + if gguf.Keys.Tokenizer.SCORES in self.metadata_changes: + _, (_, scores) = self.metadata_changes[gguf.Keys.Tokenizer.SCORES] + + # Open the tokenizer editor dialog + dialog = TokenizerEditorDialog(tokens, token_types, scores, self) + if dialog.exec() == QDialog.DialogCode.Accepted: + new_tokens, new_token_types, new_scores = dialog.get_data() + + # Store changes for all three arrays + if tokens_field: + self.metadata_changes[gguf.Keys.Tokenizer.LIST] = ( + GGUFValueType.ARRAY, + (tokens_field.types[1], new_tokens) + ) + + if token_types_field: + self.metadata_changes[gguf.Keys.Tokenizer.TOKEN_TYPE] = ( + GGUFValueType.ARRAY, + (token_types_field.types[1], new_token_types) + ) + + if scores_field: + self.metadata_changes[gguf.Keys.Tokenizer.SCORES] = ( + GGUFValueType.ARRAY, + (scores_field.types[1], new_scores) + ) + + self.modified = True + + # Update display for all three fields + self.update_tokenizer_display(gguf.Keys.Tokenizer.LIST, new_tokens) + self.update_tokenizer_display(gguf.Keys.Tokenizer.TOKEN_TYPE, new_token_types) + self.update_tokenizer_display(gguf.Keys.Tokenizer.SCORES, new_scores) + + self.statusBar().showMessage("Updated tokenizer data") + + def update_tokenizer_display(self, key, values): + """Update the display of a tokenizer field in the metadata table.""" + for row in range(self.metadata_table.rowCount()): + key_item = self.metadata_table.item(row, 0) + if key_item and key_item.text() == key: + value_str = f"[ {', '.join(str(v) for v in values[:5])}{', ...' if len(values) > 5 else ''} ]" + value_item = self.metadata_table.item(row, 2) + if value_item: + value_item.setText(value_str) + break + + def add_metadata(self): + dialog = AddMetadataDialog(self) + if dialog.exec() == QDialog.DialogCode.Accepted: + key, value_type, value = dialog.get_data() + + if not key: + QMessageBox.warning(self, "Invalid Key", "Key cannot be empty") + return + + # Check if key already exists + for row in range(self.metadata_table.rowCount()): + orig_item = self.metadata_table.item(row, 0) + if orig_item and orig_item.text() == key: + QMessageBox.warning(self, "Duplicate Key", f"Key '{key}' already exists") + return + + # Add to table + row = self.metadata_table.rowCount() + self.metadata_table.insertRow(row) + + # Key + key_item = QTableWidgetItem(key) + key_item.setFlags(key_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.metadata_table.setItem(row, 0, key_item) + + # Type + type_item = QTableWidgetItem(value_type.name) + type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.metadata_table.setItem(row, 1, type_item) + + # Value + value_item = QTableWidgetItem(str(value)) + value_item.setFlags(value_item.flags() | Qt.ItemFlag.ItemIsEditable) + self.metadata_table.setItem(row, 2, value_item) + + # Actions + actions_widget = QWidget() + actions_layout = QHBoxLayout(actions_widget) + actions_layout.setContentsMargins(2, 2, 2, 2) + + remove_button = QPushButton("Remove") + remove_button.setProperty("row", row) + remove_button.setProperty("key", key) + remove_button.clicked.connect(self.remove_metadata) + actions_layout.addWidget(remove_button) + + self.metadata_table.setCellWidget(row, 3, actions_widget) + + # Store the change + self.metadata_changes[key] = (value_type, value) + self.modified = True + + self.statusBar().showMessage(f"Added new metadata key {key}") + + def save_file(self): + if not self.reader: + QMessageBox.warning(self, "No File Open", "Please open a GGUF file first") + return + + if not self.modified and not self.metadata_changes and not self.metadata_to_remove: + QMessageBox.information(self, "No Changes", "No changes to save") + return + + file_path, _ = QFileDialog.getSaveFileName( + self, "Save GGUF File As", "", "GGUF Files (*.gguf);;All Files (*)" + ) + + if not file_path: + return + + try: + self.statusBar().showMessage(f"Saving to {file_path}...") + QApplication.processEvents() + + # Get architecture and endianness from the original file + arch = 'unknown' + field = self.reader.get_field(gguf.Keys.General.ARCHITECTURE) + if field: + arch = field.contents() + + # Create writer + writer = GGUFWriter(file_path, arch=arch, endianess=self.reader.endianess) + + # Get alignment if present + alignment = None + field = self.reader.get_field(gguf.Keys.General.ALIGNMENT) + if field: + alignment = field.contents() + if alignment is not None: + writer.data_alignment = alignment + + # Copy metadata with changes + for field in self.reader.fields.values(): + # Skip virtual fields and fields written by GGUFWriter + if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'): + continue + + # Skip fields marked for removal + if field.name in self.metadata_to_remove: + continue + + # Apply changes if any + sub_type = None + if field.name in self.metadata_changes: + value_type, value = self.metadata_changes[field.name] + if value_type == GGUFValueType.ARRAY: + # Handle array values + sub_type, value = value + else: + # Copy original value + value = field.contents() + value_type = field.types[0] + if value_type == GGUFValueType.ARRAY: + sub_type = field.types[-1] + + if value is not None: + writer.add_key_value(field.name, value, value_type, sub_type=sub_type) + + # Add new metadata + for key, (value_type, value) in self.metadata_changes.items(): + # Skip if the key already existed (we handled it above) + if self.reader.get_field(key) is not None: + continue + + sub_type = None + if value_type == GGUFValueType.ARRAY: + # Handle array values + sub_type, value = value + + writer.add_key_value(key, value, value_type, sub_type=sub_type) + + # Add tensors (including data) + for tensor in self.reader.tensors: + writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type) + + # Write header and metadata + writer.open_output_file(Path(file_path)) + writer.write_header_to_file() + writer.write_kv_data_to_file() + + # Write tensor data using the optimized method + writer.write_tensors_to_file(progress=False) + + writer.close() + + self.statusBar().showMessage(f"Saved to {file_path}") + + # Ask if user wants to open the new file + reply = QMessageBox.question( + self, "Open Saved File", + "Would you like to open the newly saved file?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.Yes + ) + + if reply == QMessageBox.StandardButton.Yes: + self.reader = GGUFReader(file_path, 'r') + self.current_file = file_path + self.file_path_edit.setText(file_path) + + self.load_metadata() + self.load_tensors() + + self.metadata_changes = {} + self.metadata_to_remove = set() + self.modified = False + + except Exception as e: + QMessageBox.critical(self, "Error", f"Failed to save file: {str(e)}") + self.statusBar().showMessage("Error saving file") + + +def main() -> None: + parser = argparse.ArgumentParser(description="GUI GGUF Editor") + parser.add_argument("model_path", nargs="?", help="path to GGUF model file to load at startup") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + + args = parser.parse_args() + + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + app = QApplication(sys.argv) + window = GGUFEditorWindow() + window.show() + + # Load model if specified + if args.model_path: + if os.path.isfile(args.model_path) and args.model_path.endswith('.gguf'): + window.load_file(args.model_path) + else: + logger.error(f"Invalid model path: {args.model_path}") + QMessageBox.warning( + window, + "Invalid Model Path", + f"The specified file does not exist or is not a GGUF file: {args.model_path}") + + sys.exit(app.exec()) + + +if __name__ == '__main__': + main() diff --git a/gguf-py/gguf/scripts/gguf_hash.py b/gguf-py/gguf/scripts/gguf_hash.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef98992197e987d8aefc8fb5f5bf9c787e9d06a --- /dev/null +++ b/gguf-py/gguf/scripts/gguf_hash.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import uuid +import hashlib + +import logging +import argparse +import os +import sys +from pathlib import Path + +from tqdm import tqdm + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from gguf import GGUFReader # noqa: E402 + + +logger = logging.getLogger("gguf-hash") + +# UUID_NAMESPACE_LLAMA_CPP = uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp') +UUID_NAMESPACE_LLAMA_CPP = uuid.UUID('ef001206-dadc-5f6d-a15f-3359e577d4e5') + + +# For more information about what field.parts and field.data represent, +# please see the comments in the modify_gguf.py example. +def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar: bool, no_layer: bool) -> None: + sha1 = hashlib.sha1() + sha256 = hashlib.sha256() + uuidv5_sha1 = hashlib.sha1() + uuidv5_sha1.update(UUID_NAMESPACE_LLAMA_CPP.bytes) + + # Total Weight Calculation For Progress Bar + total_weights = 0 + for n, tensor in enumerate(reader.tensors, 1): + + # We don't need these + if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + continue + + # Calculate Tensor Volume + sum_weights_in_tensor = 1 + for dim in tensor.shape: + sum_weights_in_tensor *= dim + total_weights += sum_weights_in_tensor + + # Hash Progress Bar + bar = tqdm(desc="Hashing", total=total_weights, unit="weights", unit_scale=True, disable=disable_progress_bar) + + # Hashing Process + for tensor in reader.tensors: + + # We don't need these + if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + continue + + # Progressbar + sum_weights_in_tensor = 1 + for dim in tensor.shape: + sum_weights_in_tensor *= dim + bar.update(sum_weights_in_tensor) + + if not no_layer: + + sha1_layer = hashlib.sha1() + sha1_layer.update(tensor.data.data) + print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100 + + sha256_layer = hashlib.sha256() + sha256_layer.update(tensor.data.data) + print("sha256 {0} {1}:{2}".format(sha256_layer.hexdigest(), filename, tensor.name)) # noqa: NP100 + + sha1.update(tensor.data.data) + sha256.update(tensor.data.data) + uuidv5_sha1.update(tensor.data.data) + + # Flush Hash Progress Bar + bar.close() + + # Display Hash Output + print("sha1 {0} {1}".format(sha1.hexdigest(), filename)) # noqa: NP100 + print("sha256 {0} {1}".format(sha256.hexdigest(), filename)) # noqa: NP100 + print("uuid {0} {1}".format(uuid.UUID(bytes=uuidv5_sha1.digest()[:16], version=5), filename)) # noqa: NP100 + + +def main() -> None: + parser = argparse.ArgumentParser(description="Dump GGUF file metadata") + parser.add_argument("model", type=str, help="GGUF format model filename") + parser.add_argument("--no-layer", action="store_true", help="exclude per layer hash") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser.add_argument("--progressbar", action="store_true", help="enable progressbar") + args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + reader = GGUFReader(args.model, 'r') + gguf_hash(reader, args.model, not args.progressbar, args.no_layer) + + +if __name__ == '__main__': + main() diff --git a/gguf-py/gguf/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..63f2300348ed0ffad3bb9a4be6aa7aae6b827350 --- /dev/null +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import argparse +import os +import sys +import json +from pathlib import Path + +from tqdm import tqdm +from typing import Any, Sequence, NamedTuple + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +import gguf + +logger = logging.getLogger("gguf-new-metadata") + + +class MetadataDetails(NamedTuple): + type: gguf.GGUFValueType + value: Any + description: str = '' + sub_type: gguf.GGUFValueType | None = None + + +def get_field_data(reader: gguf.GGUFReader, key: str) -> Any: + field = reader.get_field(key) + + return field.contents() if field else None + + +def find_token(token_list: Sequence[int], token: str) -> Sequence[int]: + token_ids = [index for index, value in enumerate(token_list) if value == token] + + if len(token_ids) == 0: + raise LookupError(f'Unable to find "{token}" in token list!') + + return token_ids + + +def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, MetadataDetails], remove_metadata: Sequence[str]) -> None: + for field in reader.fields.values(): + # Suppress virtual fields and fields written by GGUFWriter + if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'): + logger.debug(f'Suppressing {field.name}') + continue + + # Skip old chat templates if we have new ones + if field.name.startswith(gguf.Keys.Tokenizer.CHAT_TEMPLATE) and gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: + logger.debug(f'Skipping {field.name}') + continue + + if field.name in remove_metadata: + logger.debug(f'Removing {field.name}') + continue + + val_type = field.types[0] + sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None + old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type) + val = new_metadata.get(field.name, old_val) + + if field.name in new_metadata: + logger.debug(f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}') + del new_metadata[field.name] + elif val.value is not None: + logger.debug(f'Copying {field.name}') + + if val.value is not None: + writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type) + + if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: + logger.debug('Adding chat template(s)') + writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE].value) + del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] + + for key, val in new_metadata.items(): + logger.debug(f'Adding {key}: "{val.value}" {val.description}') + writer.add_key_value(key, val.value, val.type) + + total_bytes = 0 + + for tensor in reader.tensors: + total_bytes += tensor.n_bytes + writer.add_tensor_info(tensor.name, tensor.data.shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type) + + bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) + + writer.write_header_to_file() + writer.write_kv_data_to_file() + writer.write_ti_data_to_file() + + for tensor in reader.tensors: + writer.write_tensor_data(tensor.data) + bar.update(tensor.n_bytes) + + writer.close() + + +def main() -> None: + tokenizer_metadata = (getattr(gguf.Keys.Tokenizer, n) for n in gguf.Keys.Tokenizer.__dict__.keys() if not n.startswith('_')) + token_names = dict((n.split('.')[-1][:-len('_token_id')], n) for n in tokenizer_metadata if n.endswith('_token_id')) + + parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata") + parser.add_argument("input", type=Path, help="GGUF format model input filename") + parser.add_argument("output", type=Path, help="GGUF format model output filename") + parser.add_argument("--general-name", type=str, help="The models general.name", metavar='"name"') + parser.add_argument("--general-description", type=str, help="The models general.description", metavar='"Description ..."') + parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."') + parser.add_argument("--chat-template-config", type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json') + parser.add_argument("--pre-tokenizer", type=str, help="The models tokenizer.ggml.pre", metavar='"pre tokenizer"') + parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model", metavar='general.url') + parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '""')) + parser.add_argument("--special-token-by-id", action="append", type=str, help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0')) + parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation") + parser.add_argument("--verbose", action="store_true", help="Increase output verbosity") + args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"]) + + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + new_metadata = {} + remove_metadata = args.remove_metadata or [] + + if args.general_name: + new_metadata[gguf.Keys.General.NAME] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_name) + + if args.general_description: + new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_description) + + if args.chat_template: + new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template) + + if args.chat_template_config: + with open(args.chat_template_config, 'r') as fp: + config = json.load(fp) + template = config.get('chat_template') + if template: + new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template) + + if args.pre_tokenizer: + new_metadata[gguf.Keys.Tokenizer.PRE] = MetadataDetails(gguf.GGUFValueType.STRING, args.pre_tokenizer) + + if remove_metadata: + logger.warning('*** Warning *** Warning *** Warning **') + logger.warning('* Most metadata is required for a fully functional GGUF file,') + logger.warning('* removing crucial metadata may result in a corrupt output file!') + + if not args.force: + logger.warning('* Enter exactly YES if you are positive you want to proceed:') + response = input('YES, I am sure> ') + if response != 'YES': + logger.info("You didn't enter YES. Okay then, see ya!") + sys.exit(0) + + logger.info(f'* Loading: {args.input}') + reader = gguf.GGUFReader(args.input, 'r') + + arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE) + + token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or [] + + for name, token in args.special_token or []: + if name not in token_names: + logger.warning(f'Unknown special token "{name}", ignoring...') + else: + ids = find_token(token_list, token) + new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, ids[0], f'= {token}') + + if len(ids) > 1: + logger.warning(f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:') + logger.warning(', '.join(str(i) for i in ids)) + + for name, id_string in args.special_token_by_id or []: + if name not in token_names: + logger.warning(f'Unknown special token "{name}", ignoring...') + elif not id_string.isdecimal(): + raise LookupError(f'Token ID "{id_string}" is not a valid ID!') + else: + id_int = int(id_string) + + if id_int >= 0 and id_int < len(token_list): + new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, id_int, f'= {token_list[id_int]}') + else: + raise LookupError(f'Token ID {id_int} is not within token list!') + + if os.path.isfile(args.output) and not args.force: + logger.warning('*** Warning *** Warning *** Warning **') + logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!') + logger.warning('* Enter exactly YES if you are positive you want to proceed:') + response = input('YES, I am sure> ') + if response != 'YES': + logger.info("You didn't enter YES. Okay then, see ya!") + sys.exit(0) + + logger.info(f'* Writing: {args.output}') + writer = gguf.GGUFWriter(args.output, arch=arch, endianess=reader.endianess) + + alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT) + if alignment is not None: + logger.debug(f'Setting custom alignment: {alignment}') + writer.data_alignment = alignment + + copy_with_new_metadata(reader, writer, new_metadata, remove_metadata) + + +if __name__ == '__main__': + main() diff --git a/gguf-py/gguf/scripts/gguf_set_metadata.py b/gguf-py/gguf/scripts/gguf_set_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..f5809c35c887086f8055e828719041f8ae3ea749 --- /dev/null +++ b/gguf-py/gguf/scripts/gguf_set_metadata.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +import logging +import argparse +import os +import sys +from pathlib import Path + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from gguf import GGUFReader # noqa: E402 + +logger = logging.getLogger("gguf-set-metadata") + + +def minimal_example(filename: str) -> None: + reader = GGUFReader(filename, 'r+') + field = reader.fields['tokenizer.ggml.bos_token_id'] + if field is None: + return + part_index = field.data[0] + field.parts[part_index][0] = 2 # Set tokenizer.ggml.bos_token_id to 2 + # + # So what's this field.data thing? It's helpful because field.parts contains + # _every_ part of the GGUF field. For example, tokenizer.ggml.bos_token_id consists + # of: + # + # Part index 0: Key length (27) + # Part index 1: Key data ("tokenizer.ggml.bos_token_id") + # Part index 2: Field type (4, the id for GGUFValueType.UINT32) + # Part index 3: Field value + # + # Note also that each part is an NDArray slice, so even a part that + # is only a single value like the key length will be a NDArray of + # the key length type (numpy.uint32). + # + # The .data attribute in the Field is a list of relevant part indexes + # and doesn't contain internal GGUF details like the key length part. + # In this case, .data will be [3] - just the part index of the + # field value itself. + + +def set_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: + field = reader.get_field(args.key) + if field is None: + logger.error(f'! Field {repr(args.key)} not found') + sys.exit(1) + # Note that field.types is a list of types. This is because the GGUF + # format supports arrays. For example, an array of UINT32 would + # look like [GGUFValueType.ARRAY, GGUFValueType.UINT32] + handler = reader.gguf_scalar_to_np.get(field.types[0]) if field.types else None + if handler is None: + logger.error(f'! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}') + sys.exit(1) + current_value = field.parts[field.data[0]][0] + new_value = handler(args.value) + logger.info(f'* Preparing to change field {repr(args.key)} from {current_value} to {new_value}') + if current_value == new_value: + logger.info(f'- Key {repr(args.key)} already set to requested value {current_value}') + sys.exit(0) + if args.dry_run: + sys.exit(0) + if not args.force: + logger.warning('*** Warning *** Warning *** Warning **') + logger.warning('* Changing fields in a GGUF file can make it unusable. Proceed at your own risk.') + logger.warning('* Enter exactly YES if you are positive you want to proceed:') + response = input('YES, I am sure> ') + if response != 'YES': + logger.info("You didn't enter YES. Okay then, see ya!") + sys.exit(0) + field.parts[field.data[0]][0] = new_value + logger.info('* Field changed. Successful completion.') + + +def main() -> None: + parser = argparse.ArgumentParser(description="Set a simple value in GGUF file metadata") + parser.add_argument("model", type=str, help="GGUF format model filename") + parser.add_argument("key", type=str, help="Metadata key to set") + parser.add_argument("value", type=str, help="Metadata value to set") + parser.add_argument("--dry-run", action="store_true", help="Don't actually change anything") + parser.add_argument("--force", action="store_true", help="Change the field without confirmation") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + + args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) + + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + logger.info(f'* Loading: {args.model}') + reader = GGUFReader(args.model, 'r' if args.dry_run else 'r+') + set_metadata(reader, args) + + +if __name__ == '__main__': + main() diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..79f044d2a5945236b613ac3733ccc2b60ebd9ba4 --- /dev/null +++ b/gguf-py/gguf/tensor_mapping.py @@ -0,0 +1,1280 @@ +from __future__ import annotations + +from typing import Sequence + +from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES + + +class TensorNameMap: + mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { + # Token embeddings + MODEL_TENSOR.TOKEN_EMBD: ( + "gpt_neox.embed_in", # gptneox + "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone + "transformer.word_embeddings", # falcon + "word_embeddings", # bloom + "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 + "tok_embeddings", # llama-pth + "embeddings.word_embeddings", # bert nomic-bert + "language_model.embedding.word_embeddings", # persimmon + "wte", # gpt2 + "transformer.embd.wte", # phi2 + "model.tok_embeddings", # internlm2 + "model.embedding", # mamba-qbert + "backbone.embedding", # mamba + "backbone.embeddings", # mamba-hf + "transformer.in_out_embed", # Grok + "embedding.word_embeddings", # chatglm + "transformer.token_embeddings", # openelm + "shared", # t5 + "rwkv.embeddings", # rwkv6 + "model.embeddings", # rwkv7 + "model.word_embeddings", # bailingmoe + "language_model.model.embed_tokens", # llama4 + "encoder", # neobert + ), + + # Token type embeddings + MODEL_TENSOR.TOKEN_TYPES: ( + "embeddings.token_type_embeddings", # bert nomic-bert + ), + + # Normalization of token embeddings + MODEL_TENSOR.TOKEN_EMBD_NORM: ( + "word_embeddings_layernorm", # bloom + "embeddings.LayerNorm", # bert + "emb_ln", # nomic-bert + "transformer.norm", # openelm + "rwkv.blocks.0.pre_ln", # rwkv + "rwkv.blocks.0.pre_ln", # rwkv6 + "model.pre_ln", # rwkv7 + "model.layers.0.pre_norm", # rwkv7 + "backbone.norm", # wavtokenizer + ), + + # Position embeddings + MODEL_TENSOR.POS_EMBD: ( + "transformer.wpe", # gpt2 + "embeddings.position_embeddings", # bert + "wpe", # gpt2 + ), + + # Output + MODEL_TENSOR.OUTPUT: ( + "embed_out", # gptneox + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe + "output", # llama-pth bloom internlm2 + "word_embeddings_for_head", # persimmon + "lm_head.linear", # phi2 + "output_layer", # chatglm + "head", # rwkv + "head.out", # wavtokenizer + "lm_head", # llama4 + ), + + # Output norm + MODEL_TENSOR.OUTPUT_NORM: ( + "gpt_neox.final_layer_norm", # gptneox + "transformer.ln_f", # gpt2 gpt-j falcon jais exaone + "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe + "norm", # llama-pth + "transformer.norm_f", # mpt dbrx + "ln_f", # refact bloom qwen gpt2 + "language_model.encoder.final_layernorm", # persimmon + "model.final_layernorm", # persimmon + "lm_head.ln", # phi2 + "model.norm_f", # mamba-qbert + "backbone.norm_f", # mamba + "transformer.rms_norm", # Grok + "encoder.final_layernorm", # chatglm + "transformer.norm", # openelm + "model.norm", # nemotron + "rwkv.ln_out", # rwkv6 + "model.ln_out", # rwkv7 + "backbone.final_layer_norm", # wavtokenizer + "model.norm", # llama4 + ), + + # Rope frequencies + MODEL_TENSOR.ROPE_FREQS: ( + "rope.freqs", # llama-pth + "rotary_pos_emb.inv_freq", # chatglm + ), + + MODEL_TENSOR.ROPE_FACTORS_LONG: (), + MODEL_TENSOR.ROPE_FACTORS_SHORT: (), + + MODEL_TENSOR.CONV1D: ( + "backbone.embed", # roberta + ), + } + + block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { + # Attention norm + MODEL_TENSOR.ATTN_NORM: ( + "gpt_neox.layers.{bid}.input_layernorm", # gptneox + "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone + "transformer.blocks.{bid}.norm_1", # mpt + "transformer.h.{bid}.input_layernorm", # falcon7b + "h.{bid}.input_layernorm", # bloom + "transformer.h.{bid}.ln_mlp", # falcon40b + "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe + "layers.{bid}.attention_norm", # llama-pth + "language_model.encoder.layers.{bid}.input_layernorm", # persimmon + "model.layers.{bid}.ln1", # yi + "h.{bid}.ln_1", # gpt2 + "transformer.h.{bid}.ln", # phi2 + "model.layers.layers.{bid}.norm", # plamo + "model.layers.{bid}.attention_norm", # internlm2 + "model.layers.{bid}.norm", # mamba-qbert + "backbone.layers.{bid}.norm", # mamba + "transformer.decoder_layer.{bid}.rms_norm", # Grok + "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx + "encoder.layers.{bid}.input_layernorm", # chatglm + "transformer.layers.{bid}.attn_norm", # openelm + "rwkv.blocks.{bid}.ln1", # rwkv6 + "model.layers.{bid}.ln1", # rwkv7 + "model.layers.{bid}.input_layernorm", # llama4 + "transformer_encoder.{bid}.attention_norm", # neobert + ), + + # Attention norm 2 + MODEL_TENSOR.ATTN_NORM_2: ( + "transformer.h.{bid}.ln_attn", # falcon40b + "encoder.layer.{bid}.layer_norm_1", # jina-v2-code + "rwkv.blocks.{bid}.ln2", # rwkv6 + "model.layers.{bid}.ln2", # rwkv7 + ), + + # Attention query-key-value + MODEL_TENSOR.ATTN_QKV: ( + "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox + "transformer.h.{bid}.attn.c_attn", # gpt2 qwen jais + "transformer.blocks.{bid}.attn.Wqkv", # mpt + "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx + "transformer.h.{bid}.self_attention.query_key_value", # falcon + "h.{bid}.self_attention.query_key_value", # bloom + "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon + "model.layers.{bid}.self_attn.query_key_value", # persimmon + "h.{bid}.attn.c_attn", # gpt2 + "transformer.h.{bid}.mixer.Wqkv", # phi2 + "encoder.layers.{bid}.attn.Wqkv", # nomic-bert + "encoder.layers.{bid}.mixer.Wqkv", # jina + "model.layers.{bid}.self_attn.qkv_proj", # phi3 + "encoder.layers.{bid}.self_attention.query_key_value", # chatglm + "transformer.layers.{bid}.attn.qkv_proj", # openelm + "transformer_encoder.{bid}.qkv", # neobert + ), + + # Attention query + MODEL_TENSOR.ATTN_Q: ( + "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe + "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom + "layers.{bid}.attention.wq", # llama-pth + "encoder.layer.{bid}.attention.self.query", # bert + "transformer.layer.{bid}.attention.q_lin", # distillbert + "transformer.h.{bid}.attn.q_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.q_proj", # plamo + "model.layers.{bid}.attention.wq", # internlm2 + "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok + "transformer.h.{bid}.attn.attention.q_proj", # exaone + "model.layers.{bid}.self_attn.q_proj", # llama4 + ), + + # Attention key + MODEL_TENSOR.ATTN_K: ( + "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe + "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom + "layers.{bid}.attention.wk", # llama-pth + "encoder.layer.{bid}.attention.self.key", # bert + "transformer.layer.{bid}.attention.k_lin", # distillbert + "transformer.h.{bid}.attn.k_proj", # gpt-j + "transformer.h.{bid}.attn.k", # refact + "model.layers.layers.{bid}.self_attn.k_proj", # plamo + "model.layers.{bid}.attention.wk", # internlm2 + "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok + "transformer.h.{bid}.attn.attention.k_proj", # exaone + "model.layers.{bid}.self_attn.k_proj", # llama4 + ), + + # Attention value + MODEL_TENSOR.ATTN_V: ( + "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe + "layers.{bid}.attention.wv", # llama-pth + "encoder.layer.{bid}.attention.self.value", # bert + "transformer.layer.{bid}.attention.v_lin", # distillbert + "transformer.h.{bid}.attn.v_proj", # gpt-j + "transformer.h.{bid}.attn.v", # refact + "model.layers.layers.{bid}.self_attn.v_proj", # plamo + "model.layers.{bid}.attention.wv", # internlm2 + "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok + "transformer.h.{bid}.attn.attention.v_proj", # exaone + "model.layers.{bid}.self_attn.v_proj", # llama4 + ), + + # Attention output + MODEL_TENSOR.ATTN_OUT: ( + "gpt_neox.layers.{bid}.attention.dense", # gptneox + "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen jais + "transformer.blocks.{bid}.attn.out_proj", # mpt + "transformer.h.{bid}.self_attention.dense", # falcon + "h.{bid}.self_attention.dense", # bloom + "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe + "model.layers.{bid}.self_attn.linear_attn", # deci + "layers.{bid}.attention.wo", # llama-pth + "encoder.layer.{bid}.attention.output.dense", # bert + "transformer.layer.{bid}.attention.out_lin", # distillbert + "transformer.h.{bid}.attn.out_proj", # gpt-j + "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon + "model.layers.{bid}.self_attn.dense", # persimmon + "h.{bid}.attn.c_proj", # gpt2 + "transformer.h.{bid}.mixer.out_proj", # phi2 + "model.layers.layers.{bid}.self_attn.o_proj", # plamo + "model.layers.{bid}.attention.wo", # internlm2 + "encoder.layers.{bid}.attn.out_proj", # nomic-bert + "encoder.layers.{bid}.mixer.out_proj", # jina + "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok + "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx + "encoder.layers.{bid}.self_attention.dense", # chatglm + "transformer.layers.{bid}.attn.out_proj", # openelm + "transformer.h.{bid}.attn.attention.out_proj", # exaone + "model.layers.{bid}.self_attn.o_proj", # llama4 + "transformer_encoder.{bid}.wo", # neobert + ), + + # Attention output norm + MODEL_TENSOR.ATTN_OUT_NORM: ( + "encoder.layer.{bid}.attention.output.LayerNorm", # bert + "transformer.layer.{bid}.sa_layer_norm", # distillbert + "encoder.layers.{bid}.norm1", # nomic-bert + "transformer.decoder_layer.{bid}.rms_norm_1", # Grok + "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx + ), + + MODEL_TENSOR.ATTN_POST_NORM: ( + "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge + "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 + ), + + # Rotary embeddings + MODEL_TENSOR.ATTN_ROT_EMBD: ( + "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf + "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth + "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo + "transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell + ), + + # Feed-forward norm + MODEL_TENSOR.FFN_NORM: ( + "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox + "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone + "h.{bid}.post_attention_layernorm", # bloom + "transformer.blocks.{bid}.norm_2", # mpt + "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe + "layers.{bid}.ffn_norm", # llama-pth + "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon + "model.layers.{bid}.ln2", # yi + "h.{bid}.ln_2", # gpt2 + "model.layers.{bid}.ffn_norm", # internlm2 + "transformer.decoder_layer.{bid}.rms_norm_2", # Grok + "encoder.layers.{bid}.post_attention_layernorm", # chatglm + "transformer.layers.{bid}.ffn_norm", # openelm + "model.layers.{bid}.post_attention_layernorm", # llama4 + "transformer_encoder.{bid}.ffn_norm", # neobert + ), + + # Post feed-forward norm + MODEL_TENSOR.FFN_PRE_NORM: ( + "model.layers.{bid}.pre_feedforward_layernorm", # gemma2 + ), + + # Post feed-forward norm + MODEL_TENSOR.FFN_POST_NORM: ( + "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 + "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 + ), + + MODEL_TENSOR.FFN_GATE_INP: ( + "layers.{bid}.feed_forward.gate", # mixtral + "model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe + "model.layers.{bid}.mlp.gate", # qwen2moe olmoe + "transformer.decoder_layer.{bid}.router", # Grok + "transformer.blocks.{bid}.ffn.router.layer", # dbrx + "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe + "model.layers.{bid}.feed_forward.router", # llama4 + "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe + ), + + MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( + "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe + ), + + MODEL_TENSOR.FFN_EXP_PROBS_B: ( + "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1 + ), + + # Feed-forward up + MODEL_TENSOR.FFN_UP: ( + "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox + "transformer.h.{bid}.mlp.c_fc", # gpt2 jais + "transformer.blocks.{bid}.ffn.up_proj", # mpt + "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon + "h.{bid}.mlp.dense_h_to_4h", # bloom + "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2 + "layers.{bid}.feed_forward.w3", # llama-pth + "encoder.layer.{bid}.intermediate.dense", # bert + "transformer.layer.{bid}.ffn.lin1", # distillbert + "transformer.h.{bid}.mlp.fc_in", # gpt-j + "transformer.h.{bid}.mlp.linear_3", # refact + "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon + "model.layers.{bid}.mlp.dense_h_to_4h", # persimmon + "transformer.h.{bid}.mlp.w1", # qwen + "h.{bid}.mlp.c_fc", # gpt2 + "transformer.h.{bid}.mlp.fc1", # phi2 + "model.layers.{bid}.mlp.fc1", # phi2 + "model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414 + "model.layers.layers.{bid}.mlp.up_proj", # plamo + "model.layers.{bid}.feed_forward.w3", # internlm2 + "encoder.layers.{bid}.mlp.fc11", # nomic-bert + "encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe + "model.layers.{bid}.mlp.c_fc", # starcoder2 + "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 (split up/gate, no longer used) + "encoder.layer.{bid}.mlp.gated_layers", # jina-bert-v2 (GEGLU) + "encoder.layer.{bid}.mlp.up_gated_layer", # jina-v2-code (GEGLU) + "model.layers.{bid}.residual_mlp.w3", # arctic + "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm + "transformer.h.{bid}.mlp.c_fc_1", # exaone + "model.layers.{bid}.feed_forward.up_proj", # llama4 + "transformer_encoder.{bid}.ffn.w12", # neobert + ), + + MODEL_TENSOR.FFN_UP_EXP: ( + "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx + "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) + "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 + "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe + ), + + MODEL_TENSOR.FFN_UP_SHEXP: ( + "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 + "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 + ), + + # AWQ-activation gate + MODEL_TENSOR.FFN_ACT: ( + "transformer.blocks.{bid}.ffn.act", # mpt + ), + + # Feed-forward gate + MODEL_TENSOR.FFN_GATE: ( + "model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2 + "layers.{bid}.feed_forward.w1", # llama-pth + "transformer.h.{bid}.mlp.w2", # qwen + "transformer.h.{bid}.mlp.c_fc2", # jais + "model.layers.layers.{bid}.mlp.gate_proj", # plamo + "model.layers.{bid}.feed_forward.w1", # internlm2 + "encoder.layers.{bid}.mlp.fc12", # nomic-bert + "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used) + "transformer.h.{bid}.mlp.linear_1", # refact + "model.layers.{bid}.residual_mlp.w1", # arctic + "transformer.h.{bid}.mlp.c_fc_0", # exaone + "model.layers.{bid}.feed_forward.gate_proj", # llama4 + ), + + MODEL_TENSOR.FFN_GATE_EXP: ( + "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx + "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) + "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 + ), + + MODEL_TENSOR.FFN_GATE_SHEXP: ( + "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 + "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 + ), + + # Feed-forward down + MODEL_TENSOR.FFN_DOWN: ( + "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox + "transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen jais + "transformer.blocks.{bid}.ffn.down_proj", # mpt + "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon + "h.{bid}.mlp.dense_4h_to_h", # bloom + "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2 + "layers.{bid}.feed_forward.w2", # llama-pth + "encoder.layer.{bid}.output.dense", # bert + "transformer.layer.{bid}.ffn.lin2", # distillbert + "transformer.h.{bid}.mlp.fc_out", # gpt-j + "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon + "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon + "h.{bid}.mlp.c_proj", # gpt2 + "transformer.h.{bid}.mlp.fc2", # phi2 + "model.layers.{bid}.mlp.fc2", # phi2 + "model.layers.layers.{bid}.mlp.down_proj", # plamo + "model.layers.{bid}.feed_forward.w2", # internlm2 + "encoder.layers.{bid}.mlp.fc2", # nomic-bert + "model.layers.{bid}.mlp.c_proj", # starcoder2 + "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 + "transformer.layers.{bid}.ffn.proj_2", # openelm + "model.layers.{bid}.residual_mlp.w2", # arctic + "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 + "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm + "model.layers.h.{bid}.mlp.c_proj", # exaone + "model.layers.{bid}.feed_forward.down_proj", # llama4 + "transformer_encoder.{bid}.ffn.w3", # neobert + ), + + MODEL_TENSOR.FFN_DOWN_EXP: ( + "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx + "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe + "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) + "model.layers.{bid}.feed_forward.experts.down_proj", # llama4 + "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe + ), + + MODEL_TENSOR.FFN_DOWN_SHEXP: ( + "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 + "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 + "model.layers.{bid}.shared_mlp.output_linear", # granitemoe + ), + + MODEL_TENSOR.ATTN_Q_NORM: ( + "language_model.encoder.layers.{bid}.self_attention.q_layernorm", + "model.layers.{bid}.self_attn.q_layernorm", # persimmon + "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2 + "transformer.blocks.{bid}.attn.q_ln", # sea-lion + "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 + "transformer.layers.{bid}.attn.q_norm", # openelm + ), + + MODEL_TENSOR.ATTN_K_NORM: ( + "language_model.encoder.layers.{bid}.self_attention.k_layernorm", + "model.layers.{bid}.self_attn.k_layernorm", # persimmon + "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2 + "transformer.blocks.{bid}.attn.k_ln", # sea-lion + "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 + "transformer.layers.{bid}.attn.k_norm", # openelm + ), + + MODEL_TENSOR.ROPE_FREQS: ( + "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon + ), + + MODEL_TENSOR.LAYER_OUT_NORM: ( + "encoder.layer.{bid}.output.LayerNorm", # bert + "transformer.layer.{bid}.output_layer_norm", # distillbert + "encoder.layers.{bid}.norm2", # nomic-bert + "transformer.decoder_layer.{bid}.rms_norm_3", # Grok + "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 + "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + ), + + MODEL_TENSOR.SSM_IN: ( + "model.layers.{bid}.in_proj", + "backbone.layers.{bid}.mixer.in_proj", + ), + + MODEL_TENSOR.SSM_CONV1D: ( + "model.layers.{bid}.conv1d", + "backbone.layers.{bid}.mixer.conv1d", + ), + + MODEL_TENSOR.SSM_X: ( + "model.layers.{bid}.x_proj", + "backbone.layers.{bid}.mixer.x_proj", + ), + + MODEL_TENSOR.SSM_DT: ( + "model.layers.{bid}.dt_proj", + "backbone.layers.{bid}.mixer.dt_proj", + ), + + MODEL_TENSOR.SSM_A: ( + "model.layers.{bid}.A_log", + "backbone.layers.{bid}.mixer.A_log", + ), + + MODEL_TENSOR.SSM_D: ( + "model.layers.{bid}.D", + "backbone.layers.{bid}.mixer.D", + ), + + MODEL_TENSOR.SSM_OUT: ( + "model.layers.{bid}.out_proj", + "backbone.layers.{bid}.mixer.out_proj", + ), + + MODEL_TENSOR.TIME_MIX_W0: ( + "model.layers.{bid}.attention.w0", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_W1: ( + "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2 + "model.layers.{bid}.attention.w1", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_W2: ( + "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2 + "model.layers.{bid}.attention.w2", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_A0: ( + "model.layers.{bid}.attention.a0", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_A1: ( + "model.layers.{bid}.attention.a1", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_A2: ( + "model.layers.{bid}.attention.a2", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_V0: ( + "model.layers.{bid}.attention.v0", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_V1: ( + "model.layers.{bid}.attention.v1", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_V2: ( + "model.layers.{bid}.attention.v2", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_G1: ( + "model.layers.{bid}.attention.g1", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_G2: ( + "model.layers.{bid}.attention.g2", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_K_K: ( + "model.layers.{bid}.attention.k_k", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_K_A: ( + "model.layers.{bid}.attention.k_a", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_R_K: ( + "model.layers.{bid}.attention.r_k", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_LERP_X: ( + "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2 + ), + + MODEL_TENSOR.TIME_MIX_LERP_K: ( + "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2 + ), + + MODEL_TENSOR.TIME_MIX_LERP_V: ( + "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2 + ), + + MODEL_TENSOR.TIME_MIX_LERP_R: ( + "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2 + ), + + MODEL_TENSOR.TIME_MIX_LERP_G: ( + "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2 + ), + + MODEL_TENSOR.TIME_MIX_LERP_W: ( + "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2 + ), + + MODEL_TENSOR.TIME_MIX_FIRST: ( + "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6 + ), + + MODEL_TENSOR.TIME_MIX_DECAY: ( + "rwkv.blocks.{bid}.attention.time_decay", # rwkv6 + "model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2 + ), + + MODEL_TENSOR.TIME_MIX_DECAY_W1: ( + "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv6 + "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2 + ), + + MODEL_TENSOR.TIME_MIX_DECAY_W2: ( + "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv6 + "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2 + ), + + MODEL_TENSOR.TIME_MIX_KEY: ( + "rwkv.blocks.{bid}.attention.key", # rwkv6 + "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2 + "model.layers.{bid}.attention.key", # rwkv7 + "model.layers.{bid}.attention.k_proj", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_VALUE: ( + "rwkv.blocks.{bid}.attention.value", # rwkv6 + "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2 + "model.layers.{bid}.attention.value", # rwkv7 + "model.layers.{bid}.attention.v_proj", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_RECEPTANCE: ( + "rwkv.blocks.{bid}.attention.receptance", # rwkv6 + "model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2 + "model.layers.{bid}.attention.receptance", # rwkv7 + "model.layers.{bid}.attention.r_proj", # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_GATE: ( + "rwkv.blocks.{bid}.attention.gate", # rwkv6 + "model.layers.{bid}.self_attn.gate", # rwkv6qwen2 + ), + + MODEL_TENSOR.TIME_MIX_LN: ( + "rwkv.blocks.{bid}.attention.ln_x", # rwkv6 + "model.layers.{bid}.attention.ln_x" # rwkv7 + ), + + MODEL_TENSOR.TIME_MIX_OUTPUT: ( + "rwkv.blocks.{bid}.attention.output", # rwkv6 + "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2 + "model.layers.{bid}.attention.output", # rwkv7 + "model.layers.{bid}.attention.o_proj", # rwkv7 + ), + + MODEL_TENSOR.CHANNEL_MIX_LERP_K: ( + "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6 + "model.layers.{bid}.feed_forward.x_k", # rwkv7 + ), + + MODEL_TENSOR.CHANNEL_MIX_LERP_R: ( + "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6 + ), + + MODEL_TENSOR.CHANNEL_MIX_KEY: ( + "rwkv.blocks.{bid}.feed_forward.key", # rwkv6 + "model.layers.{bid}.feed_forward.key", # rwkv7 + ), + + MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: ( + "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv6 + ), + + MODEL_TENSOR.CHANNEL_MIX_VALUE: ( + "rwkv.blocks.{bid}.feed_forward.value", # rwkv6 + "model.layers.{bid}.feed_forward.value", # rwkv7 + ), + + MODEL_TENSOR.ATTN_Q_A: ( + "model.layers.{bid}.self_attn.q_a_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_Q_B: ( + "model.layers.{bid}.self_attn.q_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_A_MQA: ( + "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_B: ( + "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_Q_A_NORM: ( + "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_A_NORM: ( + "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2 + ), + + MODEL_TENSOR.ATTN_SUB_NORM: ( + "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet + ), + + MODEL_TENSOR.FFN_SUB_NORM: ( + "model.layers.{bid}.mlp.ffn_layernorm", # bitnet + ), + + MODEL_TENSOR.DEC_ATTN_NORM: ( + "decoder.block.{bid}.layer.0.layer_norm", # t5 + ), + + MODEL_TENSOR.DEC_ATTN_Q: ( + "decoder.block.{bid}.layer.0.SelfAttention.q", # t5 + ), + + MODEL_TENSOR.DEC_ATTN_K: ( + "decoder.block.{bid}.layer.0.SelfAttention.k", # t5 + ), + + MODEL_TENSOR.DEC_ATTN_V: ( + "decoder.block.{bid}.layer.0.SelfAttention.v", # t5 + ), + + MODEL_TENSOR.DEC_ATTN_OUT: ( + "decoder.block.{bid}.layer.0.SelfAttention.o", # t5 + ), + + MODEL_TENSOR.DEC_ATTN_REL_B: ( + "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_NORM: ( + "decoder.block.{bid}.layer.1.layer_norm", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_Q: ( + "decoder.block.{bid}.layer.1.EncDecAttention.q", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_K: ( + "decoder.block.{bid}.layer.1.EncDecAttention.k", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_V: ( + "decoder.block.{bid}.layer.1.EncDecAttention.v", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_OUT: ( + "decoder.block.{bid}.layer.1.EncDecAttention.o", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: ( + "decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias", # t5 + ), + + MODEL_TENSOR.DEC_FFN_NORM: ( + "decoder.block.{bid}.layer.2.layer_norm", # t5 + ), + + MODEL_TENSOR.DEC_FFN_GATE: ( + "decoder.block.{bid}.layer.2.DenseReluDense.wi_0", # flan-t5 + ), + + MODEL_TENSOR.DEC_FFN_UP: ( + "decoder.block.{bid}.layer.2.DenseReluDense.wi", # t5 + "decoder.block.{bid}.layer.2.DenseReluDense.wi_1", # flan-t5 + ), + + MODEL_TENSOR.DEC_FFN_DOWN: ( + "decoder.block.{bid}.layer.2.DenseReluDense.wo", # t5 + ), + + MODEL_TENSOR.DEC_OUTPUT_NORM: ( + "decoder.final_layer_norm", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_NORM: ( + "encoder.block.{bid}.layer.0.layer_norm", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_Q: ( + "encoder.block.{bid}.layer.0.SelfAttention.q", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_K: ( + "encoder.block.{bid}.layer.0.SelfAttention.k", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_V: ( + "encoder.block.{bid}.layer.0.SelfAttention.v", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_OUT: ( + "encoder.block.{bid}.layer.0.SelfAttention.o", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_REL_B: ( + "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5 + ), + + MODEL_TENSOR.ENC_FFN_NORM: ( + "encoder.block.{bid}.layer.1.layer_norm", # t5 + ), + + MODEL_TENSOR.ENC_FFN_GATE: ( + "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", # flan-t5 + ), + + MODEL_TENSOR.ENC_FFN_UP: ( + "encoder.block.{bid}.layer.1.DenseReluDense.wi", # t5 + "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", # flan-t5 + ), + + MODEL_TENSOR.ENC_FFN_DOWN: ( + "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5 + ), + + ############################################################################ + # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg + MODEL_TENSOR.ENC_OUTPUT_NORM: ( + "encoder.final_layer_norm", # t5 + "layer_norm", # neobert + ), + + MODEL_TENSOR.CLS: ( + "classifier", # jina + "classifier.dense", # roberta + "pre_classifier", # distillbert + "dense", # neobert + ), + + MODEL_TENSOR.CLS_OUT: ( + "classifier.out_proj", # roberta + ), + ############################################################################# + + MODEL_TENSOR.CONVNEXT_DW: ( + "backbone.convnext.{bid}.dwconv", # wavtokenizer + ), + + MODEL_TENSOR.CONVNEXT_NORM: ( + "backbone.convnext.{bid}.norm", # wavtokenizer + ), + + MODEL_TENSOR.CONVNEXT_PW1: ( + "backbone.convnext.{bid}.pwconv1", # wavtokenizer + ), + + MODEL_TENSOR.CONVNEXT_PW2: ( + "backbone.convnext.{bid}.pwconv2", # wavtokenizer + ), + + MODEL_TENSOR.CONVNEXT_GAMMA: ( + "backbone.convnext.{bid}.gamma", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_CONV1: ( + "backbone.posnet.{bid}.conv1", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_CONV2: ( + "backbone.posnet.{bid}.conv2", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_NORM: ( + "backbone.posnet.{bid}.norm", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_NORM1: ( + "backbone.posnet.{bid}.norm1", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_NORM2: ( + "backbone.posnet.{bid}.norm2", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_ATTN_NORM: ( + "backbone.posnet.{bid}.norm", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_ATTN_Q: ( + "backbone.posnet.{bid}.q", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_ATTN_K: ( + "backbone.posnet.{bid}.k", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_ATTN_V: ( + "backbone.posnet.{bid}.v", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_ATTN_OUT: ( + "backbone.posnet.{bid}.proj_out", # wavtokenizer + ), + + ############################################################################# + ## Vision encoder + + MODEL_TENSOR.V_MMPROJ: ( + "multi_modal_projector.linear_{bid}", + "visual.merger.mlp.{bid}", # qwen2vl + ), + + MODEL_TENSOR.V_MMPROJ_FC: ( + "model.connector.modality_projection.proj", # SmolVLM + ), + + MODEL_TENSOR.V_MMPROJ_MLP: ( + "model.mm_projector.mlp.mlp.{bid}", + "vision_model.vision_adapter.mlp.fc{bid}", # llama 4 + "mlp1.{bid}", # InternVL + ), + + MODEL_TENSOR.V_MMPROJ_PEG: ( + "model.mm_projector.peg.peg.{bid}", + ), + + MODEL_TENSOR.V_ENC_EMBD_CLS: ( + "vision_tower.vision_model.embeddings.class_embedding", + "vision_model.class_embedding", # llama 4 + ), + + MODEL_TENSOR.V_ENC_EMBD_PATCH: ( + "vision_tower.vision_model.embeddings.patch_embedding", + "vpm.embeddings.patch_embedding", + "model.vision_model.embeddings.patch_embedding", # SmolVLM + "vision_tower.patch_conv", # pixtral + "vision_model.patch_embedding.linear", # llama 4 + "visual.patch_embed.proj", # qwen2vl + ), + + MODEL_TENSOR.V_ENC_EMBD_POS: ( + "vision_tower.vision_model.embeddings.position_embedding", + "vpm.embeddings.position_embedding", + "model.vision_model.embeddings.position_embedding", # SmolVLM + "vision_model.positional_embedding_vlm", # llama 4 + ), + + MODEL_TENSOR.V_ENC_ATTN_Q: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", + "vpm.encoder.layers.{bid}.self_attn.q_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral + "visual.blocks.{bid}.attn.q", # qwen2vl, generated + ), + + MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL + ), + + MODEL_TENSOR.V_ENC_ATTN_K: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", + "vpm.encoder.layers.{bid}.self_attn.k_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral + "visual.blocks.{bid}.attn.k", # qwen2vl, generated + ), + + MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL + ), + + MODEL_TENSOR.V_ENC_ATTN_V: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", + "vpm.encoder.layers.{bid}.self_attn.v_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral + "visual.blocks.{bid}.attn.v", # qwen2vl, generated + ), + + MODEL_TENSOR.V_ENC_INPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", + "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL + "vpm.encoder.layers.{bid}.layer_norm1", + "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM + "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral + "vision_model.model.layers.{bid}.input_layernorm", # llama4 + "visual.blocks.{bid}.norm1", # qwen2vl + ), + + MODEL_TENSOR.V_ENC_ATTN_O: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", + "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL + "vpm.encoder.layers.{bid}.self_attn.out_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral + "visual.blocks.{bid}.attn.proj", # qwen2vl + ), + + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", + "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL + "vpm.encoder.layers.{bid}.layer_norm2", + "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM + "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4 + "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral + "visual.blocks.{bid}.norm2", # qwen2vl + ), + + MODEL_TENSOR.V_ENC_FFN_UP: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", + "vpm.encoder.layers.{bid}.mlp.fc1", + "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 + "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral + "vision_model.model.layers.{bid}.mlp.fc1", # llama4 + "visual.blocks.{bid}.mlp.fc1", # qwen2vl + "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl + ), + + MODEL_TENSOR.V_ENC_FFN_GATE: ( + "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral + "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl + ), + + MODEL_TENSOR.V_ENC_FFN_DOWN: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", + "vpm.encoder.layers.{bid}.mlp.fc2", + "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 + "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral + "vision_model.model.layers.{bid}.mlp.fc2", # llama4 + "visual.blocks.{bid}.mlp.fc2", # qwen2vl + "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl + ), + + MODEL_TENSOR.V_LAYER_SCALE_1: ( + "vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL + ), + + MODEL_TENSOR.V_LAYER_SCALE_2: ( + "vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL + ), + + MODEL_TENSOR.V_PRE_NORM: ( + "vision_tower.vision_model.pre_layrnorm", + "vision_tower.ln_pre", # pixtral + "vision_model.layernorm_pre", # llama4 + ), + + MODEL_TENSOR.V_POST_NORM: ( + "vision_tower.vision_model.post_layernorm", + "model.vision_model.post_layernorm", # SmolVLM + "vision_model.layernorm_post", # llama4 + "visual.merger.ln_q", # qwen2vl + ), + + MODEL_TENSOR.V_MM_INP_PROJ: ( + "multi_modal_projector.mm_input_projection", + ), + + MODEL_TENSOR.V_MM_INP_NORM: ( + "multi_modal_projector.norm", + ), + + MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( + "multi_modal_projector.mm_soft_emb_norm", + ), + + MODEL_TENSOR.V_RESMPL_POS_EMBD_K: ( + "resampler.pos_embed_k", + ), + + MODEL_TENSOR.V_RESMPL_ATTN_Q: ( + "resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_K: ( + "resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_V: ( + "resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_OUT: ( + "resampler.attn.out_proj", + ), + + MODEL_TENSOR.V_RESMPL_KV: ( + "resampler.kv_proj", + ), + + MODEL_TENSOR.V_RESMPL_POST_NORM: ( + "resampler.ln_post", + ), + + MODEL_TENSOR.V_RESMPL_KV_NORM: ( + "resampler.ln_kv", + ), + + MODEL_TENSOR.V_RESMPL_Q_NORM: ( + "resampler.ln_q", + ), + + MODEL_TENSOR.V_RESMPL_PROJ: ( + "resampler.proj", + ), + + MODEL_TENSOR.V_RESMPL_QUERY: ( + "resampler.query", + ), + + MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: ( + "v.token_embd.img_break", # for pixtral, this is a generated vector + ), + + MODEL_TENSOR.V_MM_PATCH_MERGER: ( + "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 + ), + + # audio (mtmd) + + MODEL_TENSOR.A_ENC_EMBD_POS: ( + "audio_tower.embed_positions", # ultravox + ), + + MODEL_TENSOR.A_ENC_CONV1D: ( + "audio_tower.conv{bid}", # ultravox + ), + + MODEL_TENSOR.A_PRE_NORM: (), + + MODEL_TENSOR.A_POST_NORM: ( + "audio_tower.layer_norm", # ultravox + "audio_tower.ln_post", # qwen2omni + ), + + MODEL_TENSOR.A_ENC_ATTN_Q: ( + "audio_tower.layers.{bid}.self_attn.q_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_ATTN_K: ( + "audio_tower.layers.{bid}.self_attn.k_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_ATTN_V: ( + "audio_tower.layers.{bid}.self_attn.v_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_INPUT_NORM: ( + "audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox + ), + + MODEL_TENSOR.A_ENC_OUTPUT: ( + "audio_tower.layers.{bid}.self_attn.out_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_OUTPUT_NORM: ( + "audio_tower.layers.{bid}.final_layer_norm", # ultravox + ), + + MODEL_TENSOR.A_ENC_FFN_UP: ( + "audio_tower.layers.{bid}.fc1", # ultravox + ), + + MODEL_TENSOR.A_ENC_FFN_GATE: (), + + MODEL_TENSOR.A_ENC_FFN_DOWN: ( + "audio_tower.layers.{bid}.fc2", # ultravox + ), + + # note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors + # this prefix is added in the conversion code in modify_tensors() + + MODEL_TENSOR.A_MMPROJ: ( + "audio.multi_modal_projector.linear_{bid}", # ultravox + ), + + MODEL_TENSOR.A_MMPROJ_FC: ( + "audio.multi_modal_projector.linear", # qwen2audio + "audio_tower.proj", # qwen2omni + ), + + MODEL_TENSOR.A_MM_NORM_PRE: ( + "audio.multi_modal_projector.ln_pre", # ultravox + ), + + MODEL_TENSOR.A_MM_NORM_MID: ( + "audio.multi_modal_projector.ln_mid", # ultravox + ), + } + + # architecture-specific block mappings + arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = { + MODEL_ARCH.ARCTIC: { + MODEL_TENSOR.FFN_NORM: ( + "model.layers.{bid}.residual_layernorm", + ), + MODEL_TENSOR.FFN_NORM_EXP: ( + "model.layers.{bid}.post_attention_layernorm", + ), + }, + } + + mapping: dict[str, tuple[MODEL_TENSOR, str]] + + def __init__(self, arch: MODEL_ARCH, n_blocks: int): + self.mapping = {} + for tensor, keys in self.mappings_cfg.items(): + if tensor not in MODEL_TENSORS[arch]: + continue + tensor_name = TENSOR_NAMES[tensor] + self.mapping[tensor_name] = (tensor, tensor_name) + for key in keys: + self.mapping[key] = (tensor, tensor_name) + if arch in self.arch_block_mappings_cfg: + self.block_mappings_cfg.update(self.arch_block_mappings_cfg[arch]) + for bid in range(n_blocks): + for tensor, keys in self.block_mappings_cfg.items(): + if tensor not in MODEL_TENSORS[arch]: + continue + + tensor_name = TENSOR_NAMES[tensor].format(bid = bid) + self.mapping[tensor_name] = (tensor, tensor_name) + for key in keys: + key = key.format(bid = bid) + self.mapping[key] = (tensor, tensor_name) + + def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None: + result = self.mapping.get(key) + if result is not None: + return result + for suffix in try_suffixes: + if key.endswith(suffix): + result = self.mapping.get(key[:-len(suffix)]) + if result is not None: + return result[0], result[1] + suffix + return None + + def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None: + result = self.get_type_and_name(key, try_suffixes = try_suffixes) + if result is None: + return None + return result[1] + + def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None: + result = self.get_type_and_name(key, try_suffixes = try_suffixes) + if result is None: + return None + return result[0] + + def __getitem__(self, key: str) -> str: + try: + return self.mapping[key][1] + except KeyError: + raise KeyError(key) + + def __contains__(self, key: str) -> bool: + return key in self.mapping + + def __repr__(self) -> str: + return repr(self.mapping) + + +def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap: + return TensorNameMap(arch, n_blocks) diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py new file mode 100644 index 0000000000000000000000000000000000000000..00adcbc937398d1e7a9d4b7159a54052bb79ae64 --- /dev/null +++ b/gguf-py/gguf/utility.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +import os +import json + + +def fill_templated_filename(filename: str, output_type: str | None) -> str: + # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf' + ftype_lowercase: str = output_type.lower() if output_type is not None else "" + ftype_uppercase: str = output_type.upper() if output_type is not None else "" + return filename.format(ftype_lowercase, + outtype=ftype_lowercase, ftype=ftype_lowercase, + OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase) + + +def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str: + if model_params_count > 1e12 : + # Trillions Of Parameters + scaled_model_params = model_params_count * 1e-12 + scale_suffix = "T" + elif model_params_count > 1e9 : + # Billions Of Parameters + scaled_model_params = model_params_count * 1e-9 + scale_suffix = "B" + elif model_params_count > 1e6 : + # Millions Of Parameters + scaled_model_params = model_params_count * 1e-6 + scale_suffix = "M" + else: + # Thousands Of Parameters + scaled_model_params = model_params_count * 1e-3 + scale_suffix = "K" + + fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0) + + return f"{scaled_model_params:.{fix}f}{scale_suffix}" + + +def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str: + + if expert_count > 0: + pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2) + size_class = f"{expert_count}x{pretty_size}" + else: + size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2) + + return size_class + + +def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str: + # Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention + + if base_name is not None: + name = base_name.strip().replace(' ', '-').replace('/', '-') + elif model_name is not None: + name = model_name.strip().replace(' ', '-').replace('/', '-') + else: + name = "ggml-model" + + parameters = f"-{size_label}" if size_label is not None else "" + + finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else "" + + version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else "" + + encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else "" + + kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else "" + + return f"{name}{parameters}{finetune}{version}{encoding}{kind}" + + +@dataclass +class RemoteTensor: + dtype: str + shape: tuple[int, ...] + offset_start: int + size: int + url: str + + def data(self) -> bytearray: + # TODO: handle request errors (maybe with limited retries?) + # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable + data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size)) + return data + + +class SafetensorRemote: + """ + Uility class to handle remote safetensor files. + This class is designed to work with Hugging Face model repositories. + + Example (one model has single safetensor file, the other has multiple): + for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]: + tensors = SafetensorRemote.get_list_tensors_hf_model(model_id) + print(tensors) + + Example reading tensor data: + tensors = SafetensorRemote.get_list_tensors_hf_model(model_id) + for name, meta in tensors.items(): + dtype, shape, offset_start, size, remote_safetensor_url = meta + # read the tensor data + data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size) + print(data) + """ + + BASE_DOMAIN = "https://huggingface.co" + ALIGNMENT = 8 # bytes + + @classmethod + def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: + """ + Get list of tensors from a Hugging Face model repository. + + Returns a dictionary of tensor names and their metadata. + Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url) + """ + # case 1: model has only one single model.safetensor file + is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors") + if is_single_file: + url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors" + return cls.get_list_tensors(url) + + # case 2: model has multiple files + index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json" + is_multiple_files = cls.check_file_exist(index_url) + if is_multiple_files: + # read the index file + index_data = cls.get_data_by_range(index_url, 0) + index_str = index_data.decode('utf-8') + index_json = json.loads(index_str) + assert index_json.get("weight_map") is not None, "weight_map not found in index file" + weight_map = index_json["weight_map"] + # get the list of files + all_files = list(set(weight_map.values())) + all_files.sort() # make sure we load shard files in order + # get the list of tensors + tensors: dict[str, RemoteTensor] = {} + for file in all_files: + url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}" + for key, val in cls.get_list_tensors(url).items(): + tensors[key] = val + return tensors + + raise ValueError(f"Model {model_id} does not have any safetensor files") + + @classmethod + def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]: + """ + Get list of tensors from a remote safetensor file. + + Returns a dictionary of tensor names and their metadata. + Each tensor is represented as a tuple of (dtype, shape, offset_start, size) + """ + metadata, data_start_offset = cls.get_metadata(url) + res: dict[str, RemoteTensor] = {} + + for name, meta in metadata.items(): + if name == "__metadata__": + continue + if not isinstance(meta, dict): + raise ValueError(f"Invalid metadata for tensor '{name}': {meta}") + try: + dtype = meta["dtype"] + shape = meta["shape"] + offset_start_relative, offset_end_relative = meta["data_offsets"] + size = offset_end_relative - offset_start_relative + offset_start = data_start_offset + offset_start_relative + res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url) + except KeyError as e: + raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") + + return res + + @classmethod + def get_metadata(cls, url: str) -> tuple[dict, int]: + """ + Get JSON metadata from a remote safetensor file. + + Returns tuple of (metadata, data_start_offset) + """ + # Request first 5MB of the file (hopefully enough for metadata) + read_size = 5 * 1024 * 1024 + raw_data = cls.get_data_by_range(url, 0, read_size) + + # Parse header + # First 8 bytes contain the metadata length as u64 little-endian + if len(raw_data) < 8: + raise ValueError("Not enough data to read metadata size") + metadata_length = int.from_bytes(raw_data[:8], byteorder='little') + + # Calculate the data start offset + data_start_offset = 8 + metadata_length + alignment = SafetensorRemote.ALIGNMENT + if data_start_offset % alignment != 0: + data_start_offset += alignment - (data_start_offset % alignment) + + # Check if we have enough data to read the metadata + if len(raw_data) < 8 + metadata_length: + raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}") + + # Extract metadata bytes and parse as JSON + metadata_bytes = raw_data[8:8 + metadata_length] + metadata_str = metadata_bytes.decode('utf-8') + try: + metadata = json.loads(metadata_str) + return metadata, data_start_offset + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") + + @classmethod + def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: + """ + Get raw byte data from a remote file by range. + If size is not specified, it will read the entire file. + """ + import requests + from urllib.parse import urlparse + + parsed_url = urlparse(url) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError(f"Invalid URL: {url}") + + headers = cls._get_request_headers() + if size > -1: + headers["Range"] = f"bytes={start}-{start + size}" + response = requests.get(url, allow_redirects=True, headers=headers) + response.raise_for_status() + + # Get raw byte data + return response.content[slice(size if size > -1 else None)] + + @classmethod + def check_file_exist(cls, url: str) -> bool: + """ + Check if a file exists at the given URL. + Returns True if the file exists, False otherwise. + """ + import requests + from urllib.parse import urlparse + + parsed_url = urlparse(url) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError(f"Invalid URL: {url}") + + try: + headers = cls._get_request_headers() + headers["Range"] = "bytes=0-0" + response = requests.head(url, allow_redirects=True, headers=headers) + # Success (2xx) or redirect (3xx) + return 200 <= response.status_code < 400 + except requests.RequestException: + return False + + @classmethod + def _get_request_headers(cls) -> dict[str, str]: + """Prepare common headers for requests.""" + headers = {"User-Agent": "convert_hf_to_gguf"} + if os.environ.get("HF_TOKEN"): + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" + return headers diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..cca0979862a71495bf9ff41c1a48016dc7185865 --- /dev/null +++ b/gguf-py/gguf/vocab.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +import re +import logging +import json +import os +from pathlib import Path +from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable + +from sentencepiece import SentencePieceProcessor + +import gguf + +from .gguf_writer import GGUFWriter + +logger = logging.getLogger(__name__) + + +class SpecialVocab: + merges: list[str] + add_special_token: dict[str, bool] + special_token_ids: dict[str, int] + chat_template: str | Sequence[Mapping[str, str]] | None + + def __init__( + self, path: str | os.PathLike[str], load_merges: bool = False, + special_token_types: Iterable[str] | None = None, + n_vocab: int | None = None, + ): + self.special_token_ids = {} + self.add_special_token = {} + self.n_vocab = n_vocab + self.load_merges = load_merges + self.merges = [] + self.chat_template = None + if special_token_types is not None: + self.special_token_types = special_token_types + else: + self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask') + self._load(Path(path)) + + def __repr__(self) -> str: + return ''.format( + len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset", + ) + + def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None: + if self.merges: + if not quiet: + logger.info(f'Adding {len(self.merges)} merge(s).') + gw.add_token_merges(self.merges) + elif self.load_merges: + logger.warning('Adding merges requested but no merges found, output may be non-functional.') + for typ, tokid in self.special_token_ids.items(): + id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None) + if id_handler is None: + logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping') + continue + if not quiet: + logger.info(f'Setting special token type {typ} to {tokid}') + id_handler(tokid) + for typ, value in self.add_special_token.items(): + add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None) + if add_handler is None: + logger.warning(f'No handler for add_{typ}_token with value {value} - skipping') + continue + if not quiet: + logger.info(f'Setting add_{typ}_token to {value}') + add_handler(value) + if self.chat_template is not None: + if not quiet: + logger.info(f'Setting chat_template to {self.chat_template}') + gw.add_chat_template(self.chat_template) + + def _load(self, path: Path) -> None: + self._try_load_from_tokenizer_json(path) + self._try_load_from_config_json(path) + if self.load_merges and not self.merges: + self._try_load_merges_txt(path) + + def _try_load_merges_txt(self, path: Path) -> bool: + merges_file = path / 'merges.txt' + if not merges_file.is_file(): + return False + with open(merges_file, 'r', encoding = 'utf-8') as fp: + first_line = next(fp, '').strip() + if not first_line.startswith('#'): + fp.seek(0) + line_num = 0 + else: + line_num = 1 + merges = [] + for line in fp: + line_num += 1 + line = line.strip() + if not line: + continue + parts = line.split(None, 3) + if len(parts) != 2: + logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring') + continue + merges.append(f'{parts[0]} {parts[1]}') + self.merges = merges + return True + + def _set_special_token(self, typ: str, tid: Any) -> None: + if not isinstance(tid, int): + return + if tid < 0: + raise ValueError(f'invalid value for special token type {typ}: {tid}') + if self.n_vocab is None or tid < self.n_vocab: + if typ in self.special_token_ids: + return + self.special_token_ids[typ] = tid + return + logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping') + + def _try_load_from_tokenizer_json(self, path: Path) -> bool: + tokenizer_file = path / 'tokenizer.json' + if tokenizer_file.is_file(): + with open(tokenizer_file, encoding = 'utf-8') as f: + tokenizer = json.load(f) + if self.load_merges: + merges = tokenizer.get('model', {}).get('merges') + if isinstance(merges, list) and merges: + if isinstance(merges[0], str): + self.merges = merges + elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str): + # New format since transformers 4.45 to support spaces in merges + # ref: https://github.com/ggml-org/llama.cpp/issues/9692 + # TODO: internally store as the new format instead of converting to old + if any(' ' in s for pair in merges for s in pair): + logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}') + self.merges = [ + ' '.join( + [ + # ensure the spaces are properly encoded + ''.join( + chr(ord(c) + 256) if c == ' ' else c + for c in part + ) + for part in pair + ] + ) + for pair in merges + ] + else: + raise ValueError("Unknown tokenizer merges format") + added_tokens = tokenizer.get('added_tokens', {}) + else: + added_tokens = {} + tokenizer_config_file = path / 'tokenizer_config.json' + if not tokenizer_config_file.is_file(): + return True + with open(tokenizer_config_file, encoding = 'utf-8') as f: + tokenizer_config = json.load(f) + chat_template_alt = None + chat_template_file = path / 'chat_template.json' + if chat_template_file.is_file(): + with open(chat_template_file, encoding = 'utf-8') as f: + chat_template_alt = json.load(f).get('chat_template') + chat_template = tokenizer_config.get('chat_template', chat_template_alt) + if chat_template is None or isinstance(chat_template, (str, list)): + self.chat_template = chat_template + else: + logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring') + for typ in self.special_token_types: + add_entry = tokenizer_config.get(f'add_{typ}_token') + if isinstance(add_entry, bool): + self.add_special_token[typ] = add_entry + entry = tokenizer_config.get(f'{typ}_token') + if isinstance(entry, str): + tc_content = entry + elif isinstance(entry, dict): + entry_content = entry.get('content') + if not isinstance(entry_content, str): + continue + tc_content = entry_content + else: + continue + # We only need the first match here. + maybe_token_id = next( + (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content), + None, + ) + self._set_special_token(typ, maybe_token_id) + return True + + def _try_load_from_config_json(self, path: Path) -> bool: + config_file = path / 'config.json' + if not config_file.is_file(): + return False + with open(config_file, encoding = 'utf-8') as f: + config = json.load(f) + for typ in self.special_token_types: + self._set_special_token(typ, config.get(f'{typ}_token_id')) + return True + + +@runtime_checkable +class BaseVocab(Protocol): + tokenizer_model: ClassVar[str] + name: ClassVar[str] + + +@runtime_checkable +class Vocab(BaseVocab, Protocol): + vocab_size: int + added_tokens_dict: dict[str, int] + added_tokens_list: list[str] + fname_tokenizer: Path + + def __init__(self, base_path: Path): ... + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ... + + +class NoVocab(BaseVocab): + tokenizer_model = "no_vocab" + name = "no_vocab" + + def __repr__(self) -> str: + return "" + + +class BpeVocab(Vocab): + tokenizer_model = "gpt2" + name = "bpe" + + def __init__(self, base_path: Path): + added_tokens: dict[str, int] = {} + + if (fname_tokenizer := base_path / 'vocab.json').exists(): + # "slow" tokenizer + with open(fname_tokenizer, encoding="utf-8") as f: + self.vocab = json.load(f) + + try: + # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab. + with open(base_path / 'added_tokens.json', encoding="utf-8") as f: + added_tokens = json.load(f) + except FileNotFoundError: + pass + else: + # "fast" tokenizer + fname_tokenizer = base_path / 'tokenizer.json' + + # if this fails, FileNotFoundError propagates to caller + with open(fname_tokenizer, encoding="utf-8") as f: + tokenizer_json = json.load(f) + + tokenizer_model: dict[str, Any] = tokenizer_json['model'] + if ( + tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False) + or tokenizer_json['decoder']['type'] != 'ByteLevel' + ): + raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer') + + self.vocab = tokenizer_model["vocab"] + + if (added := tokenizer_json.get('added_tokens')) is not None: + # Added tokens here can be duplicates of the main vocabulary. + added_tokens = {item['content']: item['id'] + for item in added + if item['content'] not in self.vocab} + + vocab_size = len(self.vocab) + expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) + actual_ids = sorted(added_tokens.values()) + if expected_ids != actual_ids: + expected_end_id = vocab_size + len(actual_ids) - 1 + raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range " + f"{vocab_size} - {expected_end_id}; got {actual_ids}") + + items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) + self.added_tokens_dict = added_tokens + self.added_tokens_list = [text for (text, idx) in items] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + + def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()} + + for i, _ in enumerate(self.vocab): + yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL + + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + score = -1000.0 + yield text.encode("utf-8"), score, gguf.TokenType.CONTROL + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.bpe_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" + + +class SentencePieceVocab(Vocab): + tokenizer_model = "llama" + name = "spm" + + def __init__(self, base_path: Path): + added_tokens: dict[str, int] = {} + if (fname_tokenizer := base_path / 'tokenizer.model').exists(): + # normal location + try: + with open(base_path / 'added_tokens.json', encoding="utf-8") as f: + added_tokens = json.load(f) + except FileNotFoundError: + pass + elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists(): + # not found in alternate location either + raise FileNotFoundError('Cannot find tokenizer.model') + + self.sentencepiece_tokenizer = SentencePieceProcessor() + self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer)) + vocab_size = self.sentencepiece_tokenizer.vocab_size() + + new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size} + expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens))) + actual_new_ids = sorted(new_tokens.keys()) + + if expected_new_ids != actual_new_ids: + raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}") + + # Token pieces that were added to the base vocabulary. + self.added_tokens_dict = added_tokens + self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + + def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + tokenizer = self.sentencepiece_tokenizer + for i in range(tokenizer.vocab_size()): + piece = tokenizer.IdToPiece(i) + text = piece.encode("utf-8") + score: float = tokenizer.GetScore(i) + + toktype = gguf.TokenType.NORMAL + if tokenizer.IsUnknown(i): + toktype = gguf.TokenType.UNKNOWN + if tokenizer.IsControl(i): + toktype = gguf.TokenType.CONTROL + + # NOTE: I think added_tokens are user defined. + # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto + # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED + + if tokenizer.IsUnused(i): + toktype = gguf.TokenType.UNUSED + if tokenizer.IsByte(i): + toktype = gguf.TokenType.BYTE + + yield text, score, toktype + + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + score = -1000.0 + yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.sentencepiece_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" + + +class LlamaHfVocab(Vocab): + tokenizer_model = "llama" + name = "hfft" + + def __init__(self, base_path: Path): + fname_tokenizer = base_path / 'tokenizer.json' + # if this fails, FileNotFoundError propagates to caller + with open(fname_tokenizer, encoding='utf-8') as f: + tokenizer_json = json.load(f) + + # pre-check so we know if we need transformers + tokenizer_model: dict[str, Any] = tokenizer_json['model'] + is_llama3 = ( + tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False) + and not tokenizer_model.get('byte_fallback', True) + ) + if is_llama3: + raise TypeError('Llama 3 must be converted with BpeVocab') + + if not is_llama3 and ( + tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False) + or tokenizer_json['decoder']['type'] != 'Sequence' + ): + raise FileNotFoundError('Cannot find Llama BPE tokenizer') + + try: + from transformers import AutoTokenizer + except ImportError as e: + raise ImportError( + "To use LlamaHfVocab, please install the `transformers` package. " + "You can install it with `pip install transformers`." + ) from e + + # Allow the tokenizer to default to slow or fast versions. + # Explicitly set tokenizer to use local paths. + self.tokenizer = AutoTokenizer.from_pretrained( + base_path, + cache_dir=base_path, + local_files_only=True, + ) + assert self.tokenizer.is_fast # assume tokenizer.json is used + + # Initialize lists and dictionaries for added tokens + self.added_tokens_list = [] + self.added_tokens_dict = dict() + self.added_tokens_ids = set() + + # Process added tokens + for tok, tokidx in sorted( + self.tokenizer.get_added_vocab().items(), key=lambda x: x[1] + ): + # Only consider added tokens that are not in the base vocabulary + if tokidx >= self.tokenizer.vocab_size: + self.added_tokens_list.append(tok) + self.added_tokens_dict[tok] = tokidx + self.added_tokens_ids.add(tokidx) + + # Store special tokens and their IDs + self.specials = { + tok: self.tokenizer.get_vocab()[tok] + for tok in self.tokenizer.all_special_tokens + } + self.special_ids = set(self.tokenizer.all_special_ids) + + # Set vocabulary sizes + self.vocab_size_base = self.tokenizer.vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + + self.fname_tokenizer = fname_tokenizer + + def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + reverse_vocab = { + id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items() + } + + for token_id in range(self.vocab_size_base): + # Skip processing added tokens here + if token_id in self.added_tokens_ids: + continue + + # Convert token text to bytes + token_text = reverse_vocab[token_id].encode("utf-8") + + # Yield token text, score, and type + yield token_text, self.get_token_score(token_id), self.get_token_type( + token_id, token_text, self.special_ids # Reuse already stored special IDs + ) + + def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType: + # Special case for byte tokens + if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text): + return gguf.TokenType.BYTE + + # Determine token type based on whether it's a special token + return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL + + def get_token_score(self, token_id: int) -> float: + # Placeholder for actual logic to determine the token's score + # This needs to be implemented based on specific requirements + return -1000.0 # Default score + + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + if text in self.specials: + toktype = self.get_token_type(self.specials[text], b'', self.special_ids) + score = self.get_token_score(self.specials[text]) + else: + toktype = gguf.TokenType.USER_DEFINED + score = -1000.0 + + yield text.encode("utf-8"), score, toktype + + def has_newline_token(self): + return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.hf_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..f11351cba17677e262cf600504ebd19eda0a5279 --- /dev/null +++ b/gguf-py/pyproject.toml @@ -0,0 +1,43 @@ +[tool.poetry] +name = "gguf" +version = "0.17.0" +description = "Read and write ML models in GGUF for GGML" +authors = ["GGML "] +packages = [ + {include = "gguf"}, + {include = "gguf/py.typed"}, +] +readme = "README.md" +homepage = "https://ggml.ai" +repository = "https://github.com/ggml-org/llama.cpp" +keywords = ["ggml", "gguf", "llama.cpp"] +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] + +[tool.poetry.dependencies] +python = ">=3.8" +numpy = ">=1.17" +tqdm = ">=4.27" +pyyaml = ">=5.1" +sentencepiece = ">=0.1.98,<=0.2.0" +PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true } + +[tool.poetry.dev-dependencies] +pytest = "^5.2" + +[tool.poetry.extras] +gui = ["PySide6"] + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry.scripts] +gguf-convert-endian = "gguf.scripts.gguf_convert_endian:main" +gguf-dump = "gguf.scripts.gguf_dump:main" +gguf-set-metadata = "gguf.scripts.gguf_set_metadata:main" +gguf-new-metadata = "gguf.scripts.gguf_new_metadata:main" +gguf-editor-gui = "gguf.scripts.gguf_editor_gui:main" diff --git a/gguf-py/tests/__init__.py b/gguf-py/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d23ff9cb7380c5e0f9e246398f2ad0bee49e6981 --- /dev/null +++ b/gguf-py/tests/__init__.py @@ -0,0 +1 @@ +from .test_metadata import * diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..40d484f4eaa9d0c1270a5b70dc2e5153c6ed8d9f --- /dev/null +++ b/gguf-py/tests/test_metadata.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 + +import unittest +from pathlib import Path +import os +import sys + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent)) + +import gguf + + +class TestMetadataMethod(unittest.TestCase): + + def test_id_to_title(self): + self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1") + self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B") + self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO") + + def test_get_model_id_components(self): + # This is the basic standard form with organization marker + self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"), + ('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B')) + + # Similar to basic standard form but without organization marker + self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"), + ('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B')) + + # Missing version + self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"), + ('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B')) + + # Missing finetune + self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"), + ('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B')) + + # Base name and size label only + self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"), + ('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B')) + + # Base name and version only + self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"), + ('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None)) + + ## Edge Cases ## + + # This is too ambiguous... best to err on caution and output nothing + self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"), + ('Mixtral', None, None, None, None, None)) + + # Basename has numbers mixed in and also size label provided. Must avoid capturing number in basename + self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"), + ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B')) + + # Non standard naming + self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"), + ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B')) + + # Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count + self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"), + ('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B')) + + # Check that it can handle a real model id with no version code + # Note that 4k in this string is non standard and microsoft were referring to context length rather than weight count + self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9), + ('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini')) + + # There is some legitimate models with only thousands of parameters + self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3), + ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K')) + + # Non standard and not easy to disambiguate + self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"), + ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None)) + + # This is a real model_id where they append 2DPO to refer to Direct Preference Optimization + self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"), + ('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B')) + + # This is a real model id where the weight size has a decimal point + self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"), + ('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B')) + + # Uses an underscore in the size label + self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"), + ('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B')) + + # Uses Iter3 for the version + self.assertEqual(gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"), + ('Gemma-2-9B-It-SPPO-Iter3', 'UCLA-AGI', 'Gemma-2', 'It-SPPO', 'Iter3', '9B')) + + # Has two potential versions in the basename + self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Hermes-2-Theta-Llama-3-8B"), + ('Hermes-2-Theta-Llama-3-8B', 'NousResearch', 'Hermes-2-Theta-Llama-3', None, None, '8B')) + + # Potential version in the basename + self.assertEqual(gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"), + ('SeaLLMs-v3-7B-Chat', 'SeaLLMs', 'SeaLLMs-v3', 'Chat', None, '7B')) + + # Underscore in the basename, and 1m for the context size + self.assertEqual(gguf.Metadata.get_model_id_components("internlm/internlm2_5-7b-chat-1m", 7 * 10**9), + ('internlm2_5-7b-chat-1m', 'internlm', 'internlm2_5', 'chat-1m', None, '7B')) + + # Version before the finetune name + self.assertEqual(gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"), + ('jamba-900M-v0.13-KIx2', 'pszemraj', 'jamba', 'KIx2', 'v0.13', '900M')) + + # TODO: hf suffix which could be ignored but isn't + self.assertEqual(gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"), + ('mamba-2.8b-hf', 'state-spaces', 'mamba', 'hf', None, '2.8B')) + + # Two sizes, don't merge them, the other is the number of tokens on which it was trained + self.assertEqual(gguf.Metadata.get_model_id_components("abacaj/llama-161M-100B", 161 * 10**6), + ('llama-161M-100B', 'abacaj', 'llama', '100b', None, '161M')) + + # It's a trap, there is no size label + self.assertEqual(gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6), + ('relu-100B', 'SparseLLM', 'relu', '100b', None, None)) + + # Weird size notation + self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"), + ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B')) + + # Ignore full-text size labels when there are number-based ones, and deduplicate size labels + self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"), + ('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B')) + + # Instruct in a name without a size label + self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407"), + ('Mistral-Nemo-Instruct-2407', 'mistralai', 'Mistral-Nemo', 'Instruct', '2407', None)) + + # Non-obvious splitting relying on 'chat' keyword + self.assertEqual(gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"), + ('DeepSeek-V2-Chat-0628', 'deepseek-ai', 'DeepSeek-V2', 'Chat', '0628', None)) + + # Multiple versions + self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"), + ('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B')) + + # TODO: DPO in the name + self.assertEqual(gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"), + ('bagel-dpo-2.8b-v0.2', 'jondurbin', 'bagel-dpo', None, 'v0.2', '2.8B')) + + # DPO in name, but can't be used for the finetune to keep 'LLaMA-3' in the basename + self.assertEqual(gguf.Metadata.get_model_id_components("voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized"), + ('SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized', 'voxmenthe', 'SFR-Iterative-DPO-LLaMA-3', 'R-unquantized', None, '8B')) + + # Too ambiguous + # TODO: should "base" be a 'finetune' or 'size_label'? + # (in this case it should be a size label, but other models use it to signal that they are not finetuned) + self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"), + ('Florence-2-base', 'microsoft', None, None, None, None)) + + ## Invalid cases ## + + # Start with a dash and has dashes in rows + self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"), + ('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None)) + + ## LoRA ## + + self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"), + ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B')) + + # Negative size --> output is a LoRA adaper --> prune "LoRA" out of the name to avoid redundancy with the suffix + self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234), + ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B')) + + def test_apply_metadata_heuristic_from_model_card(self): + model_card = { + 'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], + 'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}], + 'language': ['en'], + 'datasets': ['teknium/OpenHermes-2.5'], + 'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}], + 'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"] + } + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + expect = gguf.Metadata() + expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}] + expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'] + expect.languages=['en'] + expect.datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}] + self.assertEqual(got, expect) + + # Base Model spec is inferred from model id + model_card = {'base_models': 'teknium/OpenHermes-2.5'} + expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + # Base Model spec is only url + model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']} + expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + # Base Model spec is given directly + model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]} + expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + # Dataset spec is inferred from model id + model_card = {'datasets': 'teknium/OpenHermes-2.5'} + expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + # Dataset spec is only url + model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']} + expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + # Dataset spec is given directly + model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]} + expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + def test_apply_metadata_heuristic_from_hf_parameters(self): + hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"} + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None) + expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B') + self.assertEqual(got, expect) + + def test_apply_metadata_heuristic_from_model_dir(self): + model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO") + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path) + expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B') + self.assertEqual(got, expect) + + +if __name__ == "__main__": + unittest.main() diff --git a/gguf-py/tests/test_quants.py b/gguf-py/tests/test_quants.py new file mode 100644 index 0000000000000000000000000000000000000000..f04d5acce279325e5f45b27339c3166803d94e00 --- /dev/null +++ b/gguf-py/tests/test_quants.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 + +# Test gguf.quants so that it exactly matches the C implementation of the (de)quantization + +# NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations. + +from __future__ import annotations + +import argparse +from math import prod +import os +import sys +from pathlib import Path +import ctypes +import logging +import numpy as np + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent)) + +import gguf +from gguf.constants import GGMLQuantizationType + + +logger = logging.getLogger("test-quants") + + +c_float_p = ctypes.POINTER(ctypes.c_float) + + +class ggml_init_params(ctypes.Structure): + _fields_ = [ + ("mem_size", ctypes.c_size_t), + ("mem_buffer", ctypes.c_void_p), + ("no_alloc", ctypes.c_bool), + ] + + +class GGMLQuants: + libggml: ctypes.CDLL + + def __init__(self, libggml: Path): + self.libggml = ctypes.CDLL(str(libggml)) + self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t + # enum ggml_type type, + # const float * src, + # void * dst, + # int64_t start, + # int64_t nrows, + # int64_t n_per_row, + # const float * imatrix) { + self.libggml.ggml_quantize_chunk.argtypes = ( + ctypes.c_int, + ctypes.POINTER(ctypes.c_float), + ctypes.c_void_p, + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_int64, + ctypes.POINTER(ctypes.c_float), + ) + + self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool + self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,) + + for t in ( + "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", + "q2_K", "q3_K", "q4_K", "q5_K", "q6_K", + "tq1_0", "tq2_0", + "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m", + "iq4_nl", "iq4_xs", + ): + dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t) + dequant_func.restype = None + dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64) + + self.libggml.ggml_fp16_to_fp32_row.restype = None + self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64) + self.libggml.ggml_bf16_to_fp32_row.restype = None + self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64) + + self.libggml.ggml_init.argtypes = (ggml_init_params,) + + self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False)) + + def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: + result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C") + if qtype == GGMLQuantizationType.F32: + # no-op + result = tensor.view(np.float32) + elif qtype == GGMLQuantizationType.F16: + self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size) + elif qtype == GGMLQuantizationType.BF16: + self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size) + else: + lw_qname = qtype.name.lower() + if lw_qname[-1] == "k": + lw_qname = lw_qname[:-1] + "K" + dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname) + dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size) + return result + + def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: + result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C") + if self.libggml.ggml_quantize_requires_imatrix(qtype.value): + # TODO: is a column-wise sum of squares appropriate? + qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p) + else: + qw = ctypes.cast(0, c_float_p) + result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw) + assert result.size == result_size + return result + + +def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool: + same = np.array_equal(t1, t2) + if same: + return True + else: + block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] + if t1.dtype == np.float32: + t1 = t1.reshape((-1, block_size)) + t2 = t2.reshape((-1, block_size)) + else: + t1 = t1.reshape((-1, type_size)) + t2 = t2.reshape((-1, type_size)) + x = t1.view(np.uint8) ^ t2.view(np.uint8) + diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1) + num_bad_blocks = np.count_nonzero(diff_bits, axis=0) + if num_bad_blocks == 0 and t1.shape == t2.shape: + logger.debug("Bits are equal, but arrays don't match, likely contains NANs") + return True + logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)") + bad_block_id = np.argmax(diff_bits, axis=0) + logger.debug(f"Worst block id: {bad_block_id}") + logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}") + + sum_diff_bits = np.sum(diff_bits) + logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)") + return False + + +def do_test(libggml_path: Path, quick: bool = False): + ggml_quants = GGMLQuants(libggml_path) + + np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n}) + + r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False) + + for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()): + has_dequantize = False + has_quantize = False + + try: + gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype) + has_dequantize = True + except (NotImplementedError, AssertionError) as e: + if isinstance(e, AssertionError): + logger.error(f"Error with {qtype.name}: {e}") + raise e + try: + gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype) + has_quantize = True + except (NotImplementedError, AssertionError) as e: + if isinstance(e, AssertionError): + logger.error(f"Error with {qtype.name}: {e}") + raise e + + if not has_dequantize and not has_quantize: + continue + + logger.info(f"Testing {qtype.name}") + + rc = r.copy(order="C") + + pyq = None + ggq = None + + if has_quantize: + logger.debug(f"Quantizing to {qtype.name} with Python") + pyq = gguf.quants.quantize(rc, qtype) + + logger.debug(f"Quantizing to {qtype.name} with C") + ggq = ggml_quants.quantize(rc, qtype) + + if qtype == GGMLQuantizationType.F16: + pyq = pyq.view(np.uint8) + quant_equal = compare_tensors(pyq, ggq, qtype) + + if not quant_equal: + logger.error(f"Quantization to {qtype.name} does not match ❌") + else: + logger.info(f"Quantization to {qtype.name} matches exactly ✅") + + if has_dequantize: + if ggq is None and not quick: + logger.debug(f"Quantizing to {qtype.name} with C") + ggq = ggml_quants.quantize(rc, qtype) + + if ggq is not None: + logger.debug(f"Dequantizing from {qtype.name} with Python") + pydq = gguf.quants.dequantize(ggq, qtype) + logger.debug(f"Dequantizing from {qtype.name} with C") + ggdq = ggml_quants.dequantize(ggq, qtype) + + dequant_equal = compare_tensors(pydq, ggdq, qtype) + + if not dequant_equal: + logger.error(f"Dequantization from {qtype.name} does not match ❌") + else: + logger.info(f"Dequantization from {qtype.name} matches exactly ✅") + + rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype) + rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8) + + logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python") + pydq = gguf.quants.dequantize(rq, qtype) + logger.debug(f"Dequantizing random f16 data as {qtype.name} with C") + ggdq = ggml_quants.dequantize(rq, qtype) + + dequant_equal = compare_tensors(pydq, ggdq, qtype) + + if not dequant_equal: + logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌") + else: + logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation") + parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so") + parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary") + + args = parser.parse_args() + + logging.basicConfig(level=logging.DEBUG) + + do_test(args.libggml, args.quick) diff --git a/grammars/README.md b/grammars/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a63198b5aeb8e65c4bc6b89384792dbd3267b19d --- /dev/null +++ b/grammars/README.md @@ -0,0 +1,382 @@ +# GBNF Guide + +GBNF (GGML BNF) is a format for defining [formal grammars](https://en.wikipedia.org/wiki/Formal_grammar) to constrain model outputs in `llama.cpp`. For example, you can use it to force the model to generate valid JSON, or speak only in emojis. GBNF grammars are supported in various ways in `tools/main` and `tools/server`. + +## Background + +[Backus-Naur Form (BNF)](https://en.wikipedia.org/wiki/Backus%E2%80%93Naur_form) is a notation for describing the syntax of formal languages like programming languages, file formats, and protocols. GBNF is an extension of BNF that primarily adds a few modern regex-like features. + +## Basics + +In GBNF, we define *production rules* that specify how a *non-terminal* (rule name) can be replaced with sequences of *terminals* (characters, specifically Unicode [code points](https://en.wikipedia.org/wiki/Code_point)) and other non-terminals. The basic format of a production rule is `nonterminal ::= sequence...`. + +## Example + +Before going deeper, let's look at some of the features demonstrated in `grammars/chess.gbnf`, a small chess notation grammar: +``` +# `root` specifies the pattern for the overall output +root ::= ( + # it must start with the characters "1. " followed by a sequence + # of characters that match the `move` rule, followed by a space, followed + # by another move, and then a newline + "1. " move " " move "\n" + + # it's followed by one or more subsequent moves, numbered with one or two digits + ([1-9] [0-9]? ". " move " " move "\n")+ +) + +# `move` is an abstract representation, which can be a pawn, nonpawn, or castle. +# The `[+#]?` denotes the possibility of checking or mate signs after moves +move ::= (pawn | nonpawn | castle) [+#]? + +pawn ::= ... +nonpawn ::= ... +castle ::= ... +``` + +## Non-Terminals and Terminals + +Non-terminal symbols (rule names) stand for a pattern of terminals and other non-terminals. They are required to be a dashed lowercase word, like `move`, `castle`, or `check-mate`. + +Terminals are actual characters ([code points](https://en.wikipedia.org/wiki/Code_point)). They can be specified as a sequence like `"1"` or `"O-O"` or as ranges like `[1-9]` or `[NBKQR]`. + +## Characters and character ranges + +Terminals support the full range of Unicode. Unicode characters can be specified directly in the grammar, for example `hiragana ::= [ぁ-ゟ]`, or with escapes: 8-bit (`\xXX`), 16-bit (`\uXXXX`) or 32-bit (`\UXXXXXXXX`). + +Character ranges can be negated with `^`: +``` +single-line ::= [^\n]+ "\n" +``` + +## Sequences and Alternatives + +The order of symbols in a sequence matters. For example, in `"1. " move " " move "\n"`, the `"1. "` must come before the first `move`, etc. + +Alternatives, denoted by `|`, give different sequences that are acceptable. For example, in `move ::= pawn | nonpawn | castle`, `move` can be a `pawn` move, a `nonpawn` move, or a `castle`. + +Parentheses `()` can be used to group sequences, which allows for embedding alternatives in a larger rule or applying repetition and optional symbols (below) to a sequence. + +## Repetition and Optional Symbols + +- `*` after a symbol or sequence means that it can be repeated zero or more times (equivalent to `{0,}`). +- `+` denotes that the symbol or sequence should appear one or more times (equivalent to `{1,}`). +- `?` makes the preceding symbol or sequence optional (equivalent to `{0,1}`). +- `{m}` repeats the precedent symbol or sequence exactly `m` times +- `{m,}` repeats the precedent symbol or sequence at least `m` times +- `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included) +- `{0,n}` repeats the precedent symbol or sequence at most `n` times (included) + +## Comments and newlines + +Comments can be specified with `#`: +``` +# defines optional whitespace +ws ::= [ \t\n]+ +``` + +Newlines are allowed between rules and between symbols or sequences nested inside parentheses. Additionally, a newline after an alternate marker `|` will continue the current rule, even outside of parentheses. + +## The root rule + +In a full grammar, the `root` rule always defines the starting point of the grammar. In other words, it specifies what the entire output must match. + +``` +# a grammar for lists +root ::= ("- " item)+ +item ::= [^\n]+ "\n" +``` + +## Next steps + +This guide provides a brief overview. Check out the GBNF files in this directory (`grammars/`) for examples of full grammars. You can try them out with: +``` +./llama-cli -m --grammar-file grammars/some-grammar.gbnf -p 'Some prompt' +``` + +`llama.cpp` can also convert JSON schemas to grammars either ahead of time or at each request, see below. + +## Troubleshooting + +Grammars currently have performance gotchas (see https://github.com/ggml-org/llama.cpp/issues/4218). + +### Efficient optional repetitions + +A common pattern is to allow repetitions of a pattern `x` up to N times. + +While semantically correct, the syntax `x? x? x?.... x?` (with N repetitions) may result in extremely slow sampling. Instead, you can write `x{0,N}` (or `(x (x (x ... (x)?...)?)?)?` w/ N-deep nesting in earlier llama.cpp versions). + +## Using GBNF grammars + +You can use GBNF grammars: + +- In [llama-server](../tools/server)'s completion endpoints, passed as the `grammar` body field +- In [llama-cli](../tools/main), passed as the `--grammar` & `--grammar-file` flags +- With [test-gbnf-validator](../tests/test-gbnf-validator.cpp), to test them against strings. + +## JSON Schemas → GBNF + +`llama.cpp` supports converting a subset of https://json-schema.org/ to GBNF grammars: + +- In [llama-server](../tools/server): + - For any completion endpoints, passed as the `json_schema` body field + - For the `/chat/completions` endpoint, passed inside the `response_format` body field (e.g. `{"type", "json_object", "schema": {"items": {}}}` or `{ type: "json_schema", json_schema: {"schema": ...} }`) +- In [llama-cli](../tools/main), passed as the `--json` / `-j` flag +- To convert to a grammar ahead of time: + - in CLI, with [examples/json_schema_to_grammar.py](../examples/json_schema_to_grammar.py) + - in JavaScript with [json-schema-to-grammar.mjs](../tools/server/public_legacy/json-schema-to-grammar.mjs) (this is used by the [server](../tools/server)'s Web UI) + +Take a look at [tests](../tests/test-json-schema-to-grammar.cpp) to see which features are likely supported (you'll also find usage examples in https://github.com/ggml-org/llama.cpp/pull/5978, https://github.com/ggml-org/llama.cpp/pull/6659 & https://github.com/ggml-org/llama.cpp/pull/6555). + +```bash +llama-cli \ + -hfr bartowski/Phi-3-medium-128k-instruct-GGUF \ + -hff Phi-3-medium-128k-instruct-Q8_0.gguf \ + -j '{ + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "minLength": 1, + "maxLength": 100 + }, + "age": { + "type": "integer", + "minimum": 0, + "maximum": 150 + } + }, + "required": ["name", "age"], + "additionalProperties": false + }, + "minItems": 10, + "maxItems": 100 + }' \ + -p 'Generate a {name, age}[] JSON array with famous actors of all ages.' +``` + +

+ +Show grammar + +You can convert any schema in command-line with: + +```bash +examples/json_schema_to_grammar.py name-age-schema.json +``` + +``` +char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4}) +item ::= "{" space item-name-kv "," space item-age-kv "}" space +item-age ::= ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-4] [0-9] | [5] "0")) space +item-age-kv ::= "\"age\"" space ":" space item-age +item-name ::= "\"" char{1,100} "\"" space +item-name-kv ::= "\"name\"" space ":" space item-name +root ::= "[" space item ("," space item){9,99} "]" space +space ::= | " " | "\n" [ \t]{0,20} +``` + +
+ +Here is also a list of known limitations (contributions welcome): + +- `additionalProperties` defaults to `false` (produces faster grammars + reduces hallucinations). +- `"additionalProperties": true` may produce keys that contain unescaped newlines. +- Unsupported features are skipped silently. It is currently advised to use the command-line Python converter (see above) to see any warnings, and to inspect the resulting grammar / test it w/ [llama-gbnf-validator](../examples/gbnf-validator/gbnf-validator.cpp). +- Can't mix `properties` w/ `anyOf` / `oneOf` in the same type (https://github.com/ggml-org/llama.cpp/issues/7703) +- [prefixItems](https://json-schema.org/draft/2020-12/json-schema-core#name-prefixitems) is broken (but [items](https://json-schema.org/draft/2020-12/json-schema-core#name-items) works) +- `minimum`, `exclusiveMinimum`, `maximum`, `exclusiveMaximum`: only supported for `"type": "integer"` for now, not `number` +- Nested `$ref`s are broken (https://github.com/ggml-org/llama.cpp/issues/8073) +- [pattern](https://json-schema.org/draft/2020-12/json-schema-validation#name-pattern)s must start with `^` and end with `$` +- Remote `$ref`s not supported in the C++ version (Python & JavaScript versions fetch https refs) +- `string` [formats](https://json-schema.org/draft/2020-12/json-schema-validation#name-defined-formats) lack `uri`, `email` +- No [`patternProperties`](https://json-schema.org/draft/2020-12/json-schema-core#name-patternproperties) + +And a non-exhaustive list of other unsupported features that are unlikely to be implemented (hard and/or too slow to support w/ stateless grammars): + +- [`uniqueItems`](https://json-schema.org/draft/2020-12/json-schema-validation#name-uniqueitems) +- [`contains`](https://json-schema.org/draft/2020-12/json-schema-core#name-contains) / `minContains` +- `$anchor` (cf. [dereferencing](https://json-schema.org/draft/2020-12/json-schema-core#name-dereferencing)) +- [`not`](https://json-schema.org/draft/2020-12/json-schema-core#name-not) +- [Conditionals](https://json-schema.org/draft/2020-12/json-schema-core#name-keywords-for-applying-subsche) `if` / `then` / `else` / `dependentSchemas` + +### A word about additionalProperties + +> [!WARNING] +> The JSON schemas spec states `object`s accept [additional properties](https://json-schema.org/understanding-json-schema/reference/object#additionalproperties) by default. +> Since this is slow and seems prone to hallucinations, we default to no additional properties. +> You can set `"additionalProperties": true` in the the schema of any object to explicitly allow additional properties. + +If you're using [Pydantic](https://pydantic.dev/) to generate schemas, you can enable additional properties with the `extra` config on each model class: + +```python +# pip install pydantic +import json +from typing import Annotated, List +from pydantic import BaseModel, Extra, Field +class QAPair(BaseModel): + class Config: + extra = 'allow' # triggers additionalProperties: true in the JSON schema + question: str + concise_answer: str + justification: str + +class Summary(BaseModel): + class Config: + extra = 'allow' + key_facts: List[Annotated[str, Field(pattern='- .{5,}')]] + question_answers: List[Annotated[List[QAPair], Field(min_items=5)]] + +print(json.dumps(Summary.model_json_schema(), indent=2)) +``` + +
+Show JSON schema & grammar + +```json +{ + "$defs": { + "QAPair": { + "additionalProperties": true, + "properties": { + "question": { + "title": "Question", + "type": "string" + }, + "concise_answer": { + "title": "Concise Answer", + "type": "string" + }, + "justification": { + "title": "Justification", + "type": "string" + } + }, + "required": [ + "question", + "concise_answer", + "justification" + ], + "title": "QAPair", + "type": "object" + } + }, + "additionalProperties": true, + "properties": { + "key_facts": { + "items": { + "pattern": "^- .{5,}$", + "type": "string" + }, + "title": "Key Facts", + "type": "array" + }, + "question_answers": { + "items": { + "items": { + "$ref": "#/$defs/QAPair" + }, + "minItems": 5, + "type": "array" + }, + "title": "Question Answers", + "type": "array" + } + }, + "required": [ + "key_facts", + "question_answers" + ], + "title": "Summary", + "type": "object" +} +``` + +``` +QAPair ::= "{" space QAPair-question-kv "," space QAPair-concise-answer-kv "," space QAPair-justification-kv ( "," space ( QAPair-additional-kv ( "," space QAPair-additional-kv )* ) )? "}" space +QAPair-additional-k ::= ["] ( [c] ([o] ([n] ([c] ([i] ([s] ([e] ([_] ([a] ([n] ([s] ([w] ([e] ([r] char+ | [^"r] char*) | [^"e] char*) | [^"w] char*) | [^"s] char*) | [^"n] char*) | [^"a] char*) | [^"_] char*) | [^"e] char*) | [^"s] char*) | [^"i] char*) | [^"c] char*) | [^"n] char*) | [^"o] char*) | [j] ([u] ([s] ([t] ([i] ([f] ([i] ([c] ([a] ([t] ([i] ([o] ([n] char+ | [^"n] char*) | [^"o] char*) | [^"i] char*) | [^"t] char*) | [^"a] char*) | [^"c] char*) | [^"i] char*) | [^"f] char*) | [^"i] char*) | [^"t] char*) | [^"s] char*) | [^"u] char*) | [q] ([u] ([e] ([s] ([t] ([i] ([o] ([n] char+ | [^"n] char*) | [^"o] char*) | [^"i] char*) | [^"t] char*) | [^"s] char*) | [^"e] char*) | [^"u] char*) | [^"cjq] char* )? ["] space +QAPair-additional-kv ::= QAPair-additional-k ":" space value +QAPair-concise-answer-kv ::= "\"concise_answer\"" space ":" space string +QAPair-justification-kv ::= "\"justification\"" space ":" space string +QAPair-question-kv ::= "\"question\"" space ":" space string +additional-k ::= ["] ( [k] ([e] ([y] ([_] ([f] ([a] ([c] ([t] ([s] char+ | [^"s] char*) | [^"t] char*) | [^"c] char*) | [^"a] char*) | [^"f] char*) | [^"_] char*) | [^"y] char*) | [^"e] char*) | [q] ([u] ([e] ([s] ([t] ([i] ([o] ([n] ([_] ([a] ([n] ([s] ([w] ([e] ([r] ([s] char+ | [^"s] char*) | [^"r] char*) | [^"e] char*) | [^"w] char*) | [^"s] char*) | [^"n] char*) | [^"a] char*) | [^"_] char*) | [^"n] char*) | [^"o] char*) | [^"i] char*) | [^"t] char*) | [^"s] char*) | [^"e] char*) | [^"u] char*) | [^"kq] char* )? ["] space +additional-kv ::= additional-k ":" space value +array ::= "[" space ( value ("," space value)* )? "]" space +boolean ::= ("true" | "false") space +char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4}) +decimal-part ::= [0-9]{1,16} +dot ::= [^\x0A\x0D] +integral-part ::= [0] | [1-9] [0-9]{0,15} +key-facts ::= "[" space (key-facts-item ("," space key-facts-item)*)? "]" space +key-facts-item ::= "\"" "- " key-facts-item-1{5,} "\"" space +key-facts-item-1 ::= dot +key-facts-kv ::= "\"key_facts\"" space ":" space key-facts +null ::= "null" space +number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space +object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space +question-answers ::= "[" space (question-answers-item ("," space question-answers-item)*)? "]" space +question-answers-item ::= "[" space question-answers-item-item ("," space question-answers-item-item){4,} "]" space +question-answers-item-item ::= QAPair +question-answers-kv ::= "\"question_answers\"" space ":" space question-answers +root ::= "{" space key-facts-kv "," space question-answers-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? "}" space +space ::= | " " | "\n" [ \t]{0,20} +string ::= "\"" char* "\"" space +value ::= object | array | string | number | boolean | null +``` + +
+ +If you're using [Zod](https://zod.dev/), you can make your objects to explicitly allow extra properties w/ `nonstrict()` / `passthrough()` (or explicitly no extra props w/ `z.object(...).strict()` or `z.strictObject(...)`) but note that [zod-to-json-schema](https://github.com/StefanTerdell/zod-to-json-schema) currently always sets `"additionalProperties": false` anyway. + +```js +import { z } from 'zod'; +import { zodToJsonSchema } from 'zod-to-json-schema'; + +const Foo = z.object({ + age: z.number().positive(), + email: z.string().email(), +}).strict(); + +console.log(zodToJsonSchema(Foo)); +``` + +
+Show JSON schema & grammar + +```json +{ + "type": "object", + "properties": { + "age": { + "type": "number", + "exclusiveMinimum": 0 + }, + "email": { + "type": "string", + "format": "email" + } + }, + "required": [ + "age", + "email" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" +} +``` + +``` +age-kv ::= "\"age\"" space ":" space number +char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4}) +decimal-part ::= [0-9]{1,16} +email-kv ::= "\"email\"" space ":" space string +integral-part ::= [0] | [1-9] [0-9]{0,15} +number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space +root ::= "{" space age-kv "," space email-kv "}" space +space ::= | " " | "\n" [ \t]{0,20} +string ::= "\"" char* "\"" space +``` + +
diff --git a/grammars/arithmetic.gbnf b/grammars/arithmetic.gbnf new file mode 100644 index 0000000000000000000000000000000000000000..3aa95a9dda7e80dc6ffc5fa07773e334aa6fb9f0 --- /dev/null +++ b/grammars/arithmetic.gbnf @@ -0,0 +1,6 @@ +root ::= (expr "=" ws term "\n")+ +expr ::= term ([-+*/] term)* +term ::= ident | num | "(" ws expr ")" ws +ident ::= [a-z] [a-z0-9_]* ws +num ::= [0-9]+ ws +ws ::= [ \t\n]* diff --git a/grammars/c.gbnf b/grammars/c.gbnf new file mode 100644 index 0000000000000000000000000000000000000000..4a0331dd2d6df91b6e2154b78dc31df57a4e715e --- /dev/null +++ b/grammars/c.gbnf @@ -0,0 +1,42 @@ +root ::= (declaration)* + +declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" + +dataType ::= "int" ws | "float" ws | "char" ws +identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* + +parameter ::= dataType identifier + +statement ::= + ( dataType identifier ws "=" ws expression ";" ) | + ( identifier ws "=" ws expression ";" ) | + ( identifier ws "(" argList? ")" ";" ) | + ( "return" ws expression ";" ) | + ( "while" "(" condition ")" "{" statement* "}" ) | + ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | + ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) | + ( singleLineComment ) | + ( multiLineComment ) + +forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression +forUpdate ::= identifier ws "=" ws expression + +condition ::= expression relationOperator expression +relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") + +expression ::= term (("+" | "-") term)* +term ::= factor(("*" | "/") factor)* + +factor ::= identifier | number | unaryTerm | funcCall | parenExpression +unaryTerm ::= "-" factor +funcCall ::= identifier "(" argList? ")" +parenExpression ::= "(" ws expression ws ")" + +argList ::= expression ("," ws expression)* + +number ::= [0-9]+ + +singleLineComment ::= "//" [^\n]* "\n" +multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" + +ws ::= ([ \t\n]+) diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf new file mode 100644 index 0000000000000000000000000000000000000000..ef0fc1b07f01c0ff32465597ced7dc36780f6171 --- /dev/null +++ b/grammars/chess.gbnf @@ -0,0 +1,13 @@ +# Specifies chess moves as a list in algebraic notation, using PGN conventions + +# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern +root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+ +move ::= (pawn | nonpawn | castle) [+#]? + +# piece type, optional file/rank, optional capture, dest file & rank +nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8] + +# optional file & capture, dest file & rank, optional promotion +pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])? + +castle ::= "O-O" "-O"? diff --git a/grammars/english.gbnf b/grammars/english.gbnf new file mode 100644 index 0000000000000000000000000000000000000000..2e53686c8215143477deb5800d4694bf9b07280c --- /dev/null +++ b/grammars/english.gbnf @@ -0,0 +1,6 @@ +# note: this might be incomplete, mostly an example +root ::= en-char+ ([ \t\n] en-char+)* +en-char ::= letter | digit | punctuation +letter ::= [a-zA-Z] +digit ::= [0-9] +punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~] diff --git a/grammars/japanese.gbnf b/grammars/japanese.gbnf new file mode 100644 index 0000000000000000000000000000000000000000..43f25ab598586d740c3e92066bfad3a2e9e3c562 --- /dev/null +++ b/grammars/japanese.gbnf @@ -0,0 +1,7 @@ +# A probably incorrect grammar for Japanese +root ::= jp-char+ ([ \t\n] jp-char+)* +jp-char ::= hiragana | katakana | punctuation | cjk +hiragana ::= [ぁ-ゟ] +katakana ::= [ァ-ヿ] +punctuation ::= [、-〾] +cjk ::= [一-鿿] diff --git a/grammars/json.gbnf b/grammars/json.gbnf new file mode 100644 index 0000000000000000000000000000000000000000..b6448c87b72ec7526404d563464df4e622a32f63 --- /dev/null +++ b/grammars/json.gbnf @@ -0,0 +1,25 @@ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\\x7F\x00-\x1F] | + "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= | " " | "\n" [ \t]{0,20} diff --git a/grammars/json_arr.gbnf b/grammars/json_arr.gbnf new file mode 100644 index 0000000000000000000000000000000000000000..b3dc6f9b16277e8e8e1f7dd1f33522695794e0f8 --- /dev/null +++ b/grammars/json_arr.gbnf @@ -0,0 +1,34 @@ +# This is the same as json.gbnf but we restrict whitespaces at the end of the root array +# Useful for generating JSON arrays + +root ::= arr +value ::= object | array | string | number | ("true" | "false" | "null") ws + +arr ::= + "[\n" ws ( + value + (",\n" ws value)* + )? "]" + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\\x7F\x00-\x1F] | + "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= | " " | "\n" [ \t]{0,20} diff --git a/grammars/list.gbnf b/grammars/list.gbnf new file mode 100644 index 0000000000000000000000000000000000000000..51e6c9c4b03292a52dfd2b58acc2312fc8aa9338 --- /dev/null +++ b/grammars/list.gbnf @@ -0,0 +1,4 @@ +root ::= item+ + +# Excludes various line break characters +item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" diff --git a/include/llama-cpp.h b/include/llama-cpp.h new file mode 100644 index 0000000000000000000000000000000000000000..8f6368177de0918e67daa4059e34519eecb42a78 --- /dev/null +++ b/include/llama-cpp.h @@ -0,0 +1,30 @@ +#pragma once + +#ifndef __cplusplus +#error "This header is for C++ only" +#endif + +#include + +#include "llama.h" + +struct llama_model_deleter { + void operator()(llama_model * model) { llama_model_free(model); } +}; + +struct llama_context_deleter { + void operator()(llama_context * context) { llama_free(context); } +}; + +struct llama_sampler_deleter { + void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); } +}; + +struct llama_adapter_lora_deleter { + void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } +}; + +typedef std::unique_ptr llama_model_ptr; +typedef std::unique_ptr llama_context_ptr; +typedef std::unique_ptr llama_sampler_ptr; +typedef std::unique_ptr llama_adapter_lora_ptr; diff --git a/include/llama.h b/include/llama.h new file mode 100644 index 0000000000000000000000000000000000000000..635508b10f2ff1a2820ca98b15d26992b539f495 --- /dev/null +++ b/include/llama.h @@ -0,0 +1,1479 @@ +#ifndef LLAMA_H +#define LLAMA_H + +#include "ggml.h" +#include "ggml-cpu.h" +#include "ggml-backend.h" +#include "ggml-opt.h" + +#include +#include +#include +#include + +#ifdef LLAMA_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef LLAMA_BUILD +# define LLAMA_API __declspec(dllexport) +# else +# define LLAMA_API __declspec(dllimport) +# endif +# else +# define LLAMA_API __attribute__ ((visibility ("default"))) +# endif +#else +# define LLAMA_API +#endif + +#ifdef __GNUC__ +# define DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define DEPRECATED(func, hint) func +#endif + +#define LLAMA_DEFAULT_SEED 0xFFFFFFFF + +#define LLAMA_TOKEN_NULL -1 + +#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' +#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' +#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' + +#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN +#define LLAMA_SESSION_VERSION 9 + +#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ +#define LLAMA_STATE_SEQ_VERSION 2 + +#ifdef __cplusplus +extern "C" { +#endif + + // + // C interface + // + // TODO: show sample usage + // + + struct llama_vocab; + struct llama_model; + struct llama_context; + struct llama_sampler; + + typedef struct llama_memory_i * llama_memory_t; + + struct llama_kv_cache; // DEPRECATED (use llama_memory instead) + + typedef int32_t llama_pos; + typedef int32_t llama_token; + typedef int32_t llama_seq_id; + + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab + LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback + LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE + LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece + LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + }; + + // pre-tokenization types + enum llama_vocab_pre_type { + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + }; + + enum llama_rope_type { + LLAMA_ROPE_TYPE_NONE = -1, + LLAMA_ROPE_TYPE_NORM = 0, + LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX, + LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE, + LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION, + }; + + enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file + LLAMA_TOKEN_TYPE_UNDEFINED = 0, + LLAMA_TOKEN_TYPE_NORMAL = 1, + LLAMA_TOKEN_TYPE_UNKNOWN = 2, + LLAMA_TOKEN_TYPE_CONTROL = 3, + LLAMA_TOKEN_TYPE_USER_DEFINED = 4, + LLAMA_TOKEN_TYPE_UNUSED = 5, + LLAMA_TOKEN_TYPE_BYTE = 6, + }; + + enum llama_token_attr { + LLAMA_TOKEN_ATTR_UNDEFINED = 0, + LLAMA_TOKEN_ATTR_UNKNOWN = 1 << 0, + LLAMA_TOKEN_ATTR_UNUSED = 1 << 1, + LLAMA_TOKEN_ATTR_NORMAL = 1 << 2, + LLAMA_TOKEN_ATTR_CONTROL = 1 << 3, // SPECIAL? + LLAMA_TOKEN_ATTR_USER_DEFINED = 1 << 4, + LLAMA_TOKEN_ATTR_BYTE = 1 << 5, + LLAMA_TOKEN_ATTR_NORMALIZED = 1 << 6, + LLAMA_TOKEN_ATTR_LSTRIP = 1 << 7, + LLAMA_TOKEN_ATTR_RSTRIP = 1 << 8, + LLAMA_TOKEN_ATTR_SINGLE_WORD = 1 << 9, + }; + + // model file types + enum llama_ftype { + LLAMA_FTYPE_ALL_F32 = 0, + LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + // LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed + // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_XS = 22, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_S = 26, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors + LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors + //LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // removed from gguf files, use Q4_0 and runtime repack + //LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // removed from gguf files, use Q4_0 and runtime repack + //LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack + LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors + LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors + + LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file + }; + + enum llama_rope_scaling_type { + LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1, + LLAMA_ROPE_SCALING_TYPE_NONE = 0, + LLAMA_ROPE_SCALING_TYPE_LINEAR = 1, + LLAMA_ROPE_SCALING_TYPE_YARN = 2, + LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3, + LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE, + }; + + enum llama_pooling_type { + LLAMA_POOLING_TYPE_UNSPECIFIED = -1, + LLAMA_POOLING_TYPE_NONE = 0, + LLAMA_POOLING_TYPE_MEAN = 1, + LLAMA_POOLING_TYPE_CLS = 2, + LLAMA_POOLING_TYPE_LAST = 3, + LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph + }; + + enum llama_attention_type { + LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1, + LLAMA_ATTENTION_TYPE_CAUSAL = 0, + LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1, + }; + + enum llama_split_mode { + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + }; + + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) + typedef struct llama_token_data { + llama_token id; // token id + float logit; // log-odds of the token + float p; // probability of the token + } llama_token_data; + + typedef struct llama_token_data_array { + // TODO: consider SoA + // NOTE: this pointer can be modified by the samplers + llama_token_data * data; + size_t size; + int64_t selected; // this is the index in the data array (i.e. not the token id) + bool sorted; + } llama_token_data_array; + + typedef bool (*llama_progress_callback)(float progress, void * user_data); + + // Input data for llama_encode/llama_decode + // A llama_batch object can contain input about one or many sequences + // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens + // + // - token : the token ids of the input (used when embd is NULL) + // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) + // - pos : the positions of the respective token in the sequence + // (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode) + // - seq_id : the sequence to which the respective token belongs + // (if set to NULL, the sequence ID will be assumed to be 0) + // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output + // (if set to NULL: + // - if embeddings: all tokens are output + // - if not: only the last token is output + // ) + // + typedef struct llama_batch { + int32_t n_tokens; + + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * logits; // TODO: rename this to "output" + } llama_batch; + + enum llama_model_kv_override_type { + LLAMA_KV_OVERRIDE_TYPE_INT, + LLAMA_KV_OVERRIDE_TYPE_FLOAT, + LLAMA_KV_OVERRIDE_TYPE_BOOL, + LLAMA_KV_OVERRIDE_TYPE_STR, + }; + + struct llama_model_kv_override { + enum llama_model_kv_override_type tag; + + char key[128]; + + union { + int64_t val_i64; + double val_f64; + bool val_bool; + char val_str[128]; + }; + }; + + struct llama_model_tensor_buft_override { + const char * pattern; + ggml_backend_buffer_type_t buft; + }; + + struct llama_model_params { + // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) + ggml_backend_dev_t * devices; + + // NULL-terminated list of buffer types to use for tensors that match a pattern + const struct llama_model_tensor_buft_override * tensor_buft_overrides; + + int32_t n_gpu_layers; // number of layers to store in VRAM + enum llama_split_mode split_mode; // how to split the model across multiple GPUs + + // the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE + int32_t main_gpu; + + // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() + const float * tensor_split; + + // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. + // If the provided progress_callback returns true, model loading continues. + // If it returns false, model loading is immediately aborted. + llama_progress_callback progress_callback; + + // context pointer passed to the progress callback + void * progress_callback_user_data; + + // override key-value pairs of the model meta data + const struct llama_model_kv_override * kv_overrides; + + // Keep the booleans together to avoid misalignment during copy-by-value. + bool vocab_only; // only load the vocabulary, no weights + bool use_mmap; // use mmap if possible + bool use_mlock; // force system to keep model in RAM + bool check_tensors; // validate model tensor data + }; + + // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations + // https://github.com/ggml-org/llama.cpp/pull/7544 + struct llama_context_params { + uint32_t n_ctx; // text context, 0 = from model + uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode + uint32_t n_ubatch; // physical maximum batch size + uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) + int32_t n_threads; // number of threads to use for generation + int32_t n_threads_batch; // number of threads to use for batch processing + + enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` + enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id + enum llama_attention_type attention_type; // attention type to use for embeddings + + // ref: https://github.com/ggml-org/llama.cpp/pull/2054 + float rope_freq_base; // RoPE base frequency, 0 = from model + float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model + float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model + float yarn_attn_factor; // YaRN magnitude scaling factor + float yarn_beta_fast; // YaRN low correction dim + float yarn_beta_slow; // YaRN high correction dim + uint32_t yarn_orig_ctx; // YaRN original context size + float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default) + + ggml_backend_sched_eval_callback cb_eval; + void * cb_eval_user_data; + + enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] + enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] + + // Abort callback + // if it returns true, execution of llama_decode() will be aborted + // currently works only with CPU execution + ggml_abort_callback abort_callback; + void * abort_callback_data; + + // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. + bool embeddings; // if true, extract embeddings (together with logits) + bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // use flash attention [EXPERIMENTAL] + bool no_perf; // measure performance timings + bool op_offload; // offload host tensor operations to device + bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) + // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases + // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573 + }; + + // model quantization parameters + typedef struct llama_model_quantize_params { + int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + enum ggml_type output_tensor_type; // output tensor type + enum ggml_type token_embedding_type; // token embeddings tensor type + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + bool pure; // quantize all tensors to the default type + bool keep_split; // quantize to the same number of shards + void * imatrix; // pointer to importance matrix data + void * kv_overrides; // pointer to vector containing overrides + void * tensor_types; // pointer to vector containing tensor types + } llama_model_quantize_params; + + typedef struct llama_logit_bias { + llama_token token; + float bias; + } llama_logit_bias; + + typedef struct llama_sampler_chain_params { + bool no_perf; // whether to measure performance timings + } llama_sampler_chain_params; + + // used in chat template + typedef struct llama_chat_message { + const char * role; + const char * content; + } llama_chat_message; + + // lora adapter + struct llama_adapter_lora; + + // Helpers for getting default parameters + // TODO: update API to start accepting pointers to params structs (https://github.com/ggml-org/llama.cpp/discussions/9172) + LLAMA_API struct llama_model_params llama_model_default_params(void); + LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); + LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); + + // Initialize the llama + ggml backend + // If numa is true, use NUMA optimizations + // Call once at the start of the program + LLAMA_API void llama_backend_init(void); + + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_backend_free(void); + + //optional: + LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); + + // Optional: an auto threadpool gets created in ggml if not passed explicitly + LLAMA_API void llama_attach_threadpool( + struct llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch); + + LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); + + DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_model_params params), + "use llama_model_load_from_file instead"); + + // Load the model from a file + // If the file is split into multiple parts, the file name must follow this pattern: -%05d-of-%05d.gguf + // If the split file name does not follow this pattern, use llama_model_load_from_splits + LLAMA_API struct llama_model * llama_model_load_from_file( + const char * path_model, + struct llama_model_params params); + + // Load the model from multiple splits (support custom naming scheme) + // The paths must be in the correct order + LLAMA_API struct llama_model * llama_model_load_from_splits( + const char ** paths, + size_t n_paths, + struct llama_model_params params); + + LLAMA_API void llama_model_save_to_file( + const struct llama_model * model, + const char * path_model); + + DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), + "use llama_model_free instead"); + + LLAMA_API void llama_model_free(struct llama_model * model); + + LLAMA_API struct llama_context * llama_init_from_model( + struct llama_model * model, + struct llama_context_params params); + + DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params), + "use llama_init_from_model instead"); + + // Frees all allocated memory + LLAMA_API void llama_free(struct llama_context * ctx); + + LLAMA_API int64_t llama_time_us(void); + + LLAMA_API size_t llama_max_devices(void); + LLAMA_API size_t llama_max_parallel_sequences(void); + + LLAMA_API bool llama_supports_mmap (void); + LLAMA_API bool llama_supports_mlock (void); + LLAMA_API bool llama_supports_gpu_offload(void); + LLAMA_API bool llama_supports_rpc (void); + + LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); + + DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); + DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); + DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead"); + DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead"); + + DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); + + LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); + LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx); + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type + + DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead"); + + LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); + + LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); + + // Get the model's RoPE frequency scaling factor + LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); + + // Returns the number of classifier outputs (only valid for classifier models) + // Undefined behavior for non-classifier models + LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model); + + // Returns label of classifier output by index ( 1` + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_memory_seq_div( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d); + + // Returns the smallest position present in the memory for the specified sequence + // This is typically non-zero only for SWA caches + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory + // Return -1 if the sequence is empty + LLAMA_API llama_pos llama_memory_seq_pos_min( + llama_memory_t mem, + llama_seq_id seq_id); + + // Returns the largest position present in the memory for the specified sequence + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory + // Return -1 if the sequence is empty + LLAMA_API llama_pos llama_memory_seq_pos_max( + llama_memory_t mem, + llama_seq_id seq_id); + + // Check if the memory supports shifting + LLAMA_API bool llama_memory_can_shift(llama_memory_t mem); + + // + // KV cache for self-attention (TODO: deprecate in favor of llama_memory) + // + + // Returns the number of tokens in the KV cache (slow, use only for debug) + // If a KV cell has multiple sequences assigned to it, it will be counted multiple times + DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), + "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); + + // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) + DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx), + "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); + + // Clear the KV cache - both cell info is erased and KV data is zeroed + DEPRECATED(LLAMA_API void llama_kv_self_clear( + struct llama_context * ctx), + "Use llama_memory_clear() instead"); + + // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails + // seq_id < 0 : match any sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1), + "Use llama_memory_seq_rm() instead"); + + // Copy all tokens that belong to the specified sequence to another sequence + // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + DEPRECATED(LLAMA_API void llama_kv_self_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1), + "Use llama_memory_seq_cp() instead"); + + // Removes all tokens that do not belong to the specified sequence + DEPRECATED(LLAMA_API void llama_kv_self_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id), + "Use llama_memory_seq_keep() instead"); + + // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) + // If the KV cache is RoPEd, the KV data is updated accordingly: + // - lazily on next llama_decode() + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + DEPRECATED(LLAMA_API void llama_kv_self_seq_add( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta), + "Use llama_memory_seq_add() instead"); + + // Integer division of the positions by factor of `d > 1` + // If the KV cache is RoPEd, the KV data is updated accordingly: + // - lazily on next llama_decode() + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + DEPRECATED(void llama_kv_self_seq_div( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d), + "Use llama_memory_seq_div() instead"); + + // Returns the smallest position present in the KV cache for the specified sequence + // This is typically non-zero only for SWA caches + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache + // Return -1 if the sequence is empty + DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min( + struct llama_context * ctx, + llama_seq_id seq_id), + "Use llama_memory_seq_pos_min() instead"); + + // Returns the largest position present in the KV cache for the specified sequence + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache + // Return -1 if the sequence is empty + DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max( + struct llama_context * ctx, + llama_seq_id seq_id), + "Use llama_memory_seq_pos_max() instead"); + + // Defragment the KV cache + // This will be applied: + // - lazily on next llama_decode() + DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx), + "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); + + // Check if the context supports KV cache shifting + DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx), + "use llama_memory_can_shift() instead"); + + // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) + DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), + "simply remove this call, updates are applied lazily on the next llama_decode()"); + + // + // State / sessions + // + + // Returns the *actual* size in bytes of the state + // (logits, embedding and memory) + // Only use when saving the state, not when restoring it, otherwise the size may be too small. + LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); + LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx), + "use llama_state_get_size instead"); + + // Copies the state to the specified destination address. + // Destination needs to have allocated enough memory. + // Returns the number of bytes copied + LLAMA_API size_t llama_state_get_data( + struct llama_context * ctx, + uint8_t * dst, + size_t size); + LLAMA_API DEPRECATED(size_t llama_copy_state_data( + struct llama_context * ctx, + uint8_t * dst), + "use llama_state_get_data instead"); + + // Set the state reading from the specified address + // Returns the number of bytes read + LLAMA_API size_t llama_state_set_data( + struct llama_context * ctx, + const uint8_t * src, + size_t size); + LLAMA_API DEPRECATED(size_t llama_set_state_data( + struct llama_context * ctx, + const uint8_t * src), + "use llama_state_set_data instead"); + + // Save/load session file + LLAMA_API bool llama_state_load_file( + struct llama_context * ctx, + const char * path_session, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + LLAMA_API DEPRECATED(bool llama_load_session_file( + struct llama_context * ctx, + const char * path_session, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out), + "use llama_state_load_file instead"); + + LLAMA_API bool llama_state_save_file( + struct llama_context * ctx, + const char * path_session, + const llama_token * tokens, + size_t n_token_count); + LLAMA_API DEPRECATED(bool llama_save_session_file( + struct llama_context * ctx, + const char * path_session, + const llama_token * tokens, + size_t n_token_count), + "use llama_state_save_file instead"); + + // Get the exact size needed to copy the state of a single sequence + LLAMA_API size_t llama_state_seq_get_size( + struct llama_context * ctx, + llama_seq_id seq_id); + + // Copy the state of a single sequence into the specified buffer + LLAMA_API size_t llama_state_seq_get_data( + struct llama_context * ctx, + uint8_t * dst, + size_t size, + llama_seq_id seq_id); + + // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence + // Returns: + // - Positive: Ok + // - Zero: Failed to load + LLAMA_API size_t llama_state_seq_set_data( + struct llama_context * ctx, + const uint8_t * src, + size_t size, + llama_seq_id dest_seq_id); + + LLAMA_API size_t llama_state_seq_save_file( + struct llama_context * ctx, + const char * filepath, + llama_seq_id seq_id, + const llama_token * tokens, + size_t n_token_count); + + LLAMA_API size_t llama_state_seq_load_file( + struct llama_context * ctx, + const char * filepath, + llama_seq_id dest_seq_id, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + + // + // Decoding + // + + // Return batch for single sequence of tokens + // The sequence ID will be fixed to 0 + // The position of the tokens will be tracked automatically by llama_decode + // + // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it + // + LLAMA_API struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens); + + // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens + // Each token can be assigned up to n_seq_max sequence ids + // The batch has to be freed with llama_batch_free() + // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) + // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token + // The rest of the llama_batch members are allocated with size n_tokens + // All members are left uninitialized + LLAMA_API struct llama_batch llama_batch_init( + int32_t n_tokens, + int32_t embd, + int32_t n_seq_max); + + // Frees a batch of tokens allocated with llama_batch_init() + LLAMA_API void llama_batch_free(struct llama_batch batch); + + // Process a batch of tokens. + // In contrast to llama_decode() - this call does not use KV cache. + // For encode-decoder contexts, processes the batch using the encoder. + // Can store the encoder output internally for later use by the decoder's cross-attention layers. + // 0 - success + // < 0 - error. the memory state is restored to the state before this call + LLAMA_API int32_t llama_encode( + struct llama_context * ctx, + struct llama_batch batch); + + // Process a batch of tokens. + // Requires the context to have a memory. + // For encode-decoder contexts, processes the batch using the decoder. + // Positive return values does not mean a fatal error, but rather a warning. + // Upon non-zero return values, the memory state is restored to the state before this call + // 0 - success + // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) + // 2 - aborted + // -1 - invalid input batch + // < -1 - error + LLAMA_API int32_t llama_decode( + struct llama_context * ctx, + struct llama_batch batch); + + // Set the number of threads used for decoding + // n_threads is the number of threads used for generation (single token) + // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) + LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch); + + // Get the number of threads used for generation of a single token. + LLAMA_API int32_t llama_n_threads(struct llama_context * ctx); + + // Get the number of threads used for prompt and batch processing (multiple token). + LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx); + + // Set whether the context outputs embeddings or not + // TODO: rename to avoid confusion with llama_get_embeddings() + LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); + + // Set whether to use causal attention or not + // If set to true, the model will only attend to the past tokens + LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn); + + // Set whether the model is in warmup mode or not + // If true, all model tensors are activated during llama_decode() to load and cache their weights. + LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); + + // Set abort callback + LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); + + // Wait until all computations are finished + // This is automatically done when using one of the functions below to obtain the computation results + // and is not necessary to call it explicitly in most cases + LLAMA_API void llama_synchronize(struct llama_context * ctx); + + // Token logits obtained from the last call to llama_decode() + // The logits for which llama_batch.logits[i] != 0 are stored contiguously + // in the order they have appeared in the batch. + // Rows: number of tokens for which llama_batch.logits[i] != 0 + // Cols: n_vocab + LLAMA_API float * llama_get_logits(struct llama_context * ctx); + + // Logits for the ith token. For positive indices, Equivalent to: + // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab + // Negative indicies can be used to access logits in reverse order, -1 is the last logit. + // returns NULL for invalid ids. + LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + + // Get all output token embeddings. + // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, + // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously + // in the order they have appeared in the batch. + // shape: [n_outputs*n_embd] + // Otherwise, returns NULL. + LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); + + // Get the embeddings for the ith token. For positive indices, Equivalent to: + // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd + // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding. + // shape: [n_embd] (1-dimensional) + // returns NULL for invalid ids. + LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); + + // Get the embeddings for a sequence id + // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE + // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence + // otherwise: float[n_embd] (1-dimensional) + LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + + // + // Vocab + // + + LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token); + + LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token); + + LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token); + + // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) + LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token); + + // Identify if Token Id is a control token or a render-able token + LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token); + + // Special tokens + LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence + LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence + LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn + LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator + LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line + LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding + + LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); + LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); + + LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab); + + DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead"); + DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead"); + DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); + DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); + DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead"); + DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead"); + DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead"); + DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead"); + DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead"); + DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead"); + DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead"); + DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead"); + + // CLS is equivalent to BOS + DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification + "use llama_vocab_bos instead"); + + // + // Tokenization + // + // The API is thread-safe. + // + + /// @details Convert the provided text into tokens. + /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. + /// @return Returns the number of tokens on success, no more than n_tokens_max + /// @return Returns a negative number on failure - the number of tokens that would have been returned + /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so. + /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated + /// as plaintext. Does not insert a leading space. + LLAMA_API int32_t llama_tokenize( + const struct llama_vocab * vocab, + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special); + + // Token Id -> Piece. + // Uses the vocabulary in the provided context. + // Does not write null terminator to the buffer. + // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') + // @param special If true, special tokens are rendered in the output. + LLAMA_API int32_t llama_token_to_piece( + const struct llama_vocab * vocab, + llama_token token, + char * buf, + int32_t length, + int32_t lstrip, + bool special); + + /// @details Convert the provided tokens into text (inverse of llama_tokenize()). + /// @param text The char pointer must be large enough to hold the resulting text. + /// @return Returns the number of chars/bytes on success, no more than text_len_max. + /// @return Returns a negative number on failure - the number of chars/bytes that would have been returned. + /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so. + /// @param unparse_special If true, special tokens are rendered in the output. + LLAMA_API int32_t llama_detokenize( + const struct llama_vocab * vocab, + const llama_token * tokens, + int32_t n_tokens, + char * text, + int32_t text_len_max, + bool remove_special, + bool unparse_special); + + // + // Chat templates + // + + /// Apply chat template. Inspired by hf apply_chat_template() on python. + /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" + /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template + /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. + /// @param chat Pointer to a list of multiple llama_chat_message + /// @param n_msg Number of llama_chat_message in this chat + /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. + /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages) + /// @param length The size of the allocated buffer + /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. + LLAMA_API int32_t llama_chat_apply_template( + const char * tmpl, + const struct llama_chat_message * chat, + size_t n_msg, + bool add_ass, + char * buf, + int32_t length); + + // Get list of built-in chat templates + LLAMA_API int32_t llama_chat_builtin_templates(const char ** output, size_t len); + + // + // Sampling API + // + // Sample usage: + // + // // prepare the sampling chain at the start + // auto sparams = llama_sampler_chain_default_params(); + // + // llama_sampler * smpl = llama_sampler_chain_init(sparams); + // + // llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50)); + // llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); + // llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8)); + // + // // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat" + // // this sampler will be responsible to select the actual token + // llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed)); + // + // ... + // + // // decoding loop: + // while (...) { + // ... + // + // llama_decode(ctx, batch); + // + // // sample from the logits of the last token in the batch + // const llama_token id = llama_sampler_sample(smpl, ctx, -1); + // + // // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.) + // llama_sampler_accept(smpl, id); + // ... + // } + // + // llama_sampler_free(smpl); + // + // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). + // + + typedef void * llama_sampler_context_t; + + // user code can implement the interface below in order to create custom llama_sampler + struct llama_sampler_i { + const char * (*name) (const struct llama_sampler * smpl); // can be NULL + void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL + void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required + void (*reset) ( struct llama_sampler * smpl); // can be NULL + struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL + void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL + + // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph + //void (*apply_ggml) (struct llama_sampler * smpl, ...); + }; + + struct llama_sampler { + const struct llama_sampler_i * iface; + llama_sampler_context_t ctx; + }; + + // mirror of llama_sampler_i: + LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); + LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); + LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); + LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); + LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); + // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + + // llama_sampler_chain + // a type of llama_sampler that can chain multiple samplers one after another + + LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params); + + // important: takes ownership of the sampler object and will free it when llama_sampler_free is called + LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); + LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); + + // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed + LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); + + // available samplers: + + LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); + LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); + + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. + /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. + DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), + "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); + + /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + /// Setting k <= 0 makes this a noop + LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); + + /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep); + + /// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 + LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); + + /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep); + + /// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf + LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); + + /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. + LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent); + + /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 + LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); + + /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 + LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n); + + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + LLAMA_API struct llama_sampler * llama_sampler_init_mirostat( + int32_t n_vocab, + uint32_t seed, + float tau, + float eta, + int32_t m); + + /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2( + uint32_t seed, + float tau, + float eta); + + /// @details Intializes a GBNF grammar, see grammars/README.md for details. + /// @param vocab The vocabulary that this grammar will be used with. + /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. + /// @param grammar_root The name of the start symbol for the grammar. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root); + + DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens), + "use llama_sampler_init_grammar_lazy_patterns instead"); + + + /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 + /// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + + + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. + LLAMA_API struct llama_sampler * llama_sampler_init_penalties( + int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat, // 1.0 = disabled + float penalty_freq, // 0.0 = disabled + float penalty_present); // 0.0 = disabled + + /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 + LLAMA_API struct llama_sampler * llama_sampler_init_dry( + const struct llama_vocab * vocab, + int32_t n_ctx_train, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const char ** seq_breakers, + size_t num_breakers); + + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( + int32_t n_vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); + + // this sampler is meant to be used for fill-in-the-middle infilling + // it's supposed to be used after top_k + top_p sampling + // + // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG + // 2. combine probs of tokens that have the same prefix + // + // example: + // + // - before: + // "hel": 0.5 + // "hell": 0.2 + // "hello": 0.1 + // "dummy": 0.1 + // + // - after: + // "hel": 0.8 + // "dummy": 0.1 + // + // 3. discard non-EOG tokens with low prob + // 4. if no tokens are left -> pick EOT + // + LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); + + // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise + LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); + + /// @details Sample and accept a token from the idx-th output of the last evaluation + // + // Shorthand for: + // const auto * logits = llama_get_logits_ith(ctx, idx); + // llama_token_data_array cur_p = { ... init from logits ... }; + // llama_sampler_apply(smpl, &cur_p); + // auto token = cur_p.data[cur_p.selected].id; + // llama_sampler_accept(smpl, token); + // return token; + // Returns the sampled token + LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); + + // TODO: extend in the future + //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...); + + // + // Model split + // + + /// @details Build a split GGUF final path for this chunk. + /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" + // Returns the split_path length. + LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count); + + /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match. + /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" + // Returns the split_prefix length. + LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count); + + // Print system information + LLAMA_API const char * llama_print_system_info(void); + + // Set callback for all future logging events. + // If this is not called, or NULL is supplied, everything is output on stderr. + LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); + + // + // Performance utils + // + // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements. + // + + struct llama_perf_context_data { + double t_start_ms; + double t_load_ms; + double t_p_eval_ms; + double t_eval_ms; + + int32_t n_p_eval; + int32_t n_eval; + }; + + struct llama_perf_sampler_data { + double t_sample_ms; + + int32_t n_sample; + }; + + LLAMA_API struct llama_perf_context_data llama_perf_context (const struct llama_context * ctx); + LLAMA_API void llama_perf_context_print(const struct llama_context * ctx); + LLAMA_API void llama_perf_context_reset( struct llama_context * ctx); + + // NOTE: the following work only with samplers constructed via llama_sampler_chain_init + LLAMA_API struct llama_perf_sampler_data llama_perf_sampler (const struct llama_sampler * chain); + LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); + LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); + + // + // training + // + + // function that returns whether or not a given tensor contains trainable parameters + typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata); + + // always returns true + LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata); + + struct llama_opt_params { + uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0 + + llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters + void * param_filter_ud; // userdata for determining which tensors contain trainable parameters + + ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters + void * get_opt_pars_ud; // userdata for calculating optimizer parameters + }; + + LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params); + + LLAMA_API void llama_opt_epoch( + struct llama_context * lctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + +#ifdef __cplusplus +} +#endif + +#endif // LLAMA_H diff --git a/licenses/LICENSE-curl b/licenses/LICENSE-curl new file mode 100644 index 0000000000000000000000000000000000000000..da9c038253092b4992247bae77088eddfc1eca53 --- /dev/null +++ b/licenses/LICENSE-curl @@ -0,0 +1,9 @@ +Copyright (c) 1996 - 2025, Daniel Stenberg, daniel@haxx.se, and many contributors, see the THANKS file. + +All rights reserved. + +Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +Except as contained in this notice, the name of a copyright holder shall not be used in advertising or otherwise to promote the sale, use or other dealings in this Software without prior written authorization of the copyright holder. diff --git a/licenses/LICENSE-httplib b/licenses/LICENSE-httplib new file mode 100644 index 0000000000000000000000000000000000000000..47c418e0726767c755c592cab89164219b45c43a --- /dev/null +++ b/licenses/LICENSE-httplib @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 yhirose + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/licenses/LICENSE-jsonhpp b/licenses/LICENSE-jsonhpp new file mode 100644 index 0000000000000000000000000000000000000000..b5a10275c1cdff6d2d1b4d84f51922d42499dec5 --- /dev/null +++ b/licenses/LICENSE-jsonhpp @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2013-2025 Niels Lohmann + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/licenses/LICENSE-linenoise b/licenses/LICENSE-linenoise new file mode 100644 index 0000000000000000000000000000000000000000..b006b3b24dcf79007f257178d29a52e5c7fe67fd --- /dev/null +++ b/licenses/LICENSE-linenoise @@ -0,0 +1,26 @@ +Copyright (c) 2010-2014, Salvatore Sanfilippo +Copyright (c) 2010-2013, Pieter Noordhuis +Copyright (c) 2025, Eric Curtin + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/media/llama0-banner.png b/media/llama0-banner.png new file mode 100644 index 0000000000000000000000000000000000000000..d9d38839076050c68372eebac609000b939a946e --- /dev/null +++ b/media/llama0-banner.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f15963f0d41f76ff4fb341d919c5dd0e01c22d8c839935d0ef390ea968ab7a1c +size 144615 diff --git a/media/llama0-logo.png b/media/llama0-logo.png new file mode 100644 index 0000000000000000000000000000000000000000..7cd462b42b1babf59663670f891f5422cb0a6e0f --- /dev/null +++ b/media/llama0-logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4aa1256ae1b68367e8ad32d4f47d8b69c29be35a4f2c01c7fa3f45f9fdae16f8 +size 179940 diff --git a/media/llama1-banner.png b/media/llama1-banner.png new file mode 100644 index 0000000000000000000000000000000000000000..1e469584e0cea32f7949fd061d2dd64e2753026a Binary files /dev/null and b/media/llama1-banner.png differ diff --git a/media/llama1-logo.png b/media/llama1-logo.png new file mode 100644 index 0000000000000000000000000000000000000000..365c5b865f3f4518bcc080bf685b7a55f414938a Binary files /dev/null and b/media/llama1-logo.png differ diff --git a/media/llama1-logo.svg b/media/llama1-logo.svg new file mode 100644 index 0000000000000000000000000000000000000000..e080481fa67c3981fa7392435febf013a35728e6 --- /dev/null +++ b/media/llama1-logo.svg @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/media/matmul.png b/media/matmul.png new file mode 100644 index 0000000000000000000000000000000000000000..c28ba0d1b1fb309e31e4335dc4ca26c0889020f6 --- /dev/null +++ b/media/matmul.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d6484cc146bbfcc1e8eae12ed0290057bfabeb7944d24d15d5d78995364bde4 +size 265705 diff --git a/media/matmul.svg b/media/matmul.svg new file mode 100644 index 0000000000000000000000000000000000000000..1d6cb4bb78a2206b06bd613723d8b221b2f01243 --- /dev/null +++ b/media/matmul.svg @@ -0,0 +1,1238 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ARow-major + BTColumn-major + CT=ABTColumn-major + + ne00 + + ne01 + + ne1 + + ne0 + + ne10 + + ne11 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + BRow-major + ATColumn-major + C=BATRow-major + + ne10 + + ne11 + + ne0 + + ne1 + + ne00 + + ne01 + + + diff --git a/models/.editorconfig b/models/.editorconfig new file mode 100644 index 0000000000000000000000000000000000000000..78b36ca0838fc0e4747c102a9b16154d7d481c2b --- /dev/null +++ b/models/.editorconfig @@ -0,0 +1 @@ +root = true diff --git a/models/ggml-vocab-aquila.gguf b/models/ggml-vocab-aquila.gguf new file mode 100644 index 0000000000000000000000000000000000000000..1d28649d92732d9cea3df1ae3c88542febd8f99a --- /dev/null +++ b/models/ggml-vocab-aquila.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c53c3c516ac67c7ca12977b9690fdea3d2ef13bbaed6378f98191a13ef5ca00 +size 4825676 diff --git a/models/ggml-vocab-baichuan.gguf b/models/ggml-vocab-baichuan.gguf new file mode 100644 index 0000000000000000000000000000000000000000..024492cfd361c3dcc79bd9a72f04beb38eb72439 --- /dev/null +++ b/models/ggml-vocab-baichuan.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f5b955697f3bd3108070b1d5936c7eb9fc542b81c6932e59abddec75bca1963 +size 1340998 diff --git a/models/ggml-vocab-bert-bge.gguf b/models/ggml-vocab-bert-bge.gguf new file mode 100644 index 0000000000000000000000000000000000000000..0c12abe39ef7abb07aa0b5dcd9457cf4aa3135d0 --- /dev/null +++ b/models/ggml-vocab-bert-bge.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbcbe22278fb302694d5f4a41bfe48c5f90e8e3554eab1c0435387dff654a854 +size 627549 diff --git a/models/ggml-vocab-bert-bge.gguf.inp b/models/ggml-vocab-bert-bge.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-bert-bge.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-bert-bge.gguf.out b/models/ggml-vocab-bert-bge.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..b1c49672fadad845bad979532cf6491189d83c08 --- /dev/null +++ b/models/ggml-vocab-bert-bge.gguf.out @@ -0,0 +1,46 @@ + 29464 2094 1018 1092 2706 + 9706 7959 2140 + + + + + + + + + + 7592 2088 + 7592 2088 + 7592 2088 + 7592 2088 + 7592 2088 999 + 7592 1010 2088 999 + 7592 1010 2088 999 + 2023 2003 100 1012 18133 2361 + 1059 2692 18139 1021 8525 28418 2243 16233 20952 6979 + 1192 15290 29754 14150 1192 10260 1181 29755 29436 29741 10260 16856 29747 23925 10325 + 100 + 100 1006 3671 1007 100 1006 3674 7861 29147 2483 9530 16280 23854 1007 100 1006 2069 7861 29147 2072 2008 2038 2049 2219 19204 1007 + 7592 + 7592 + 7592 + 7592 + 7592 + 7592 7592 + 1006 + 1027 + 1005 3690 + 7592 1010 1061 1005 2035 999 2129 2024 2017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995 + 999 999 999 999 999 999 + 1017 + 3943 + 21211 + 21211 2509 + 21211 22394 + 21211 22394 2509 + 21211 22394 22394 + 21211 22394 22394 2509 + 21211 22394 22394 22394 + 12731 2050 19710 + 5860 18117 + 100 1006 3671 1007 100 1006 3674 7861 29147 2483 9530 16280 23854 1007 100 100 1017 3943 21211 21211 2509 21211 22394 21211 22394 2509 21211 22394 22394 21211 22394 22394 2509 1017 1012 1017 1017 1012 1012 1017 1017 1012 1012 1012 1017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995 1011 1011 1011 1011 1011 1011 1027 1027 1027 1027 1027 1027 1027 1192 15290 29754 14150 1192 10260 1181 29755 29436 29741 10260 16856 29747 23925 10325 1005 1005 1005 1005 1005 1005 1036 1036 1036 1036 1036 1036 1036 1000 1000 1000 1000 1012 1012 1012 1012 1012 1012 999 999 999 999 999 999 1029 1029 1029 1029 1029 1029 1045 1005 2310 2042 1005 2409 2002 1005 1055 2045 1010 1005 2128 2017 2469 1029 1005 1049 2025 2469 1045 1005 2222 2191 2009 1010 1005 1040 2017 2066 2070 5572 1029 2057 1005 2310 1037 1005 2222 diff --git a/models/ggml-vocab-command-r.gguf b/models/ggml-vocab-command-r.gguf new file mode 100644 index 0000000000000000000000000000000000000000..91656898fc12b3038b6ecf1c716fc61b6b88ab09 --- /dev/null +++ b/models/ggml-vocab-command-r.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a2f8cfea952ef7c391a6d92a1c309d0bd32e36384d9b9230569a7425732f27d9 +size 10874545 diff --git a/models/ggml-vocab-command-r.gguf.inp b/models/ggml-vocab-command-r.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-command-r.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-command-r.gguf.out b/models/ggml-vocab-command-r.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..0e3af72eb1c23471ec7ed292fa7fbab89f8371b8 --- /dev/null +++ b/models/ggml-vocab-command-r.gguf.out @@ -0,0 +1,46 @@ + 2536 228 27 228 22957 6983 + 90711 87 20910 + + 228 + 1667 + 1742 + 205 + 206 + 2126 + 11516 + 34777 + 28339 3845 + 46609 3845 + 28339 3930 + 46609 3930 + 46609 3930 8 + 28339 19 3845 8 + 46609 19 3845 8 + 2075 1801 11254 107 255 21 19317 + 94 23 27 31 228 30 21213 20752 39267 6405 9980 + 4929 40071 2196 3236 8750 1764 37097 41168 + 38111 230 174833 38111 249 86325 241 38111 245 86325 232 38111 252 38111 123 38111 261 165 24629 38111 261 38111 103 174833 38111 235 38111 231 38111 257 38111 235 165 24629 38111 239 + 2226 256 230 1737 18258 16 80503 122 35927 2226 242 112 57462 1737 54457 223165 106230 2096 16 48389 1737 10203 109160 1875 2222 2517 3342 12523 16 + 28339 + 46609 + 228 46609 + 1667 46609 + 1742 46609 + 1742 46609 1856 46609 + 1737 + 206 1857 + 14 4515 + 28339 19 1770 14 1954 8 4070 1955 1933 80503 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372 + 57178 10251 + 26 + 26 26 + 26 26 26 + 26 26 26 26 + 26 26 26 26 26 + 26 26 26 26 26 26 + 26 26 26 26 26 26 26 + 26 26 26 26 26 26 26 26 + 26 26 26 26 26 26 26 26 26 + 42 30719 12584 + 3642 4388 + 127731 51628 205 57788 18494 97469 126134 206 2226 256 230 1737 18258 16 80503 122 35927 2226 242 112 57462 1737 54457 223165 106230 2096 16 48389 11254 107 255 2226 107 255 228 26 228 26 26 228 26 26 26 228 26 26 26 26 228 26 26 26 26 26 228 26 26 26 26 26 26 228 26 26 26 26 26 26 26 228 26 26 26 26 26 26 26 26 228 26 21 26 228 26 2271 26 228 26 3834 26 182018 230 174833 38111 249 86325 241 38111 245 86325 232 38111 252 38111 123 38111 261 165 24629 38111 261 38111 103 174833 38111 235 188568 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372 8391 158343 3512 40071 2196 3236 8750 1764 37097 41168 29721 32797 25646 3802 4975 4975 116167 57178 10251 154048 27292 1767 5125 2632 2155 91 2378 1919 1914 2782 19 2155 3354 1933 5470 38 2155 52 2068 5470 1767 4961 3059 1894 19 2155 43 1933 3026 2725 23186 38 2930 14 20676 1671 14 83 51 diff --git a/models/ggml-vocab-deepseek-coder.gguf b/models/ggml-vocab-deepseek-coder.gguf new file mode 100644 index 0000000000000000000000000000000000000000..b04617a314846e33cf088107f4854933105dc4fa --- /dev/null +++ b/models/ggml-vocab-deepseek-coder.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91cb1379f2e33af1c4866b194622b7a0e12e8f0c9dba7ba2f10d55978730bec1 +size 1156067 diff --git a/models/ggml-vocab-deepseek-coder.gguf.inp b/models/ggml-vocab-deepseek-coder.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-deepseek-coder.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-coder.gguf.out b/models/ggml-vocab-deepseek-coder.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..ef6bc5b8a3776d4febec2b43f0f8d57b1370206c --- /dev/null +++ b/models/ggml-vocab-deepseek-coder.gguf.out @@ -0,0 +1,46 @@ + 1050 207 19 207 19192 4217 + 125 213 26862 282 + + 207 + 243 + 315 + 184 + 185 + 185 185 + 185 185 185 + 184 185 + 17535 1835 + 414 9489 1835 + 17535 5414 + 414 9489 5414 + 414 9489 5414 0 + 17535 11 1835 0 + 414 9489 11 1835 0 + 437 317 12394 99 234 13 14789 + 86 15 19 23 207 22 83 3963 27659 26078 3934 14072 + 1593 6478 616 2251 14994 + 155 239 209 155 239 114 155 239 228 155 240 220 155 239 224 155 240 211 155 239 231 155 239 115 155 239 240 155 240 210 155 239 240 155 239 95 155 239 114 155 239 214 155 239 210 155 239 236 155 239 214 155 240 210 155 239 218 + 10047 235 209 334 8760 8 12394 233 114 350 222 10047 221 104 169 116 224 334 4684 3909 992 24330 262 29651 612 8 207 156 237 214 334 5950 992 78 12896 344 638 891 1372 10736 8 + 17535 + 414 9489 + 207 414 9489 + 243 414 9489 + 315 414 9489 + 315 414 9489 185 315 414 9489 + 334 + 185 405 + 6 2895 + 17535 11 320 6 435 0 1717 417 340 12394 233 210 3015 19100 608 9413 2668 16 18 16 19 16 20 16 1393 169 121 239 + 15330 3023 + 18 + 18 18 + 18 18 18 + 18 18 18 18 + 18 18 18 18 18 + 18 18 18 18 18 18 + 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 18 + 34 155 119 242 64 24297 155 119 216 83 + 1607 2539 + 185 207 185 185 207 185 185 185 207 12405 459 22758 185 243 185 315 185 251 185 730 185 10047 235 209 334 8760 8 12394 233 114 350 222 10047 221 104 169 116 224 334 4684 3909 992 24330 262 29651 612 8 207 156 237 214 12394 99 234 10047 99 234 207 18 207 18 18 207 18 18 18 207 18 18 18 18 207 18 18 18 18 18 207 18 18 18 18 18 18 207 18 18 18 18 18 18 18 207 18 18 18 18 18 18 18 18 207 18 13 18 207 18 524 18 207 18 1202 18 207 155 239 209 155 239 114 155 239 228 155 240 220 155 239 224 155 240 211 155 239 231 155 239 115 155 239 240 155 240 210 155 239 240 155 239 95 155 239 114 155 239 214 10047 233 210 3015 19100 608 9413 2668 16 18 16 19 16 20 16 1393 169 121 239 18155 374 17194 28 2861 6478 616 2251 14994 31269 4191 6 4686 4686 10252 3358 3358 3409 524 15330 3023 15031 5668 303 6 312 798 651 83 839 362 6 82 741 11 651 1369 340 2037 30 651 44 441 2037 303 6 642 1098 359 11 651 35 340 833 738 10860 30 998 6 10709 245 6 75 43 diff --git a/models/ggml-vocab-deepseek-llm.gguf b/models/ggml-vocab-deepseek-llm.gguf new file mode 100644 index 0000000000000000000000000000000000000000..8b821b95edba0591c11d05e308835ff84a9fa8c6 --- /dev/null +++ b/models/ggml-vocab-deepseek-llm.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:867f77537b54565f0d81d508c04edc41aa1d4ffc1a92745f225b4c1b02755f76 +size 3970167 diff --git a/models/ggml-vocab-deepseek-llm.gguf.inp b/models/ggml-vocab-deepseek-llm.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-deepseek-llm.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-llm.gguf.out b/models/ggml-vocab-deepseek-llm.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..f9d49c9afe703703ce26bf30bc4a0dd52c4ffbbe --- /dev/null +++ b/models/ggml-vocab-deepseek-llm.gguf.out @@ -0,0 +1,46 @@ + 1052 207 19 207 19109 4223 + 82077 26723 282 + + 207 + 243 + 300 + 184 + 185 + 185 185 + 185 185 185 + 184 185 + 17464 1843 + 37727 1843 + 17464 5427 + 37727 5427 + 37727 5427 0 + 17464 11 1843 0 + 37727 11 1843 0 + 437 317 12356 99 234 13 14743 + 86 15 19 23 207 22 83 3970 27519 26016 3944 14025 + 1603 6476 620 91754 + 71374 209 71374 114 71374 228 155 240 220 71374 224 155 240 211 71374 231 71374 115 71374 240 155 240 210 71374 240 71374 95 71374 114 71374 214 71374 210 71374 236 71374 214 155 240 210 71374 218 + 10044 95300 334 8754 8 33701 114 350 222 10044 221 104 46713 334 34732 996 24250 262 80923 8 207 37103 214 334 5956 89213 344 643 895 1377 10728 8 + 17464 + 37727 + 207 37727 + 243 37727 + 300 37727 + 300 37727 185 300 37727 + 334 + 185 403 + 6 2906 + 17464 11 320 6 436 0 1724 418 340 33701 210 3025 19017 612 9407 2681 16 18 16 19 16 20 16 1398 68940 239 + 15278 3033 + 18 + 18 18 + 18 18 18 + 18 18 18 18 + 18 18 18 18 18 + 18 18 18 18 18 18 + 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 18 + 34 32555 242 64 23708 32555 216 83 + 1763 2550 + 185 207 185 185 207 185 185 185 207 11969 486 22504 185 243 185 300 185 251 185 663 185 10044 95300 334 8754 8 33701 114 350 222 10044 221 104 46713 334 34732 996 24250 262 80923 8 207 37103 214 12356 99 234 10044 99 234 207 18 207 18 18 207 18 18 18 207 18 18 18 18 207 18 18 18 18 18 207 18 18 18 18 18 18 207 18 18 18 18 18 18 18 207 18 18 18 18 18 18 18 18 207 18 13 18 207 18 526 18 207 18 1204 18 207 71374 209 71374 114 71374 228 155 240 220 71374 224 155 240 211 71374 231 71374 115 71374 240 155 240 210 71374 240 71374 95 71374 114 71374 214 71899 210 3025 19017 612 9407 2681 16 18 16 19 16 20 16 1398 68940 239 78827 55170 76659 620 91754 31116 36804 4885 4885 10897 4390 4390 41047 15278 3033 14986 5675 304 6 313 803 655 33326 362 6 82 745 11 655 1374 340 2049 30 655 44 441 2049 304 6 647 1099 359 11 655 35 340 837 742 10842 30 1003 6 10699 245 6 75 43 diff --git a/models/ggml-vocab-falcon.gguf b/models/ggml-vocab-falcon.gguf new file mode 100644 index 0000000000000000000000000000000000000000..1916c29e2e7602678c84dfa555e5153a68f7e27b --- /dev/null +++ b/models/ggml-vocab-falcon.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f0bf8b0733680398b72e652e90f260f43782f326e75545fc0e49611a5ba35ad +size 2287728 diff --git a/models/ggml-vocab-falcon.gguf.inp b/models/ggml-vocab-falcon.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-falcon.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-falcon.gguf.out b/models/ggml-vocab-falcon.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..6319de60e241269e91f79f6d0ce4bcc63f99a1b4 --- /dev/null +++ b/models/ggml-vocab-falcon.gguf.out @@ -0,0 +1,46 @@ + 878 204 31 3068 133 2137 + 34502 18614 286 + + 204 + 258 + 466 + 192 + 193 + 1001 + 11331 + 19125 + 9856 1079 + 23090 1079 + 9856 2889 + 23090 2889 + 23090 2889 12 + 9856 23 1079 12 + 23090 23 1079 12 + 414 304 3346 111 231 25 29247 + 98 55866 204 34 16682 7149 36190 6869 11481 + 150 133 6207 151 215 150 134 5052 133 6279 5052 223 151 216 49679 123 53110 47043 7795 + 38154 206 38154 126 38154 225 167 237 217 38154 221 167 237 208 38154 228 38154 127 38154 237 167 237 207 38154 237 38154 107 38154 126 38154 211 38154 207 38154 233 38154 211 167 237 207 38154 215 + 2571 232 206 204 19 11003 20 8196 126 283 219 48778 116 13392 204 19 51831 732 63209 1741 7955 522 20 22438 211 204 19 7927 53360 325 504 701 946 10930 20 + 9856 + 23090 + 204 23090 + 258 23090 + 466 23090 + 466 23090 742 23090 + 204 19 + 1212 40 + 18 4932 + 9856 23 291 18 436 12 1265 362 299 8196 207 204 42 50087 123 2727 20300 32022 133 234 17419 30137 28 7858 181 133 236 + 51520 + 30 + 3138 + 22287 + 22287 30 + 22287 3138 + 22287 22287 + 22287 22287 30 + 22287 22287 3138 + 22287 22287 22287 + 46 19768 239 76 9634 19768 213 95 + 1080 1502 + 1212 4824 1001 1212 192 204 663 49453 2069 742 561 1501 193 2571 232 206 204 19 11003 20 8196 126 283 219 48778 116 13392 204 19 51831 732 63209 1741 7955 522 20 22438 211 3346 111 231 2571 111 231 204 30 204 3138 204 22287 204 22287 30 204 22287 3138 204 22287 22287 204 22287 22287 30 204 22287 22287 3138 204 30 25 30 204 30 513 30 204 30 951 30 27171 236 206 38154 126 38154 225 167 237 217 38154 221 167 237 208 38154 228 38154 127 38154 237 167 237 207 38154 237 38154 107 38154 126 38154 211 20589 207 204 42 50087 123 2727 20300 32022 133 234 17419 30137 28 7858 181 133 236 204 37057 2228 10666 5052 133 6207 151 215 150 134 5052 133 6279 5052 223 151 216 49679 123 53110 47043 7795 204 7544 7544 7544 8543 8543 17593 3513 3513 12844 51520 17664 4247 295 18 298 650 204 18 95 693 332 18 94 629 23 204 18 1553 299 1310 42 204 18 56 416 1310 295 18 567 717 334 23 204 18 47 299 606 596 6696 42 703 18 16139 241 18 87 55 diff --git a/models/ggml-vocab-gpt-2.gguf b/models/ggml-vocab-gpt-2.gguf new file mode 100644 index 0000000000000000000000000000000000000000..2745d9f7aa50a3b4e152c9a98dffbc544bdfe005 --- /dev/null +++ b/models/ggml-vocab-gpt-2.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cedc56ca6e2e89f63e781696d1fd76b4b1d49e6720dee86463e915f6e90016ac +size 1766807 diff --git a/models/ggml-vocab-gpt-2.gguf.inp b/models/ggml-vocab-gpt-2.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-gpt-2.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-gpt-2.gguf.out b/models/ggml-vocab-gpt-2.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..6464ded3d276709bc22d49f38f2d799e5828a59f --- /dev/null +++ b/models/ggml-vocab-gpt-2.gguf.out @@ -0,0 +1,46 @@ + 798 604 25208 1933 + 127 226 79 69 417 + + 220 + 220 220 + 220 220 220 + 197 + 198 + 628 + 628 198 + 197 198 + 15496 995 + 18435 995 + 15496 2159 + 18435 2159 + 18435 2159 0 + 15496 11 995 0 + 18435 11 995 0 + 428 318 12520 99 247 13 20322 + 86 47202 767 28047 45961 288 82 7568 13415 + 22177 16843 141 231 15166 12466 121 16142 12466 239 141 232 30143 140 111 16142 21169 21727 31583 18849 + 157 252 222 157 252 114 157 252 241 157 253 233 157 252 237 157 253 224 157 252 244 157 252 115 157 252 253 157 253 223 157 252 253 157 252 95 157 252 114 157 252 227 157 252 223 157 252 249 157 252 227 157 253 223 157 252 231 + 8582 248 222 357 11265 8 30325 114 447 235 8582 234 104 37929 357 48101 795 13210 271 1673 36686 515 8 14519 227 357 8807 44805 326 468 663 898 11241 8 + 15496 + 18435 + 220 18435 + 220 220 18435 + 220 220 220 18435 + 220 220 220 18435 198 220 220 220 18435 + 357 + 198 796 + 6 6980 + 15496 11 331 6 439 0 1374 389 345 30325 223 5633 22755 239 46349 111 28839 101 18040 32432 98 43291 1485 1415 24309 25465 171 121 252 + 13896 3228 + 18 + 2091 + 20370 + 24840 + 2091 20370 + 24840 2091 + 24840 20370 + 24840 24840 + 24840 2091 20370 + 34 157 119 255 64 16049 157 119 229 83 + 1221 1371 + 198 220 628 220 628 198 220 197 220 197 197 220 197 198 220 220 198 220 220 220 198 220 220 220 220 198 220 220 220 220 220 198 8582 248 222 357 11265 8 30325 114 447 235 8582 234 104 37929 357 48101 795 13210 271 1673 36686 515 8 14519 227 12520 99 247 8582 99 247 513 4747 23460 513 20370 23460 2091 23460 20370 23460 24840 23460 2091 20370 513 13 18 513 492 18 513 986 18 28053 252 222 157 252 114 157 252 241 157 253 233 157 252 237 157 253 224 157 252 244 157 252 115 157 252 253 157 253 223 157 252 253 157 252 95 157 252 114 157 252 227 47249 223 5633 22755 239 46349 111 28839 101 18040 32432 98 43291 1485 1415 24309 25465 171 121 252 40103 1421 18604 12466 121 16843 141 231 15166 12466 121 16142 12466 239 141 232 30143 140 111 16142 21169 21727 31583 18849 705 39115 6 33153 15506 63 15931 15931 16317 13896 3228 9805 3548 314 1053 587 705 44040 339 338 612 11 705 2200 345 1654 30 705 44 407 1654 314 1183 787 340 11 705 35 345 588 617 8887 30 775 6 26979 257 6 75 43 diff --git a/models/ggml-vocab-gpt-neox.gguf b/models/ggml-vocab-gpt-neox.gguf new file mode 100644 index 0000000000000000000000000000000000000000..62b552628e043adb7c58163d4375ccc4f0d8653f --- /dev/null +++ b/models/ggml-vocab-gpt-neox.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae593a7f9b8bb174ed4f5019e41530463e4dac7aa06e42dee8aa650d2bdac53d +size 1771431 diff --git a/models/ggml-vocab-llama-bpe.gguf b/models/ggml-vocab-llama-bpe.gguf new file mode 100644 index 0000000000000000000000000000000000000000..0c6826d8a9c8abd7623cafcad23bfab019bce8cc --- /dev/null +++ b/models/ggml-vocab-llama-bpe.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97272e430d53bc7688f52d5e0ad8ea8f163ede9f1bbd1694feaa504797d5d96e +size 7818140 diff --git a/models/ggml-vocab-llama-bpe.gguf.inp b/models/ggml-vocab-llama-bpe.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-llama-bpe.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-llama-bpe.gguf.out b/models/ggml-vocab-llama-bpe.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..a77376625a2a0f25c2cf158353b5dabf63756c2a --- /dev/null +++ b/models/ggml-vocab-llama-bpe.gguf.out @@ -0,0 +1,46 @@ + 1142 220 19 220 27154 4038 + 88075 16276 301 + + 220 + 256 + 262 + 197 + 198 + 271 + 1432 + 1602 + 9906 1917 + 22691 1917 + 9906 4435 + 22691 4435 + 22691 4435 0 + 9906 11 1917 0 + 22691 11 1917 0 + 420 374 11410 99 247 13 11055 + 86 23904 220 22 83 2005 42908 11729 3013 17156 + 79862 102118 13373 64571 34694 3114 112203 80112 + 21549 222 98629 241 45358 233 21549 237 45358 224 21549 244 21549 115 21549 253 45358 223 21549 253 21549 95 98629 227 21549 223 21549 249 21549 227 45358 223 21549 231 + 9468 248 222 320 8416 8 27623 114 102470 9468 234 104 31643 320 36773 100166 98634 8 26602 227 320 3323 43465 430 706 1202 1866 4037 8 + 9906 + 22691 + 220 22691 + 256 22691 + 262 22691 + 262 22691 198 262 22691 + 320 + 198 284 + 6 11639 + 9906 11 379 65948 0 2650 527 499 27623 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909 + 17523 3001 + 18 + 1644 + 8765 + 8765 18 + 8765 1644 + 8765 8765 + 8765 8765 18 + 8765 8765 1644 + 8765 8765 8765 + 34 91163 101798 + 2624 2402 + 198 4815 15073 66597 8004 1602 2355 79772 11187 9468 248 222 320 8416 8 27623 114 102470 9468 234 104 31643 320 36773 100166 98634 8 26602 227 11410 99 247 9468 99 247 220 18 220 1644 220 8765 220 8765 18 220 8765 1644 220 8765 8765 220 8765 8765 18 220 8765 8765 1644 220 18 13 18 220 18 497 18 220 18 1131 18 220 21549 222 98629 241 45358 233 21549 237 45358 224 21549 244 21549 115 21549 253 45358 223 21549 253 21549 95 98629 227 76460 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909 56560 54337 19175 102118 13373 64571 34694 3114 112203 80112 3436 106451 14196 14196 74694 3089 3089 29249 17523 3001 27708 7801 358 3077 1027 364 83 820 568 596 1070 11 364 793 499 2771 30 364 44 539 2771 358 3358 1304 433 11 364 35 499 1093 1063 15600 30 1226 6 43712 264 64966 43 diff --git a/models/ggml-vocab-llama-spm.gguf b/models/ggml-vocab-llama-spm.gguf new file mode 100644 index 0000000000000000000000000000000000000000..2093285959bf2c0a9f46febf445019645ca222b7 --- /dev/null +++ b/models/ggml-vocab-llama-spm.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16c3724582d59aa8bf84711894e833f916ee46a31d80e21312759c48bf8d0e69 +size 723869 diff --git a/models/ggml-vocab-llama-spm.gguf.inp b/models/ggml-vocab-llama-spm.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-llama-spm.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-llama-spm.gguf.out b/models/ggml-vocab-llama-spm.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..2a71a6ef86efa9ee78509747de5cd1ceb59c8a5a --- /dev/null +++ b/models/ggml-vocab-llama-spm.gguf.out @@ -0,0 +1,46 @@ + 474 287 29871 29946 29871 30226 7378 + 11585 7810 295 + + 259 + 1678 + 268 + 29871 12 + 29871 13 + 29871 13 13 + 29871 13 13 13 + 29871 12 13 + 15043 3186 + 29871 15043 3186 + 15043 2787 + 29871 15043 2787 + 29871 15043 2787 29991 + 15043 29892 3186 29991 + 29871 15043 29892 3186 29991 + 29871 445 338 29871 243 162 169 156 29889 8223 + 281 29900 29946 29947 29871 29955 9161 13535 18031 2176 6905 + 1538 4851 665 1386 29713 1305 + 29871 31849 31324 31934 228 162 142 228 161 146 228 162 133 228 161 153 228 161 186 31708 228 162 132 31708 228 161 165 31324 228 161 136 228 161 132 228 161 158 228 161 136 228 162 132 228 161 140 + 29871 243 162 157 131 313 8945 29897 29871 243 162 155 185 30722 243 162 143 174 30598 313 20787 953 3848 275 16125 630 29897 29871 31681 313 6194 953 29877 2397 393 756 967 1914 5993 29897 + 15043 + 29871 15043 + 259 15043 + 1678 15043 + 268 15043 + 268 15043 13 1678 15043 + 29871 313 + 29871 13 353 + 525 3152 + 15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739 + 1738 6824 21004 + 29871 29941 + 29871 29941 29941 + 29871 29941 29941 29941 + 29871 29941 29941 29941 29941 + 29871 29941 29941 29941 29941 29941 + 29871 29941 29941 29941 29941 29941 29941 + 29871 29941 29941 29941 29941 29941 29941 29941 + 29871 29941 29941 29941 29941 29941 29941 29941 29941 + 29871 29941 29941 29941 29941 29941 29941 29941 29941 29941 + 315 228 190 176 29874 10630 30529 29873 + 29871 2313 3163 + 29871 13 29871 13 13 29871 13 13 13 29871 12 29871 12 12 29871 12 13 259 13 1678 13 268 13 418 13 243 162 157 131 313 8945 29897 29871 243 162 155 185 30722 243 162 143 174 30598 313 20787 953 3848 275 16125 630 29897 29871 31681 29871 243 162 169 156 243 162 169 156 29871 29941 29871 29941 29941 29871 29941 29941 29941 29871 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29941 29941 29871 29941 29889 29941 29871 29941 636 29941 29871 29941 856 29941 29871 31849 31324 31934 228 162 142 228 161 146 228 162 133 228 161 153 228 161 186 31708 228 162 132 31708 228 161 165 31324 228 161 136 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739 448 23648 2751 25512 1538 4851 665 1386 29713 1305 14550 4907 11120 16159 16159 16159 15945 15945 3045 636 6824 6824 6824 8773 8773 8773 306 29915 345 1063 525 29873 1025 540 29915 29879 727 29892 525 1525 366 1854 29973 525 29924 451 1854 306 29915 645 1207 372 29892 525 29928 366 763 777 23429 29973 1334 29915 29963 29872 263 29915 29880 29931 diff --git a/models/ggml-vocab-mpt.gguf b/models/ggml-vocab-mpt.gguf new file mode 100644 index 0000000000000000000000000000000000000000..188bfa128057c36f035c44eedc0bcb4a77bdf630 --- /dev/null +++ b/models/ggml-vocab-mpt.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59dc382612866d1fc6c11ea531318d327598f3412d9c8f8600607cdf3030898f +size 1771393 diff --git a/models/ggml-vocab-mpt.gguf.inp b/models/ggml-vocab-mpt.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-mpt.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-mpt.gguf.out b/models/ggml-vocab-mpt.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..ca62669ad0945dcf646759d38b0192e33086b85e --- /dev/null +++ b/models/ggml-vocab-mpt.gguf.out @@ -0,0 +1,46 @@ + 728 577 24142 2607 + 37515 18569 293 + + 209 + 50276 + 50275 + 186 + 187 + 535 + 2756 + 186 187 + 12092 1533 + 24387 1533 + 12092 3645 + 24387 3645 + 24387 3645 2 + 12092 13 1533 2 + 24387 13 1533 2 + 436 310 22692 101 236 15 14161 + 88 27244 818 16853 16392 20505 4989 11917 + 32520 11514 1068 8713 38177 13396 3415 9925 12559 10453 1389 + 18081 211 18081 116 18081 230 39936 222 18081 226 39936 213 18081 233 18081 117 18081 242 39936 212 18081 242 18081 97 18081 116 18081 216 18081 212 18081 238 18081 216 39936 212 18081 220 + 14931 237 211 313 6320 10 49042 116 325 224 14931 223 106 171 118 226 313 34263 802 13511 261 32147 456 10 3384 239 216 313 7483 802 80 8020 326 556 697 1211 10669 10 + 12092 + 24387 + 50276 12092 + 50275 12092 + 50274 12092 + 50274 12092 187 50274 12092 + 313 + 187 426 + 8 8685 + 12092 13 340 8 455 2 1359 403 368 49042 212 3736 15367 41197 13610 19934 41869 21275 1012 1047 18795 40120 20422 241 + 18963 4672 + 20 + 1610 + 20084 + 26409 + 1610 20084 + 26409 1610 + 26409 20084 + 26409 26409 + 26409 1610 20084 + 36 6829 244 66 17721 35177 85 + 1262 2196 + 586 1744 33525 186 209 623 28910 187 50276 187 50275 187 50274 187 50273 187 14931 237 211 313 6320 10 49042 116 325 224 14931 223 106 171 118 226 313 34263 802 13511 261 32147 456 10 3384 239 216 22692 101 236 14931 101 236 495 5922 30057 495 20084 495 26409 30057 20084 495 26409 1610 495 26409 20084 495 15 20 495 537 20 495 1051 20 209 18081 211 18081 116 18081 230 39936 222 18081 226 39936 213 18081 233 18081 117 18081 242 39936 212 18081 242 18081 97 18081 116 18081 216 14931 235 212 3736 15367 41197 13610 19934 41869 21275 1012 1047 18795 40120 20422 241 16081 6877 12880 11514 1068 8713 38177 13396 3415 9925 12559 10453 1389 42011 35033 34842 11202 9739 9739 33021 18963 4672 25561 8220 309 1849 644 686 42618 344 434 627 13 686 1848 368 2119 32 686 46 417 2119 309 1833 1056 352 13 686 37 368 751 690 10331 32 844 8 31516 247 8 77 45 diff --git a/models/ggml-vocab-nomic-bert-moe.gguf b/models/ggml-vocab-nomic-bert-moe.gguf new file mode 100644 index 0000000000000000000000000000000000000000..fee7ec54610be32a526dd4f0b76625c4a1e0ae0e --- /dev/null +++ b/models/ggml-vocab-nomic-bert-moe.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90a6746926454784a98389ad36a36d89bc9cfc81db9cb0f33c941bcc959fe5f9 +size 6821877 diff --git a/models/ggml-vocab-phi-3.gguf b/models/ggml-vocab-phi-3.gguf new file mode 100644 index 0000000000000000000000000000000000000000..c1db01d8d6c3f59d275809509b130fe95203ba19 --- /dev/null +++ b/models/ggml-vocab-phi-3.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:967d7190d11c4842eab697079d98d56c2116e10eb617be355a2733bfc132e326 +size 726019 diff --git a/models/ggml-vocab-phi-3.gguf.inp b/models/ggml-vocab-phi-3.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-phi-3.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-phi-3.gguf.out b/models/ggml-vocab-phi-3.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..2a71a6ef86efa9ee78509747de5cd1ceb59c8a5a --- /dev/null +++ b/models/ggml-vocab-phi-3.gguf.out @@ -0,0 +1,46 @@ + 474 287 29871 29946 29871 30226 7378 + 11585 7810 295 + + 259 + 1678 + 268 + 29871 12 + 29871 13 + 29871 13 13 + 29871 13 13 13 + 29871 12 13 + 15043 3186 + 29871 15043 3186 + 15043 2787 + 29871 15043 2787 + 29871 15043 2787 29991 + 15043 29892 3186 29991 + 29871 15043 29892 3186 29991 + 29871 445 338 29871 243 162 169 156 29889 8223 + 281 29900 29946 29947 29871 29955 9161 13535 18031 2176 6905 + 1538 4851 665 1386 29713 1305 + 29871 31849 31324 31934 228 162 142 228 161 146 228 162 133 228 161 153 228 161 186 31708 228 162 132 31708 228 161 165 31324 228 161 136 228 161 132 228 161 158 228 161 136 228 162 132 228 161 140 + 29871 243 162 157 131 313 8945 29897 29871 243 162 155 185 30722 243 162 143 174 30598 313 20787 953 3848 275 16125 630 29897 29871 31681 313 6194 953 29877 2397 393 756 967 1914 5993 29897 + 15043 + 29871 15043 + 259 15043 + 1678 15043 + 268 15043 + 268 15043 13 1678 15043 + 29871 313 + 29871 13 353 + 525 3152 + 15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739 + 1738 6824 21004 + 29871 29941 + 29871 29941 29941 + 29871 29941 29941 29941 + 29871 29941 29941 29941 29941 + 29871 29941 29941 29941 29941 29941 + 29871 29941 29941 29941 29941 29941 29941 + 29871 29941 29941 29941 29941 29941 29941 29941 + 29871 29941 29941 29941 29941 29941 29941 29941 29941 + 29871 29941 29941 29941 29941 29941 29941 29941 29941 29941 + 315 228 190 176 29874 10630 30529 29873 + 29871 2313 3163 + 29871 13 29871 13 13 29871 13 13 13 29871 12 29871 12 12 29871 12 13 259 13 1678 13 268 13 418 13 243 162 157 131 313 8945 29897 29871 243 162 155 185 30722 243 162 143 174 30598 313 20787 953 3848 275 16125 630 29897 29871 31681 29871 243 162 169 156 243 162 169 156 29871 29941 29871 29941 29941 29871 29941 29941 29941 29871 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29941 29941 29871 29941 29889 29941 29871 29941 636 29941 29871 29941 856 29941 29871 31849 31324 31934 228 162 142 228 161 146 228 162 133 228 161 153 228 161 186 31708 228 162 132 31708 228 161 165 31324 228 161 136 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739 448 23648 2751 25512 1538 4851 665 1386 29713 1305 14550 4907 11120 16159 16159 16159 15945 15945 3045 636 6824 6824 6824 8773 8773 8773 306 29915 345 1063 525 29873 1025 540 29915 29879 727 29892 525 1525 366 1854 29973 525 29924 451 1854 306 29915 645 1207 372 29892 525 29928 366 763 777 23429 29973 1334 29915 29963 29872 263 29915 29880 29931 diff --git a/models/ggml-vocab-qwen2.gguf b/models/ggml-vocab-qwen2.gguf new file mode 100644 index 0000000000000000000000000000000000000000..d97fb6fc8f222f7a92f8c0a916c9549de38b81dd --- /dev/null +++ b/models/ggml-vocab-qwen2.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44c2f46b715f585c6ab513970e8a006bfa5badd6108560054921cf598d154d8c +size 5928681 diff --git a/models/ggml-vocab-qwen2.gguf.inp b/models/ggml-vocab-qwen2.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-qwen2.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-qwen2.gguf.out b/models/ggml-vocab-qwen2.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..595d59a44963cddd3653b3659406c2d65d5640b2 --- /dev/null +++ b/models/ggml-vocab-qwen2.gguf.out @@ -0,0 +1,46 @@ + 1122 220 19 220 26062 3951 + 86975 15897 301 + + 220 + 256 + 262 + 197 + 198 + 271 + 1406 + 1572 + 9707 1879 + 21927 1879 + 9707 4337 + 21927 4337 + 21927 4337 0 + 9707 11 1879 0 + 21927 11 1879 0 + 419 374 11162 99 247 13 10821 + 86 15 19 23 220 22 83 1963 41808 11472 2940 16739 + 78762 14144 1456 13073 63471 33594 3038 133178 79012 + 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 147805 148301 147270 44258 223 146848 + 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 320 3243 42365 429 702 1181 1828 3950 8 + 9707 + 21927 + 220 21927 + 256 21927 + 262 21927 + 262 21927 198 262 21927 + 320 + 198 284 + 6 11385 + 9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 + 17085 2928 + 18 + 18 18 + 18 18 18 + 18 18 18 18 + 18 18 18 18 18 + 18 18 18 18 18 18 + 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 18 + 34 90063 128324 + 2560 2347 + 198 4710 14731 65497 7847 1572 2303 78672 10947 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 11162 99 247 149955 220 18 220 18 18 220 18 18 18 220 18 18 18 18 220 18 18 18 18 18 220 18 18 18 18 18 18 220 18 18 18 18 18 18 18 220 18 18 18 18 18 18 18 18 220 18 13 18 220 18 496 18 220 18 1112 18 220 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 144534 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 55460 53237 18658 14144 1456 13073 63471 33594 3038 133178 79012 3355 4605 4605 13874 13874 73594 3014 3014 28149 17085 2928 26610 7646 358 3003 1012 364 83 813 566 594 1052 11 364 787 498 2704 30 364 44 537 2704 358 3278 1281 432 11 364 35 498 1075 1045 15243 30 1205 6 42612 264 63866 43 diff --git a/models/ggml-vocab-refact.gguf b/models/ggml-vocab-refact.gguf new file mode 100644 index 0000000000000000000000000000000000000000..ba90829117ad48af114796550f2681549dea09ad --- /dev/null +++ b/models/ggml-vocab-refact.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac3ceda902fed91ccf74312b305d9b86c37e4f8e35fa9cc6ef3ce34fca7d4678 +size 1720710 diff --git a/models/ggml-vocab-refact.gguf.inp b/models/ggml-vocab-refact.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-refact.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-refact.gguf.out b/models/ggml-vocab-refact.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..f13dda52ce41e1d93309d9def498d3cf2c322d22 --- /dev/null +++ b/models/ggml-vocab-refact.gguf.out @@ -0,0 +1,46 @@ + 4833 225 38 225 143 140 17723 + 144 231 7132 342 + + 225 + 261 + 264 + 202 + 203 + 478 + 2831 + 15773 + 8279 5788 + 12000 5788 + 8279 10896 + 12000 10896 + 12000 10896 19 + 8279 30 5788 19 + 12000 30 5788 19 + 458 438 5945 118 252 32 3766 + 105 34 38 42 225 41 102 1707 12530 10180 1479 8278 + 39862 8372 1039 9446 40242 13852 2053 8949 12531 1520 10700 + 14574 227 14574 133 14574 246 30457 238 14574 242 30457 229 14574 249 14574 134 14574 258 30457 228 14574 258 14574 114 14574 133 14574 232 14574 228 14574 254 14574 232 30457 228 14574 236 + 3807 253 227 308 4382 27 18458 133 46113 44967 123 13868 308 12565 19775 33071 40824 733 27 41889 308 2585 22680 688 1401 2819 4369 2404 27 + 8279 + 12000 + 225 12000 + 261 12000 + 264 12000 + 264 12000 284 12000 + 308 + 203 280 + 25 34666 + 8279 30 533 25 464 19 4971 884 844 18458 228 1018 4982 13368 2909 9513 17827 35 37 35 38 35 39 35 11873 47838 + 9163 3202 + 37 + 37 37 + 37 37 37 + 37 37 37 37 + 37 37 37 37 37 + 37 37 37 37 37 37 + 37 37 37 37 37 37 37 + 37 37 37 37 37 37 37 37 + 37 37 37 37 37 37 37 37 37 + 53 33934 83 33217 17102 102 + 1214 12258 + 334 719 8878 202 10885 4222 16104 28570 203 3807 253 227 308 4382 27 18458 133 46113 44967 123 13868 308 12565 19775 33071 40824 733 27 41889 5945 118 252 3807 118 252 225 37 225 37 37 225 37 37 37 225 37 37 37 37 225 37 37 37 37 37 225 37 37 37 37 37 37 225 37 37 37 37 37 37 37 225 37 37 37 37 37 37 37 37 225 37 32 37 225 37 497 37 225 37 1179 37 225 14574 227 14574 133 14574 246 30457 238 14574 242 30457 229 14574 249 14574 134 14574 258 30457 228 14574 258 14574 114 14574 133 14574 232 36628 228 1018 4982 13368 2909 9513 17827 35 37 35 38 35 39 35 11873 47838 20921 16623 13028 8372 1039 9446 40242 13852 2053 8949 12531 1520 10700 5881 9592 13299 914 31753 31359 9163 3202 35472 10397 439 4763 2583 330 102 1455 938 1182 2017 30 330 613 844 3654 49 330 63 646 3654 439 4621 1930 561 30 330 54 844 2124 1629 35993 49 2688 25 7709 312 25 94 62 diff --git a/models/ggml-vocab-starcoder.gguf b/models/ggml-vocab-starcoder.gguf new file mode 100644 index 0000000000000000000000000000000000000000..66c113ad7e480a2421e7098595e0e2f28e427aa7 --- /dev/null +++ b/models/ggml-vocab-starcoder.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fedb892b4e1bd3c1f2fcdae356440b14fb458f4264d586e5c987ed93df4e174d +size 1719346 diff --git a/models/ggml-vocab-starcoder.gguf.inp b/models/ggml-vocab-starcoder.gguf.inp new file mode 100644 index 0000000000000000000000000000000000000000..86b934e4020fba7061c15b6349d5e4163bb2383e --- /dev/null +++ b/models/ggml-vocab-starcoder.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-starcoder.gguf.out b/models/ggml-vocab-starcoder.gguf.out new file mode 100644 index 0000000000000000000000000000000000000000..4698e2c3c81ad21e5046d5ba1a0bb8010f3c212b --- /dev/null +++ b/models/ggml-vocab-starcoder.gguf.out @@ -0,0 +1,46 @@ + 4850 244 57 244 162 159 17722 + 163 250 7146 361 + + 244 + 280 + 283 + 221 + 222 + 499 + 3067 + 15767 + 8302 5810 + 12009 5810 + 8302 10914 + 12009 10914 + 12009 10914 38 + 8302 49 5810 38 + 12009 49 5810 38 + 477 458 5954 137 271 51 3779 + 124 53 57 61 244 60 121 1726 12568 10240 1519 8290 + 39916 8389 1059 9504 40216 13858 2073 8983 12571 1539 10721 + 14566 246 14566 152 14566 265 30428 257 14566 261 30428 248 14566 268 14566 153 14566 277 30428 247 14566 277 14566 133 14566 152 14566 251 14566 247 14566 273 14566 251 30428 247 14566 255 + 3822 272 246 327 4434 46 18445 152 46030 45022 142 13878 327 12585 19884 33773 40920 751 46 41839 327 2605 22716 708 1421 2840 4387 2421 46 + 8302 + 12009 + 244 12009 + 280 12009 + 283 12009 + 283 12009 303 12009 + 327 + 222 299 + 44 34719 + 8302 49 553 44 483 38 4998 904 863 18445 247 1037 4995 13379 2924 9515 17823 54 56 54 57 54 58 54 11904 47892 + 9221 3226 + 56 + 56 56 + 56 56 56 + 56 56 56 56 + 56 56 56 56 56 + 56 56 56 56 56 56 + 56 56 56 56 56 56 56 + 56 56 56 56 56 56 56 56 + 56 56 56 56 56 56 56 56 56 + 72 34269 102 33245 17234 121 + 1236 12266 + 353 736 8886 221 10883 4238 16101 28540 222 3822 272 246 327 4434 46 18445 152 46030 45022 142 13878 327 12585 19884 33773 40920 751 46 41839 5954 137 271 3822 137 271 244 56 244 56 56 244 56 56 56 244 56 56 56 56 244 56 56 56 56 56 244 56 56 56 56 56 56 244 56 56 56 56 56 56 56 244 56 56 56 56 56 56 56 56 244 56 51 56 244 56 516 56 244 56 1198 56 244 14566 246 14566 152 14566 265 30428 257 14566 261 30428 248 14566 268 14566 153 14566 277 30428 247 14566 277 14566 133 14566 152 14566 251 36570 247 1037 4995 13379 2924 9515 17823 54 56 54 57 54 58 54 11904 47892 20895 16625 13047 8389 1059 9504 40216 13858 2073 8983 12571 1539 10721 5918 9643 13298 932 31723 31330 9221 3226 35426 10400 457 4783 2602 349 121 1477 957 1200 2038 49 349 632 863 3673 68 349 82 666 3673 457 4650 1949 580 49 349 73 863 2144 1649 35941 68 2726 44 7728 331 44 113 81 diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000000000000000000000000000000000..e51910ca77e7fd4c8f45c1e9f1f4279028899f94 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,7 @@ +[mypy] +strict = true +allow_untyped_calls = true +allow_untyped_defs = true +allow_incomplete_defs = true +disable_error_code = import-untyped +warn_return_any = false diff --git a/pocs/CMakeLists.txt b/pocs/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d49d14dee4351395d04745660bd970085cfe6aad --- /dev/null +++ b/pocs/CMakeLists.txt @@ -0,0 +1,14 @@ +# dependencies + +find_package(Threads REQUIRED) + +# third-party + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +if (EMSCRIPTEN) +else() + if (NOT GGML_BACKEND_DL) + add_subdirectory(vdot) + endif() +endif() diff --git a/pocs/vdot/CMakeLists.txt b/pocs/vdot/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6235aec1fdade03f7c0e4c8abdfdf1c94f71b8b0 --- /dev/null +++ b/pocs/vdot/CMakeLists.txt @@ -0,0 +1,9 @@ +set(TARGET llama-vdot) +add_executable(${TARGET} vdot.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TARGET llama-q8dot) +add_executable(${TARGET} q8dot.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/pocs/vdot/q8dot.cpp b/pocs/vdot/q8dot.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3df6e1f42112f658f6b9337408a0272dfeb16179 --- /dev/null +++ b/pocs/vdot/q8dot.cpp @@ -0,0 +1,173 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +constexpr int kVecSize = 1 << 16; + +// Copy-pasted from ggml.c +#define QK4_0 32 +typedef struct { + float d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + float d; // delta + float m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +// Copy-pasted from ggml.c +#define QK8_0 32 +typedef struct { + float d; // delta + float s; // d * sum(qs[i]) + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); + +static_assert(QK4_1 == QK8_0, "QK4_1 and QK8_0 must be the same"); +static_assert(QK4_0 == QK8_0, "QK4_0 and QK8_0 must be the same"); + +template +static void fillQ4blocks(std::vector& blocks, std::mt19937& rndm) { + for (auto& b : blocks) { + b.d = 1; + for (int i=0; i> 28; + uint8_t v2 = rndm() >> 28; + b.qs[i] = v1 | (v2 << 4); + } + } +} + +static void fillQ80blocks(std::vector& blocks, std::mt19937& rndm) { + for (auto& b : blocks) { + b.d = 1; + int sum = 0; + for (int i=0; i> 24) - 128; + sum += b.qs[i]; + } + b.s = b.d * sum; + } +} + +static float simpleDot(const block_q4_0& x, const block_q8_0& y) { + int s1 = 0; //, s2 = 0; + for (int i=0; i> 4; + int v3 = x.qs[i+1] & 0xf; + int v4 = x.qs[i+1] >> 4; + int j = 2*i; + s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3]; + //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3]; + } + return y.d * x.d * s1 - 8 * x.d * y.s; + //return y.d * x.d * (s1 - 8 * s2); +} + +static float simpleDot(const block_q4_1& x, const block_q8_0& y) { + int s1 = 0; //, s2 = 0; + for (int i=0; i> 4; + int v3 = x.qs[i+1] & 0xf; + int v4 = x.qs[i+1] >> 4; + int j = 2*i; + s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3]; + //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3]; + } + return y.d * x.d * s1 + y.s * x.m; + //return y.d * (x.d * s1 + x.m * s2); +} + +struct Stat { + double sum = 0, sumt = 0, sumt2 = 0, maxt = 0; + int nloop = 0; + void addResult(double s, double t) { + sum += s; + sumt += t; sumt2 += t*t; maxt = std::max(maxt, t); + ++nloop; + } + void reportResult(const char* title) const { + if (nloop < 1) { + printf("%s(%s): no result\n",__func__,title); + return; + } + printf("============ %s\n",title); + printf(" = %g\n",sum/nloop); + auto t = sumt/nloop, dt = sumt2/nloop - t*t; + if (dt > 0) dt = sqrt(dt); + printf("
[-h|--help]" + exit 1 +} + +# check for curl or wget +function has_cmd { + if ! [ -x "$(command -v $1)" ]; then + return 1 + fi +} + +if has_cmd wget; then + cmd="wget -q -c -O %s/%s %s" +elif has_cmd curl; then + cmd="curl -C - -f --output-dir %s -o %s -L %s" +else + log "[E] curl or wget not found" + exit 1 +fi + +url="" +repo="" +file="" +outdir="." + +# parse args +while [[ $# -gt 0 ]]; do + case "$1" in + --url) + url="$2" + shift 2 + ;; + --repo) + repo="$2" + shift 2 + ;; + --file) + file="$2" + shift 2 + ;; + --outdir) + outdir="$2" + shift 2 + ;; + -h|--help) + usage + ;; + *) + url="$1" + shift + ;; + esac +done + +if [ -n "$repo" ] && [ -n "$file" ]; then + url="https://huggingface.co/$repo/resolve/main/$file" +fi + +if [ -z "$url" ]; then + log "[E] missing --url" + usage +fi + +# check if the URL is a HuggingFace model, and if so, try to download it +is_url=false + +if [[ ${#url} -gt 22 ]]; then + if [[ ${url:0:22} == "https://huggingface.co" ]]; then + is_url=true + fi +fi + +if [ "$is_url" = false ]; then + log "[E] invalid URL, must start with https://huggingface.co" + exit 0 +fi + +# replace "blob/main" with "resolve/main" +url=${url/blob\/main/resolve\/main} + +basename=$(basename $url) + +log "[+] attempting to download $basename" + +if [ -n "$cmd" ]; then + cmd=$(printf "$cmd" "$outdir" "$basename" "$url") + log "[+] $cmd" + if $cmd; then + echo $outdir/$basename + exit 0 + fi +fi + +log "[-] failed to download" + +exit 1 diff --git a/scripts/install-oneapi.bat b/scripts/install-oneapi.bat new file mode 100644 index 0000000000000000000000000000000000000000..e99bef14a45255c4711ab3baf1813afd548249a7 --- /dev/null +++ b/scripts/install-oneapi.bat @@ -0,0 +1,19 @@ +:: MIT license +:: Copyright (C) 2024 Intel Corporation +:: SPDX-License-Identifier: MIT + + +set URL=%1 +set COMPONENTS=%2 + +curl.exe --output %TEMP%\webimage.exe --url %URL% --retry 5 --retry-delay 5 +start /b /wait %TEMP%\webimage.exe -s -x -f webimage_extracted --log extract.log +del %TEMP%\webimage.exe +if "%COMPONENTS%"=="" ( + webimage_extracted\bootstrapper.exe -s --action install --eula=accept -p=NEED_VS2017_INTEGRATION=0 -p=NEED_VS2019_INTEGRATION=0 -p=NEED_VS2022_INTEGRATION=0 --log-dir=. +) else ( + webimage_extracted\bootstrapper.exe -s --action install --components=%COMPONENTS% --eula=accept -p=NEED_VS2017_INTEGRATION=0 -p=NEED_VS2019_INTEGRATION=0 -p=NEED_VS2022_INTEGRATION=0 --log-dir=. +) +set installer_exit_code=%ERRORLEVEL% +rd /s/q "webimage_extracted" +exit /b %installer_exit_code% diff --git a/scripts/qnt-all.sh b/scripts/qnt-all.sh new file mode 100644 index 0000000000000000000000000000000000000000..bc43738a2f49831c16290d986942f5ca46103f12 --- /dev/null +++ b/scripts/qnt-all.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +qnt=(q8_0 q6_k q5_k q5_1 q5_0 q4_k q4_1 q4_0 q3_k q2_k) +args="" + +if [ -z "$1" ]; then + echo "usage: $0 [qnt] [args]" + echo "default: $0 \"${qnt[@]}\" \"${args}\"" + exit 1 +fi + +if [ ! -z "$2" ]; then + qnt=($2) +fi + +if [ ! -z "$3" ]; then + args="$3" +fi + +model="$1" +out="../tmp/results-${model}" + +set -o pipefail +set -e + +mkdir -p ${out} + +for q in ${qnt[@]}; do + time ./bin/llama-quantize ../models/${model}/ggml-model-f16.gguf ../models/${model}/ggml-model-${q}.gguf ${q} 2>&1 ${args} | tee ${out}/qnt-${q}.txt +done diff --git a/scripts/run-all-perf.sh b/scripts/run-all-perf.sh new file mode 100644 index 0000000000000000000000000000000000000000..6384e364d558439ad5a7e46ba2f0d7188615c66a --- /dev/null +++ b/scripts/run-all-perf.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +qnt=(f16 q8_0 q6_k q5_k q5_1 q5_0 q4_k q4_1 q4_0 q3_k q2_k) +args="-ngl 999 -n 64 -p 512" + +if [ -z "$1" ]; then + echo "usage: $0 [qnt] [args]" + echo "default: $0 \"${qnt[@]}\" \"${args}\"" + exit 1 +fi + +if [ ! -z "$2" ]; then + qnt=($2) +fi + +if [ ! -z "$3" ]; then + args="$3" +fi + +model="$1" +out="../tmp/results-${model}" + +set -o pipefail +set -e + +mkdir -p ${out} + +mstr="" + +for q in ${qnt[@]}; do + mstr="${mstr} -m ../models/${model}/ggml-model-${q}.gguf" +done + +./bin/llama-bench ${mstr} ${args} 2> /dev/null diff --git a/scripts/run-all-ppl.sh b/scripts/run-all-ppl.sh new file mode 100644 index 0000000000000000000000000000000000000000..e15f74f1b666db23a245a0cf3850e8f6929ff0dd --- /dev/null +++ b/scripts/run-all-ppl.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +qnt=(f16 q8_0 q6_k q5_k q5_1 q5_0 q4_k q4_1 q4_0 q3_k q2_k) +args="-ngl 999 -t 8" + +if [ -z "$1" ]; then + echo "usage: $0 [qnt] [args]" + echo "default: $0 \"${qnt[@]}\" \"${args}\"" + exit 1 +fi + +if [ ! -z "$2" ]; then + qnt=($2) +fi + +if [ ! -z "$3" ]; then + args="$3" +fi + +set -o pipefail +set -e + +model="$1" +out="../tmp/results-${model}" + +mkdir -p ${out} + +for q in ${qnt[@]}; do + time ./bin/llama-perplexity -m ../models/${model}/ggml-model-f16.gguf -f ./wiki.test.raw ${args} 2>&1 | tee ${out}/ppl-${q}.txt +done diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh new file mode 100644 index 0000000000000000000000000000000000000000..204354209f2d6dccc039199b00fea7099fc35d9c --- /dev/null +++ b/scripts/sync-ggml-am.sh @@ -0,0 +1,202 @@ +#!/bin/bash +# +# Synchronize ggml changes to llama.cpp +# +# Usage: +# +# $ cd /path/to/llama.cpp +# $ ./scripts/sync-ggml-am.sh -skip hash0,hash1,hash2... -C 3 +# + +set -e + +sd=$(dirname $0) +cd $sd/../ + +SRC_LLAMA=$(pwd) +SRC_GGML=$(cd ../ggml; pwd) + +if [ ! -d $SRC_GGML ]; then + echo "ggml not found at $SRC_GGML" + exit 1 +fi + +lc=$(cat $SRC_LLAMA/scripts/sync-ggml.last) +echo "Syncing ggml changes since commit $lc" + +to_skip="" + +# context for git patches in number of lines +ctx="8" + +while [ "$1" != "" ]; do + case $1 in + -skip ) + shift + to_skip=$1 + ;; + -C ) + shift + ctx=$1 + ;; + esac + shift +done + +cd $SRC_GGML + +git log --oneline $lc..HEAD +git log --oneline $lc..HEAD --reverse | grep -v "(llama/[0-9]*)" | cut -d' ' -f1 > $SRC_LLAMA/ggml-commits + +if [ ! -s $SRC_LLAMA/ggml-commits ]; then + rm -v $SRC_LLAMA/ggml-commits + echo "No new commits" + exit 0 +fi + +if [ -f $SRC_LLAMA/ggml-src.patch ]; then + rm -v $SRC_LLAMA/ggml-src.patch +fi + +while read c; do + if [ -n "$to_skip" ]; then + if [[ $to_skip == *"$c"* ]]; then + echo "Skipping $c" + continue + fi + fi + + git format-patch -U${ctx} -k $c~1..$c --stdout -- \ + CMakeLists.txt \ + src/CMakeLists.txt \ + cmake/BuildTypes.cmake \ + cmake/GitVars.cmake \ + cmake/common.cmake \ + cmake/ggml-config.cmake.in \ + src/ggml-cpu/cmake/FindSIMD.cmake \ + src/ggml*.h \ + src/ggml*.c \ + src/ggml*.cpp \ + src/gguf*.cpp \ + src/ggml-blas/* \ + src/ggml-cann/* \ + src/ggml-cpu/* \ + src/ggml-cuda/* \ + src/ggml-hip/* \ + src/ggml-kompute/* \ + src/ggml-metal/* \ + src/ggml-musa/* \ + src/ggml-opencl/* \ + src/ggml-rpc/* \ + src/ggml-sycl/* \ + src/ggml-vulkan/* \ + include/ggml*.h \ + include/gguf*.h \ + tests/test-opt.cpp \ + tests/test-quantize-fns.cpp \ + tests/test-quantize-perf.cpp \ + tests/test-backend-ops.cpp \ + LICENSE \ + scripts/gen-authors.sh \ + >> $SRC_LLAMA/ggml-src.patch +done < $SRC_LLAMA/ggml-commits + +rm -v $SRC_LLAMA/ggml-commits + +# delete files if empty +if [ ! -s $SRC_LLAMA/ggml-src.patch ]; then + rm -v $SRC_LLAMA/ggml-src.patch +fi + +cd $SRC_LLAMA + +if [ -f $SRC_LLAMA/ggml-src.patch ]; then + # replace PR numbers + # + # Subject: some text (#1234) + # Subject: some text (ggml/1234) + cat ggml-src.patch | sed -e 's/^Subject: \(.*\) (#\([0-9]*\))/Subject: \1 (ggml\/\2)/' > ggml-src.patch.tmp + mv ggml-src.patch.tmp ggml-src.patch + + cat ggml-src.patch | sed -e 's/^\(.*\) (#\([0-9]*\))$/\1 (ggml\/\2)/' > ggml-src.patch.tmp + mv ggml-src.patch.tmp ggml-src.patch + + # replace filenames: + # + # CMakelists.txt -> ggml/CMakeLists.txt + # src/CMakeLists.txt -> ggml/src/CMakeLists.txt + + # cmake/BuildTypes.cmake -> ggml/cmake/BuildTypes.cmake + # cmake/GitVars.cmake -> ggml/cmake/GitVars.cmake + # cmake/common.cmake -> ggml/cmake/common.cmake + # cmake/ggml-config.cmake.in -> ggml/cmake/ggml-config.cmake.in + # src/ggml-cpu/cmake/FindSIMD.cmake -> ggml/src/ggml-cpu/cmake/FindSIMD.cmake + # + # src/ggml*.c -> ggml/src/ggml*.c + # src/ggml*.cpp -> ggml/src/ggml*.cpp + # src/ggml*.h -> ggml/src/ggml*.h + # src/gguf*.cpp -> ggml/src/gguf*.cpp + # src/ggml-blas/* -> ggml/src/ggml-blas/* + # src/ggml-cann/* -> ggml/src/ggml-cann/* + # src/ggml-cpu/* -> ggml/src/ggml-cpu/* + # src/ggml-cuda/* -> ggml/src/ggml-cuda/* + # src/ggml-hip/* -> ggml/src/ggml-hip/* + # src/ggml-kompute/* -> ggml/src/ggml-kompute/* + # src/ggml-metal/* -> ggml/src/ggml-metal/* + # src/ggml-musa/* -> ggml/src/ggml-musa/* + # src/ggml-opencl/* -> ggml/src/ggml-opencl/* + # src/ggml-rpc/* -> ggml/src/ggml-rpc/* + # src/ggml-sycl/* -> ggml/src/ggml-sycl/* + # src/ggml-vulkan/* -> ggml/src/ggml-vulkan/* + # + # include/ggml*.h -> ggml/include/ggml*.h + # include/gguf*.h -> ggml/include/gguf*.h + # + # tests/test*.cpp -> tests/ + # + # LICENSE -> LICENSE + # scripts/gen-authors.sh -> scripts/gen-authors.sh + + cat ggml-src.patch | sed -E \ + -e 's/([[:space:]]| [ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \ + -e 's/([[:space:]]| [ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \ + -e 's/([[:space:]]| [ab]\/)cmake\/BuildTypes.cmake/\1ggml\/cmake\/BuildTypes.cmake/g' \ + -e 's/([[:space:]]| [ab]\/)cmake\/GitVars.cmake/\1ggml\/cmake\/GitVars.cmake/g' \ + -e 's/([[:space:]]| [ab]\/)cmake\/common.cmake/\1ggml\/cmake\/common.cmake/g' \ + -e 's/([[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \ + -e 's/([[:space:]]| [ab]\/)src\/gguf(.*)\.cpp/\1ggml\/src\/gguf\2.cpp/g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-blas\//\1ggml\/src\/ggml-blas\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-kompute\//\1ggml\/src\/ggml-kompute\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-opencl\//\1ggml\/src\/ggml-opencl\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-sycl\//\1ggml\/src\/ggml-sycl\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-vulkan\//\1ggml\/src\/ggml-vulkan\//g' \ + -e 's/([[:space:]]| [ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \ + -e 's/([[:space:]]| [ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \ + -e 's/([[:space:]]| [ab]\/)tests\/(.*)\.cpp/\1tests\/\2.cpp/g' \ + -e 's/([[:space:]]| [ab]\/)LICENSE/\1LICENSE/g' \ + -e 's/([[:space:]]| [ab]\/)scripts\/gen-authors\.sh/\1scripts\/gen-authors.sh/g' \ + > ggml-src.patch.tmp + mv ggml-src.patch.tmp ggml-src.patch + + git am -C${ctx} ggml-src.patch + + rm -v $SRC_LLAMA/ggml-src.patch +fi + +# update last commit +cd $SRC_GGML +git log -1 --format=%H > $SRC_LLAMA/scripts/sync-ggml.last + +echo "Done" + +exit 0 diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last new file mode 100644 index 0000000000000000000000000000000000000000..914fe47ff6a3457d13fb29665bc8b21c0f425388 --- /dev/null +++ b/scripts/sync-ggml.last @@ -0,0 +1 @@ +6a7d170c04789f6ebcf320ed03c1b16973f93bd7 diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh new file mode 100644 index 0000000000000000000000000000000000000000..aa1a46b4bfccd59e75b98931410594f1d1bfe87d --- /dev/null +++ b/scripts/sync-ggml.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +cp -rpv ../ggml/CMakeLists.txt ./ggml/CMakeLists.txt +cp -rpv ../ggml/src/CMakeLists.txt ./ggml/src/CMakeLists.txt + +cp -rpv ../ggml/cmake/* ./ggml/cmake/ +cp -rpv ../ggml/src/ggml-cpu/cmake/* ./ggml/src/ggml-cpu/cmake/ + +cp -rpv ../ggml/src/ggml*.c ./ggml/src/ +cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/ +cp -rpv ../ggml/src/ggml*.h ./ggml/src/ +cp -rpv ../ggml/src/gguf*.cpp ./ggml/src/ +cp -rpv ../ggml/src/ggml-blas/* ./ggml/src/ggml-blas/ +cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/ +cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/ +cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/ +cp -rpv ../ggml/src/ggml-hip/* ./ggml/src/ggml-hip/ +cp -rpv ../ggml/src/ggml-kompute/* ./ggml/src/ggml-kompute/ +cp -rpv ../ggml/src/ggml-metal/* ./ggml/src/ggml-metal/ +cp -rpv ../ggml/src/ggml-musa/* ./ggml/src/ggml-musa/ +cp -rpv ../ggml/src/ggml-opencl/* ./ggml/src/ggml-opencl/ +cp -rpv ../ggml/src/ggml-rpc/* ./ggml/src/ggml-rpc/ +cp -rpv ../ggml/src/ggml-sycl/* ./ggml/src/ggml-sycl/ +cp -rpv ../ggml/src/ggml-vulkan/* ./ggml/src/ggml-vulkan/ + +cp -rpv ../ggml/include/ggml*.h ./ggml/include/ +cp -rpv ../ggml/include/gguf*.h ./ggml/include/ + +cp -rpv ../ggml/tests/test-opt.cpp ./tests/test-opt.cpp +cp -rpv ../ggml/tests/test-quantize-fns.cpp ./tests/test-quantize-fns.cpp +cp -rpv ../ggml/tests/test-quantize-perf.cpp ./tests/test-quantize-perf.cpp +cp -rpv ../ggml/tests/test-backend-ops.cpp ./tests/test-backend-ops.cpp + +cp -rpv ../LICENSE ./LICENSE +cp -rpv ../ggml/scripts/gen-authors.sh ./scripts/gen-authors.sh diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py new file mode 100644 index 0000000000000000000000000000000000000000..1151c9f01963b892531cde8246a951fb79f72e2c --- /dev/null +++ b/scripts/sync_vendor.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +import urllib.request + +vendor = { + "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", + "https://github.com/nlohmann/json/releases/latest/download/json_fwd.hpp": "vendor/nlohmann/json_fwd.hpp", + + # sync manually + # "https://raw.githubusercontent.com/ochafik/minja/refs/heads/main/include/minja/minja.hpp": "vendor/minja/minja.hpp", + # "https://raw.githubusercontent.com/ochafik/minja/refs/heads/main/include/minja/chat-template.hpp": "vendor/minja/chat-template.hpp", + + "https://raw.githubusercontent.com/nothings/stb/refs/heads/master/stb_image.h": "vendor/stb/stb_image.h", + + "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.22/miniaudio.h": "vendor/miniaudio/miniaudio.h", + + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.20.1/httplib.h": "vendor/cpp-httplib/httplib.h", +} + +for url, filename in vendor.items(): + print(f"downloading {url} to {filename}") # noqa: NP100 + urllib.request.urlretrieve(url, filename) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py new file mode 100644 index 0000000000000000000000000000000000000000..d8018e2e23c0dfe6360a08c4b36c01cbd65a1790 --- /dev/null +++ b/scripts/tool_bench.py @@ -0,0 +1,379 @@ +#!/usr/bin/env uv run +''' + Simplistic tool call benchmarks for llama-server and ollama. + + Essentially runs the tests at server/tools/server/tests/unit/test_tool_call.py N times, at different temperatures and on different backends (current llama-server, baseline llama-server and ollama), + and plots the results of multiple runs (from same .jsonl file or multiple ones) as a success rate heatmap. + + Simple usage example: + + cmake -B build -DLLAMA_CURL=1 && cmake --build build --config Release -j -t llama-server + + export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server + export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp} + + ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L + ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M + ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b + + ./scripts/tool_bench.py plot *.jsonl # Opens window w/ heatmap + ./scripts/tool_bench.py plot qwen*.jsonl --output qwen.png # Saves heatmap to qwen.png + + (please see ./scripts/tool_bench.sh for a more complete example) +''' +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "pytest", +# "pandas", +# "matplotlib", +# "seaborn", +# "requests", +# "wget", +# "typer", +# ] +# /// +from contextlib import contextmanager +from pathlib import Path +import re +from statistics import mean, median +from typing import Annotated, Dict, List, Optional, Tuple +import atexit +import json +import logging +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import subprocess +import sys +import time +import typer + +sys.path.insert(0, Path(__file__).parent.parent.as_posix()) +if True: + from tools.server.tests.utils import ServerProcess + from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather + + +@contextmanager +def scoped_server(sp: ServerProcess): + def stop(): + nonlocal sp + if sp is not None: + sp.stop() + sp = None # type: ignore + atexit.register(stop) + yield sp + stop() + + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +app = typer.Typer() + + +@app.command() +def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None, server_regex: Optional[str] = None): + + lines: List[Dict] = [] + for file in files: + if not file.exists(): + logger.error(f"File not found: {file}") + continue + + try: + with file.open() as f: + raw_data = f.read() + logger.info(f"Reading {file} ({len(raw_data)} bytes)") + + for line_num, line in enumerate(raw_data.split('\n'), 1): + line = line.strip() + if not line: + continue + try: + record = json.loads(line) + lines.append(record) + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON at {file}:{line_num} - {e}") + except Exception as e: + logger.error(f"Error processing {file}: {e}") + + if not lines: + raise Exception("No valid data was loaded") + + data_dict: Dict[Tuple, float] = {} + models: List[str] = [] + temps = set() + tests = set() + server_names = set() + total_counts = set() + for rec in lines: + try: + model = rec["model"] + temp = rec["temp"] + server_name = rec["server_name"] + test = rec["test"] + success = rec["success_ratio"] + success_count = rec["success_count"] + failure_count = rec["failure_count"] + total_count = success_count + failure_count + total_counts.add(total_count) + + if test_regex and not re.search(test_regex, test): + continue + + if server_regex and not re.search(server_regex, server_name): + continue + + data_dict[(model, temp, server_name, test)] = success + + if model not in models: + models.append(model) + temps.add(temp) + tests.add(test) + server_names.add(server_name) + + except KeyError as e: + logger.warning(f"Missing required field in record: {e}") + + if len(total_counts) > 1: + logger.warning(f"Total counts are not consistent: {total_counts}") + + # Sort the collected values + temps = list(sorted(temps, key=lambda x: x if x is not None else -1)) + tests = list(sorted(tests)) + server_names = list(sorted(server_names)) + + logger.info(f"Processed {len(lines)} lines") + logger.info(f"Found {len(data_dict)} valid data points") + logger.info(f"Models: {models}") + logger.info(f"Temperatures: {temps}") + logger.info(f"Tests: {tests}") + logger.info(f"Servers: {server_names}") + + matrix: list[list[float]] = [] + index: list[str] = [] + + all_cols = [ + (server_name, test) + for server_name in server_names + for test in tests + ] + for model in models: + for temp in temps: + index.append(f"{model} @ {temp}") + row_vals = [ + data_dict.get((model, temp, server_name, test), np.nan) + for server_name, test in all_cols + ] + matrix.append(row_vals) + + columns: list[str] = [f"{server_name}\n{test}" for server_name, test in all_cols] + + df = pd.DataFrame(matrix, index=np.array(index), columns=np.array(columns)) + + plt.figure(figsize=(12, 6)) + + sns.heatmap( + df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5, + cbar_kws={"label": "Success Ratio"}, + ) + + plt.title(f"Tool Call Bench (n = {str(min(total_counts)) if len(total_counts) == 1 else f'{min(total_counts)}-{max(total_counts)}'})\nSuccess Ratios by Server & Test", pad=20) + plt.xlabel("Server & Test", labelpad=10) + plt.ylabel("Model @ Temperature", labelpad=10) + + plt.xticks(rotation=45, ha='right') + plt.yticks(rotation=0) + + plt.tight_layout() + + if output: + plt.savefig(output, dpi=300, bbox_inches='tight') + logger.info(f"Plot saved to {output}") + else: + plt.show() + + +@app.command() +def run( + output: Annotated[Path, typer.Option(help="Output JSON file")], + model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None, + hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None, + chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None, + chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None, + ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None, + llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None, + n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10, + temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test")] = None, + top_p: Annotated[Optional[float], typer.Option(help="top_p")] = None, + top_k: Annotated[Optional[int], typer.Option(help="top_k")] = None, + ctk: Annotated[Optional[str], typer.Option(help="ctk")] = None, + ctv: Annotated[Optional[str], typer.Option(help="ctv")] = None, + fa: Annotated[Optional[bool], typer.Option(help="fa")] = None, + seed: Annotated[Optional[int], typer.Option(help="Random seed")] = None, + port: Annotated[int, typer.Option(help="llama-server port")] = 8084, + force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False, + append: Annotated[bool, typer.Option(help="Append to output file")] = False, + + test_hello_world: Annotated[bool, typer.Option(help="Whether to run the hello world test")] = True, + test_weather: Annotated[bool, typer.Option(help="Whether to run the weather test")] = True, + test_calc_result: Annotated[bool, typer.Option(help="Whether to run the calc result test")] = False, +): + # Check only one of output and append + + n_predict = 512 # High because of DeepSeek R1 + # n_ctx = 8192 + n_ctx = 2048 + + if model is None: + if hf is not None: + model = hf.split("/")[-1] + elif ollama is not None: + model = ollama + + assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite" + + with output.open('a' if append else 'w') as output_file: + + def run(server: ServerProcess, *, server_name: str, model_id: str, temp: Optional[float] = None, output_kwargs={}, request_kwargs={}): + request_kwargs = {**request_kwargs} + if temp is not None: + request_kwargs['temperature'] = temp + if top_p is not None: + request_kwargs['top_p'] = top_p + if top_k is not None: + request_kwargs['top_k'] = top_k + if seed is not None: + request_kwargs['seed'] = seed + + request_kwargs['cache_prompt'] = False + + tests = {} + if test_hello_world: + tests["hello world"] = lambda server: do_test_hello_world(server, **request_kwargs) + if test_weather: + tests["weather"] = lambda server: do_test_weather(server, **request_kwargs) + if test_calc_result: + tests["calc result"] = lambda server: do_test_calc_result(server, None, 512, **request_kwargs) + + for test_name, test in tests.items(): + success_count = 0 + failure_count = 0 + failures = [] + success_times = [] + failure_times = [] + logger.info(f"Running {test_name} ({server_name}, {model}): ") + for i in range(n): + start_time = time.time() + + def elapsed(): + return time.time() - start_time + + try: + test(server) + success_times.append(elapsed()) + success_count += 1 + logger.info('success') + except Exception as e: + logger.error(f'failure: {e}') + failure_count += 1 + failure_times.append(elapsed()) + failures.append(str(e)) + # import traceback + # traceback.print_exc() + output_file.write(json.dumps({**output_kwargs, **dict( + model=model, + server_name=server_name, + model_id=model_id, + test=test_name, + temp=t, + top_p=top_p, + top_k=top_k, + ctk=ctk, + ctv=ctv, + seed=seed, + success_ratio=float(success_count) / n, + avg_time=mean(success_times + failure_times), + median_time=median(success_times + failure_times), + success_count=success_count, + success_times=success_times, + failure_count=failure_count, + failure_times=failure_times, + failures=list(set(failures)), + )}) + '\n') + output_file.flush() + + for t in [None] if temp is None else [t if t >= 0 else None for t in temp]: + if hf is not None: + + servers: list[Tuple[str, Optional[str]]] = [('llama-server', None)] + if llama_baseline is not None: + servers.append(('llama-server (baseline)', llama_baseline)) + + for server_name, server_path in servers: + server = ServerProcess() + server.n_ctx = n_ctx + server.n_slots = 1 + server.jinja = True + server.ctk = ctk + server.ctv = ctv + server.fa = fa + server.n_predict = n_predict + server.model_hf_repo = hf + server.model_hf_file = None + server.chat_template = chat_template + server.chat_template_file = chat_template_file + server.server_path = server_path + if port is not None: + server.server_port = port + # server.debug = True + + with scoped_server(server): + server.start(timeout_seconds=TIMEOUT_SERVER_START) + for ignore_chat_grammar in [False]: + run( + server, + server_name=server_name, + model_id=hf, + temp=t, + output_kwargs=dict( + chat_template=chat_template, + chat_template_file=chat_template_file, + ), + request_kwargs=dict( + ignore_chat_grammar=ignore_chat_grammar, + ), + ) + + if ollama is not None: + server = ServerProcess() + server.server_port = 11434 + server.server_host = "localhost" + subprocess.check_call(["ollama", "pull", ollama]) + + with scoped_server(server): + run( + server, + server_name="ollama", + model_id=ollama, + temp=t, + output_kwargs=dict( + chat_template=None, + chat_template_file=None, + ), + request_kwargs=dict( + model=ollama, + max_tokens=n_predict, + num_ctx = n_ctx, + ), + ) + + +if __name__ == "__main__": + app() diff --git a/scripts/tool_bench.sh b/scripts/tool_bench.sh new file mode 100644 index 0000000000000000000000000000000000000000..6c7616a88fe5bcd353c6636faa877fcad0744186 --- /dev/null +++ b/scripts/tool_bench.sh @@ -0,0 +1,66 @@ +#!/bin/bash +set -euo pipefail + +cmake --build build -j + +export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp} +export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server + +if [ ! -x "$LLAMA_SERVER_BIN_PATH" ]; then + echo "Could not find llama-server binary at $LLAMA_SERVER_BIN_PATH" + exit 1 +fi +if [ ! -d "$LLAMA_CACHE" ]; then + echo "Could not find llama cache at $LLAMA_CACHE, please set LLAMA_CACHE explicitly." + exit 1 +fi + +export ARGS=( + --llama-baseline="$(which llama-server)" + --n 30 + --temp -1 # Leaves temperature parameter unset (use the server's default, e.g. 0.6 for ollama) + --temp 0 + --temp 0.5 + --temp 0.75 + --temp 1 + --temp 1.5 + --temp 2 + --temp 5 + "$@" +) + +./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 0.5B Q4_K_M" --output ../qwenc0.5b.jsonl --hf bartowski/Qwen2.5-Coder-0.5B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:0.5b-instruct-q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 1.5B Q4_K_M" --output ../qwenc1.5b.jsonl --hf bartowski/Qwen2.5-Coder-1.5B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:1.5b-instruct-q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 3B Q4_K_M" --output ../qwenc3b.jsonl --hf bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:3b-instruct-q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --output ../qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:7b-instruct-q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 32B Q4_K_M" --output ../qwenc32b.jsonl --hf bartowski/Qwen2.5-Coder-32B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:32B-instruct-q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --output ../qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M --ollama qwen2.5:1.5b-instruct-q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 3B Q4_K_M" --output ../qwen3b.jsonl --hf bartowski/Qwen2.5-3B-Instruct-GGUF:Q4_K_M --ollama qwen2.5:3b-instruct-q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --output ../qwen7b.jsonl --hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M --ollama qwen2.5:7b-instruct-q4_K_M + +./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --output ../llama1b.jsonl --hf bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M --ollama llama3.2:1b-instruct-q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --output ../llama3b.jsonl --hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M --ollama llama3.2:3b-instruct-q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --output ../llama8b.jsonl --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M --ollama llama3.1:8b-instruct-q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.3 70B Q4_K_M" --output ../llama70b.jsonl --hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M + +./scripts/tool_bench.py run ${ARGS[@]} --model "Mistral Nemo Q4_K_M" --output ../nemo.jsonl --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M --ollama mistral-nemo:12b-instruct-2407-q4_K_M + +./scripts/tool_bench.py run ${ARGS[@]} --model "Hermes 3 Llama 3.1 8B Q4_K_M" --output ../hermes3.jsonl --hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M --ollama hermes3:8b-llama3.1-q4_K_M --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use ) +./scripts/tool_bench.py run ${ARGS[@]} --model "Hermes 2 Pro Llama 3 8B Q4_K_M" --output ../hermes2.jsonl --hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M --ollama hermes2:8b-llama3-q4_K_M --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use ) + +./scripts/tool_bench.py run ${ARGS[@]} --model "Functionary Small V3.2 Q4_K_M" --output ../funct3.2.jsonl --hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M +./scripts/tool_bench.py run ${ARGS[@]} --model "FireFunction V2 IQ1_M" --output ../firef2.jsonl --hf bartowski/firefunction-v2-GGUF:IQ1_M --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use ) + +./scripts/tool_bench.py run ${ARGS[@]} --model "Command R7B 12-2024 Q6_K_L" --output ../c4ai.jsonl --hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L --chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use ) + +./scripts/tool_bench.py run ${ARGS[@]} --model "Gemma 2 2B Q8_0" --output ../gemma2.jsonl --hf bartowski/gemma-2-2b-it-GGUF:Q8_0 +./scripts/tool_bench.py run ${ARGS[@]} --model "Phi 4 Instruct Q4_K_M" --output ../phi4.jsonl --hf bartowski/phi-4-GGUF:Q4_K_M # --ollama phi4 +./scripts/tool_bench.py run ${ARGS[@]} --model "Phi 3.5 Mini Instruct Q4_K_M" --output ../phi3.5.jsonl --hf bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M # --ollama phi3.5:3.8b-mini-instruct-q4_K_M + +# ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 7B Q6_K_L" --output ../dsqw7.jsonl --hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L --chat-template-file <( python scripts/get_chat_template.py NousResearch/DeepSeek-R1-Distill-Qwen-7B tool_use ) +# ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 32B Q4_K_M" --output ../dsqw32.jsonl --hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M --chat-template-file <( python scripts/get_chat_template.py NousResearch/DeepSeek-R1-Distill-Qwen-32B tool_use ) + + +for f in ../*.jsonl; do + ./scripts/tool_bench.py plot "$f" --output ${f%.jsonl}.png || true +done diff --git a/scripts/verify-checksum-models.py b/scripts/verify-checksum-models.py new file mode 100644 index 0000000000000000000000000000000000000000..0b5b9aafaade3d25547cb633bd39c2eaf29f5199 --- /dev/null +++ b/scripts/verify-checksum-models.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 + +import logging +import os +import hashlib + +logger = logging.getLogger("verify-checksum-models") + + +def sha256sum(file): + block_size = 16 * 1024 * 1024 # 16 MB block size + b = bytearray(block_size) + file_hash = hashlib.sha256() + mv = memoryview(b) + with open(file, 'rb', buffering=0) as f: + while True: + n = f.readinto(mv) + if not n: + break + file_hash.update(mv[:n]) + + return file_hash.hexdigest() + + +# Define the path to the llama directory (parent folder of script directory) +llama_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + +# Define the file with the list of hashes and filenames +hash_list_file = os.path.join(llama_path, "SHA256SUMS") + +# Check if the hash list file exists +if not os.path.exists(hash_list_file): + logger.error(f"Hash list file not found: {hash_list_file}") + exit(1) + +# Read the hash file content and split it into an array of lines +with open(hash_list_file, "r") as f: + hash_list = f.read().splitlines() + +# Create an array to store the results +results = [] + +# Loop over each line in the hash list +for line in hash_list: + # Split the line into hash and filename + hash_value, filename = line.split(" ") + + # Get the full path of the file by joining the llama path and the filename + file_path = os.path.join(llama_path, filename) + + # Informing user of the progress of the integrity check + logger.info(f"Verifying the checksum of {file_path}") + + # Check if the file exists + if os.path.exists(file_path): + # Calculate the SHA256 checksum of the file using hashlib + file_hash = sha256sum(file_path) + + # Compare the file hash with the expected hash + if file_hash == hash_value: + valid_checksum = "V" + file_missing = "" + else: + valid_checksum = "" + file_missing = "" + else: + valid_checksum = "" + file_missing = "X" + + # Add the results to the array + results.append({ + "filename": filename, + "valid checksum": valid_checksum, + "file missing": file_missing + }) + + +# Print column headers for results table +print("filename".ljust(40) + "valid checksum".center(20) + "file missing".center(20)) # noqa: NP100 +print("-" * 80) # noqa: NP100 + +# Output the results as a table +for r in results: + print(f"{r['filename']:40} {r['valid checksum']:^20} {r['file missing']:^20}") # noqa: NP100 diff --git a/scripts/xxd.cmake b/scripts/xxd.cmake new file mode 100644 index 0000000000000000000000000000000000000000..14d2753808a8e2fd7709b87d9568e982e7192fb1 --- /dev/null +++ b/scripts/xxd.cmake @@ -0,0 +1,16 @@ +# CMake equivalent of `xxd -i ${INPUT} ${OUTPUT}` +# Usage: cmake -DINPUT=tools/server/public/index.html -DOUTPUT=tools/server/index.html.hpp -P scripts/xxd.cmake + +SET(INPUT "" CACHE STRING "Input File") +SET(OUTPUT "" CACHE STRING "Output File") + +get_filename_component(filename "${INPUT}" NAME) +string(REGEX REPLACE "\\.|-" "_" name "${filename}") + +file(READ "${INPUT}" hex_data HEX) +string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," hex_sequence "${hex_data}") + +string(LENGTH ${hex_data} hex_len) +math(EXPR len "${hex_len} / 2") + +file(WRITE "${OUTPUT}" "unsigned char ${name}[] = {${hex_sequence}};\nunsigned int ${name}_len = ${len};\n") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..70be604e4b0d336a30b24d9cf0601cde62b91a91 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,49 @@ +llama_add_compile_flags() + +# +# libraries +# + +# llama + +add_library(llama + ../include/llama.h + llama.cpp + llama-adapter.cpp + llama-arch.cpp + llama-batch.cpp + llama-chat.cpp + llama-context.cpp + llama-cparams.cpp + llama-grammar.cpp + llama-graph.cpp + llama-hparams.cpp + llama-impl.cpp + llama-io.cpp + llama-kv-cache-unified.cpp + llama-kv-cache-unified-iswa.cpp + llama-kv-cache-recurrent.cpp + llama-memory.cpp + llama-mmap.cpp + llama-model-loader.cpp + llama-model-saver.cpp + llama-model.cpp + llama-quant.cpp + llama-sampling.cpp + llama-vocab.cpp + unicode-data.cpp + unicode.cpp + unicode.h + ) + +target_include_directories(llama PRIVATE .) +target_include_directories(llama PUBLIC ../include) +target_compile_features (llama PRIVATE cxx_std_17) # don't bump + +target_link_libraries(llama PUBLIC ggml) + +if (BUILD_SHARED_LIBS) + set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(llama PRIVATE LLAMA_BUILD) + target_compile_definitions(llama PUBLIC LLAMA_SHARED) +endif() diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d94034aed95debd4b1ea9269996578744ba809b --- /dev/null +++ b/src/llama-adapter.cpp @@ -0,0 +1,388 @@ +#include "llama-adapter.h" + +#include "llama-impl.h" +#include "llama-mmap.h" +#include "llama-model.h" + +#include +#include +#include + +// vec + +ggml_tensor * llama_adapter_cvec::tensor_for(int il) const { + if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { + return nullptr; + } + + return tensors[il]; +} + +ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const { + ggml_tensor * layer_dir = tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx, cur, layer_dir); + } + + return cur; +} + +bool llama_adapter_cvec::init(const llama_model & model) { + const auto & hparams = model.hparams; + + GGML_ASSERT(tensors.empty()); + GGML_ASSERT(ctxs.empty()); + GGML_ASSERT(bufs.empty()); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + // make tensors + tensors.reserve(hparams.n_layer); + tensors.push_back(nullptr); // there's never a tensor for layer 0 + for (size_t il = 1; il < hparams.n_layer; il++) { + ggml_backend_buffer_type_t buft = model.select_buft(il); + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__); + return false; + } + ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + tensors.push_back(tensor); + } + + // allocate tensors / buffers and zero + bufs.reserve(ctx_map.size()); + for (auto it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx = it.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__); + return false; + } + ggml_backend_buffer_clear(buf, 0); + bufs.emplace_back(buf); + } + + return true; +} + +bool llama_adapter_cvec::apply( + const llama_model & model, + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + const auto & hparams = model.hparams; + + if (data == nullptr) { + // disable the current control vector (but leave allocated for later) + layer_start = -1; + layer_end = -1; + return true; + } + + if (n_embd != (int) hparams.n_embd) { + LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__); + return false; + } + + if (tensors.empty()) { + if (!init(model)) { + return false; + } + } + + layer_start = il_start; + layer_end = il_end; + + for (size_t il = 1; il < hparams.n_layer; il++) { + assert(tensors[il] != nullptr); + + const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present + if (off + n_embd <= len) { + ggml_backend_tensor_set(tensors[il], data + off, 0, n_embd * ggml_element_size(tensors[il])); + } + } + + return true; +} + +// lora + +llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) { + const std::string name(w->name); + + const auto pos = ab_map.find(name); + if (pos != ab_map.end()) { + return &pos->second; + } + + return nullptr; +} + +static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) { + LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); + + ggml_context * ctx_init; + gguf_init_params meta_gguf_params = { + /* .no_alloc = */ true, + /* .ctx = */ &ctx_init, + }; + + gguf_context_ptr ctx_gguf { gguf_init_from_file(path_lora, meta_gguf_params) }; + if (!ctx_gguf) { + throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora)); + } + + ggml_context_ptr ctx { ctx_init }; + + // check metadata + { + auto get_kv_str = [&](const std::string & key) -> std::string { + int id = gguf_find_key(ctx_gguf.get(), key.c_str()); + return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id)); + }; + auto get_kv_f32 = [&](const std::string & key) -> float { + int id = gguf_find_key(ctx_gguf.get(), key.c_str()); + return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id); + }; + LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); + + auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE)); + if (general_type != "adapter") { + throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type); + } + + auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE)); + auto general_arch = llm_arch_from_string(general_arch_str); + if (general_arch != model.arch) { + throw std::runtime_error("model arch and LoRA arch mismatch"); + } + + auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE)); + if (adapter_type != "lora") { + throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type); + } + + adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA)); + } + + int n_tensors = gguf_get_n_tensors(ctx_gguf.get()); + + // contexts for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + // add a new context + ggml_init_params params = { + /*.mem_size =*/ n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * buft_ctx = ggml_init(params); + if (!buft_ctx) { + return nullptr; + } + ctx_map[buft] = buft_ctx; + adapter.ctxs.emplace_back(buft_ctx); + return buft_ctx; + }; + return it->second; + }; + + // bundle lora_a and lora_b into pairs + std::map ab_map; + auto str_endswith = [](const std::string & str, const std::string & suffix) { + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; + }; + + for (ggml_tensor * cur = ggml_get_first_tensor(ctx.get()); cur; cur = ggml_get_next_tensor(ctx.get(), cur)) { + std::string name(cur->name); + if (str_endswith(name, ".lora_a")) { + replace_all(name, ".lora_a", ""); + if (ab_map.find(name) == ab_map.end()) { + ab_map[name] = llama_adapter_lora_weight(cur, nullptr); + } else { + ab_map[name].a = cur; + } + } else if (str_endswith(name, ".lora_b")) { + replace_all(name, ".lora_b", ""); + if (ab_map.find(name) == ab_map.end()) { + ab_map[name] = llama_adapter_lora_weight(nullptr, cur); + } else { + ab_map[name].b = cur; + } + } else if (str_endswith(name, "_norm.weight")) { + // TODO: add support for norm vector + // for now, we don't really care because most adapters still work fine without it + continue; + } else { + throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix"); + } + } + + // get extra buffer types of the CPU + // TODO: a more general solution for non-CPU extra buft should be imlpemented in the future + // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948 + std::vector buft_extra; + { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_extra.emplace_back(*extra_bufts); + ++extra_bufts; + } + } + } + + // add tensors + for (auto & it : ab_map) { + const std::string & name = it.first; + llama_adapter_lora_weight & w = it.second; + bool is_token_embd = str_endswith(name, "token_embd.weight"); + + if (!w.a || !w.b) { + throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component"); + } + + // device buft and device ctx + const auto * model_tensor = model.get_tensor(name.c_str()); + if (!model_tensor) { + throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); + } + + auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer); + + // do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case + for (auto & ex : buft_extra) { + if (ex == buft) { + LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); + + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + buft = ggml_backend_dev_buffer_type(cpu_dev); + + break; + } + } + + LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); + + ggml_context * dev_ctx = ctx_for_buft(buft); + // validate tensor shape + if (is_token_embd) { + // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() + if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + } else { + if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + if (w.a->ne[1] != w.b->ne[0]) { + throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + } + } + + // save tensor to adapter + ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a); + ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); + ggml_set_name(tensor_a, w.a->name); + ggml_set_name(tensor_b, w.b->name); + adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b); + } + + // allocate tensors / buffers and zero + { + adapter.ctxs.reserve(ctx_map.size()); + adapter.bufs.reserve(ctx_map.size()); + for (auto & it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx_dev = it.second; + ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft) }; + if (!buf) { + throw std::runtime_error("failed to allocate buffer for lora adapter\n"); + } + LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get())/1024.0/1024.0); + adapter.bufs.emplace_back(std::move(buf)); + } + } + + // set tensor data + { + llama_file gguf_file(path_lora, "rb"); + std::vector read_buf; + auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) { + size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name)); + size_t size = ggml_nbytes(orig); + read_buf.resize(size); + gguf_file.seek(offs, SEEK_SET); + gguf_file.read_raw(read_buf.data(), size); + ggml_backend_tensor_set(dev, read_buf.data(), 0, size); + }; + for (auto & it : adapter.ab_map) { + auto orig = ab_map[it.first]; + auto dev = it.second; + set_tensor(orig.a, dev.a); + set_tensor(orig.b, dev.b); + } + } + + LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); +} + +llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) { + llama_adapter_lora * adapter = new llama_adapter_lora(); + + try { + llama_adapter_lora_init_impl(*model, path_lora, *adapter); + return adapter; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); + + delete adapter; + } + + return nullptr; +} + +void llama_adapter_lora_free(llama_adapter_lora * adapter) { + delete adapter; +} diff --git a/src/llama-adapter.h b/src/llama-adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..65824e972765bd5b8b845e0c41ec46dc81657a4f --- /dev/null +++ b/src/llama-adapter.h @@ -0,0 +1,76 @@ +#pragma once + +#include "llama.h" + +#include "ggml-cpp.h" + +#include +#include +#include + +// TODO: pimpl + +// +// llama_adapter_cvec +// + +struct llama_adapter_cvec { + ggml_tensor * tensor_for(int il) const; + + ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const; + + bool apply( + const llama_model & model, + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end); + +private: + bool init(const llama_model & model); + + int32_t layer_start = -1; + int32_t layer_end = -1; + + std::vector ctxs; + std::vector bufs; + + std::vector tensors; // per layer +}; + +// +// llama_adapter_lora +// + +struct llama_adapter_lora_weight { + ggml_tensor * a = nullptr; + ggml_tensor * b = nullptr; + + // get actual scale based on rank and alpha + float get_scale(float alpha, float adapter_scale) const { + const float rank = (float) b->ne[0]; + const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale; + return scale; + } + + llama_adapter_lora_weight() = default; + llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {} +}; + +struct llama_adapter_lora { + // map tensor name to lora_a_b + std::unordered_map ab_map; + + std::vector ctxs; + std::vector bufs; + + float alpha; + + llama_adapter_lora() = default; + ~llama_adapter_lora() = default; + + llama_adapter_lora_weight * get_weight(ggml_tensor * w); +}; + +using llama_adapter_loras = std::unordered_map; diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp new file mode 100644 index 0000000000000000000000000000000000000000..de8d289cf967e989cbd3fcca85639b1b11c4ba32 --- /dev/null +++ b/src/llama-arch.cpp @@ -0,0 +1,1818 @@ +#include "llama-arch.h" + +#include "llama-impl.h" + +#include + +static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA4, "llama4" }, + { LLM_ARCH_DECI, "deci" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, + { LLM_ARCH_NEO_BERT, "neo-bert" }, + { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN2VL, "qwen2vl" }, + { LLM_ARCH_QWEN3, "qwen3" }, + { LLM_ARCH_QWEN3MOE, "qwen3moe" }, + { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, + { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, + { LLM_ARCH_INTERNLM2, "internlm2" }, + { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_GEMMA3, "gemma3" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, + { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_COHERE2, "cohere2" }, + { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OLMO2, "olmo2" }, + { LLM_ARCH_OLMOE, "olmoe" }, + { LLM_ARCH_OPENELM, "openelm" }, + { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK, "deepseek" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_T5ENCODER, "t5encoder" }, + { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, + { LLM_ARCH_RWKV7, "rwkv7" }, + { LLM_ARCH_ARWKV7, "arwkv7" }, + { LLM_ARCH_GRANITE, "granite" }, + { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, + { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_PLM, "plm" }, + { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_DOTS1, "dots1" }, + { LLM_ARCH_ARCEE, "arcee" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, +}; + +static const std::map LLM_KV_NAMES = { + { LLM_KV_GENERAL_TYPE, "general.type" }, + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_VERSION, "general.version" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, + + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, + { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, + { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, + { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, + { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, + { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, + { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, + { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + { LLM_KV_SWIN_NORM, "%s.swin_norm" }, + { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, + { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" }, + { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, + { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, + { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, + { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, + { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, + + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, + { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" }, + { LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" }, + { LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" }, + { LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" }, + { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, + { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + + { LLM_KV_SPLIT_NO, "split.no" }, + { LLM_KV_SPLIT_COUNT, "split.count" }, + { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" }, + + { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, + { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, + { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, + { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, + + { LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" }, + { LLM_KV_POSNET_BLOCK_COUNT, "%s.posnet.block_count" }, + + { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" }, + { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" }, + + { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, + { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, + { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + + // deprecated + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, + { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, + { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, +}; + +static const std::map> LLM_TENSOR_NAMES = { + { + LLM_ARCH_LLAMA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_ARCEE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_LLAMA4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, + { + LLM_ARCH_DECI, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_BAICHUAN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_FALCON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GROK, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + }, + }, + { + LLM_ARCH_GPT2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_GPTJ, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_GPTNEOX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MPT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output"}, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"}, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"}, + }, + }, + { + LLM_ARCH_STARCODER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_REFACT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, + }, + }, + { + LLM_ARCH_NOMIC_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_NOMIC_BERT_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_NEO_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, + }, + }, + { + LLM_ARCH_JINA_BERT_V2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + }, + }, + { + LLM_ARCH_BLOOM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_STABLELM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, + { + LLM_ARCH_QWEN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN2VL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN2MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, + { + LLM_ARCH_QWEN3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN3MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_PHI2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_PHI3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_PHIMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_PLAMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_CODESHELL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_ORION, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_INTERNLM2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MINICPM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + }, + }, + { + LLM_ARCH_MINICPM3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_GEMMA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GEMMA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, + { + LLM_ARCH_GEMMA3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, + { + LLM_ARCH_STARCODER2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, + { + LLM_ARCH_XVERSE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_COMMAND_R, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, + { + LLM_ARCH_COHERE2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_DBRX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_OLMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_OLMO2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_OLMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_OPENELM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_ARCTIC, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_DEEPSEEK, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, + { + LLM_ARCH_DEEPSEEK2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, + { + LLM_ARCH_PLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_CHATGLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_GLM4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, + { + LLM_ARCH_BITNET, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, + { + LLM_ARCH_T5, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" }, + { LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" }, + { LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" }, + { LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" }, + { LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" }, + { LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" }, + { LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" }, + { LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" }, + { LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" }, + { LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" }, + { LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" }, + { LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" }, + { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" }, + { LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" }, + { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" }, + { LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" }, + { LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_T5ENCODER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_JAIS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_NEMOTRON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_EXAONE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_RWKV6, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" }, + { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" }, + { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, + { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, + { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_R, "blk.%d.channel_mix_lerp_r" }, + { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" }, + { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, + { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, + }, + }, + { + LLM_ARCH_RWKV6QWEN2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_RWKV7, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" }, + { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" }, + { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" }, + { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" }, + { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" }, + { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" }, + { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" }, + { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" }, + { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" }, + { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" }, + { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" }, + { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" }, + { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, + }, + }, + { + LLM_ARCH_ARWKV7, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" }, + { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" }, + { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" }, + { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" }, + { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" }, + { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" }, + { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" }, + { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" }, + { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" }, + { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" }, + { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GRANITE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GRANITE_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, + { + LLM_ARCH_CHAMELEON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, + { + LLM_ARCH_WAVTOKENIZER_DEC, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_CONV1D, "conv1d" }, + { LLM_TENSOR_CONVNEXT_DW, "convnext.%d.dw" }, + { LLM_TENSOR_CONVNEXT_NORM, "convnext.%d.norm" }, + { LLM_TENSOR_CONVNEXT_PW1, "convnext.%d.pw1" }, + { LLM_TENSOR_CONVNEXT_PW2, "convnext.%d.pw2" }, + { LLM_TENSOR_CONVNEXT_GAMMA, "convnext.%d.gamma" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_POS_NET_CONV1, "posnet.%d.conv1" }, + { LLM_TENSOR_POS_NET_CONV2, "posnet.%d.conv2" }, + { LLM_TENSOR_POS_NET_NORM, "posnet.%d.norm" }, + { LLM_TENSOR_POS_NET_NORM1, "posnet.%d.norm1" }, + { LLM_TENSOR_POS_NET_NORM2, "posnet.%d.norm2" }, + { LLM_TENSOR_POS_NET_ATTN_NORM, "posnet.%d.attn_norm" }, + { LLM_TENSOR_POS_NET_ATTN_Q, "posnet.%d.attn_q" }, + { LLM_TENSOR_POS_NET_ATTN_K, "posnet.%d.attn_k" }, + { LLM_TENSOR_POS_NET_ATTN_V, "posnet.%d.attn_v" }, + { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, + }, + }, + { + LLM_ARCH_BAILINGMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, + { + LLM_ARCH_DOTS1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + } + }, + { + LLM_ARCH_UNKNOWN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, +}; + +static const std::map LLM_TENSOR_INFOS = { + {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, + {LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, + {LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, + {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_INP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_INP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_IN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_VALUE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_RECEPTANCE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_OUTPUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CHANNEL_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CHANNEL_MIX_VALUE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}}, + {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, + {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, + {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, + {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_CROSS_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_ENC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + // this tensor is loaded for T5, but never used + {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, + {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_NORM2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_CONV1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_CONV2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_DW, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, + {LLM_TENSOR_CONVNEXT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, +}; + +LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} + +std::string LLM_KV::operator()(llm_kv kv) const { + std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + + if (suffix != nullptr) { + name += "."; + name += suffix; + } + + return name; +} + +std::string LLM_TN_IMPL::str() const { + if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + return "__missing__"; + } + + std::string name = ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid, xid); + + if (suffix != nullptr) { + name += "."; + name += suffix; + } + + return name; +} + +const char * llm_arch_name(llm_arch arch) { + auto it = LLM_ARCH_NAMES.find(arch); + if (it == LLM_ARCH_NAMES.end()) { + return "unknown"; + } + return it->second; +} + +llm_arch llm_arch_from_string(const std::string & name) { + for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT + if (kv.second == name) { + return kv.first; + } + } + + return LLM_ARCH_UNKNOWN; +} + +const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { + return LLM_TENSOR_INFOS.at(tensor); +} diff --git a/src/llama-arch.h b/src/llama-arch.h new file mode 100644 index 0000000000000000000000000000000000000000..3e8a61da3c13e38fc1711e003d2f2b6ba3f59393 --- /dev/null +++ b/src/llama-arch.h @@ -0,0 +1,441 @@ +#pragma once + +#include "ggml.h" // ggml_op + +#include + +// +// gguf constants (sync with gguf.py) +// + +enum llm_arch { + LLM_ARCH_LLAMA, + LLM_ARCH_LLAMA4, + LLM_ARCH_DECI, + LLM_ARCH_FALCON, + LLM_ARCH_BAICHUAN, + LLM_ARCH_GROK, + LLM_ARCH_GPT2, + LLM_ARCH_GPTJ, + LLM_ARCH_GPTNEOX, + LLM_ARCH_MPT, + LLM_ARCH_STARCODER, + LLM_ARCH_REFACT, + LLM_ARCH_BERT, + LLM_ARCH_NOMIC_BERT, + LLM_ARCH_NOMIC_BERT_MOE, + LLM_ARCH_NEO_BERT, + LLM_ARCH_JINA_BERT_V2, + LLM_ARCH_BLOOM, + LLM_ARCH_STABLELM, + LLM_ARCH_QWEN, + LLM_ARCH_QWEN2, + LLM_ARCH_QWEN2MOE, + LLM_ARCH_QWEN2VL, + LLM_ARCH_QWEN3, + LLM_ARCH_QWEN3MOE, + LLM_ARCH_PHI2, + LLM_ARCH_PHI3, + LLM_ARCH_PHIMOE, + LLM_ARCH_PLAMO, + LLM_ARCH_CODESHELL, + LLM_ARCH_ORION, + LLM_ARCH_INTERNLM2, + LLM_ARCH_MINICPM, + LLM_ARCH_MINICPM3, + LLM_ARCH_GEMMA, + LLM_ARCH_GEMMA2, + LLM_ARCH_GEMMA3, + LLM_ARCH_STARCODER2, + LLM_ARCH_MAMBA, + LLM_ARCH_XVERSE, + LLM_ARCH_COMMAND_R, + LLM_ARCH_COHERE2, + LLM_ARCH_DBRX, + LLM_ARCH_OLMO, + LLM_ARCH_OLMO2, + LLM_ARCH_OLMOE, + LLM_ARCH_OPENELM, + LLM_ARCH_ARCTIC, + LLM_ARCH_DEEPSEEK, + LLM_ARCH_DEEPSEEK2, + LLM_ARCH_CHATGLM, + LLM_ARCH_GLM4, + LLM_ARCH_BITNET, + LLM_ARCH_T5, + LLM_ARCH_T5ENCODER, + LLM_ARCH_JAIS, + LLM_ARCH_NEMOTRON, + LLM_ARCH_EXAONE, + LLM_ARCH_RWKV6, + LLM_ARCH_RWKV6QWEN2, + LLM_ARCH_RWKV7, + LLM_ARCH_ARWKV7, + LLM_ARCH_GRANITE, + LLM_ARCH_GRANITE_MOE, + LLM_ARCH_CHAMELEON, + LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_PLM, + LLM_ARCH_BAILINGMOE, + LLM_ARCH_DOTS1, + LLM_ARCH_ARCEE, + LLM_ARCH_UNKNOWN, +}; + +enum llm_kv { + LLM_KV_GENERAL_TYPE, + LLM_KV_GENERAL_ARCHITECTURE, + LLM_KV_GENERAL_QUANTIZATION_VERSION, + LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_FILE_TYPE, + LLM_KV_GENERAL_NAME, + LLM_KV_GENERAL_AUTHOR, + LLM_KV_GENERAL_VERSION, + LLM_KV_GENERAL_URL, + LLM_KV_GENERAL_DESCRIPTION, + LLM_KV_GENERAL_LICENSE, + LLM_KV_GENERAL_SOURCE_URL, + LLM_KV_GENERAL_SOURCE_HF_REPO, + + LLM_KV_VOCAB_SIZE, + LLM_KV_CONTEXT_LENGTH, + LLM_KV_EMBEDDING_LENGTH, + LLM_KV_FEATURES_LENGTH, + LLM_KV_BLOCK_COUNT, + LLM_KV_LEADING_DENSE_BLOCK_COUNT, + LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, + LLM_KV_USE_PARALLEL_RESIDUAL, + LLM_KV_TENSOR_DATA_LAYOUT, + LLM_KV_EXPERT_COUNT, + LLM_KV_EXPERT_USED_COUNT, + LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_WEIGHTS_SCALE, + LLM_KV_EXPERT_WEIGHTS_NORM, + LLM_KV_EXPERT_GATING_FUNC, + LLM_KV_MOE_EVERY_N_LAYERS, + LLM_KV_POOLING_TYPE, + LLM_KV_LOGIT_SCALE, + LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_SWIN_NORM, + LLM_KV_RESCALE_EVERY_N_LAYERS, + LLM_KV_TIME_MIX_EXTRA_DIM, + LLM_KV_TIME_DECAY_EXTRA_DIM, + LLM_KV_RESIDUAL_SCALE, + LLM_KV_EMBEDDING_SCALE, + LLM_KV_TOKEN_SHIFT_COUNT, + LLM_KV_INTERLEAVE_MOE_LAYER_STEP, + + LLM_KV_ATTENTION_HEAD_COUNT, + LLM_KV_ATTENTION_HEAD_COUNT_KV, + LLM_KV_ATTENTION_MAX_ALIBI_BIAS, + LLM_KV_ATTENTION_CLAMP_KQV, + LLM_KV_ATTENTION_KEY_LENGTH, + LLM_KV_ATTENTION_VALUE_LENGTH, + LLM_KV_ATTENTION_LAYERNORM_EPS, + LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + LLM_KV_ATTENTION_GROUPNORM_EPS, + LLM_KV_ATTENTION_GROUPNORM_GROUPS, + LLM_KV_ATTENTION_CAUSAL, + LLM_KV_ATTENTION_Q_LORA_RANK, + LLM_KV_ATTENTION_KV_LORA_RANK, + LLM_KV_ATTENTION_DECAY_LORA_RANK, + LLM_KV_ATTENTION_ICLR_LORA_RANK, + LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, + LLM_KV_ATTENTION_GATE_LORA_RANK, + LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, + LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SCALE, + LLM_KV_ATTENTION_KEY_LENGTH_MLA, + LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_DIMENSION_SECTIONS, + LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_SCALE_LINEAR, + LLM_KV_ROPE_SCALING_TYPE, + LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ATTN_FACTOR, + LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, + LLM_KV_ROPE_SCALING_FINETUNED, + LLM_KV_ROPE_SCALING_YARN_LOG_MUL, + + LLM_KV_SPLIT_NO, + LLM_KV_SPLIT_COUNT, + LLM_KV_SPLIT_TENSORS_COUNT, + + LLM_KV_SSM_INNER_SIZE, + LLM_KV_SSM_CONV_KERNEL, + LLM_KV_SSM_STATE_SIZE, + LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_DT_B_C_RMS, + + LLM_KV_WKV_HEAD_SIZE, + + LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_PRE, + LLM_KV_TOKENIZER_LIST, + LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, + LLM_KV_TOKENIZER_SCORES, + LLM_KV_TOKENIZER_MERGES, + LLM_KV_TOKENIZER_BOS_ID, + LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_EOT_ID, + LLM_KV_TOKENIZER_EOM_ID, + LLM_KV_TOKENIZER_UNK_ID, + LLM_KV_TOKENIZER_SEP_ID, + LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_CLS_ID, + LLM_KV_TOKENIZER_MASK_ID, + LLM_KV_TOKENIZER_ADD_BOS, + LLM_KV_TOKENIZER_ADD_EOS, + LLM_KV_TOKENIZER_ADD_PREFIX, + LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, + LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, + LLM_KV_TOKENIZER_HF_JSON, + LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_FIM_PRE_ID, + LLM_KV_TOKENIZER_FIM_SUF_ID, + LLM_KV_TOKENIZER_FIM_MID_ID, + LLM_KV_TOKENIZER_FIM_PAD_ID, + LLM_KV_TOKENIZER_FIM_REP_ID, + LLM_KV_TOKENIZER_FIM_SEP_ID, + + LLM_KV_ADAPTER_TYPE, + LLM_KV_ADAPTER_LORA_ALPHA, + + LLM_KV_POSNET_EMBEDDING_LENGTH, + LLM_KV_POSNET_BLOCK_COUNT, + + LLM_KV_CONVNEXT_EMBEDDING_LENGTH, + LLM_KV_CONVNEXT_BLOCK_COUNT, + + LLM_KV_CLASSIFIER_OUTPUT_LABELS, + + // deprecated: + LLM_KV_TOKENIZER_PREFIX_ID, + LLM_KV_TOKENIZER_SUFFIX_ID, + LLM_KV_TOKENIZER_MIDDLE_ID, +}; + +enum llm_tensor { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_ACT, + LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility + LLM_TENSOR_FFN_GATE_EXP, + LLM_TENSOR_FFN_UP_EXP, + LLM_TENSOR_FFN_NORM_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, // merged experts + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_POST_ATTN_NORM, + LLM_TENSOR_POST_MLP_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_X, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_OUT, + LLM_TENSOR_TIME_MIX_W0, + LLM_TENSOR_TIME_MIX_W1, + LLM_TENSOR_TIME_MIX_W2, + LLM_TENSOR_TIME_MIX_A0, + LLM_TENSOR_TIME_MIX_A1, + LLM_TENSOR_TIME_MIX_A2, + LLM_TENSOR_TIME_MIX_V0, + LLM_TENSOR_TIME_MIX_V1, + LLM_TENSOR_TIME_MIX_V2, + LLM_TENSOR_TIME_MIX_G1, + LLM_TENSOR_TIME_MIX_G2, + LLM_TENSOR_TIME_MIX_K_K, + LLM_TENSOR_TIME_MIX_K_A, + LLM_TENSOR_TIME_MIX_R_K, + LLM_TENSOR_TIME_MIX_LERP_X, + LLM_TENSOR_TIME_MIX_LERP_W, + LLM_TENSOR_TIME_MIX_LERP_K, + LLM_TENSOR_TIME_MIX_LERP_V, + LLM_TENSOR_TIME_MIX_LERP_R, + LLM_TENSOR_TIME_MIX_LERP_G, + LLM_TENSOR_TIME_MIX_LERP_FUSED, + LLM_TENSOR_TIME_MIX_FIRST, + LLM_TENSOR_TIME_MIX_DECAY, + LLM_TENSOR_TIME_MIX_DECAY_W1, + LLM_TENSOR_TIME_MIX_DECAY_W2, + LLM_TENSOR_TIME_MIX_KEY, + LLM_TENSOR_TIME_MIX_VALUE, + LLM_TENSOR_TIME_MIX_RECEPTANCE, + LLM_TENSOR_TIME_MIX_GATE, + LLM_TENSOR_TIME_MIX_LN, + LLM_TENSOR_TIME_MIX_OUTPUT, + LLM_TENSOR_CHANNEL_MIX_LERP_K, + LLM_TENSOR_CHANNEL_MIX_LERP_R, + LLM_TENSOR_CHANNEL_MIX_KEY, + LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, + LLM_TENSOR_CHANNEL_MIX_VALUE, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + LLM_TENSOR_FFN_SUB_NORM, + LLM_TENSOR_DEC_ATTN_NORM, + LLM_TENSOR_DEC_ATTN_Q, + LLM_TENSOR_DEC_ATTN_K, + LLM_TENSOR_DEC_ATTN_V, + LLM_TENSOR_DEC_ATTN_OUT, + LLM_TENSOR_DEC_ATTN_REL_B, + LLM_TENSOR_DEC_CROSS_ATTN_NORM, + LLM_TENSOR_DEC_CROSS_ATTN_Q, + LLM_TENSOR_DEC_CROSS_ATTN_K, + LLM_TENSOR_DEC_CROSS_ATTN_V, + LLM_TENSOR_DEC_CROSS_ATTN_OUT, + LLM_TENSOR_DEC_CROSS_ATTN_REL_B, + LLM_TENSOR_DEC_FFN_NORM, + LLM_TENSOR_DEC_FFN_GATE, + LLM_TENSOR_DEC_FFN_DOWN, + LLM_TENSOR_DEC_FFN_UP, + LLM_TENSOR_DEC_OUTPUT_NORM, + LLM_TENSOR_ENC_ATTN_NORM, + LLM_TENSOR_ENC_ATTN_Q, + LLM_TENSOR_ENC_ATTN_K, + LLM_TENSOR_ENC_ATTN_V, + LLM_TENSOR_ENC_ATTN_OUT, + LLM_TENSOR_ENC_ATTN_REL_B, + LLM_TENSOR_ENC_FFN_NORM, + LLM_TENSOR_ENC_FFN_GATE, + LLM_TENSOR_ENC_FFN_DOWN, + LLM_TENSOR_ENC_FFN_UP, + LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CONV1D, + LLM_TENSOR_CONVNEXT_DW, + LLM_TENSOR_CONVNEXT_NORM, + LLM_TENSOR_CONVNEXT_PW1, + LLM_TENSOR_CONVNEXT_PW2, + LLM_TENSOR_CONVNEXT_GAMMA, + LLM_TENSOR_POS_NET_CONV1, + LLM_TENSOR_POS_NET_CONV2, + LLM_TENSOR_POS_NET_NORM, + LLM_TENSOR_POS_NET_NORM1, + LLM_TENSOR_POS_NET_NORM2, + LLM_TENSOR_POS_NET_ATTN_NORM, + LLM_TENSOR_POS_NET_ATTN_Q, + LLM_TENSOR_POS_NET_ATTN_K, + LLM_TENSOR_POS_NET_ATTN_V, + LLM_TENSOR_POS_NET_ATTN_OUT, +}; + +enum llm_tensor_layer { + LLM_TENSOR_LAYER_INPUT, + LLM_TENSOR_LAYER_REPEATING, + LLM_TENSOR_LAYER_OUTPUT, +}; + +struct LLM_KV { + LLM_KV(llm_arch arch, const char * suffix = nullptr); + + llm_arch arch; + const char * suffix; + + std::string operator()(llm_kv kv) const; +}; + +// helper to handle gguf constants +// usage: +// +// const auto tn = LLM_TN(LLM_ARCH_LLAMA); +// +// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" +// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" +// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" +// +struct LLM_TN_IMPL { + const llm_arch arch; + const llm_tensor tensor; + const char * const suffix; + const int bid; + const int xid; + + std::string str() const; + + operator std::string() const { + return str(); + } + + friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) { + return str == tn.str(); + } + + friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) { + return str != tn.str(); + } +}; + +struct LLM_TN { + LLM_TN(llm_arch arch) : arch(arch) {} + + llm_arch arch; + + LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const { + return { arch, tensor, suffix, bid, xid }; + } + + LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const { + return { arch, tensor, nullptr, bid, xid }; + } +}; + + +struct llm_tensor_info { + llm_tensor_layer layer; + ggml_op op; +}; + +const char * llm_arch_name(llm_arch arch); + +llm_arch llm_arch_from_string(const std::string & name); + +const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b6d14fe8813c3d0874a4ad2aaca71c9e6f0aa4b --- /dev/null +++ b/src/llama-batch.cpp @@ -0,0 +1,632 @@ +#include "llama-batch.h" + +#include "llama-impl.h" +#include "llama-cparams.h" +#include "llama-vocab.h" +#include "llama-memory.h" + +#include +#include +#include +#include + +llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { + // clear empty sequences + // the previous ubatch is assumed to be gone, + // so nothing should refer to values in these sequences anymore. + for (size_t i = seq.size(); i-- > 0;) { + if (seq[i].length == 0) { + seq.pop_back(); + } else { + break; + } + } + + udatas.push_back({}); + + auto & udata = udatas.back(); + + udata.token.resize(!has_embd ? n_ubatch : 0); + udata.embd.resize(has_embd ? n_embd * n_ubatch : 0); + udata.pos.resize(n_ubatch); + udata.n_seq_id.resize(n_ubatch); + udata.seq_id.resize(n_ubatch); + udata.output.resize(n_ubatch); + + llama_ubatch ubatch = { + /*equal_seqs =*/ true, + /*n_tokens =*/ 0, + /*n_seq_tokens =*/ 0, + /*n_seqs =*/ 0, + /*token =*/ !has_embd ? udata.token.data() : nullptr, + /*embd =*/ has_embd ? udata.embd.data() : nullptr, + /*pos =*/ udata.pos.data(), + /*n_seq_id =*/ udata.n_seq_id.data(), + /*seq_id =*/ udata.seq_id.data(), + /*output =*/ udata.output.data(), + }; + + return ubatch; +} + +void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { + GGML_ASSERT(batch != nullptr); + GGML_ASSERT(length <= seq.length); + // Can only add sequences of equal lengths to a batch, + // otherwise it isn't clear to which sequence a token belongs + GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); + GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); + // NOTE: loops are separated for cache-friendliness + if (batch->token) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.token = batch->token + seq.offset; + } + } else { + ubatch.token = nullptr; + } + if (batch->embd) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + memcpy( + ubatch.embd + (n_embd * (ubatch.n_tokens + i)), + batch->embd + (n_embd * ids[seq.offset + i]), + n_embd * sizeof(float) + ); + } + } else { + // simple split + ubatch.embd = batch->embd + (n_embd * seq.offset); + } + } else { + ubatch.embd = nullptr; + } + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.pos = batch->pos + seq.offset; + } + if (ubatch.equal_seqs) { + ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; + if (seq.seq_id) { + ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; + } + } else { + // simple split + if (batch->n_seq_id) { + ubatch.n_seq_id = batch->n_seq_id + seq.offset; + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = 1; + } + } + if (batch->seq_id) { + ubatch.seq_id = batch->seq_id + seq.offset; + } + } + if (batch->logits) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_output = batch->logits[id]; + ubatch.output[ubatch.n_tokens + i] = is_output; + if (is_output) { out_ids.push_back(id); } + } + } else { + // simple split + ubatch.output = batch->logits + seq.offset; + for (size_t i = 0; i < length; ++i) { + if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } + } + } + } else { + // only get last output + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_last = id == ids.size() - 1; + ubatch.output[ubatch.n_tokens + i] = is_last; + if (is_last) { out_ids.push_back(id); } + } + } + if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { + ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1; + } + ubatch.n_tokens += length; + ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits + seq.offset += length; + seq.length -= length; + n_tokens -= length; + GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); +} + +llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + ubatch.equal_seqs = false; + if (!seq.empty()) { + llama_sbatch_seq & s = seq[0]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; +} + +llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + size_t length = 0; + size_t n_tokens_in_ubatch = 0; + GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits + // smallest first, because it's easier to split this way; + // starting from the end to pop in constant time. + for (size_t i = seq.size(); i-- > 0;) { + llama_sbatch_seq & s = seq[i]; + GGML_ASSERT(s.length > 0); + if (length == 0) { + length = s.length < n_ubatch ? s.length : n_ubatch; + } + add_seq_to_ubatch(ubatch, s, length); + n_tokens_in_ubatch += length; + // shared prompts can't be mixed with any of their sequences, + // so it's safer to compute them in their own ubatch + if (s.n_seq_id > 1) { break; } + // stop when there isn't enough space for another sequence + if (length + n_tokens_in_ubatch > n_ubatch) { break; } + } + } + return ubatch; +} + +llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + llama_sbatch_seq & s = seq[seq.size() - 1]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; +} + +llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) { + GGML_ASSERT(batch.n_tokens >= 0); + this->batch = &batch; + this->n_embd = n_embd; + + n_tokens = batch.n_tokens; + ids.resize(n_tokens); + out_ids.clear(); + // TODO: reserve out_ids and seq + + for (size_t i = 0; i < n_tokens; ++i) { + ids[i] = i; + } + + if (simple_split) { + seq.resize(1); + llama_sbatch_seq & s = seq[0]; + s.n_seq_id = 0; + s.seq_id = nullptr; + s.offset = 0; + s.length = n_tokens; + return; + } + + std::sort(ids.begin(), ids.end(), + [&batch](size_t a, size_t b) { + int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; + int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1; + // sort by seq_id, then by pos + if (n_seq_a == n_seq_b) { + if (batch.seq_id) { + for (int32_t i = 0; i < n_seq_a; ++i) { + llama_seq_id seq_id_a = batch.seq_id[a][i]; + llama_seq_id seq_id_b = batch.seq_id[b][i]; + // smaller seq_ids go first + if (seq_id_a != seq_id_b) { + return seq_id_a < seq_id_b; + } + } + } + // when all else is equal, sort by pos + if (batch.pos) { + return batch.pos[a] < batch.pos[b]; + } + // no pos, sort by id + return a < b; + } + // shared prompts go first + return n_seq_a > n_seq_b; + } + ); + + // init seq + llama_sbatch_seq * last_seq = nullptr; + + for (size_t i = 0; i < n_tokens; ++i) { + const size_t bi = ids[i]; + const int32_t n_seqs = batch.n_seq_id[bi]; + llama_seq_id * seq_ids = batch.seq_id[bi]; + if (last_seq != nullptr) { + bool same = n_seqs == last_seq->n_seq_id; + for (int32_t j = 0; same && j < n_seqs; ++j) { + if (seq_ids[j] != last_seq->seq_id[j]) { + same = false; + } + } + if (same) { + last_seq->length += 1; + continue; + } + } + llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1}; + seq.push_back(new_seq); + last_seq = &seq.back(); + } + + // keep shared prompts first at the end, then sort by length descending. + std::sort(seq.begin(), seq.end(), + [](llama_sbatch_seq & a, llama_sbatch_seq & b) { + if (a.n_seq_id == b.n_seq_id) { + return a.length > b.length; + } + return a.n_seq_id < b.n_seq_id; + } + ); +} + +llama_batch_allocr::llama_batch_allocr() { + const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG"); + debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0; + + seq_pos.resize(LLAMA_MAX_SEQ); + seq_cpl.resize(LLAMA_MAX_SEQ); + for (auto & cur : seq_cpl) { + cur.resize(LLAMA_MAX_SEQ); + } +} + +bool llama_batch_allocr::init( + const llama_batch & batch_inp, + const llama_vocab & vocab, + const llama_memory_i * memory, + bool embd_all) { + clear(); + + batch = batch_inp; + + GGML_ASSERT(batch.n_tokens > 0); + + // + // validate input batch + // + + if (batch.token) { + for (int32_t i = 0; i < batch.n_tokens; ++i) { + if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + return false; + } + } + } + + if (batch.seq_id) { + for (int32_t i = 0; i < batch.n_tokens; ++i) { + for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { + if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ); + return false; + } + } + } + } + + // + // auto-generate missing fields + // + + if (!batch.n_seq_id) { + n_seq_id.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + n_seq_id[i] = seq_id_0.size(); + } + batch.n_seq_id = n_seq_id.data(); + } + + if (!batch.seq_id) { + seq_id.resize(batch.n_tokens + 1); + seq_id[batch.n_tokens] = NULL; + for (int32_t i = 0; i < batch.n_tokens; i++) { + seq_id[i] = seq_id_0.data(); + } + batch.seq_id = seq_id.data(); + } + + if (!batch.pos) { + pos.resize(batch.n_tokens); + + // initialize the starting position for each sequence based on the positions in the memory + llama_pos p0[LLAMA_MAX_SEQ]; + for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (!memory) { + p0[s] = 0; + } else { + p0[s] = memory->seq_pos_max(s) + 1; + } + } + + for (int32_t i = 0; i < batch.n_tokens; i++) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + + pos[i] = p0[seq_id]; + + for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { + p0[batch.seq_id[i][s]] = pos[i] + 1; + } + } + + batch.pos = pos.data(); + } + + if (!batch.logits) { + if (embd_all) { + // return the output for all tokens + output.resize(batch.n_tokens, true); + } else { + // return the output only for the last token + output.resize(batch.n_tokens, false); + output[output.size() - 1] = true; + } + + batch.logits = output.data(); + } else if (embd_all) { + bool warn = false; + + for (int32_t i = 0; i < batch.n_tokens; ++i) { + if (batch.logits[i] == 0) { + warn = true; + } + } + + if (warn) { + LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__); + + output.resize(batch.n_tokens, true); + batch.logits = output.data(); + } + } + + // + // compute stats + // + + for (int32_t i = 0; i < batch.n_tokens; ++i) { + n_outputs += batch.logits[i] != 0; + } + + // determine coupled sequences + // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them + for (int32_t i = 0; i < batch.n_tokens; ++i) { + for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { + seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]); + + if (s > 0) { + const llama_seq_id s0 = batch.seq_id[i][0]; + const llama_seq_id s1 = batch.seq_id[i][s]; + + // mark that sequence s1 is coupled to s0 + seq_cpl[s1][s0] = true; + + // note: the other way around is not necessary for now + //seq_cpl[s0][s1] = true; + } + } + } + + if (debug > 0) { + LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__); + LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens); + LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token); + LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd); + LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos); + LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id); + LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id); + LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits); + LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs); + + if (debug > 1) { + int seq_id_max = 0; + for (int32_t i = 0; i < batch.n_tokens; ++i) { + for (int s = 0; s < batch.n_seq_id[i]; ++s) { + for (int s = 0; s < batch.n_seq_id[i]; ++s) { + seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]); + } + } + } + ++seq_id_max; + + LLAMA_LOG_DEBUG("%s: token = [\n", __func__); + for (int32_t i = 0; i < batch.n_tokens; ++i) { + std::vector seq_id(seq_id_max); + + for (int s = 0; s < batch.n_seq_id[i]; ++s) { + seq_id[batch.seq_id[i][s]] = 1; + } + + std::stringstream ss; + for (int s = 0; s < seq_id_max; ++s) { + if (seq_id[s]) { + ss << s%10; + } else { + ss << "."; + } + } + + LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", + __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(), + batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]); + } + LLAMA_LOG_DEBUG("%s: ]\n", __func__); + + LLAMA_LOG_DEBUG("%s: seq = [\n", __func__); + for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) { + if (seq_pos[s0].empty()) { + continue; + } + + std::stringstream ss; + for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) { + if (seq_cpl[s0][s1]) { + ss << s1 << " "; + } + } + + LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n", + __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str()); + } + LLAMA_LOG_DEBUG("%s: ]\n", __func__); + } + } + + // + // consistency checks + // + + for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq_pos[s].empty()) { + continue; + } + + if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) { + LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s); + return false; + } + + if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { + LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s); + return false; + } + } + + if (memory) { + for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) { + for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) { + if (seq_cpl[s0][s1]) { + if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) || + memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) { + LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1); + return false; + } + } + } + } + } + + return true; +} + +const llama_batch & llama_batch_allocr::get_batch() const { + return batch; +} + +uint32_t llama_batch_allocr::get_n_outputs() const { + return n_outputs; +} + +llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const { + return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin(); +} + +llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const { + return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin(); +} + +void llama_batch_allocr::clear() { + n_outputs = 0; + + batch = {}; + pos.clear(); + n_seq_id.clear(); + seq_id.clear(); + output.clear(); + + for (auto & cur : seq_pos) { + cur.clear(); + } + + for (auto & cur : seq_cpl) { + std::fill(cur.begin(), cur.end(), false); + } +} + +// +// interface implementation +// + +struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens) { + return { + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + }; +} + +struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch batch = { + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + }; + + if (embd) { + batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + } else { + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); + } + + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); + for (int i = 0; i < n_tokens_alloc; ++i) { + batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch.seq_id[n_tokens_alloc] = nullptr; + + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + + return batch; +} + +void llama_batch_free(struct llama_batch batch) { + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i] != nullptr; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} diff --git a/src/llama-batch.h b/src/llama-batch.h new file mode 100644 index 0000000000000000000000000000000000000000..a555c157234be82933b0bf7b35b43c4fed2593e4 --- /dev/null +++ b/src/llama-batch.h @@ -0,0 +1,119 @@ +#pragma once + +#include "llama.h" + +#include +#include +#include + +// very similar to llama_batch, +// but has more metadata about sequences +struct llama_ubatch { + bool equal_seqs; + // TODO: whole_seqs for embeddings? + + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_seq_tokens; // tokens per sequence + uint32_t n_seqs; + + llama_token * token; // [n_tokens] + float * embd; // [n_embd, n_tokens] + llama_pos * pos; // [n_tokens] + int32_t * n_seq_id; // [n_seqs] + llama_seq_id ** seq_id; // [n_seqs] + int8_t * output; // [n_tokens] +}; + +struct llama_sbatch_seq { + int32_t n_seq_id; + + llama_seq_id * seq_id; + + size_t offset; + size_t length; +}; + +// sequence-length-aware batch splitting +struct llama_sbatch { + // tokens left in this batch + size_t n_tokens; + + size_t n_embd; + + // sorted indices into the batch + std::vector ids; + // batch indices of the output + std::vector out_ids; + std::vector seq; + + const llama_batch * batch = nullptr; + + // buffers for the ubatches + // TODO: very hacky, this needs a complete rework + struct ubatch_data { + std::vector token; + std::vector embd; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector output; + }; + + std::vector udatas; + + llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); + + void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length); + + // simple split, unknown number of sequences of unequal lengths + llama_ubatch split_simple(size_t n_ubatch); + + // make batches of equal-length sequences + llama_ubatch split_equal(size_t n_ubatch); + + // sequence-wise split + llama_ubatch split_seq(size_t n_ubatch); + + llama_sbatch() = default; + llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false); +}; + +// a helper for sanitizing and fulfilling a batch +class llama_batch_allocr { +public: + llama_batch_allocr(); + + // sanitize and auto-gen missing data in the input batch + // memory is optional. if provided will be used to check for sequence continuity and to determine the positions + bool init( + const llama_batch & batch_inp, + const llama_vocab & vocab, + const llama_memory_i * memory, + bool embd_all); + + const llama_batch & get_batch() const; + + uint32_t get_n_outputs() const; + + llama_pos seq_pos_min(llama_seq_id seq_id) const; + llama_pos seq_pos_max(llama_seq_id seq_id) const; + +private: + void clear(); + + llama_batch batch; + + uint32_t n_outputs; + + std::array seq_id_0 = { 0 }; // default sequence id + + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector output; + + std::vector> seq_pos; // seq_pos[s]: the set of positions in sequence s + std::vector> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 + + int debug; +}; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bc4fa05a74ef470796099125835dcc32b0ad52d9 --- /dev/null +++ b/src/llama-chat.cpp @@ -0,0 +1,680 @@ +#include "llama-chat.h" + +#include "llama.h" + +#include +#include +#include + +#if __cplusplus >= 202000L + #define LU8(x) (const char*)(u8##x) +#else + #define LU8(x) u8##x +#endif + +// trim whitespace from the beginning and end of a string +static std::string trim(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && isspace(str[start])) { + start += 1; + } + while (end > start && isspace(str[end - 1])) { + end -= 1; + } + return str.substr(start, end - start); +} + +static const std::map LLM_CHAT_TEMPLATES = { + { "chatml", LLM_CHAT_TEMPLATE_CHATML }, + { "llama2", LLM_CHAT_TEMPLATE_LLAMA_2 }, + { "llama2-sys", LLM_CHAT_TEMPLATE_LLAMA_2_SYS }, + { "llama2-sys-bos", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS }, + { "llama2-sys-strip", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP }, + { "mistral-v1", LLM_CHAT_TEMPLATE_MISTRAL_V1 }, + { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, + { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, + { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, + { "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN }, + { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, + { "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, + { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, + { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR }, + { "monarch", LLM_CHAT_TEMPLATE_MONARCH }, + { "gemma", LLM_CHAT_TEMPLATE_GEMMA }, + { "orion", LLM_CHAT_TEMPLATE_ORION }, + { "openchat", LLM_CHAT_TEMPLATE_OPENCHAT }, + { "vicuna", LLM_CHAT_TEMPLATE_VICUNA }, + { "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA }, + { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, + { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, + { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, + { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, + { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, + { "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 }, + { "chatglm4", LLM_CHAT_TEMPLATE_CHATGLM_4 }, + { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE }, + { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, + { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, + { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, + { "granite", LLM_CHAT_TEMPLATE_GRANITE }, + { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, + { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, + { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, + { "bailing", LLM_CHAT_TEMPLATE_BAILING }, + { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, + { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, +}; + +llm_chat_template llm_chat_template_from_str(const std::string & name) { + return LLM_CHAT_TEMPLATES.at(name); +} + +llm_chat_template llm_chat_detect_template(const std::string & tmpl) { + try { + return llm_chat_template_from_str(tmpl); + } catch (const std::out_of_range &) { + // ignore + } + + auto tmpl_contains = [&tmpl](const char * haystack) -> bool { + return tmpl.find(haystack) != std::string::npos; + }; + if (tmpl_contains("<|im_start|>")) { + return tmpl_contains("<|im_sep|>") + ? LLM_CHAT_TEMPLATE_PHI_4 + : tmpl_contains("") + ? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml + : LLM_CHAT_TEMPLATE_CHATML; + } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { + if (tmpl_contains("[SYSTEM_PROMPT]")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V7; + } else if ( + // catches official 'v1' template + tmpl_contains("' [INST] ' + system_message") + // catches official 'v3' and 'v3-tekken' templates + || tmpl_contains("[AVAILABLE_TOOLS]") + ) { + // Official mistral 'v1', 'v3' and 'v3-tekken' templates + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md + if (tmpl_contains(" [INST]")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V1; + } else if (tmpl_contains("\"[INST]\"")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN; + } + return LLM_CHAT_TEMPLATE_MISTRAL_V3; + } else { + // llama2 template and its variants + // [variant] support system message + // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + bool support_system_message = tmpl_contains("<>"); + bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); + bool strip_message = tmpl_contains("content.strip()"); + if (strip_message) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP; + } else if (add_bos_inside_history) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS; + } else if (support_system_message) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS; + } else { + return LLM_CHAT_TEMPLATE_LLAMA_2; + } + } + } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { + return LLM_CHAT_TEMPLATE_PHI_3; + } else if (tmpl_contains("[gMASK]")) { + return LLM_CHAT_TEMPLATE_CHATGLM_4; + } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { + return tmpl_contains("") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; + } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) { + return LLM_CHAT_TEMPLATE_GLMEDGE; + } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { + return LLM_CHAT_TEMPLATE_ZEPHYR; + } else if (tmpl_contains("bos_token + message['role']")) { + return LLM_CHAT_TEMPLATE_MONARCH; + } else if (tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_GEMMA; + } else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + // OrionStarAI/Orion-14B-Chat + return LLM_CHAT_TEMPLATE_ORION; + } else if (tmpl_contains("GPT4 Correct ")) { + // openchat/openchat-3.5-0106 + return LLM_CHAT_TEMPLATE_OPENCHAT; + } else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) { + // eachadea/vicuna-13b-1.1 (and Orca variant) + if (tmpl_contains("SYSTEM: ")) { + return LLM_CHAT_TEMPLATE_VICUNA_ORCA; + } + return LLM_CHAT_TEMPLATE_VICUNA; + } else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) { + // deepseek-ai/deepseek-coder-33b-instruct + return LLM_CHAT_TEMPLATE_DEEPSEEK; + } else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) { + // CohereForAI/c4ai-command-r-plus + return LLM_CHAT_TEMPLATE_COMMAND_R; + } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) { + return LLM_CHAT_TEMPLATE_LLAMA_3; + } else if (tmpl_contains("[gMASK]sop")) { + // chatglm3-6b + return LLM_CHAT_TEMPLATE_CHATGLM_3; + } else if (tmpl_contains(LU8("<用户>"))) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + return LLM_CHAT_TEMPLATE_MINICPM; + } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_2; + } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_3; + } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + return LLM_CHAT_TEMPLATE_EXAONE_3; + } else if (tmpl_contains("rwkv-world")) { + return LLM_CHAT_TEMPLATE_RWKV_WORLD; + } else if (tmpl_contains("<|start_of_role|>")) { + return LLM_CHAT_TEMPLATE_GRANITE; + } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { + return LLM_CHAT_TEMPLATE_GIGACHAT; + } else if (tmpl_contains("<|role_start|>")) { + return LLM_CHAT_TEMPLATE_MEGREZ; + } else if (tmpl_contains(" Ассистент:")) { + return LLM_CHAT_TEMPLATE_YANDEX; + } else if (tmpl_contains("ASSISTANT") && tmpl_contains("'HUMAN'")) { + return LLM_CHAT_TEMPLATE_BAILING; + } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { + return LLM_CHAT_TEMPLATE_LLAMA4; + } else if (tmpl_contains("<|endofuserprompt|>")) { + return LLM_CHAT_TEMPLATE_DOTS1; + } + return LLM_CHAT_TEMPLATE_UNKNOWN; +} + +// Simple version of "llama_apply_chat_template" that only works with strings +// This function uses heuristic checks to determine commonly used template. It is not a jinja parser. +int32_t llm_chat_apply_template( + llm_chat_template tmpl, + const std::vector & chat, + std::string & dest, bool add_ass) { + // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + std::stringstream ss; + if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; + } + if (add_ass) { + ss << "<|im_start|>assistant\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) { + // Official mistral 'v7' template + // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 + // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken + const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : ""; + for (auto message : chat) { + std::string role(message->role); + std::string content(message->content); + if (role == "system") { + ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]"; + } else if (role == "user") { + ss << "[INST]" << trailing_space << content << "[/INST]"; + } else { + ss << trailing_space << content << ""; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 + || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3 + || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) { + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md + std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : ""; + std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " "; + bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3; + bool is_inside_turn = false; + for (auto message : chat) { + if (!is_inside_turn) { + ss << leading_space << "[INST]" << trailing_space; + is_inside_turn = true; + } + std::string role(message->role); + std::string content(message->content); + if (role == "system") { + ss << content << "\n\n"; + } else if (role == "user") { + ss << content << leading_space << "[/INST]"; + } else { + ss << trailing_space << (trim_assistant_message ? trim(content) : content) << ""; + is_inside_turn = false; + } + } + } else if ( + tmpl == LLM_CHAT_TEMPLATE_LLAMA_2 + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) { + // llama2 template and its variants + // [variant] support system message + // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2; + // [variant] add BOS inside history + bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS; + // [variant] trim spaces from the input message + bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP; + // construct the prompt + bool is_inside_turn = true; // skip BOS at the beginning + ss << "[INST] "; + for (auto message : chat) { + std::string content = strip_message ? trim(message->content) : message->content; + std::string role(message->role); + if (!is_inside_turn) { + is_inside_turn = true; + ss << (add_bos_inside_history ? "[INST] " : "[INST] "); + } + if (role == "system") { + if (support_system_message) { + ss << "<>\n" << content << "\n<>\n\n"; + } else { + // if the model does not support system message, we still include it in the first message, but without <> + ss << content << "\n"; + } + } else if (role == "user") { + ss << content << " [/INST]"; + } else { + ss << content << ""; + is_inside_turn = false; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) { + // Phi 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << message->content << "<|end|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_4) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "<|im_sep|>" << message->content << "<|im_end|>"; + } + if (add_ass) { + ss << "<|im_start|>assistant<|im_sep|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) { + // Falcon 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << message->content << "\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) { + // zephyr template + for (auto message : chat) { + ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) { + // mlabonne/AlphaMonarch-7B template (the is included inside history) + for (auto message : chat) { + std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message + ss << bos << message->role << "\n" << message->content << "\n"; + } + if (add_ass) { + ss << "assistant\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) { + // google/gemma-7b-it + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken + system_prompt = trim(message->content); + continue; + } + // in gemma, "assistant" is "model" + role = role == "assistant" ? "model" : message->role; + ss << "" << role << "\n"; + if (!system_prompt.empty() && role != "model") { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << trim(message->content) << "\n"; + } + if (add_ass) { + ss << "model\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_ORION) { + // OrionStarAI/Orion-14B-Chat + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message support, we will merge it with user prompt + system_prompt = message->content; + continue; + } else if (role == "user") { + ss << "Human: "; + if (!system_prompt.empty()) { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << message->content << "\n\nAssistant: "; + } else { + ss << message->content << ""; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) { + // openchat/openchat-3.5-0106, + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "<|end_of_turn|>"; + } else { + role[0] = toupper(role[0]); + ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>"; + } + } + if (add_ass) { + ss << "GPT4 Correct Assistant:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { + // eachadea/vicuna-13b-1.1 (and Orca variant) + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // Orca-Vicuna variant uses a system prefix + if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { + ss << "SYSTEM: " << message->content << "\n"; + } else { + ss << message->content << "\n\n"; + } + } else if (role == "user") { + ss << "USER: " << message->content << "\n"; + } else if (role == "assistant") { + ss << "ASSISTANT: " << message->content << "\n"; + } + } + if (add_ass) { + ss << "ASSISTANT:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) { + // deepseek-ai/deepseek-coder-33b-instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content; + } else if (role == "user") { + ss << "### Instruction:\n" << message->content << "\n"; + } else if (role == "assistant") { + ss << "### Response:\n" << message->content << "\n<|EOT|>\n"; + } + } + if (add_ass) { + ss << "### Response:\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) { + // CohereForAI/c4ai-command-r-plus + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "user") { + ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "assistant") { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } + } + if (add_ass) { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) { + // Llama 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>"; + } + if (add_ass) { + ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_3) { + // chatglm3-6b + ss << "[gMASK]" << "sop"; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n " << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) { + ss << "[gMASK]" << ""; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + for (auto message : chat) { + std::string role(message->role); + if (role == "user") { + ss << LU8("<用户>"); + ss << trim(message->content); + ss << ""; + } else { + ss << trim(message->content); + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) { + // DeepSeek-V2 + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << "User: " << message->content << "\n\n"; + } else if (role == "assistant") { + ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>"); + } + } + if (add_ass) { + ss << "Assistant:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) { + // DeepSeek-V3 + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << LU8("<|User|>") << message->content; + } else if (role == "assistant") { + ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>"); + } + } + if (add_ass) { + ss << LU8("<|Assistant|>"); + } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "user") { + ss << "[|user|]" << trim(message->content) << "\n"; + } else if (role == "assistant") { + ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n"; + } + } + if (add_ass) { + ss << "[|assistant|]"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { + // this template requires the model to have "\n\n" as EOT token + for (auto message : chat) { + std::string role(message->role); + if (role == "user") { + ss << "User: " << message->content << "\n\nAssistant:"; + } else { + ss << message->content << "\n\n"; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { + // IBM Granite template + for (const auto & message : chat) { + std::string role(message->role); + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + if (role == "assistant_tool_call") { + ss << "<|tool_call|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { + // GigaChat template + bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; + + // Handle system message if present + if (has_system) { + ss << "" << chat[0]->content << "<|message_sep|>"; + } else { + ss << ""; + } + + // Process remaining messages + for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (role == "user") { + ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>" + << "available functions<|role_sep|>[]<|message_sep|>"; + } else if (role == "assistant") { + ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>"; + } + } + + // Add generation prompt if needed + if (add_ass) { + ss << "assistant<|role_sep|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MEGREZ) { + // Megrez template + for (auto message : chat) { + std::string role(message->role); + ss << "<|role_start|>" << role << "<|role_end|>" << message->content << "<|turn_end|>"; + } + + if (add_ass) { + ss << "<|role_start|>assistant<|role_end|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) { + // Yandex template ("\n\n" is defined as EOT token) + + ss << ""; + + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (role == "user") { + ss << " Пользователь: " << chat[i]->content << "\n\n"; + } else if (role == "assistant") { + ss << " Ассистент: " << chat[i]->content << "\n\n"; + } + } + + // Add generation prompt if needed + if (add_ass) { + ss << " Ассистент:[SEP]"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) { + // Bailing (Ling) template + for (auto message : chat) { + std::string role(message->role); + + if (role == "user") { + role = "HUMAN"; + } else { + std::transform(role.begin(), role.end(), role.begin(), ::toupper); + } + + ss << "" << role << "" << message->content; + } + + if (add_ass) { + ss << "ASSISTANT"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) { + // Llama 4 + for (auto message : chat) { + std::string role(message->role); + ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>"; + } + if (add_ass) { + ss << "<|header_start|>assistant<|header_end|>\n\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) { + // SmolVLM + ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << "User: " << message->content << "\n"; + } else { + ss << "Assistant: " << message->content << "\n"; + } + } + if (add_ass) { + ss << "Assistant:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) { + // dots.llm1.inst (DOTS1) + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|system|>" << message->content << "<|endofsystem|>"; + } else if (role == "user") { + ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>"; + } else { + ss << "<|response|>" << message->content << "<|endofresponse|>"; + } + } + if (add_ass) { + ss << "<|response|>"; + } + } else { + // template not supported + return -1; + } + dest = ss.str(); + return dest.size(); +} + +// public interface + +int32_t llama_chat_builtin_templates(const char ** output, size_t len) { + auto it = LLM_CHAT_TEMPLATES.begin(); + for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) { + output[i] = it->first.c_str(); + std::advance(it, 1); + } + return (int32_t) LLM_CHAT_TEMPLATES.size(); +} diff --git a/src/llama-chat.h b/src/llama-chat.h new file mode 100644 index 0000000000000000000000000000000000000000..38800010ae48b5da3474fb47ee4c490d693db68f --- /dev/null +++ b/src/llama-chat.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include + +enum llm_chat_template { + LLM_CHAT_TEMPLATE_CHATML, + LLM_CHAT_TEMPLATE_LLAMA_2, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP, + LLM_CHAT_TEMPLATE_MISTRAL_V1, + LLM_CHAT_TEMPLATE_MISTRAL_V3, + LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, + LLM_CHAT_TEMPLATE_MISTRAL_V7, + LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN, + LLM_CHAT_TEMPLATE_PHI_3, + LLM_CHAT_TEMPLATE_PHI_4, + LLM_CHAT_TEMPLATE_FALCON_3, + LLM_CHAT_TEMPLATE_ZEPHYR, + LLM_CHAT_TEMPLATE_MONARCH, + LLM_CHAT_TEMPLATE_GEMMA, + LLM_CHAT_TEMPLATE_ORION, + LLM_CHAT_TEMPLATE_OPENCHAT, + LLM_CHAT_TEMPLATE_VICUNA, + LLM_CHAT_TEMPLATE_VICUNA_ORCA, + LLM_CHAT_TEMPLATE_DEEPSEEK, + LLM_CHAT_TEMPLATE_DEEPSEEK_2, + LLM_CHAT_TEMPLATE_DEEPSEEK_3, + LLM_CHAT_TEMPLATE_COMMAND_R, + LLM_CHAT_TEMPLATE_LLAMA_3, + LLM_CHAT_TEMPLATE_CHATGLM_3, + LLM_CHAT_TEMPLATE_CHATGLM_4, + LLM_CHAT_TEMPLATE_GLMEDGE, + LLM_CHAT_TEMPLATE_MINICPM, + LLM_CHAT_TEMPLATE_EXAONE_3, + LLM_CHAT_TEMPLATE_RWKV_WORLD, + LLM_CHAT_TEMPLATE_GRANITE, + LLM_CHAT_TEMPLATE_GIGACHAT, + LLM_CHAT_TEMPLATE_MEGREZ, + LLM_CHAT_TEMPLATE_YANDEX, + LLM_CHAT_TEMPLATE_BAILING, + LLM_CHAT_TEMPLATE_LLAMA4, + LLM_CHAT_TEMPLATE_SMOLVLM, + LLM_CHAT_TEMPLATE_DOTS1, + LLM_CHAT_TEMPLATE_UNKNOWN, +}; + +struct llama_chat_message; + +llm_chat_template llm_chat_template_from_str(const std::string & name); + +llm_chat_template llm_chat_detect_template(const std::string & tmpl); + +int32_t llm_chat_apply_template( + llm_chat_template tmpl, + const std::vector & chat, + std::string & dest, bool add_ass); diff --git a/src/llama-context.cpp b/src/llama-context.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f56a58e9b6ec6208c92e0e38c4a879c5af44937c --- /dev/null +++ b/src/llama-context.cpp @@ -0,0 +1,2851 @@ +#include "llama-context.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-io.h" +#include "llama-memory.h" +#include "llama-mmap.h" +#include "llama-model.h" + +#include +#include +#include +#include + +// +// llama_context +// + +llama_context::llama_context( + const llama_model & model, + llama_context_params params) : + model(model), + batch_allocr(std::make_unique()) { + LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); + + t_start_us = model.t_start_us; + t_load_us = model.t_load_us; + + const auto & hparams = model.hparams; + + cparams.n_seq_max = std::max(1u, params.n_seq_max); + if (cparams.n_seq_max > LLAMA_MAX_SEQ) { + throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); + } + + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.yarn_ext_factor = params.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow; + cparams.defrag_thold = params.defrag_thold; + cparams.embeddings = params.embeddings; + cparams.offload_kqv = params.offload_kqv; + cparams.flash_attn = params.flash_attn; + cparams.no_perf = params.no_perf; + cparams.pooling_type = params.pooling_type; + cparams.warmup = false; + + cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; + cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; + cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; + + cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : + hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn : + hparams.n_ctx_train; + + cparams.cb_eval = params.cb_eval; + cparams.cb_eval_user_data = params.cb_eval_user_data; + + auto rope_scaling_type = params.rope_scaling_type; + if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { + rope_scaling_type = hparams.rope_scaling_type_train; + } + + if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) { + cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none + } + + if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set' + cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; + } + + cparams.yarn_attn_factor *= hparams.rope_attn_factor; + + if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { + if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { + cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; + } else { + cparams.pooling_type = hparams.pooling_type; + } + } + + if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) { + cparams.causal_attn = hparams.causal_attn; + } else { + cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; + } + + // with causal attention, the batch size is limited by the context size + cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; + + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + // ref: https://github.com/ggerganov/llama.cpp/pull/5021 + // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self + if (cparams.n_batch < GGML_KQ_MASK_PAD) { + LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); + cparams.n_batch = GGML_KQ_MASK_PAD; + } + + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + + cparams.op_offload = params.op_offload; + + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + + LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq); + LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); + LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); + LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); + LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + + if (n_ctx_per_seq < hparams.n_ctx_train) { + LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", + __func__, n_ctx_per_seq, hparams.n_ctx_train); + } + + if (n_ctx_per_seq > hparams.n_ctx_train) { + LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", + __func__, n_ctx_per_seq, hparams.n_ctx_train); + } + + if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) { + LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n", + __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573"); + } + + if (!hparams.vocab_only) { + // GPU backends + for (auto * dev : model.devices) { + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (backend == nullptr) { + throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); + } + backends.emplace_back(backend); + } + + // add ACCEL backends (such as BLAS) + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (backend == nullptr) { + throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); + } + backends.emplace_back(backend); + } + } + + // add CPU backend + backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (backend_cpu == nullptr) { + throw std::runtime_error("failed to initialize CPU backend"); + } + backends.emplace_back(backend_cpu); + + // create a list of the set_n_threads functions in the backends + for (auto & backend : backends) { + ggml_backend_dev_t dev = ggml_backend_get_device(backend.get()); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + if (reg) { + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn); + } + } + } + + llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data); + + // graph outputs buffer + { + // resized during inference when a batch uses more outputs + if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) { + throw std::runtime_error("failed to reserve initial output buffer"); + } + + LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__, + ggml_backend_buffer_name (buf_output.get()), + ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0); + } + } + + // init the memory module + if (!hparams.vocab_only) { + llama_memory_params params_mem = { + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + /*.swa_full =*/ params.swa_full, + }; + + memory.reset(model.create_memory(params_mem, cparams)); + } + + // init backends + if (!hparams.vocab_only) { + LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__); + + backend_buft.clear(); + backend_ptrs.clear(); + + for (auto & backend : backends) { + auto * buft = ggml_backend_get_default_buffer_type(backend.get()); + auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); + + if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { + // use the host buffer of the first device CPU for faster transfer of the intermediate state + auto * dev = model.devices[0]; + auto * host_buft = ggml_backend_dev_host_buffer_type(dev); + if (host_buft) { + buft = host_buft; + } + } + + backend_buft.push_back(buft); + backend_ptrs.push_back(backend.get()); + } + + LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); + + const size_t max_nodes = this->graph_max_nodes(); + + LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); + + // buffer used to store the computation graph and the tensor meta data + buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); + + // TODO: move these checks to ggml_backend_sched + // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary + bool pipeline_parallel = + model.n_devices() > 1 && + model.params.n_gpu_layers > (int) model.hparams.n_layer && + model.params.split_mode == LLAMA_SPLIT_MODE_LAYER && + cparams.offload_kqv && + !model.has_tensor_overrides(); + + // pipeline parallelism requires support for async compute and events in all devices + if (pipeline_parallel) { + for (auto & backend : backends) { + auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); + if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { + // ignore CPU backend + continue; + } + auto * dev = ggml_backend_get_device(backend.get()); + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + if (!props.caps.async || !props.caps.events) { + // device does not support async compute or events + pipeline_parallel = false; + break; + } + } + } + + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); + + if (pipeline_parallel) { + LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); + } + } + + // reserve worst-case graph + if (!hparams.vocab_only && memory) { + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + + int n_splits_pp = -1; + int n_nodes_pp = -1; + + int n_splits_tg = -1; + int n_nodes_tg = -1; + + // simulate full KV cache + + const auto mstate = memory->init_full(); + if (!mstate) { + throw std::runtime_error("failed to initialize KV cache"); + } + + cross.v_embd.clear(); + + // reserve pp graph first so that buffers are only allocated once + { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + + n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_pp = ggml_graph_n_nodes(gf); + } + + // reserve with tg graph to get the number of splits and nodes + { + auto * gf = graph_reserve(1, 1, 1, mstate.get()); + if (!gf) { + throw std::runtime_error("failed to allocate compute tg buffers"); + } + + n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_tg = ggml_graph_n_nodes(gf); + } + + // reserve again with pp graph to avoid ggml-alloc reallocations during inference + { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + } + + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; + size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend); + if (size > 1) { + LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, + ggml_backend_buft_name(buft), + size / 1024.0 / 1024.0); + } + } + + if (n_nodes_pp == n_nodes_tg) { + LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); + } else { + LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + } + + if (n_splits_pp == n_splits_tg) { + LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); + } else { + LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); + } + } +} + +llama_context::~llama_context() { + ggml_opt_free(opt_ctx); +} + +void llama_context::synchronize() { + ggml_backend_sched_synchronize(sched.get()); + + // FIXME: if multiple single tokens are evaluated without a synchronization, + // the stats will be added to the prompt evaluation stats + // this should only happen when using batch size 1 to evaluate a batch + + // add the evaluation to the stats + if (n_queued_tokens == 1) { + if (!cparams.no_perf) { + t_eval_us += ggml_time_us() - t_compute_start_us; + } + n_eval++; + } else if (n_queued_tokens > 1) { + if (!cparams.no_perf) { + t_p_eval_us += ggml_time_us() - t_compute_start_us; + } + n_p_eval += n_queued_tokens; + } + + // get a more accurate load time, upon first eval + if (n_queued_tokens > 0 && !has_evaluated_once) { + t_load_us = ggml_time_us() - t_start_us; + has_evaluated_once = true; + } + + n_queued_tokens = 0; + t_compute_start_us = 0; +} + +const llama_model & llama_context::get_model() const { + return model; +} + +const llama_cparams & llama_context::get_cparams() const { + return cparams; +} + +ggml_backend_sched_t llama_context::get_sched() const { + return sched.get(); +} + +ggml_context * llama_context::get_ctx_compute() const { + return ctx_compute.get(); +} + +uint32_t llama_context::n_ctx() const { + return cparams.n_ctx; +} + +uint32_t llama_context::n_ctx_per_seq() const { + return cparams.n_ctx / cparams.n_seq_max; +} + +uint32_t llama_context::n_batch() const { + return cparams.n_batch; +} + +uint32_t llama_context::n_ubatch() const { + return cparams.n_ubatch; +} + +uint32_t llama_context::n_seq_max() const { + return cparams.n_seq_max; +} + +uint32_t llama_context::n_threads() const { + return cparams.n_threads; +} + +uint32_t llama_context::n_threads_batch() const { + return cparams.n_threads_batch; +} + +llama_memory_t llama_context::get_memory() const { + return memory.get(); +} + +// deprecated +void llama_context::kv_self_defrag_sched() { + if (!memory) { + return; + } + + memory_force_optimize = true; +} + +// deprecated +bool llama_context::kv_self_update(bool optimize) { + if (!memory) { + return false; + } + + { + // TODO: remove in the future + optimize |= memory_force_optimize; + memory_force_optimize = false; + + const auto mstate = memory->init_update(this, optimize); + switch (mstate->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + // noop + } break; + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + // no updates need to be performed + return false; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__); + return false; + } + } + + if (!mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); + } + } + + // if the memory module did any computation, we have to reserve a new worst-case graph + { + const auto mstate = memory->init_full(); + if (!mstate) { + throw std::runtime_error("failed to initialize memory state"); + } + + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); + } + } + + return true; +} + +enum llama_pooling_type llama_context::pooling_type() const { + return cparams.pooling_type; +} + +float * llama_context::get_logits() { + return logits; +} + +float * llama_context::get_logits_ith(int32_t i) { + int64_t j = -1; + + try { + if (logits == nullptr) { + throw std::runtime_error("no logits"); + } + + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); + } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); + } else { + j = output_ids[i]; + } + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if (j >= n_outputs) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); + } + + return logits + j*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + +float * llama_context::get_embeddings() { + return embd; +} + +float * llama_context::get_embeddings_ith(int32_t i) { + int64_t j = -1; + + try { + if (embd == nullptr) { + throw std::runtime_error("no embeddings"); + } + + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); + } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); + } else { + j = output_ids[i]; + } + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if (j >= n_outputs) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); + } + + return embd + j*model.hparams.n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + +float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { + auto it = embd_seq.find(seq_id); + if (it == embd_seq.end()) { + return nullptr; + } + + return it->second.data(); +} + +void llama_context::attach_threadpool( + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch) { + LLAMA_LOG_DEBUG("%s: call\n", __func__); + + this->threadpool = threadpool; + this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; +} + +void llama_context::detach_threadpool() { + LLAMA_LOG_DEBUG("%s: call\n", __func__); + + this->threadpool = nullptr; + this->threadpool_batch = nullptr; +} + +void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { + LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch); + + cparams.n_threads = n_threads; + cparams.n_threads_batch = n_threads_batch; +} + +void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { + LLAMA_LOG_DEBUG("%s: call\n", __func__); + + this->abort_callback = abort_callback; + this->abort_callback_data = abort_callback_data; + + for (auto & backend : backends) { + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + } + } +} + +void llama_context::set_embeddings(bool value) { + LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + + cparams.embeddings = value; +} + +void llama_context::set_causal_attn(bool value) { + LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + + cparams.causal_attn = value; +} + +void llama_context::set_warmup(bool value) { + LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + + cparams.warmup = value; +} + +void llama_context::set_adapter_lora( + llama_adapter_lora * adapter, + float scale) { + LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); + + loras[adapter] = scale; +} + +bool llama_context::rm_adapter_lora( + llama_adapter_lora * adapter) { + LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); + + auto pos = loras.find(adapter); + if (pos != loras.end()) { + loras.erase(pos); + return true; + } + + return false; +} + +void llama_context::clear_adapter_lora() { + LLAMA_LOG_DEBUG("%s: call\n", __func__); + + loras.clear(); +} + +bool llama_context::apply_adapter_cvec( + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); + + return cvec.apply(model, data, len, n_embd, il_start, il_end); +} + +llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) { + if (mstate && !mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto * gf = graph_init(); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } + + res->set_inputs(&ubatch); + + const auto status = graph_compute(gf, ubatch.n_tokens > 1); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); + ret = status; + return nullptr; + } + + ret = GGML_STATUS_SUCCESS; + + return res; +} + +int llama_context::encode(const llama_batch & batch_inp) { + if (batch_inp.n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); + return -1; + } + + // note: during encode, we always pass the full sequence starting from pos = 0 + if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return -1; + } + + const llama_batch & batch = batch_allocr->get_batch(); + + const uint32_t n_tokens = batch.n_tokens; + + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + + // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot + GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens"); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + + // TODO: this clear of the buffer can easily be forgotten - need something better + embd_seq.clear(); + + n_queued_tokens += n_tokens; + + const auto & hparams = model.hparams; + + const int64_t n_embd = hparams.n_embd; + + llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true); + + const llama_ubatch ubatch = sbatch.split_simple(n_tokens); + + // reserve output buffer + if (output_reserve(n_tokens) < n_tokens) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); + return -2; + }; + + for (uint32_t i = 0; i < n_tokens; ++i) { + output_ids[i] = i; + } + + n_outputs = n_tokens; + + ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + const auto causal_attn_org = cparams.causal_attn; + + // always use non-causal attention for encoder graphs + // TODO: this is a tmp solution until we have a proper way to support enc-dec models + // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 + cparams.causal_attn = false; + + ggml_status status; + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + + cparams.causal_attn = causal_attn_org; + + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } + } + + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + + // extract embeddings + if (t_embd) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings + auto & embd_seq_out = embd_seq; + embd_seq_out.clear(); + + GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits + + // TODO: fix indexing [UBATCH_IDX] + for (uint32_t i = 0; i < n_tokens; i++) { + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + const uint32_t n_cls_out = hparams.n_cls_out; + + // TODO: fix indexing [UBATCH_IDX] + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } + } + } + + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + + // TODO: hacky solution + if (model.arch == LLM_ARCH_T5 && t_embd) { + //cross.t_embd = t_embd; + + synchronize(); + + cross.n_embd = t_embd->ne[0]; + cross.n_enc = t_embd->ne[1]; + cross.v_embd.resize(cross.n_embd*cross.n_enc); + memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); + + // remember the sequence ids used during the encoding - needed for cross attention later + cross.seq_ids_enc.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + cross.seq_ids_enc[i].clear(); + for (int s = 0; s < batch.n_seq_id[i]; s++) { + llama_seq_id seq_id = batch.seq_id[i][s]; + cross.seq_ids_enc[i].insert(seq_id); + } + } + } + + return 0; +} + +int llama_context::decode(const llama_batch & batch_inp) { + if (!memory) { + LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); + return encode(batch_inp); + } + + if (batch_inp.n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); + return -1; + } + + // when computing embeddings, all tokens are output + const bool embd_all = cparams.embeddings; + + if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return -1; + } + + const llama_batch & batch = batch_allocr->get_batch(); + + const auto & vocab = model.vocab; + const auto & hparams = model.hparams; + + const int32_t n_vocab = vocab.n_tokens(); + const int64_t n_embd = hparams.n_embd; + + const uint32_t n_tokens_all = batch.n_tokens; + + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + + const uint32_t n_outputs_all = batch_allocr->get_n_outputs(); + + if (embd_all) { + // require that all tokens are output + if (n_outputs_all != n_tokens_all) { + LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", + __func__, n_outputs_all, n_tokens_all); + return -1; + } + } + + GGML_ASSERT(n_tokens_all <= cparams.n_batch); + + GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + n_queued_tokens += n_tokens_all; + + // TODO: this clear of the buffer can easily be forgotten - need something better + embd_seq.clear(); + + bool did_optimize = false; + + // handle any pending defrags/shifts + kv_self_update(false); + + llama_memory_state_ptr mstate; + + while (true) { + mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all); + if (!mstate) { + return -2; + } + + switch (mstate->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + } break; + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status()); + + return -2; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + { + if (!did_optimize) { + did_optimize = true; + + if (kv_self_update(true)) { + LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens); + + continue; + } + } + + LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens); + + return 1; + } + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens); + + return -2; + } + } + + break; + } + + // reserve output buffer + if (output_reserve(n_outputs_all) < n_outputs_all) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); + return -2; + }; + + int64_t n_outputs_prev = 0; + + do { + const auto & ubatch = mstate->get_ubatch(); + + // count the outputs in this ubatch + { + int32_t n_outputs_new = 0; + + if (n_outputs_all == n_tokens_all) { + n_outputs_new = ubatch.n_tokens; + } else { + GGML_ASSERT(ubatch.output); + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + n_outputs_new += (int32_t) (ubatch.output[i] != 0); + } + } + + // needs to happen before the graph is built + n_outputs = n_outputs_new; + } + + ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + ggml_status status; + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status); + + if (!res) { + // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache + llama_pos pos_min[LLAMA_MAX_SEQ]; + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + pos_min[s] = std::numeric_limits::max(); + } + + // TODO: fix sequence indexing + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + const auto & seq_id = ubatch.seq_id[i][0]; + + pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); + } + + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (pos_min[s] == std::numeric_limits::max()) { + continue; + } + + LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); + + memory->seq_rm(s, pos_min[s], -1); + } + + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } + } + + // plot the computation graph in dot format (for debugging purposes) + //if (n_past%100 == 0) { + // ggml_graph_dump_dot(gf, NULL, "llama.dot"); + //} + + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + + if (t_embd && res->get_embd_pooled()) { + t_embd = res->get_embd_pooled(); + } + + // extract logits + if (t_logits && n_outputs > 0) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + float * logits_out = logits + n_outputs_prev*n_vocab; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } + } + + // extract embeddings + if (t_embd && n_outputs > 0) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd + n_outputs_prev*n_embd; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - a single float per sequence + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(1); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } + } + } + + n_outputs_prev += n_outputs; + } while (mstate->next()); + + // set to total number of outputs in the batch, for use in llama_get_logits_ith + n_outputs = n_outputs_all; + + // set output mappings + if (n_outputs > 0) { + bool sorted_output = true; + + auto & out_ids = mstate->out_ids(); + + GGML_ASSERT(out_ids.size() == (size_t) n_outputs); + + for (int64_t i = 0; i < n_outputs; ++i) { + int64_t out_id = out_ids[i]; + output_ids[out_id] = i; + if (out_id != i) { + sorted_output = false; + } + } + + // make the outputs have the same order they had in the user-provided batch + // note: this is mostly relevant for recurrent models atm + if (!sorted_output) { + const uint32_t n_vocab = model.vocab.n_tokens(); + const uint64_t n_embd = model.hparams.n_embd; + + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (uint32_t i = 0; i < n_outputs - 1; ++i) { + uint32_t j_min = i; + for (uint32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { + continue; + } + std::swap(out_ids[i], out_ids[j_min]); + if (logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); + } + } + if (embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); + } + } + } + + std::fill(output_ids.begin(), output_ids.end(), -1); + + for (uint32_t i = 0; i < n_outputs; ++i) { + output_ids[out_ids[i]] = i; + } + } + } + + // wait for the computation to finish (automatically done when obtaining the model output) + //synchronize(); + + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + + return 0; +} + +// +// output +// + +uint32_t llama_context::output_reserve(int32_t n_outputs) { + const auto & hparams = model.hparams; + const auto & vocab = model.vocab; + + const int64_t n_outputs_max = std::max(n_outputs, n_seq_max()); + + const auto n_batch = cparams.n_batch; + const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; + + bool has_logits = true; + bool has_embd = cparams.embeddings; + + // TODO: hacky enc-dec support + if (model.arch == LLM_ARCH_T5) { + has_logits = true; + has_embd = true; + } + + logits_size = has_logits ? n_vocab*n_outputs_max : 0; + embd_size = has_embd ? n_embd*n_outputs_max : 0; + + if (output_ids.empty()) { + // init, never resized afterwards + output_ids.resize(n_batch); + } + + const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; + const size_t new_size = (logits_size + embd_size) * sizeof(float); + + // alloc only when more than the current capacity is required + // TODO: also consider shrinking the buffer + if (!buf_output || prev_size < new_size) { + if (buf_output) { +#ifndef NDEBUG + // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) + LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); +#endif + buf_output = nullptr; + logits = nullptr; + embd = nullptr; + } + + auto * buft = ggml_backend_cpu_buffer_type(); + // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory + auto * output_dev = model.dev_output(); + auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; + if (output_dev_host_buft) { + buft = output_dev_host_buft; + } + buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); + if (buf_output == nullptr) { + LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); + return 0; + } + } + + float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); + + logits = has_logits ? output_base : nullptr; + embd = has_embd ? output_base + logits_size : nullptr; + + // set all ids as invalid (negative) + std::fill(output_ids.begin(), output_ids.end(), -1); + + this->n_outputs = 0; + + return n_outputs_max; +} + +// +// graph +// + +int32_t llama_context::graph_max_nodes() const { + return std::max(65536, 5*model.n_tensors()); +} + +ggml_cgraph * llama_context::graph_init() { + ggml_init_params params = { + /*.mem_size =*/ buf_compute_meta.size(), + /*.mem_buffer =*/ buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ctx_compute.reset(ggml_init(params)); + + return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); +} + +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) { + LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); + + if (n_tokens % n_seqs != 0) { + n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs + n_outputs = std::min(n_outputs, n_tokens); + + LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); + } + + // store the n_outputs as it is, and restore it afterwards + // TODO: not sure if needed, might simplify in the future by removing this + const auto save_n_outputs = this->n_outputs; + + this->n_outputs = n_outputs; + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate); + + this->n_outputs = save_n_outputs; + + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__); + return nullptr; + } + + ggml_backend_sched_reset(sched.get()); + + // initialize scheduler with the specified graph + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + return nullptr; + } + + return gf; +} + +llm_graph_result_ptr llama_context::graph_build( + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate) { + return model.build_graph( + { + /*.ctx =*/ ctx, + /*.arch =*/ model.arch, + /*.hparams =*/ model.hparams, + /*.cparams =*/ cparams, + /*.ubatch =*/ ubatch, + /*.sched =*/ sched.get(), + /*.backend_cpu =*/ backend_cpu, + /*.cvec =*/ &cvec, + /*.loras =*/ &loras, + /*.mstate =*/ mstate, + /*.cross =*/ &cross, + /*.n_outputs =*/ n_outputs, + /*.cb =*/ graph_get_cb(), + }, gf, gtype); +} + +ggml_status llama_context::graph_compute( + ggml_cgraph * gf, + bool batched) { + int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; + ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; + + if (backend_cpu != nullptr) { + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); + auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); + set_threadpool_fn(backend_cpu, tp); + } + + // set the number of threads for all the backends + for (const auto & set_n_threads_fn : set_n_threads_fns) { + set_n_threads_fn.second(set_n_threads_fn.first, n_threads); + } + + auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); + } + + // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched)); + + return status; +} + +llm_graph_cb llama_context::graph_get_cb() const { + return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { + if (il >= 0) { + ggml_format_name(cur, "%s-%d", name, il); + } else { + ggml_set_name(cur, name); + } + + if (!cparams.offload_kqv) { + if (strcmp(name, "kqv_merged_cont") == 0) { + // all nodes between the KV store and the attention output are run on the CPU + ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); + } + } + + // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends + // FIXME: fix in ggml_backend_sched + const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer; + if (ubatch.n_tokens < 32 || full_offload) { + if (il != -1 && strcmp(name, "norm") == 0) { + const auto & dev_layer = model.dev_layer(il); + for (const auto & backend : backends) { + if (ggml_backend_get_device(backend.get()) == dev_layer) { + if (ggml_backend_supports_op(backend.get(), cur)) { + ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); + } + } + } + } + } + }; +} + +// +// state save/load +// + +class llama_io_write_dummy : public llama_io_write_i { +public: + llama_io_write_dummy() = default; + + void write(const void * /* src */, size_t size) override { + size_written += size; + } + + void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + size_written += size; + } + + size_t n_bytes() override { + return size_written; + } + +private: + size_t size_written = 0; +}; + +class llama_io_write_buffer : public llama_io_write_i { +public: + llama_io_write_buffer( + uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + + void write(const void * src, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(ptr, src, size); + ptr += size; + size_written += size; + buf_size -= size; + } + + void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + ggml_backend_tensor_get(tensor, ptr, offset, size); + ptr += size; + size_written += size; + buf_size -= size; + } + + size_t n_bytes() override { + return size_written; + } + +private: + uint8_t * ptr; + size_t buf_size = 0; + size_t size_written = 0; +}; + +class llama_io_read_buffer : public llama_io_read_i { +public: + llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + + const uint8_t * read(size_t size) override { + const uint8_t * base_ptr = ptr; + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + ptr += size; + size_read += size; + buf_size -= size; + return base_ptr; + } + + void read_to(void * dst, size_t size) override { + memcpy(dst, read(size), size); + } + + size_t n_bytes() override { + return size_read; + } + +private: + const uint8_t * ptr; + size_t buf_size = 0; + size_t size_read = 0; +}; + +class llama_io_write_file : public llama_io_write_i { +public: + llama_io_write_file(llama_file * f) : file(f) {} + + void write(const void * src, size_t size) override { + file->write_raw(src, size); + size_written += size; + } + + void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { + temp_buffer.resize(size); + ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); + write(temp_buffer.data(), temp_buffer.size()); + } + + size_t n_bytes() override { + return size_written; + } + +private: + llama_file * file; + size_t size_written = 0; + std::vector temp_buffer; +}; + +class llama_io_read_file : public llama_io_read_i { +public: + llama_io_read_file(llama_file * f) : file(f) {} + + void read_to(void * dst, size_t size) override { + file->read_raw(dst, size); + size_read += size; + } + + const uint8_t * read(size_t size) override { + temp_buffer.resize(size); + read_to(temp_buffer.data(), size); + return temp_buffer.data(); + } + + size_t n_bytes() override { + return size_read; + } + +private: + llama_file * file; + size_t size_read = 0; + std::vector temp_buffer; +}; + +size_t llama_context::state_get_size() { + llama_io_write_dummy io; + try { + return state_write_data(io); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context::state_get_data(uint8_t * dst, size_t size) { + llama_io_write_buffer io(dst, size); + try { + return state_write_data(io); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context::state_set_data(const uint8_t * src, size_t size) { + llama_io_read_buffer io(src, size); + try { + return state_read_data(io); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context::state_seq_get_size(llama_seq_id seq_id) { + llama_io_write_dummy io; + try { + return state_seq_write_data(io, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { + llama_io_write_buffer io(dst, size); + try { + return state_seq_write_data(io, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { + llama_io_read_buffer io(src, size); + try { + return state_seq_read_data(io, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); + return 0; + } +} + +bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(filepath, "rb"); + + // sanity checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { + LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); + return false; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return false; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t n_state_size_cur = file.size() - file.tell(); + + llama_io_read_file io( &file); + const size_t n_read = state_read_data(io); + + if (n_read != n_state_size_cur) { + LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read); + return false; + } + } + + return true; +} + +bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { + llama_file file(filepath, "wb"); + + file.write_u32(LLAMA_SESSION_MAGIC); + file.write_u32(LLAMA_SESSION_VERSION); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_io_write_file io(&file); + state_write_data(io); + + return true; +} + +size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(filepath, "rb"); + + // version checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { + LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); + return 0; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return 0; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t state_size = file.size() - file.tell(); + llama_io_read_file io(&file); + const size_t nread = state_seq_read_data(io, seq_id); + if (!nread) { + LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); + return 0; + } + GGML_ASSERT(nread <= state_size); + GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); + } + + return file.tell(); +} + +size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { + llama_file file(filepath, "wb"); + + file.write_u32(LLAMA_STATE_SEQ_MAGIC); + file.write_u32(LLAMA_STATE_SEQ_VERSION); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_io_write_file io(&file); + state_seq_write_data(io, seq_id); + + const size_t res = file.tell(); + GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); + + return res; +} + +size_t llama_context::state_write_data(llama_io_write_i & io) { + LLAMA_LOG_DEBUG("%s: writing state\n", __func__); + + // write model info + { + LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__); + + const std::string arch_str = llm_arch_name(model.arch); + io.write_string(arch_str); + // TODO: add more model-specific info which should prevent loading the session file if not identical + } + + // write output ids + { + LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); + + const auto n_outputs = this->n_outputs; + const auto & output_ids = this->output_ids; + + std::vector w_output_pos; + + w_output_pos.resize(n_outputs); + + // build a more compact representation of the output ids + for (size_t i = 0; i < n_batch(); ++i) { + // map an output id to a position in the batch + int64_t pos = output_ids[i]; + if (pos >= 0) { + GGML_ASSERT(pos < n_outputs); + w_output_pos[pos] = i; + } + } + + io.write(&n_outputs, sizeof(n_outputs)); + + if (n_outputs) { + io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); + } + } + + // write logits + { + LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); + + const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens()); + + io.write(&logits_size, sizeof(logits_size)); + + if (logits_size) { + io.write(logits, logits_size * sizeof(float)); + } + } + + // write embeddings + { + LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); + + const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); + + io.write(&embd_size, sizeof(embd_size)); + + if (embd_size) { + io.write(embd, embd_size * sizeof(float)); + } + } + + if (memory != nullptr) { + LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + memory->state_write(io); + } + + return io.n_bytes(); +} + +size_t llama_context::state_read_data(llama_io_read_i & io) { + LLAMA_LOG_DEBUG("%s: reading state\n", __func__); + + // read model info + { + LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__); + + const std::string cur_arch_str = llm_arch_name(model.arch); + + std::string arch_str; + io.read_string(arch_str); + if (cur_arch_str != arch_str) { + throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); + } + // TODO: add more info which needs to be identical but which is not verified otherwise + } + + // read output ids + { + LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); + + auto n_outputs = this->n_outputs; + io.read_to(&n_outputs, sizeof(n_outputs)); + + if (n_outputs > output_reserve(n_outputs)) { + throw std::runtime_error("could not reserve outputs"); + } + + std::vector output_pos; + + if (n_outputs) { + output_pos.resize(n_outputs); + io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); + + for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { + int32_t id = output_pos[i]; + if ((uint32_t) id >= n_batch()) { + throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); + } + this->output_ids[id] = i; + } + + this->n_outputs = n_outputs; + } + } + + // read logits + { + LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); + + uint64_t logits_size; + io.read_to(&logits_size, sizeof(logits_size)); + + if (this->logits_size < logits_size) { + throw std::runtime_error("logits buffer too small"); + } + + if (logits_size) { + io.read_to(this->logits, logits_size * sizeof(float)); + } + } + + // read embeddings + { + LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); + + uint64_t embd_size; + io.read_to(&embd_size, sizeof(embd_size)); + + if (this->embd_size < embd_size) { + throw std::runtime_error("embeddings buffer too small"); + } + + if (embd_size) { + io.read_to(this->embd, embd_size * sizeof(float)); + } + } + + if (memory) { + LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); + + memory->state_read(io); + } + + return io.n_bytes(); +} + +size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { + GGML_UNUSED(seq_id); + + if (memory) { + memory->state_write(io, seq_id); + } + + return io.n_bytes(); +} + +size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { + GGML_UNUSED(seq_id); + + if (memory) { + memory->state_read(io, seq_id); + } + + return io.n_bytes(); +} + +// +// perf +// + +llama_perf_context_data llama_context::perf_get_data() const { + llama_perf_context_data data = {}; + + data.t_start_ms = 1e-3 * t_start_us; + data.t_load_ms = 1e-3 * t_load_us; + data.t_p_eval_ms = 1e-3 * t_p_eval_us; + data.t_eval_ms = 1e-3 * t_eval_us; + data.n_p_eval = std::max(1, n_p_eval); + data.n_eval = std::max(1, n_eval); + + return data; +} + +void llama_context::perf_reset() { + t_start_us = ggml_time_us(); + t_eval_us = n_eval = 0; + t_p_eval_us = n_p_eval = 0; +} + +// +// training +// + +static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) { + if (!tensor || tensor->type != GGML_TYPE_F32) { + return; + } + if (!param_filter(tensor, userdata)) { + return; + } + if (strcmp(tensor->name, "token_embd.weight") == 0) { + return; // FIXME + } + if (strcmp(tensor->name, "rope_freqs.weight") == 0) { + return; // FIXME + } + ggml_set_param(tensor); +} + +void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) { + GGML_ASSERT(!opt_ctx); + model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx(); + const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train); + const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); + GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); + GGML_ASSERT(n_batch % n_ubatch == 0); + + ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); + opt_params.opt_period = n_batch / n_ubatch; + opt_params.get_opt_pars = lopt_params.get_opt_pars; + opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; + + opt_ctx = ggml_opt_init(opt_params); + + llama_opt_param_filter param_filter = lopt_params.param_filter; + void * param_filter_ud = lopt_params.param_filter_ud; + + //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME + llama_set_param(model->type_embd, param_filter, param_filter_ud); + llama_set_param(model->pos_embd, param_filter, param_filter_ud); + llama_set_param(model->tok_norm, param_filter, param_filter_ud); + llama_set_param(model->tok_norm_b, param_filter, param_filter_ud); + llama_set_param(model->output_norm, param_filter, param_filter_ud); + llama_set_param(model->output_norm_b, param_filter, param_filter_ud); + llama_set_param(model->output, param_filter, param_filter_ud); + llama_set_param(model->output_b, param_filter, param_filter_ud); + llama_set_param(model->output_norm_enc, param_filter, param_filter_ud); + llama_set_param(model->cls, param_filter, param_filter_ud); + llama_set_param(model->cls_b, param_filter, param_filter_ud); + llama_set_param(model->cls_out, param_filter, param_filter_ud); + llama_set_param(model->cls_out_b, param_filter, param_filter_ud); + + for (struct llama_layer & layer : model->layers) { + for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { + llama_set_param(reinterpret_cast(&layer)[i], param_filter, param_filter_ud); + } + } +} + +void llama_context::opt_epoch_iter( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + const std::vector & tokens, + const std::vector & labels_sparse, + llama_batch & batch, + ggml_opt_epoch_callback callback, + bool train, + int64_t idata_in_loop, + int64_t ndata_in_loop, + int64_t t_loop_start) { + GGML_ASSERT(opt_ctx); + const uint32_t n_ctx = llama_model_n_ctx_train(&model); + const uint32_t n_batch = std::min(this->n_batch(), n_ctx); + const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); + + memory->clear(true); + + for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { + batch.n_tokens = n_batch; + for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) { + batch.token [pos_batch] = tokens[pos_ctx + pos_batch]; + batch.pos [pos_batch] = pos_ctx + pos_batch; + batch.n_seq_id[pos_batch] = 1; + batch.seq_id [pos_batch][0] = 0; + batch.logits [pos_batch] = true; + } + + const auto n_tokens_all = batch.n_tokens; + + n_queued_tokens += n_tokens_all; + + embd_seq.clear(); + + uint32_t n_outputs_all = n_tokens_all; + + auto mstate = memory->init_batch(batch, cparams.n_ubatch, true); + if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); + break; + } + + // reserve output buffer + if (output_reserve(n_outputs_all) < n_outputs_all) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); + GGML_ABORT("TODO: handle this error"); + }; + + uint32_t pos_batch = 0; + do { + const auto & ubatch = mstate->get_ubatch(); + + n_outputs = ubatch.n_tokens; + + if (!mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); + break; + } + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get()); + + struct ggml_context * ctx_compute_opt; + { + const size_t size_gf = ggml_graph_size(gf); + const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_compute_opt = ggml_init(params); + } + ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); + ggml_opt_alloc(opt_ctx, train); + + res->set_inputs(&ubatch); + { + struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); + GGML_ASSERT(labels->ne[1] == n_ubatch); + ggml_set_zero(labels); + const float onef = 1.0f; + for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { + const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch; + GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]); + ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); + } + } + ggml_opt_eval(opt_ctx, result); + if (callback) { + callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); + } + ggml_free(ctx_compute_opt); + + pos_batch += ubatch.n_tokens; + } while (mstate->next()); + } +} + +void llama_context::opt_epoch( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + const uint32_t n_ctx = this->n_ctx(); + const uint32_t n_batch = std::min(cparams.n_batch, n_ctx); + const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch); + const int64_t ndata = ggml_opt_dataset_ndata(dataset); + + GGML_ASSERT(idata_split >= 0); + GGML_ASSERT(idata_split <= ndata); + + const uint32_t ubatch_per_ctx = n_ctx / n_ubatch; + + struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + std::vector tokens(n_ctx); + std::vector labels_sparse(n_ctx); + + int64_t idata = 0; + + int64_t t_loop_start = ggml_time_us(); + int64_t ndata_in_loop = idata_split*ubatch_per_ctx; + for (; idata < idata_split; ++idata) { + constexpr bool train = true; + const int64_t idata_in_loop = idata*ubatch_per_ctx; + + ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); + opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, + callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start); + } + + t_loop_start = ggml_time_us(); + ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx; + for (; idata < ndata; ++idata) { + constexpr bool train = false; + const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx; + + ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); + opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, + callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start); + } + + llama_batch_free(batch); +} + +// +// interface implementation +// + +llama_context_params llama_context_default_params() { + llama_context_params result = { + /*.n_ctx =*/ 512, + /*.n_batch =*/ 2048, + /*.n_ubatch =*/ 512, + /*.n_seq_max =*/ 1, + /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default + /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, + /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, + /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, + /*.rope_freq_base =*/ 0.0f, + /*.rope_freq_scale =*/ 0.0f, + /*.yarn_ext_factor =*/ -1.0f, + /*.yarn_attn_factor =*/ 1.0f, + /*.yarn_beta_fast =*/ 32.0f, + /*.yarn_beta_slow =*/ 1.0f, + /*.yarn_orig_ctx =*/ 0, + /*.defrag_thold =*/ -1.0f, + /*.cb_eval =*/ nullptr, + /*.cb_eval_user_data =*/ nullptr, + /*.type_k =*/ GGML_TYPE_F16, + /*.type_v =*/ GGML_TYPE_F16, + /*.abort_callback =*/ nullptr, + /*.abort_callback_data =*/ nullptr, + /*.embeddings =*/ false, + /*.offload_kqv =*/ true, + /*.flash_attn =*/ false, + /*.no_perf =*/ true, + /*.op_offload =*/ true, + /*.swa_full =*/ true, + }; + + return result; +} + +llama_context * llama_init_from_model( + llama_model * model, + llama_context_params params) { + if (!model) { + LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__); + return nullptr; + } + + if (params.n_batch == 0 && params.n_ubatch == 0) { + LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); + return nullptr; + } + + if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) { + LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__); + return nullptr; + } + + if (params.flash_attn && model->arch == LLM_ARCH_GROK) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); + params.flash_attn = false; + } + + if (ggml_is_quantized(params.type_v) && !params.flash_attn) { + LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); + return nullptr; + } + + try { + auto * ctx = new llama_context(*model, params); + return ctx; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what()); + } + + return nullptr; +} + +// deprecated +llama_context * llama_new_context_with_model( + llama_model * model, + llama_context_params params) { + return llama_init_from_model(model, params); +} + +void llama_free(llama_context * ctx) { + delete ctx; +} + +uint32_t llama_n_ctx(const llama_context * ctx) { + return ctx->n_ctx(); +} + +uint32_t llama_n_batch(const llama_context * ctx) { + return ctx->n_batch(); +} + +uint32_t llama_n_ubatch(const llama_context * ctx) { + return ctx->n_ubatch(); +} + +uint32_t llama_n_seq_max(const llama_context * ctx) { + return ctx->n_seq_max(); +} + +const llama_model * llama_get_model(const llama_context * ctx) { + return &ctx->get_model(); +} + +// deprecated +llama_kv_cache * llama_get_kv_self(llama_context * ctx) { + return dynamic_cast(ctx->get_memory()); +} + +// deprecated +void llama_kv_self_update(llama_context * ctx) { + ctx->kv_self_update(false); +} + +enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { + return ctx->pooling_type(); +} + +void llama_attach_threadpool( + llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch) { + ctx->attach_threadpool(threadpool, threadpool_batch); +} + +void llama_detach_threadpool(llama_context * ctx) { + ctx->detach_threadpool(); +} + +void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) { + ctx->set_n_threads(n_threads, n_threads_batch); +} + +int32_t llama_n_threads(llama_context * ctx) { + return ctx->n_threads(); +} + +int32_t llama_n_threads_batch(llama_context * ctx) { + return ctx->n_threads_batch(); +} + +void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { + ctx->set_abort_callback(abort_callback, abort_callback_data); +} + +void llama_set_embeddings(llama_context * ctx, bool embeddings) { + ctx->set_embeddings(embeddings); +} + +void llama_set_causal_attn(llama_context * ctx, bool causal_attn) { + ctx->set_causal_attn(causal_attn); +} + +void llama_set_warmup(llama_context * ctx, bool warmup) { + ctx->set_warmup(warmup); +} + +void llama_synchronize(llama_context * ctx) { + ctx->synchronize(); +} + +float * llama_get_logits(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_logits(); +} + +float * llama_get_logits_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_logits_ith(i); +} + +float * llama_get_embeddings(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings(); +} + +float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_ith(i); +} + +float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { + ctx->synchronize(); + + return ctx->get_embeddings_seq(seq_id); +} + +// llama adapter API + +int32_t llama_set_adapter_lora( + llama_context * ctx, + llama_adapter_lora * adapter, + float scale) { + ctx->set_adapter_lora(adapter, scale); + + return 0; +} + +int32_t llama_rm_adapter_lora( + llama_context * ctx, + llama_adapter_lora * adapter) { + bool res = ctx->rm_adapter_lora(adapter); + + return res ? 0 : -1; +} + +void llama_clear_adapter_lora(llama_context * ctx) { + ctx->clear_adapter_lora(); +} + +int32_t llama_apply_adapter_cvec( + llama_context * ctx, + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); + + return res ? 0 : -1; +} + +// +// memory +// + +llama_memory_t llama_get_memory(const struct llama_context * ctx) { + return ctx->get_memory(); +} + +void llama_memory_clear(llama_memory_t mem, bool data) { + if (!mem) { + return; + } + + mem->clear(data); +} + +bool llama_memory_seq_rm( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + if (!mem) { + return true; + } + + return mem->seq_rm(seq_id, p0, p1); +} + +void llama_memory_seq_cp( + llama_memory_t mem, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + if (!mem) { + return; + } + + mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_seq_keep( + llama_memory_t mem, + llama_seq_id seq_id) { + if (!mem) { + return; + } + + mem->seq_keep(seq_id); +} + +void llama_memory_seq_add( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + if (!mem) { + return; + } + + mem->seq_add(seq_id, p0, p1, delta); +} + +void llama_memory_seq_div( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (!mem) { + return; + } + + mem->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_seq_pos_min( + llama_memory_t mem, + llama_seq_id seq_id) { + if (!mem) { + return -1; + } + + return mem->seq_pos_min(seq_id); +} + +llama_pos llama_memory_seq_pos_max( + llama_memory_t mem, + llama_seq_id seq_id) { + if (!mem) { + return -1; + } + + return mem->seq_pos_max(seq_id); +} + +bool llama_memory_can_shift(llama_memory_t mem) { + if (!mem) { + return false; + } + + return mem->get_can_shift(); +} + +// +// kv cache +// + +// deprecated +int32_t llama_kv_self_n_tokens(const llama_context * ctx) { + const auto * kv = llama_get_memory(ctx); + if (!kv) { + return 0; + } + + int32_t res = 0; + + for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) { + const llama_pos p0 = kv->seq_pos_min(s); + const llama_pos p1 = kv->seq_pos_max(s); + + if (p0 >= 0) { + res += (p1 - p0) + 1; + } + } + + return res; +} + +// deprecated +// note: this is the same as above - will be removed anyway, so it's ok +int32_t llama_kv_self_used_cells(const llama_context * ctx) { + const auto * kv = llama_get_memory(ctx); + if (!kv) { + return 0; + } + + int32_t res = 0; + + for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) { + const llama_pos p0 = kv->seq_pos_min(s); + const llama_pos p1 = kv->seq_pos_max(s); + + if (p0 >= 0) { + res += (p1 - p0) + 1; + } + } + + return res; +} + +// deprecated +void llama_kv_self_clear(llama_context * ctx) { + auto * kv = llama_get_memory(ctx); + if (!kv) { + return; + } + + llama_memory_clear(kv, true); +} + +// deprecated +bool llama_kv_self_seq_rm( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + auto * kv = llama_get_memory(ctx); + if (!kv) { + return true; + } + + return llama_memory_seq_rm(kv, seq_id, p0, p1); +} + +// deprecated +void llama_kv_self_seq_cp( + llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + auto * kv = llama_get_memory(ctx); + if (!kv) { + return; + } + + llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1); +} + +// deprecated +void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { + auto * kv = llama_get_memory(ctx); + if (!kv) { + return; + } + + llama_memory_seq_keep(kv, seq_id); +} + +// deprecated +void llama_kv_self_seq_add( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + auto * kv = llama_get_memory(ctx); + if (!kv) { + return; + } + + llama_memory_seq_add(kv, seq_id, p0, p1, delta); +} + +// deprecated +void llama_kv_self_seq_div( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + auto * kv = llama_get_memory(ctx); + if (!kv) { + return; + } + + llama_memory_seq_div(kv, seq_id, p0, p1, d); +} + +// deprecated +llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { + auto * kv = llama_get_memory(ctx); + if (!kv) { + return -1; + } + + return llama_memory_seq_pos_min(kv, seq_id); +} + +// deprecated +llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { + auto * kv = llama_get_memory(ctx); + if (!kv) { + return -1; + } + + return llama_memory_seq_pos_max(kv, seq_id); +} + +// deprecated +void llama_kv_self_defrag(llama_context * ctx) { + // force defrag + ctx->kv_self_defrag_sched(); +} + +// deprecated +bool llama_kv_self_can_shift(const llama_context * ctx) { + auto * kv = llama_get_memory(ctx); + if (!kv) { + return false; + } + + return llama_memory_can_shift(kv); +} + +// llama state API + +// deprecated +size_t llama_get_state_size(llama_context * ctx) { + return llama_state_get_size(ctx); +} + +// deprecated +size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) { + return llama_state_get_data(ctx, dst, -1); +} + +// deprecated +size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) { + return llama_state_set_data(ctx, src, -1); +} + +// deprecated +bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); +} + +// deprecated +bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + return llama_state_save_file(ctx, path_session, tokens, n_token_count); +} + +// Returns the *actual* size of the state. +// Intended to be used when saving to state to a buffer. +size_t llama_state_get_size(llama_context * ctx) { + return ctx->state_get_size(); +} + +size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) { + ctx->synchronize(); + + return ctx->state_get_data(dst, size); +} + +// Sets the state reading from the specified source address +size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) { + ctx->synchronize(); + + return ctx->state_set_data(src, size); +} + +bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + ctx->synchronize(); + + try { + return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); + return false; + } +} + +bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + ctx->synchronize(); + + try { + return ctx->state_save_file(path_session, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); + return false; + } +} + +size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { + return ctx->state_seq_get_size(seq_id); +} + +size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { + ctx->synchronize(); + + return ctx->state_seq_get_data(seq_id, dst, size); +} + +size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { + ctx->synchronize(); + + return ctx->state_seq_set_data(seq_id, src, size); +} + +size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + ctx->synchronize(); + + try { + return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + ctx->synchronize(); + + try { + return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what()); + return 0; + } +} + +/// + +int32_t llama_encode( + llama_context * ctx, + llama_batch batch) { + const int ret = ctx->encode(batch); + if (ret != 0) { + LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); + } + + return ret; +} + +int32_t llama_decode( + llama_context * ctx, + llama_batch batch) { + const int ret = ctx->decode(batch); + if (ret != 0 && ret != 1) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); + } + + return ret; +} + +// +// perf +// + +llama_perf_context_data llama_perf_context(const llama_context * ctx) { + llama_perf_context_data data = {}; + + if (ctx == nullptr) { + return data; + } + + data = ctx->perf_get_data(); + + return data; +} + +void llama_perf_context_print(const llama_context * ctx) { + const auto data = llama_perf_context(ctx); + + const double t_end_ms = 1e-3 * ggml_time_us(); + + LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms); + LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval); + LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); + LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); +} + +void llama_perf_context_reset(llama_context * ctx) { + ctx->perf_reset(); +} + +// +// training +// + +bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) { + GGML_UNUSED(tensor); + GGML_UNUSED(userdata); + return true; +} + +void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) { + ctx->opt_init(model, lopt_params); +} + +void llama_opt_epoch( + struct llama_context * ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + ctx->opt_epoch( + dataset, + result_train, + result_eval, + idata_split, + callback_train, + callback_eval); +} diff --git a/src/llama-context.h b/src/llama-context.h new file mode 100644 index 0000000000000000000000000000000000000000..040f03ae42e65b242f05f7e7f49a6cea577eb6ba --- /dev/null +++ b/src/llama-context.h @@ -0,0 +1,297 @@ +#pragma once + +#include "llama.h" +#include "llama-cparams.h" +#include "llama-graph.h" +#include "llama-adapter.h" + +#include "ggml-cpp.h" +#include "ggml-opt.h" + +#include +#include + +struct llama_model; +class llama_batch_allocr; + +class llama_io_read_i; +class llama_io_write_i; + +struct llama_memory_i; +struct llama_memory_state_i; + +struct llama_context { + // init scheduler and compute buffers, reserve worst-case graphs + llama_context( + const llama_model & model, + llama_context_params params); + + ~llama_context(); + + void synchronize(); + + const llama_model & get_model() const; + const llama_cparams & get_cparams() const; + + ggml_backend_sched_t get_sched() const; + + ggml_context * get_ctx_compute() const; + + uint32_t n_ctx() const; + uint32_t n_ctx_per_seq() const; + uint32_t n_batch() const; + uint32_t n_ubatch() const; + uint32_t n_seq_max() const; + + uint32_t n_threads() const; + uint32_t n_threads_batch() const; + + llama_memory_t get_memory() const; + + // return true of the KV cache was updated + // TODO: remove + bool kv_self_update(bool optimize); + void kv_self_defrag_sched(); + + enum llama_pooling_type pooling_type() const; + + float * get_logits(); + float * get_logits_ith(int32_t i); + + float * get_embeddings(); + float * get_embeddings_ith(int32_t i); + float * get_embeddings_seq(llama_seq_id seq_id); + + void attach_threadpool( + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch); + + void detach_threadpool(); + + void set_n_threads(int32_t n_threads, int32_t n_threads_batch); + + void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); + + void set_embeddings (bool value); + void set_causal_attn(bool value); + void set_warmup(bool value); + + void set_adapter_lora( + llama_adapter_lora * adapter, + float scale); + + bool rm_adapter_lora( + llama_adapter_lora * adapter); + + void clear_adapter_lora(); + + bool apply_adapter_cvec( + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end); + + // process a single ubatch with a specific graph type + // if memory_state is provided, it will be applied first to the context's memory + // ret contains the status of the graph computation + // returns nullptr only if ret != GGML_STATUS_SUCCESS + llm_graph_result_ptr process_ubatch( + const llama_ubatch & ubatch, + llm_graph_type gtype, + llama_memory_state_i * mstate, + ggml_status & ret); + + int encode(const llama_batch & batch_inp); + int decode(const llama_batch & batch_inp); + + // + // state save/load + // + + size_t state_get_size(); + size_t state_get_data( uint8_t * dst, size_t size); + size_t state_set_data(const uint8_t * src, size_t size); + + size_t state_seq_get_size(llama_seq_id seq_id); + size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size); + size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size); + + bool state_load_file( + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + + bool state_save_file( + const char * filepath, + const llama_token * tokens, + size_t n_token_count); + + size_t state_seq_load_file( + llama_seq_id seq_id, + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + + size_t state_seq_save_file( + llama_seq_id seq_id, + const char * filepath, + const llama_token * tokens, + size_t n_token_count); + + // + // perf + // + + llama_perf_context_data perf_get_data() const; + void perf_reset(); + + // + // training + // + + void opt_init(struct llama_model * model, struct llama_opt_params lopt_params); + + void opt_epoch( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + + void opt_epoch_iter( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + const std::vector & tokens, + const std::vector & labels_sparse, + llama_batch & batch, + ggml_opt_epoch_callback callback, + bool train, + int64_t idata_in_loop, + int64_t ndata_in_loop, + int64_t t_loop_start); + +private: + // + // output + // + + // Make sure enough space is available for outputs. + // Returns max number of outputs for which space was reserved. + uint32_t output_reserve(int32_t n_outputs); + + // + // graph + // + +public: + int32_t graph_max_nodes() const; + + // zero-out inputs and create the ctx_compute for the compute graph + ggml_cgraph * graph_init(); + + // returns the result of ggml_backend_sched_graph_compute_async execution + ggml_status graph_compute(ggml_cgraph * gf, bool batched); + + // reserve a graph with a dummy ubatch of the specified size + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate); + +private: + llm_graph_result_ptr graph_build( + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate); + + llm_graph_cb graph_get_cb() const; + + // TODO: read/write lora adapters and cvec + size_t state_write_data(llama_io_write_i & io); + size_t state_read_data (llama_io_read_i & io); + + size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id); + size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id); + + // + // members + // + + const llama_model & model; + + llama_cparams cparams; + llama_adapter_cvec cvec; + llama_adapter_loras loras; + + llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably + + std::unique_ptr memory; + + // TODO: temporary, until the llama_kv_self_defrag() API is removed + bool memory_force_optimize = false; + + // decode output (2-dimensional array: [n_outputs][n_vocab]) + size_t logits_size = 0; // capacity (of floats) for logits + float * logits = nullptr; + + // embeddings output (2-dimensional array: [n_outputs][n_embd]) + // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE + size_t embd_size = 0; // capacity (of floats) for embeddings + float * embd = nullptr; + + // sequence embeddings output (map of [n_embd] vectors) + // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE + std::map> embd_seq; + + // reuse the batch_allocr to avoid unnecessary memory allocations + std::unique_ptr batch_allocr; + + uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch + + std::vector output_ids; // map batch token positions to ids of the logits and embd buffers + + ggml_backend_sched_ptr sched; + + ggml_backend_t backend_cpu = nullptr; + std::vector backends; + + ggml_context_ptr ctx_compute; + + // training + ggml_opt_context_t opt_ctx = nullptr; + + ggml_threadpool_t threadpool = nullptr; + ggml_threadpool_t threadpool_batch = nullptr; + + ggml_abort_callback abort_callback = nullptr; + void * abort_callback_data = nullptr; + + std::vector> set_n_threads_fns; + + // buffer types used for the compute buffer of each backend + std::vector backend_ptrs; + std::vector backend_buft; + + // memory buffers used to evaluate the model + std::vector buf_compute_meta; + + // host buffer for the model output (logits and embeddings) + ggml_backend_buffer_ptr buf_output; + + bool has_evaluated_once = false; + + // perf + mutable int64_t t_start_us = 0; + mutable int64_t t_load_us = 0; + mutable int64_t t_p_eval_us = 0; + mutable int64_t t_eval_us = 0; + + mutable int64_t t_compute_start_us = 0; + mutable int64_t n_queued_tokens = 0; + + mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + mutable int32_t n_eval = 0; // number of eval calls +}; diff --git a/src/llama-cparams.cpp b/src/llama-cparams.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3e7a37ee36d78fd6af7d5e3aba8762ecbb636b6 --- /dev/null +++ b/src/llama-cparams.cpp @@ -0,0 +1,5 @@ +#include "llama-cparams.h" + +size_t llama_max_parallel_sequences(void) { + return LLAMA_MAX_SEQ; +} diff --git a/src/llama-cparams.h b/src/llama-cparams.h new file mode 100644 index 0000000000000000000000000000000000000000..118615d5bd2d59f4e640c661592cf91ee0d3fda3 --- /dev/null +++ b/src/llama-cparams.h @@ -0,0 +1,41 @@ +#pragma once + +#include "llama.h" + +#include + +#define LLAMA_MAX_SEQ 64 + +struct llama_cparams { + uint32_t n_ctx; // context size used during inference + uint32_t n_batch; + uint32_t n_ubatch; + uint32_t n_seq_max; + int n_threads; // number of threads to use for generation + int n_threads_batch; // number of threads to use for batch processing + + float rope_freq_base; + float rope_freq_scale; + + uint32_t n_ctx_orig_yarn; + // These hyperparameters are not exposed in GGUF, because all + // existing YaRN models use the same values for them. + float yarn_ext_factor; + float yarn_attn_factor; + float yarn_beta_fast; + float yarn_beta_slow; + float defrag_thold; + + bool embeddings; + bool causal_attn; + bool offload_kqv; + bool flash_attn; + bool no_perf; + bool warmup; + bool op_offload; + + enum llama_pooling_type pooling_type; + + ggml_backend_sched_eval_callback cb_eval; + void * cb_eval_user_data; +}; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bed706bb248d139664d8024726948e9fb1ba4cb5 --- /dev/null +++ b/src/llama-grammar.cpp @@ -0,0 +1,1229 @@ +#include "llama-grammar.h" + +#include "llama-impl.h" +#include "llama-vocab.h" +#include "llama-sampling.h" + +#include +#include +#include + +// +// helpers +// + +// NOTE: assumes valid utf8 (but checks for overrun) +static std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); +} + +static std::pair, llama_partial_utf8> decode_utf8( + const std::string & src, + llama_partial_utf8 partial_start) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; + const char * pos = src.c_str(); + std::vector code_points; + + // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0. + code_points.reserve(src.size() + 1); + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; + + // continue previous decode, if applicable + while (*pos != 0 && n_remain > 0) { + uint8_t next_byte = static_cast(*pos); + if ((next_byte >> 6) != 2) { + // invalid sequence, abort + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); + } + value = (value << 6) + (next_byte & 0x3F); + ++pos; + --n_remain; + } + + if (partial_start.n_remain > 0 && n_remain == 0) { + code_points.push_back(value); + } + + // decode any subsequent utf-8 sequences, which may end in an incomplete one + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + n_remain = lookup[highbits] - 1; + + if (n_remain < 0) { + // invalid sequence, abort + code_points.clear(); + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); + } + + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; + + ++pos; + while (*pos != 0 && n_remain > 0) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + ++pos; + --n_remain; + } + if (n_remain == 0) { + code_points.push_back(value); + } + } + code_points.push_back(0); + + return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); +} + +static bool is_digit_char(char c) { + return '0' <= c && c <= '9'; +} + +static bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); +} + +static std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); +} + +static const char * parse_space(const char * src, bool newline_ok) { + const char * pos = src; + while (*pos == ' ' || *pos == '\t' || *pos == '#' || + (newline_ok && (*pos == '\r' || *pos == '\n'))) { + if (*pos == '#') { + while (*pos && *pos != '\r' && *pos != '\n') { + pos++; + } + } else { + pos++; + } + } + return pos; +} + +static const char * parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting name at ") + src); + } + return pos; +} + +static const char * parse_int(const char * src) { + const char * pos = src; + while (is_digit_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting integer at ") + src); + } + return pos; +} + +static std::pair parse_char(const char * src) { + if (*src == '\\') { + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); + } + } else if (*src) { + return decode_utf8(src); + } + throw std::runtime_error("unexpected end of input"); +} + +static void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + fprintf(file, "%c", static_cast(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "", c); + } +} + +static bool is_char_element(llama_grammar_element elem) { + switch (elem.type) { + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; + default: return false; + } +} + +static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) { + for (auto elem : rule) { + switch (elem.type) { + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; + } + switch (elem.type) { + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; + } + } + fprintf(file, "\n"); +} + +static void print_rule( + FILE * file, + uint32_t rule_id, + const llama_grammar_rule & rule, + const std::map & symbol_id_names) { + if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + llama_grammar_element elem = rule[i]; + switch (elem.type) { + case LLAMA_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: + break; + default: + fprintf(file, "] "); + } + } + } + fprintf(file, "\n"); +} + +// +// implementation +// + +uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) { + uint32_t next_id = static_cast(symbol_ids.size()); + auto result = symbol_ids.emplace(std::string(src, len), next_id); + return result.first->second; +} + +uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) { + uint32_t next_id = static_cast(symbol_ids.size()); + symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; +} + +void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) { + if (rules.size() <= rule_id) { + rules.resize(rule_id + 1); + } + rules[rule_id] = rule; +} + +const char * llama_grammar_parser::parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested) { + llama_grammar_rule rule; + const char * pos = parse_sequence(src, rule_name, rule, is_nested); + while (*pos == '|') { + rule.push_back({LLAMA_GRETYPE_ALT, 0}); + pos = parse_space(pos + 1, true); + pos = parse_sequence(pos, rule_name, rule, is_nested); + } + rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(rule_id, rule); + return pos; +} + +const char * llama_grammar_parser::parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested) { + size_t last_sym_start = rule.size(); + const char * pos = src; + + auto handle_repetitions = [&](int min_times, int max_times) { + + if (last_sym_start == rule.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | + + llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + if (min_times == 0) { + rule.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); + } + } + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + llama_grammar_rule rec_rule(prev_rule); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(prev_rule.size()); + uint32_t rec_rule_id = generate_symbol_id( rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule( rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = rule.size(); + while (*pos != '"') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '[') { // char range(s) + pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } + last_sym_start = rule.size(); + while (*pos != ']') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum llama_gretype type = last_sym_start < rule.size() + ? LLAMA_GRETYPE_CHAR_ALT + : start_type; + + rule.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + if (!pos[1]) { + throw std::runtime_error("unexpected end of input"); + } + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(rule_name); + pos = parse_alternates(pos, rule_name, sub_rule_id, true); + last_sym_start = rule.size(); + // output reference to synthesized rule + rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '.') { // any char + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); + + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); + } + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + + int max_times = -1; + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else { + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); + } else { + break; + } + } + return pos; +} + +const char * llama_grammar_parser::parse_rule(const char * src) { + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); + } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); +} + +bool llama_grammar_parser::parse(const char * src) { + try { + const char * pos = parse_space(src, true); + while (*pos) { + pos = parse_rule(pos); + } + // Validate the state to ensure that all rules are defined + for (const auto & rule : rules) { + if (rule.empty()) { + throw std::runtime_error("Undefined rule"); + } + for (const auto & elem : rule) { + if (elem.type == LLAMA_GRETYPE_RULE_REF) { + // Ensure that the rule at that location exists + if (elem.value >= rules.size() || rules[elem.value].empty()) { + // Get the name of the rule that is missing + for (const auto & kv : symbol_ids) { + if (kv.second == elem.value) { + throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); + } + } + } + } + } + } + } catch (const std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); + rules.clear(); + return false; + } + + return true; +} + +void llama_grammar_parser::print(FILE * file) { + try { + std::map symbol_id_names; + for (const auto & kv : symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, rules[i]); + print_rule(file, uint32_t(i), rules[i], symbol_id_names); + // fprintf(file, "\n"); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); + } +} + +llama_grammar_stack llama_grammar_parser::c_rules() const { + llama_grammar_stack ret; + ret.reserve(rules.size()); + for (const auto & rule : rules) { + ret.push_back(rule.data()); + } + return ret; +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { + switch (pos->type) { + case LLAMA_GRETYPE_END: return true; // NOLINT + case LLAMA_GRETYPE_ALT: return true; // NOLINT + default: return false; + } +} + +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair llama_grammar_match_char( + const llama_grammar_element * pos, + const uint32_t chr) { + bool found = false; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; + + GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + found = true; + pos += 1; + } else { + // exact char match, e.g. [a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char +// range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static bool llama_grammar_match_partial_char( + const llama_grammar_element * pos, + const llama_partial_utf8 partial_utf8) { + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; + GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + uint32_t partial_value = partial_utf8.value; + int n_remain = partial_utf8.n_remain; + + // invalid sequence or 7-bit char split across 2 bytes (overlong) + if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { + return false; + } + + // range of possible code points this partial UTF-8 sequence could complete to + uint32_t low = partial_value << (n_remain * 6); + uint32_t high = low | ((1 << (n_remain * 6)) - 1); + + if (low == 0) { + if (n_remain == 2) { + low = 1 << 11; + } else if (n_remain == 3) { + low = 1 << 16; + } + } + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + if (pos->value <= high && low <= pos[1].value) { + return is_positive_char; + } + pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + return true; + } else { + // exact char match, e.g. [a] or "a" + if (low <= pos->value && pos->value <= high) { + return is_positive_char; + } + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return !is_positive_char; +} + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void llama_grammar_advance_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + llama_grammar_stacks & new_stacks) { + if (stack.empty()) { + if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + new_stacks.emplace_back(stack); + } + return; + } + + const llama_grammar_element * pos = stack.back(); + + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_ANY: + if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + // only add the stack if it's not a duplicate of one we already have + new_stacks.emplace_back(stack); + } + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + GGML_ABORT("fatal error"); + } +} + +static llama_grammar_candidates llama_grammar_reject_candidates( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const llama_grammar_candidates & candidates) { + GGML_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + return {}; + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + + return rejects; +} + +static bool llama_grammar_detect_left_recursion( + const llama_grammar_rules & rules, + size_t rule_index, + std::vector * rules_visited, + std::vector * rules_in_progress, + std::vector * rules_may_be_empty) { + if ((*rules_in_progress)[rule_index]) { + return true; + } + + (*rules_in_progress)[rule_index] = true; + + const llama_grammar_rule & rule = rules[rule_index]; + + // First check if the rule might produce the empty string. This could be done combined with the second + // step but it's more readable as two steps. + bool at_rule_start = true; + for (size_t i = 0; i < rule.size(); i++) { + if (llama_grammar_is_end_of_sequence(&rule[i])) { + if (at_rule_start) { + (*rules_may_be_empty)[rule_index] = true; + break; + } + at_rule_start = true; + } else { + at_rule_start = false; + } + } + + // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may + // be empty) + bool recurse_into_nonterminal = true; + for (size_t i = 0; i < rule.size(); i++) { + if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { + if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { + return true; + } + if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { + recurse_into_nonterminal = false; + } + } else if (llama_grammar_is_end_of_sequence(&rule[i])) { + recurse_into_nonterminal = true; + } else { + recurse_into_nonterminal = false; + } + } + + (*rules_in_progress)[rule_index] = false; + (*rules_visited)[rule_index] = true; + + return false; +} + +const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { + return grammar->rules; +} + +llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { + return grammar->stacks; +} + +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar->stacks.size()); + + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + continue; + } + + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; + + // update top of stack to next element, if any + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new); + } + } + + grammar->stacks = std::move(stacks_new); +} + +llama_grammar_candidates llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates) { + + llama_grammar_candidates rejects; + rejects.reserve(candidates.size()); + + if (stack.empty()) { + for (const auto & tok : candidates) { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } + return rejects; + } + + const llama_grammar_element * stack_pos = stack.back(); + + llama_grammar_candidates next_candidates; + next_candidates.reserve(candidates.size()); + + for (const auto & tok : candidates) { + if (*tok.code_points == 0) { + // reached end of full codepoints in token, reject iff it ended in a partial sequence + // that cannot satisfy this position in grammar + if (tok.partial_utf8.n_remain != 0 && + !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + rejects.push_back(tok); + } + } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); + } else { + rejects.push_back(tok); + } + } + + const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + llama_grammar_stack stack_after(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + llama_grammar_stacks next_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (const auto & tok : next_rejects) { + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); + } + + return rejects; +} + +//////////////////// + +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; + + // copy rule definitions into vectors + llama_grammar_rules vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } + + // Check for left recursion + std::vector rules_visited(n_rules); + std::vector rules_in_progress(n_rules); + std::vector rules_may_be_empty(n_rules); + for (size_t i = 0; i < n_rules; i++) { + if (rules_visited[i]) { + continue; + } + if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + return nullptr; + } + } + + // loop over alternates of start rule to build initial stacks + llama_grammar_stacks stacks; + pos = vec_rules[start_rule_index].data(); + do { + llama_grammar_stack stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + // Important: vec_rules has to be moved here, not copied, because stacks contains + // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar + // then the pointers would be invalidated when the local vec_rules goes out of scope. + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy =*/ false, + /* .awaiting_trigger = */ false, + /* .trigger_buffer = */ "", + /* .trigger_tokens = */ {}, + /* .trigger_patterns = */ {}, + }; +} + +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + llama_grammar_parser parser; + + // if there is a grammar, parse it + // rules will be empty (default) if there are parse errors + if (!parser.parse(grammar_str) || parser.rules.empty()) { + fprintf(stderr, "%s: failed to parse grammar\n", __func__); + return nullptr; + } + + // Ensure that there is a "root" node. + if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { + fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + return nullptr; + } + + std::vector grammar_rules(parser.c_rules()); + + const size_t n_rules = grammar_rules.size(); + const size_t start_rule_index = parser.symbol_ids.at(grammar_root); + + const llama_grammar_element * pos; + + // copy rule definitions into vectors + llama_grammar_rules vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } + + // Check for left recursion + std::vector rules_visited(n_rules); + std::vector rules_in_progress(n_rules); + std::vector rules_may_be_empty(n_rules); + for (size_t i = 0; i < n_rules; i++) { + if (rules_visited[i]) { + continue; + } + if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + return nullptr; + } + } + + // loop over alternates of start rule to build initial stacks + llama_grammar_stacks stacks; + pos = vec_rules[start_rule_index].data(); + do { + llama_grammar_stack stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + std::vector vec_trigger_tokens; + std::vector vec_trigger_patterns; + for (size_t i = 0; i < num_trigger_tokens; i++) { + GGML_ASSERT(trigger_tokens != nullptr); + vec_trigger_tokens.push_back(trigger_tokens[i]); + } + for (size_t i = 0; i < num_trigger_patterns; i++) { + GGML_ASSERT(trigger_patterns != nullptr); + auto & trigger = vec_trigger_patterns.emplace_back(); + trigger.pattern = trigger_patterns[i]; + trigger.regex = std::regex(trigger.pattern); + } + + // Important: vec_rules has to be moved here, not copied, because stacks contains + // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar + // then the pointers would be invalidated when the local vec_rules goes out of scope. + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, + /* .trigger_buffer = */ "", + std::move(vec_trigger_tokens), + std::move(vec_trigger_patterns), + }; +} + +void llama_grammar_free_impl(struct llama_grammar * grammar) { + if (grammar == nullptr) { + return; + } + + delete grammar; +} + +struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { + auto * result = new llama_grammar { + grammar.vocab, + grammar.rules, + grammar.stacks, + grammar.partial_utf8, + grammar.lazy, + grammar.awaiting_trigger, + grammar.trigger_buffer, + grammar.trigger_tokens, + grammar.trigger_patterns, + }; + + // redirect elements in stacks to point to new rules + for (size_t is = 0; is < result->stacks.size(); is++) { + for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { + for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { + if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { + result->stacks[is][ie] = &result->rules[ir0][ir1]; + } + } + } + } + } + + return result; +} + +void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { + GGML_ASSERT(grammar.vocab != nullptr); + + if (grammar.awaiting_trigger) { + return; + } + + bool allow_eog = false; + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + allow_eog = true; + break; + } + } + + std::vector, llama_partial_utf8>> candidates_decoded; + candidates_decoded.reserve(cur_p->size); + + llama_grammar_candidates candidates_grammar; + candidates_grammar.reserve(cur_p->size); + + for (size_t i = 0; i < cur_p->size; ++i) { + const llama_token id = cur_p->data[i].id; + const std::string & piece = grammar.vocab->token_to_piece(id); + + if (grammar.vocab->is_eog(id)) { + if (!allow_eog) { + cur_p->data[i].logit = -INFINITY; + } + } else if (piece.empty() || piece[0] == 0) { + cur_p->data[i].logit = -INFINITY; + } else { + candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); + candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + } + } + + const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); + for (const auto & reject : rejects) { + cur_p->data[reject.index].logit = -INFINITY; + } +} + +void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { + GGML_ASSERT(grammar.vocab != nullptr); + + const auto & piece = grammar.vocab->token_to_piece(token); + + if (grammar.awaiting_trigger) { + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { + grammar.awaiting_trigger = false; + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, piece); + LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); + return; + } else { + grammar.trigger_buffer += piece; + + std::smatch match; + for (const auto & trigger_pattern : grammar.trigger_patterns) { + if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { + grammar.awaiting_trigger = false; + // get from the first matched capturing group to the end of the string + size_t start = std::string::npos; + for (auto i = 1u; i < match.size(); i++) { + if (match.length(i) > 0) { + start = match.position(i); + break; + } + } + if (start == std::string::npos) { + start = match.position(0); + } + auto constrained_str = grammar.trigger_buffer.substr(start); + // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, constrained_str); + LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); + return; + } + } + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); + return; + } + } + + if (grammar.vocab->is_eog(token)) { + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + return; + } + } + GGML_ABORT("fatal error"); + } + + llama_grammar_accept_str(grammar, piece); +} + +void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(piece, grammar.partial_utf8); + const auto & code_points = decoded.first; + + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + llama_grammar_accept(&grammar, *it); + } + + grammar.partial_utf8 = decoded.second; + if (grammar.stacks.empty()) { + throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); + } +} diff --git a/src/llama-grammar.h b/src/llama-grammar.h new file mode 100644 index 0000000000000000000000000000000000000000..f8c291de999ac6f320936e7b1d6255a57ad65ac7 --- /dev/null +++ b/src/llama-grammar.h @@ -0,0 +1,173 @@ +#pragma once + +#include "llama.h" + +#include +#include +#include +#include + +struct llama_vocab; + +// grammar element type +enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) + LLAMA_GRETYPE_CHAR_ANY = 7, +}; + +typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID +} llama_grammar_element; + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; +}; + +using llama_grammar_rule = std::vector< llama_grammar_element>; +using llama_grammar_stack = std::vector; + +using llama_grammar_rules = std::vector; +using llama_grammar_stacks = std::vector; +using llama_grammar_candidates = std::vector; + +// TODO: remove, needed for tests atm +const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); + llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr); + +std::vector llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates); + +struct llama_grammar_parser { + std::map symbol_ids; + + llama_grammar_rules rules; + + llama_grammar_stack c_rules() const; + + uint32_t get_symbol_id(const char * src, size_t len); + uint32_t generate_symbol_id(const std::string & base_name); + + void add_rule(uint32_t rule_id, const llama_grammar_rule & rule); + + const char * parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested); + + const char * parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested); + + const char * parse_rule(const char * src); + + bool parse(const char * src); + void print(FILE * file); +}; + +struct llama_grammar_trigger_pattern { + std::string pattern; + std::regex regex; +}; + +struct llama_grammar { + // note: allow null vocab for testing (not great) + const llama_vocab * vocab; + + const llama_grammar_rules rules; // TODO: shared ptr + llama_grammar_stacks stacks; + + // buffer for partially generated UTF-8 sequence from accepted tokens + llama_partial_utf8 partial_utf8; + + // lazy grammars wait for trigger words or tokens before constraining the sampling. + // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // (useful e.g. for tool_choice=required) + bool lazy = false; + bool awaiting_trigger = false; // Initialized to true for lazy grammars only + std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). + std::vector + trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated + // string, and the grammar will be given the string from the first match group onwards. + +}; + +// +// internal API +// + +// note: needed for tests (not great) +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + +void llama_grammar_free_impl(struct llama_grammar * grammar); + +struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar); + +// TODO: move the API below as member functions of llama_grammar +void llama_grammar_apply_impl( + const struct llama_grammar & grammar, + llama_token_data_array * cur_p); + +void llama_grammar_accept_impl( + struct llama_grammar & grammar, + llama_token token); + +void llama_grammar_accept_str( + struct llama_grammar & grammar, + const std::string & piece); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..337fb5cb0df3634d00284c64429dce9141c641a9 --- /dev/null +++ b/src/llama-graph.cpp @@ -0,0 +1,1631 @@ +#include "llama-graph.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-cparams.h" + +#include "llama-kv-cache-unified.h" +#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-recurrent.h" + +#include +#include +#include + +void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { + if (ubatch->token) { + const int64_t n_tokens = ubatch->n_tokens; + + ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens)); + } + + if (ubatch->embd) { + const int64_t n_embd = embd->ne[0]; + const int64_t n_tokens = ubatch->n_tokens; + + ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd)); + } +} + +void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { + if (ubatch->pos && pos) { + const int64_t n_tokens = ubatch->n_tokens; + + if (ubatch->token && n_pos_per_embd == 4) { + // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D + // the 3 first dims are the same, and 4th dim is all 0 + std::vector pos_data(n_tokens*n_pos_per_embd); + // copy the first dimension + for (int i = 0; i < n_tokens; ++i) { + pos_data[ i] = ubatch->pos[i]; + pos_data[ n_tokens + i] = ubatch->pos[i]; + pos_data[2 * n_tokens + i] = ubatch->pos[i]; + pos_data[3 * n_tokens + i] = 0; // 4th dim is 0 + } + ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos)); + } else { + ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos)); + } + } +} + +void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { + if (ubatch->pos && attn_scale) { + const int64_t n_tokens = ubatch->n_tokens; + + std::vector attn_scale_data(n_tokens, 0.0f); + for (int i = 0; i < n_tokens; ++i) { + const float pos = ubatch->pos[i]; + attn_scale_data[i] = std::log( + std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0 + ) * f_attn_temp_scale + 1.0; + } + + ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale)); + } +} + +void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { + if (pos_bucket) { + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + + int32_t * data = (int32_t *) pos_bucket->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); + } + } + } + } +} + +void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { + if (pos_bucket) { + kv_state->set_input_pos_bucket(pos_bucket, ubatch); + } +} + +void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) { + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + //GGML_ASSERT(out_ids && "every model that can must skip unused outputs"); + + if (!out_ids) { + LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__); + } else { + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer)); + int32_t * data = (int32_t *) out_ids->data; + + if (n_outputs == n_tokens) { + for (int i = 0; i < n_tokens; ++i) { + data[i] = i; + } + } else if (ubatch->output) { + int32_t n_outputs = 0; + for (int i = 0; i < n_tokens; ++i) { + if (ubatch->output[i]) { + data[n_outputs++] = i; + } + } + // the graph needs to have been passed the correct number of outputs + GGML_ASSERT(n_outputs == n_outputs); + } else if (n_outputs == 1) { + // only keep last output + data[0] = n_tokens - 1; + } else { + GGML_ASSERT(n_outputs == 0); + } + } + } +} + +void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(mean); + GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer)); + + float * data = (float *) mean->data; + memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean)); + + std::vector sum(n_tokens, 0); + + // TODO: fix indexing [UBATCH_IDX] + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); + + sum[seq_id] += ubatch->n_seq_tokens; + } + + std::vector div(n_tokens, 0.0f); + for (int i = 0; i < n_tokens; ++i) { + const uint64_t s = sum[i]; + if (s > 0) { + div[i] = 1.0f/float(s); + } + } + + // TODO: fix indexing [UBATCH_IDX] + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; + } + } + } +} + +void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { + if (cparams.embeddings && ( + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(cls); + GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); + + uint32_t * data = (uint32_t *) cls->data; + memset(cls->data, 0, n_tokens * ggml_element_size(cls)); + + // TODO: fix indexing [UBATCH_IDX] + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch->pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } + } + } + } + + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(cls); + GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); + + uint32_t * data = (uint32_t *) cls->data; + memset(cls->data, 0, n_tokens * ggml_element_size(cls)); + + std::vector last_pos(n_tokens, -1); + std::vector last_row(n_tokens, -1); + + // TODO: fix indexing [UBATCH_IDX] + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch->pos[s*n_seq_tokens + i]; + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = s*n_seq_tokens + i; + } + } + } + + for (int i = 0; i < n_tokens; ++i) { + if (last_row[i] >= 0) { + data[i] = last_row[i]; + } + } + } +} + +void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + const int64_t n_kv = kv_state->get_n_kv(); + + if (s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); + int32_t * data = (int32_t *) s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_kv; ++i) { + data[i] = kv_state->s_copy(i); + } + } +} + +void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (cross_embd && !cross->v_embd.empty()) { + assert(cross_embd->type == GGML_TYPE_F32); + + ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd)); + } +} + +void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { + if (kq_mask) { + if (cparams.causal_attn) { + const int64_t n_kv = ubatch->n_tokens; + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); + float * data = (float *) kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch->seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + // TODO: fix indexing [UBATCH_IDX] + for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { + if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) { + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f; + } + } + } + } + } + } else { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + const int64_t n_stride = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); + + float * data = (float *) kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch->seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + // TODO: fix indexing [UBATCH_IDX] + for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { + if (ubatch->seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; + } + } + + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } + } + } + } + } + } +} + +void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { + if (self_kq_mask) { + kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } +} + +void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { + if (self_kq_mask) { + kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } + + if (self_kq_mask_swa) { + kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } +} + +void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { + if (cross_kq_mask) { + const int64_t n_enc = cross_kq_mask->ne[0]; + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + + float * data = (float *) cross_kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_enc; ++i) { + float f = -INFINITY; + // TODO: fix indexing [UBATCH_IDX] + for (int s = 0; s < ubatch->n_seq_id[j]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[j][s]; + if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) { + f = 0.0f; + } + } + data[h*(n_enc*n_tokens) + j*n_enc + i] = f; + } + } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_enc; ++j) { + data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY; + } + } + } + } +} + +// +// llm_graph_context +// + +llm_graph_context::llm_graph_context(const llm_graph_params & params) : + arch (params.arch), + hparams (params.hparams), + cparams (params.cparams), + ubatch (params.ubatch), + n_embd (hparams.n_embd), + n_layer (hparams.n_layer), + n_rot (hparams.n_rot), + n_ctx (cparams.n_ctx), + n_head (hparams.n_head()), + n_head_kv (hparams.n_head_kv()), + n_embd_head_k (hparams.n_embd_head_k), + n_embd_k_gqa (hparams.n_embd_k_gqa()), + n_embd_head_v (hparams.n_embd_head_v), + n_embd_v_gqa (hparams.n_embd_v_gqa()), + n_expert (hparams.n_expert), + n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used), + freq_base (cparams.rope_freq_base), + freq_scale (cparams.rope_freq_scale), + ext_factor (cparams.yarn_ext_factor), + attn_factor (cparams.yarn_attn_factor), + beta_fast (cparams.yarn_beta_fast), + beta_slow (cparams.yarn_beta_slow), + norm_eps (hparams.f_norm_eps), + norm_rms_eps (hparams.f_norm_rms_eps), + n_tokens (ubatch.n_tokens), + n_outputs (params.n_outputs), + n_ctx_orig (cparams.n_ctx_orig_yarn), + pooling_type (cparams.pooling_type), + rope_type (hparams.rope_type), + ctx0 (params.ctx), + sched (params.sched), + backend_cpu (params.backend_cpu), + cvec (params.cvec), + loras (params.loras), + mstate (params.mstate), + cross (params.cross), + cb_func (params.cb), + res (std::make_unique()) { + } + +int64_t llm_graph_context::n_pos_per_embd() const { + return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1; +} + +void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { + if (cb_func) { + cb_func(ubatch, cur, name, il); + } +} + +ggml_tensor * llm_graph_context::build_cvec( + ggml_tensor * cur, + int il) const { + return cvec->apply_to(ctx0, cur, il); +} + +ggml_tensor * llm_graph_context::build_lora_mm( + ggml_tensor * w, + ggml_tensor * cur) const { + ggml_tensor * res = ggml_mul_mat(ctx0, w, cur); + + for (const auto & lora : *loras) { + llama_adapter_lora_weight * lw = lora.first->get_weight(w); + if (lw == nullptr) { + continue; + } + + const float adapter_scale = lora.second; + const float scale = lw->get_scale(lora.first->alpha, adapter_scale); + + ggml_tensor * ab_cur = ggml_mul_mat( + ctx0, lw->b, + ggml_mul_mat(ctx0, lw->a, cur) + ); + + ab_cur = ggml_scale(ctx0, ab_cur, scale); + res = ggml_add(ctx0, res, ab_cur); + } + + return res; +} + +ggml_tensor * llm_graph_context::build_lora_mm_id( + ggml_tensor * w, // ggml_tensor * as + ggml_tensor * cur, // ggml_tensor * b + ggml_tensor * ids) const { + ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids); + for (const auto & lora : *loras) { + llama_adapter_lora_weight * lw = lora.first->get_weight(w); + if (lw == nullptr) { + continue; + } + + const float alpha = lora.first->alpha; + const float rank = (float) lw->b->ne[0]; + const float scale = alpha ? lora.second * alpha / rank : lora.second; + + ggml_tensor * ab_cur = ggml_mul_mat_id( + ctx0, lw->b, + ggml_mul_mat_id(ctx0, lw->a, cur, ids), + ids + ); + + ab_cur = ggml_scale(ctx0, ab_cur, scale); + res = ggml_add(ctx0, res, ab_cur); + } + + return res; +} + +ggml_tensor * llm_graph_context::build_norm( + ggml_tensor * cur, + ggml_tensor * mw, + ggml_tensor * mb, + llm_norm_type type, + int il) const { + switch (type) { + case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break; + case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break; + case LLM_NORM_GROUP: + { + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]); + cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]); + } break; + } + + if (mw || mb) { + cb(cur, "norm", il); + } + + if (mw) { + cur = ggml_mul(ctx0, cur, mw); + if (mb) { + cb(cur, "norm_w", il); + } + } + + if (mb) { + cur = ggml_add(ctx0, cur, mb); + } + + return cur; +} + +ggml_tensor * llm_graph_context::build_ffn( + ggml_tensor * cur, + ggml_tensor * up, + ggml_tensor * up_b, + ggml_tensor * up_s, + ggml_tensor * gate, + ggml_tensor * gate_b, + ggml_tensor * gate_s, + ggml_tensor * down, + ggml_tensor * down_b, + ggml_tensor * down_s, + ggml_tensor * act_scales, + llm_ffn_op_type type_op, + llm_ffn_gate_type type_gate, + int il) const { + ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur; + cb(tmp, "ffn_up", il); + + if (up_b) { + tmp = ggml_add(ctx0, tmp, up_b); + cb(tmp, "ffn_up_b", il); + } + + if (up_s) { + tmp = ggml_mul(ctx0, tmp, up_s); + cb(tmp, "ffn_up_s", il); + } + + if (gate) { + switch (type_gate) { + case LLM_FFN_SEQ: + { + cur = build_lora_mm(gate, tmp); + cb(cur, "ffn_gate", il); + } break; + case LLM_FFN_PAR: + { + cur = build_lora_mm(gate, cur); + cb(cur, "ffn_gate", il); + } break; + } + + if (gate_b) { + cur = ggml_add(ctx0, cur, gate_b); + cb(cur, "ffn_gate_b", il); + } + + if (gate_s) { + cur = ggml_mul(ctx0, cur, gate_s); + cb(cur, "ffn_gate_s", il); + } + + } else { + cur = tmp; + } + + switch (type_op) { + case LLM_FFN_SILU: + { + cur = ggml_silu(ctx0, cur); + cb(cur, "ffn_silu", il); + } break; + case LLM_FFN_GELU: + { + cur = ggml_gelu(ctx0, cur); + cb(cur, "ffn_gelu", il); + if (act_scales != NULL) { + cur = ggml_div(ctx0, cur, act_scales); + cb(cur, "ffn_act", il); + } + } break; + case LLM_FFN_RELU: + { + cur = ggml_relu(ctx0, cur); + cb(cur, "ffn_relu", il); + } break; + case LLM_FFN_RELU_SQR: + { + cur = ggml_relu(ctx0, cur); + cb(cur, "ffn_relu", il); + + cur = ggml_sqr(ctx0, cur); + cb(cur, "ffn_sqr(relu)", il); + } break; + case LLM_FFN_SWIGLU: + { + // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + int64_t split_point = cur->ne[0] / 2; + // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 + ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); + ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); + + x0 = ggml_silu(ctx0, x0); + cb(cur, "ffn_silu", il); + + cur = ggml_mul(ctx0, x0, x1); + cb(cur, "ffn_mul", il); + } break; + case LLM_FFN_GEGLU: + { + // Split into two equal parts + int64_t split_point = cur->ne[0] / 2; + // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 + ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); + ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); + + x0 = ggml_gelu(ctx0, x0); + cb(x0, "ffn_gelu", il); + + cur = ggml_mul(ctx0, x0, x1); + cb(cur, "ffn_geglu", il); + } break; + } + + if (gate && type_gate == LLM_FFN_PAR) { + cur = ggml_mul(ctx0, cur, tmp); + cb(cur, "ffn_gate_par", il); + } + + if (down) { + cur = build_lora_mm(down, cur); + if (arch == LLM_ARCH_GLM4) { + // GLM4 seems to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (down_b) { + cb(cur, "ffn_down", il); + } + + if (down_b) { + cur = ggml_add(ctx0, cur, down_b); + } + + if (down_s) { + cur = ggml_mul(ctx0, cur, down_s); + cb(cur, "ffn_down_s", il); + } + + return cur; +} + +ggml_tensor * llm_graph_context::build_moe_ffn( + ggml_tensor * cur, + ggml_tensor * gate_inp, + ggml_tensor * up_exps, + ggml_tensor * gate_exps, + ggml_tensor * down_exps, + ggml_tensor * exp_probs_b, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + bool scale_w, + float w_scale, + llama_expert_gating_func_type gating_op, + int il) const { + const int64_t n_embd = cur->ne[0]; + const int64_t n_tokens = cur->ne[1]; + const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN + + ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + ggml_tensor * probs = nullptr; + switch (gating_op) { + case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: + { + probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens] + } break; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: + { + probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] + } break; + default: + GGML_ABORT("fatal error"); + } + cb(probs, "ffn_moe_probs", il); + + // add experts selection bias - introduced in DeepSeek V3 + // leave probs unbiased as it's later used to get expert weights + ggml_tensor * selection_probs = probs; + if (exp_probs_b != nullptr) { + selection_probs = ggml_add(ctx0, probs, exp_probs_b); + cb(selection_probs, "ffn_moe_probs_biased", il); + } + + // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k + // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198 + if (arch == LLM_ARCH_LLAMA4) { + selection_probs = logits; + } + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + cb(selected_experts, "ffn_moe_topk", il); + + ggml_tensor * weights = ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights", il); + + if (norm_w) { + weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); + + ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens] + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights_norm", il); + + weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); + } + if (scale_w) { + weights = ggml_scale(ctx0, weights, w_scale); + cb(weights, "ffn_moe_weights_scaled", il); + } + + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); + + if (weight_before_ffn) { + // repeat cur to [n_embd, n_expert_used, n_tokens] + ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1); + cur = ggml_mul(ctx0, repeated, weights); + cb(cur, "ffn_moe_weighted", il); + } + + ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + ggml_tensor * experts = nullptr; + if (gate_exps) { + cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate", il); + } else { + cur = up; + } + + switch (type_op) { + case LLM_FFN_SILU: + { + cur = ggml_silu(ctx0, cur); + cb(cur, "ffn_moe_silu", il); + } break; + case LLM_FFN_GELU: + { + cur = ggml_gelu(ctx0, cur); + cb(cur, "ffn_moe_gelu", il); + } break; + default: + GGML_ABORT("fatal error"); + } + + if (gate_exps) { + cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate_par", il); + } + + experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] + cb(experts, "ffn_moe_down", il); + + if (!weight_before_ffn) { + experts = ggml_mul(ctx0, experts, weights); + cb(cur, "ffn_moe_weighted", il); + } + + // aggregate experts + ggml_tensor * moe_out = nullptr; + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens, + experts->nb[2], i*experts->nb[1]); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx0, moe_out, cur_expert); + } + } + + if (n_expert_used == 1) { + // avoid returning a non-contiguous tensor + moe_out = ggml_cont(ctx0, moe_out); + } + + cb(moe_out, "ffn_moe_out", il); + + return moe_out; +} + +// input embeddings with optional lora +ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { + const int64_t n_embd = hparams.n_embd; + + auto inp = std::make_unique(); + + ggml_tensor * cur = nullptr; + + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + //cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + res->t_tokens = inp->tokens; + + cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); + + // apply lora for embedding tokens if needed + for (const auto & lora : *loras) { + llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd); + if (lw == nullptr) { + continue; + } + + const float adapter_scale = lora.second; + const float scale = lw->get_scale(lora.first->alpha, adapter_scale); + + ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat( + ctx0, lw->b, // non-transposed lora_b + ggml_get_rows(ctx0, lw->a, inp->tokens) + ), scale); + + cur = ggml_add(ctx0, cur, inpL_delta); + } + } else { + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); + ggml_set_input(inp->embd); + + cur = inp->embd; + } + + // For Granite architecture + if (hparams.f_embedding_scale != 0.0f) { + cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); + } + + cb(cur, "inp_embd", -1); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_pos() const { + auto inp = std::make_unique(n_pos_per_embd()); + + auto & cur = inp->pos; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd()); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_attn_scale() const { + auto inp = std::make_unique(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale); + + auto & cur = inp->attn_scale; + + // this need to be 1x1xN for broadcasting + cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_out_ids() const { + auto inp = std::make_unique(hparams, cparams, n_outputs); + + auto & cur = inp->out_ids; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_mean() const { + auto inp = std::make_unique(cparams); + + auto & cur = inp->mean; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_cls() const { + auto inp = std::make_unique(cparams); + + auto & cur = inp->cls; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_s_copy() const { + const auto * kv_state = static_cast(mstate); + + auto inp = std::make_unique(kv_state); + + const auto n_kv = kv_state->get_n_kv(); + + auto & cur = inp->s_copy; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_cross_embd() const { + auto inp = std::make_unique(cross); + + auto & cur = inp->cross_embd; + + // if we have the output embeddings from the encoder, use them directly + // TODO: needs more work to be correct, for now just use the tensor shape + //if (cross->t_embd) { + // cur = ggml_view_tensor(ctx0, cross->t_embd); + + // return cur; + //} + + const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd; + const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { + auto inp = std::make_unique(hparams); + + auto & cur = inp->pos_bucket; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { + const auto * kv_state = static_cast(mstate); + + auto inp = std::make_unique(hparams, kv_state); + + const auto n_kv = kv_state->get_n_kv(); + + auto & cur = inp->pos_bucket; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const { + ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]); + cb(pos_bucket_1d, "pos_bucket_1d", -1); + + ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d); + + pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]); + pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3); + pos_bias = ggml_cont (ctx0, pos_bias); + + cb(pos_bias, "pos_bias", -1); + + return pos_bias; +} + +ggml_tensor * llm_graph_context::build_attn_mha( + ggml_cgraph * gf, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * kq_b, + ggml_tensor * kq_mask, + ggml_tensor * v_mla, + float kq_scale) const { + const bool v_trans = v->nb[1] > v->nb[2]; + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + k = ggml_permute(ctx0, k, 0, 2, 1, 3); + v = ggml_permute(ctx0, v, 0, 2, 1, 3); + + const auto n_tokens = q->ne[1]; + const auto n_head = q->ne[2]; + const auto n_kv = k->ne[1]; + + ggml_tensor * cur; + + // TODO: replace hardcoded padding with ggml-provided padding + if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) { + GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); + + if (v_trans) { + v = ggml_transpose(ctx0, v); + } + + // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn) + if (k->type == GGML_TYPE_F32) { + k = ggml_cast(ctx0, k, GGML_TYPE_F16); + } + + if (v->type == GGML_TYPE_F32) { + v = ggml_cast(ctx0, v, GGML_TYPE_F16); + } + + cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, + hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); + + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + + if (v_mla) { +#if 0 + // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. + // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient. + cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); + cur = ggml_mul_mat(ctx0, v_mla, cur); +#else + // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1. + // The permutations are noops and only change how the tensor data is interpreted. + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_mul_mat(ctx0, v_mla, cur); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs. +#endif + } + + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + } else { + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + + // note: this op tends to require high floating point range + // while for some models F16 is enough, for others it is not, so we default to F32 here + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + + if (arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below + + kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx0, kq, 30); + } + + if (hparams.attn_soft_cap) { + kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping); + kq = ggml_tanh (ctx0, kq); + kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping); + } + + if (kq_b) { + kq = ggml_add(ctx0, kq, kq_b); + } + + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + + if (!v_trans) { + // note: avoid this branch + v = ggml_cont(ctx0, ggml_transpose(ctx0, v)); + } + + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + + // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA + if (v_mla) { + kqv = ggml_mul_mat(ctx0, v_mla, kqv); + } + + cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + + if (!cparams.offload_kqv) { + // all nodes between the KV store and the attention output are run on the CPU + ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu); + } + } + + ggml_build_forward_expand(gf, cur); + + return cur; +} + +llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const { + auto inp = std::make_unique(hparams, cparams); + + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch + inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->kq_mask); + + inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; + + return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_no_cache * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + GGML_UNUSED(n_tokens); + + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = k_cur; + ggml_tensor * v = v_cur; + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + } + + if (wo_b) { + //cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { + const auto * kv_state = static_cast(mstate); + + auto inp = std::make_unique(hparams, cparams, kv_state); + + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); + + const auto n_kv = kv_state->get_n_kv(); + + inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_kv_unified * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto * kv_state = static_cast(mstate); + + // store to KV cache + { + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4) { + // GLM4 seems to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { + const auto * kv_state = static_cast(mstate); + + auto inp = std::make_unique(hparams, cparams, kv_state); + + { + const auto n_kv = kv_state->get_base()->get_n_kv(); + + inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + { + GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); + + const auto n_kv = kv_state->get_swa()->get_n_kv(); + + inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); + ggml_set_input(inp->self_kq_mask_swa); + + inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + } + + return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_kv_unified_iswa * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto * kv_state_iswa = static_cast(mstate); + + const bool is_swa = hparams.is_swa(il); + + const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base(); + + // store to KV cache + { + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + } + + const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + } + + if (wo_b) { + //cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { + auto inp = std::make_unique(cross); + + const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; + + inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(inp->cross_kq_mask); + + inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; + + return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_cross * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto & kq_mask = inp->get_kq_mask_cross(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = k_cur; + ggml_tensor * v = v_cur; + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + } + + if (wo_b) { + //cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +ggml_tensor * llm_graph_context::build_recurrent_state( + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies) const { + const auto * kv_state = static_cast(mstate); + + const auto n_kv = kv_state->get_n_kv(); + const auto kv_head = kv_state->get_head(); + const auto rs_zero = kv_state->get_rs_z(); + + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size()); + + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); + + ggml_tensor * output_states; + + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); + } else { + // FIXME: make the gathering operation happen before the copy below + // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) + output_states = states; + } + + // copy extra states which won't be changed further (between n_seqs and n_kv) + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + states_extra, + ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s)))); + + return output_states; +} + +ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( + ggml_cgraph * gf, + ggml_tensor * state_copy, + const llama_ubatch & ubatch, + int il) const { + const auto * kv_state = static_cast(mstate); + + const auto token_shift_count = hparams.token_shift_count; + + const int64_t n_seqs = ubatch.n_seqs; + + ggml_tensor * token_shift_all = kv_state->get_k_l(il); + + ggml_tensor * token_shift = build_recurrent_state( + gf, token_shift_all, state_copy, + hparams.n_embd_k_s(), n_seqs); + + token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); + + return token_shift; +} + +ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( + ggml_tensor * token_shift, + const llama_ubatch & ubatch, + int il) const { + const auto * kv_state = static_cast(mstate); + + const auto token_shift_count = hparams.token_shift_count; + const auto n_embd = hparams.n_embd; + + const int64_t n_seqs = ubatch.n_seqs; + + const auto kv_head = kv_state->get_head(); + + return ggml_cpy( + ctx0, + ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), + ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il))) + ); +} + +void llm_graph_context::build_pooling( + ggml_cgraph * gf, + ggml_tensor * cls, + ggml_tensor * cls_b, + ggml_tensor * cls_out, + ggml_tensor * cls_out_b) const { + if (!cparams.embeddings) { + return; + } + + ggml_tensor * inp = res->t_embd; + + //// find result_norm tensor for input + //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) { + // inp = ggml_graph_node(gf, i); + // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) { + // break; + // } + + // inp = nullptr; + //} + + GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor"); + + ggml_tensor * cur; + + switch (pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + cur = inp; + } break; + case LLAMA_POOLING_TYPE_MEAN: + { + ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } break; + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } break; + case LLAMA_POOLING_TYPE_RANK: + { + ggml_tensor * inp_cls = build_inp_cls(); + inp = ggml_get_rows(ctx0, inp, inp_cls); + + if (cls) { + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 + cur = ggml_mul_mat(ctx0, cls, inp); + if (cls_b) { + cur = ggml_add(ctx0, cur, cls_b); + } + cur = ggml_tanh(ctx0, cur); + + // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 + if (cls_out) { + cur = ggml_mul_mat(ctx0, cls_out, cur); + if (cls_out_b) { + cur = ggml_add(ctx0, cur, cls_out_b); + } + } + } else if (cls_out) { + // Single layer classification head (direct projection) + // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476 + cur = ggml_mul_mat(ctx0, cls_out, inp); + if (cls_out_b) { + cur = ggml_add(ctx0, cur, cls_out_b); + } + } else { + GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b"); + } + } break; + default: + { + GGML_ABORT("unknown pooling type"); + } + } + + cb(cur, "result_embd_pooled", -1); + res->t_embd_pooled = cur; + + ggml_build_forward_expand(gf, cur); +} + +int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { + // TODO move to hparams if a T5 variant appears that uses a different value + const int64_t max_distance = 128; + + if (bidirectional) { + n_buckets >>= 1; + } + + const int64_t max_exact = n_buckets >> 1; + + int32_t relative_position = x - y; + int32_t relative_bucket = 0; + + if (bidirectional) { + relative_bucket += (relative_position > 0) * n_buckets; + relative_position = abs(relative_position); + } else { + relative_position = -std::min(relative_position, 0); + } + + int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); + relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); + relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large); + + return relative_bucket; +} diff --git a/src/llama-graph.h b/src/llama-graph.h new file mode 100644 index 0000000000000000000000000000000000000000..87813119b1a3cf5656bec8b837df18628e7902a5 --- /dev/null +++ b/src/llama-graph.h @@ -0,0 +1,628 @@ +#pragma once + +#include "llama-arch.h" +#include "llama-hparams.h" +#include "llama-adapter.h" + +#include +#include +#include +#include +#include + +struct ggml_cgraph; +struct ggml_context; +struct ggml_tensor; + +struct llama_ubatch; +struct llama_cparams; + +struct llama_memory_state_i; + +class llama_kv_cache_unified_state; +class llama_kv_cache_unified_iswa_state; +class llama_kv_cache_recurrent_state; + +// certain models (typically multi-modal) can produce different types of graphs +enum llm_graph_type { + LLM_GRAPH_TYPE_DEFAULT, + LLM_GRAPH_TYPE_ENCODER, + LLM_GRAPH_TYPE_DECODER, +}; + +enum llm_ffn_op_type { + LLM_FFN_SILU, + LLM_FFN_GELU, + LLM_FFN_RELU, + LLM_FFN_RELU_SQR, + LLM_FFN_SWIGLU, + LLM_FFN_GEGLU, +}; + +enum llm_ffn_gate_type { + LLM_FFN_SEQ, + LLM_FFN_PAR, // ffn_gate is parallel to ffn_up +}; + +enum llm_norm_type { + LLM_NORM, + LLM_NORM_RMS, + LLM_NORM_GROUP, +}; + +// TODO: tmp - need something better to pass the data from the encoder to the decoder +struct llama_cross { + // the output embeddings from the encoder as a ggml tensor + // TODO: this needs more work to be correct, for now copy the embeddings data to host memory + // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524 + //ggml_tensor * t_embd = nullptr; + + int64_t n_embd = 0; + int64_t n_enc = 0; + + // embeddings data copied to host memory (tmp) + std::vector v_embd; + + // needed to construct the cross-attention mask in the decoder + std::vector> seq_ids_enc; +}; + +// +// llm_graph_input +// + +class llm_graph_input_i { +public: + virtual ~llm_graph_input_i() = default; + + virtual void set_input(const llama_ubatch * ubatch) = 0; +}; + +using llm_graph_input_ptr = std::unique_ptr; + + +class llm_graph_input_embd : public llm_graph_input_i { +public: + llm_graph_input_embd() = default; + virtual ~llm_graph_input_embd() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * tokens = nullptr; // I32 [n_batch] + ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] +}; + +class llm_graph_input_pos : public llm_graph_input_i { +public: + llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {} + virtual ~llm_graph_input_pos() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * pos = nullptr; // I32 [n_batch] + + const int64_t n_pos_per_embd = 1; +}; + +// temperature tuning, used by llama4 +class llm_graph_input_attn_temp : public llm_graph_input_i { +public: + llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale) + : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {} + virtual ~llm_graph_input_attn_temp() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * attn_scale = nullptr; // F32 [n_batch] + + const uint32_t n_attn_temp_floor_scale; + const float f_attn_temp_scale; +}; + +class llm_graph_input_pos_bucket : public llm_graph_input_i { +public: + llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {} + virtual ~llm_graph_input_pos_bucket() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch] + + const llama_hparams & hparams; +}; + +class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { +public: + llm_graph_input_pos_bucket_kv( + const llama_hparams & hparams, + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {} + virtual ~llm_graph_input_pos_bucket_kv() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] + + const llama_hparams & hparams; + const llama_kv_cache_unified_state * kv_state; +}; + +class llm_graph_input_out_ids : public llm_graph_input_i { +public: + llm_graph_input_out_ids( + const llama_hparams & hparams, + const llama_cparams & cparams, + int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {} + virtual ~llm_graph_input_out_ids() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * out_ids; // I32 [n_outputs] + + const llama_hparams & hparams; + const llama_cparams & cparams; + + const int32_t n_outputs; +}; + +class llm_graph_input_mean : public llm_graph_input_i { +public: + llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {} + virtual ~llm_graph_input_mean() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * mean; // F32 [n_batch, n_batch] + + const llama_cparams & cparams; +}; + +class llm_graph_input_cls : public llm_graph_input_i { +public: + llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {} + virtual ~llm_graph_input_cls() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * cls; // I32 [n_batch] + + const llama_cparams & cparams; +}; + +class llm_graph_input_s_copy : public llm_graph_input_i { +public: + llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} + virtual ~llm_graph_input_s_copy() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * s_copy; // I32 [kv_size] + + const llama_kv_cache_recurrent_state * kv_state; +}; + +class llm_graph_input_cross_embd : public llm_graph_input_i { +public: + llm_graph_input_cross_embd( + const llama_cross * cross) : cross(cross) {} + virtual ~llm_graph_input_cross_embd() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc] + + const llama_cross * cross; +}; + +class llm_graph_input_attn_no_cache : public llm_graph_input_i { +public: + llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) : + hparams(hparams), + cparams(cparams) { + } + ~llm_graph_input_attn_no_cache() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } + + ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch] + ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch] + + const llama_hparams & hparams; + const llama_cparams & cparams; +}; + +class llm_graph_input_attn_kv_unified : public llm_graph_input_i { +public: + llm_graph_input_attn_kv_unified( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_unified_state * kv_state) : + hparams(hparams), + cparams(cparams), + kv_state(kv_state) { + } + ~llm_graph_input_attn_kv_unified() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + + const llama_hparams & hparams; + const llama_cparams & cparams; + + const llama_kv_cache_unified_state * kv_state; +}; + +class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { +public: + llm_graph_input_attn_kv_unified_iswa( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_unified_iswa_state * kv_state) : + hparams(hparams), + cparams(cparams), + kv_state(kv_state) { + } + ~llm_graph_input_attn_kv_unified_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch] + + const llama_hparams & hparams; + const llama_cparams & cparams; + + const llama_kv_cache_unified_iswa_state * kv_state; +}; + +class llm_graph_input_attn_cross : public llm_graph_input_i { +public: + llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} + ~llm_graph_input_attn_cross() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } + + ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch] + + const llama_cross * cross = nullptr; +}; + +// +// llm_graph_result +// + +// these objects deliver the result from the graph build process back to the llama_context +// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their +// specific data, by calling the set_inputs() method +// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc. +// these are used by the llama_context to extact the relevant data, based on the compute parameters + +class llm_graph_result_i { +public: + virtual ~llm_graph_result_i() = default; + + virtual ggml_tensor * get_tokens() = 0; + virtual ggml_tensor * get_logits() = 0; + virtual ggml_tensor * get_embd() = 0; + virtual ggml_tensor * get_embd_pooled() = 0; + + virtual void set_inputs(const llama_ubatch * ubatch) = 0; +}; + +using llm_graph_result_ptr = std::unique_ptr; + + +class llm_graph_result : public llm_graph_result_i { +public: + virtual ~llm_graph_result() = default; + + ggml_tensor * get_tokens() override { return t_tokens; } + ggml_tensor * get_logits() override { return t_logits; } + ggml_tensor * get_embd() override { return t_embd; } + ggml_tensor * get_embd_pooled() override { return t_embd_pooled; } + + void set_inputs(const llama_ubatch * ubatch) override { + for (auto & input : inputs) { + input->set_input(ubatch); + } + } + + llm_graph_input_i * add_input(llm_graph_input_ptr input) { + inputs.emplace_back(std::move(input)); + return inputs.back().get(); + } + + // important graph nodes + ggml_tensor * t_tokens = nullptr; + ggml_tensor * t_logits = nullptr; + ggml_tensor * t_embd = nullptr; + ggml_tensor * t_embd_pooled = nullptr; + + std::vector inputs; +}; + +// +// llm_graph_context +// + +// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) +using llm_graph_cb = std::function; + +struct llm_graph_params { + ggml_context * ctx; + + const llm_arch arch; + + const llama_hparams & hparams; + const llama_cparams & cparams; + const llama_ubatch & ubatch; + + ggml_backend_sched_t sched; + ggml_backend_t backend_cpu; + + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; + + uint32_t n_outputs; + + const llm_graph_cb & cb; +}; + +struct llm_graph_context { + const llm_arch arch; + + const llama_hparams & hparams; + const llama_cparams & cparams; + const llama_ubatch & ubatch; + + const int64_t n_embd; + const int64_t n_layer; + const int64_t n_rot; + const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) + const int64_t n_head; + const int64_t n_head_kv; + const int64_t n_embd_head_k; + const int64_t n_embd_k_gqa; + const int64_t n_embd_head_v; + const int64_t n_embd_v_gqa; + const int64_t n_expert; + const int64_t n_expert_used; + + const float freq_base; + const float freq_scale; + const float ext_factor; + const float attn_factor; + const float beta_fast; + const float beta_slow; + const float norm_eps; + const float norm_rms_eps; + + const int64_t n_tokens; + const int64_t n_outputs; + const int32_t n_ctx_orig; // yarn + + const enum llama_pooling_type pooling_type; + const enum llama_rope_type rope_type; + + ggml_context * ctx0 = nullptr; + + ggml_backend_sched_t sched; + + ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? + + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; + + const llm_graph_cb & cb_func; + + std::unique_ptr res; + + llm_graph_context(const llm_graph_params & params); + + int64_t n_pos_per_embd() const; + + void cb(ggml_tensor * cur, const char * name, int il) const; + + // + // common + // + + ggml_tensor * build_cvec( + ggml_tensor * cur, + int il) const; + + // do mat_mul, while optionally apply lora + ggml_tensor * build_lora_mm( + ggml_tensor * w, + ggml_tensor * cur) const; + + // do mat_mul_id, while optionally apply lora + ggml_tensor * build_lora_mm_id( + ggml_tensor * w, // ggml_tensor * as + ggml_tensor * cur, // ggml_tensor * b + ggml_tensor * ids) const; + + ggml_tensor * build_norm( + ggml_tensor * cur, + ggml_tensor * mw, + ggml_tensor * mb, + llm_norm_type type, + int il) const; + + ggml_tensor * build_ffn( + ggml_tensor * cur, + ggml_tensor * up, + ggml_tensor * up_b, + ggml_tensor * up_s, + ggml_tensor * gate, + ggml_tensor * gate_b, + ggml_tensor * gate_s, + ggml_tensor * down, + ggml_tensor * down_b, + ggml_tensor * down_s, + ggml_tensor * act_scales, + llm_ffn_op_type type_op, + llm_ffn_gate_type type_gate, + int il) const; + + ggml_tensor * build_moe_ffn( + ggml_tensor * cur, + ggml_tensor * gate_inp, + ggml_tensor * up_exps, + ggml_tensor * gate_exps, + ggml_tensor * down_exps, + ggml_tensor * exp_probs_b, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + bool scale_w, + float w_scale, + llama_expert_gating_func_type gating_op, + int il) const; + + // + // inputs + // + + ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; + ggml_tensor * build_inp_pos() const; + ggml_tensor * build_inp_attn_scale() const; + ggml_tensor * build_inp_out_ids() const; + ggml_tensor * build_inp_mean() const; + ggml_tensor * build_inp_cls() const; + ggml_tensor * build_inp_s_copy() const; + + ggml_tensor * build_inp_cross_embd() const; + ggml_tensor * build_inp_pos_bucket_enc() const; + ggml_tensor * build_inp_pos_bucket_dec() const; + ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const; + + // + // attention + // + + ggml_tensor * build_attn_mha( + ggml_cgraph * gf, + ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) + ggml_tensor * kq_b, + ggml_tensor * kq_mask, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale) const; + + llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_no_cache * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; + + llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_kv_unified * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; + + llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_kv_unified_iswa * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; + + llm_graph_input_attn_cross * build_attn_inp_cross() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_cross * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; + + // + // recurrent + // + + ggml_tensor * build_recurrent_state( + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies = false) const; + + ggml_tensor * build_rwkv_token_shift_load( + ggml_cgraph * gf, + ggml_tensor * state_copy, + const llama_ubatch & ubatch, + int il) const; + + ggml_tensor * build_rwkv_token_shift_store( + ggml_tensor * token_shift, + const llama_ubatch & ubatch, + int il) const; + + // + // pooling + // + + void build_pooling( + ggml_cgraph * gf, + ggml_tensor * cls, + ggml_tensor * cls_b, + ggml_tensor * cls_out, + ggml_tensor * cls_out_b) const; +}; + +// TODO: better name +int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1499eb08a5dd9246f182dc3545d4e23aecc5ca29 --- /dev/null +++ b/src/llama-hparams.cpp @@ -0,0 +1,95 @@ +#include "llama-hparams.h" + +#include "ggml.h" + +void llama_hparams::set_swa_pattern(uint32_t n_pattern) { + for (uint32_t il = 0; il < n_layer; ++il) { + swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } +} + +bool llama_hparams::is_swa_any() const { + for (uint32_t il = 0; il < n_layer; ++il) { + if (swa_layers[il]) { + return true; + } + } + + return false; +} + +uint32_t llama_hparams::n_head(uint32_t il) const { + if (il < n_layer) { + return n_head_arr[il]; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_head_kv(uint32_t il) const { + if (il < n_layer) { + return n_head_kv_arr[il]; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_ff(uint32_t il) const { + if (il < n_layer) { + return n_ff_arr[il]; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_gqa(uint32_t il) const { + const uint32_t n_head = this->n_head(il); + const uint32_t n_head_kv = this->n_head_kv(il); + + if (n_head_kv == 0) { + return 0; + } + + return n_head/n_head_kv; +} + +uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { + const uint32_t n_head_kv = this->n_head_kv(il); + + return n_embd_head_k * n_head_kv; +} + +uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { + const uint32_t n_head_kv = this->n_head_kv(il); + + return n_embd_head_v * n_head_kv; +} + +uint32_t llama_hparams::n_embd_k_s() const { + if (wkv_head_size != 0) { + // for RWKV models + return token_shift_count * n_embd; + } + + // TODO: maybe support other convolution strides than 1 + // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; +} + +uint32_t llama_hparams::n_embd_v_s() const { + if (wkv_head_size != 0) { + // corresponds to RWKV's wkv_states size + return n_embd * wkv_head_size; + } + + // corresponds to Mamba's ssm_states size + return ssm_d_state * ssm_d_inner; +} + +bool llama_hparams::is_swa(uint32_t il) const { + if (il < n_layer) { + return swa_layers[il]; + } + + GGML_ABORT("fatal error"); +} diff --git a/src/llama-hparams.h b/src/llama-hparams.h new file mode 100644 index 0000000000000000000000000000000000000000..b2bcb8b01a18b8e07476cb332e6fd356687db0f8 --- /dev/null +++ b/src/llama-hparams.h @@ -0,0 +1,193 @@ +#pragma once + +#include "llama.h" + +#include + +// bump if necessary +#define LLAMA_MAX_LAYERS 512 +#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3 + +enum llama_expert_gating_func_type { + LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, +}; + +enum llama_swa_type { + LLAMA_SWA_TYPE_NONE = 0, + LLAMA_SWA_TYPE_STANDARD = 1, + LLAMA_SWA_TYPE_CHUNKED = 2, +}; + +struct llama_hparams_posnet { + uint32_t n_embd; + uint32_t n_layer; +}; + +struct llama_hparams_convnext { + uint32_t n_embd; + uint32_t n_layer; +}; + +struct llama_hparams { + bool vocab_only; + bool rope_finetuned; + bool use_par_res; + bool swin_norm; + + uint32_t n_ctx_train; // context size the model was trained on + uint32_t n_embd; + uint32_t n_embd_features = 0; + uint32_t n_layer; + uint32_t n_rot; + uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads + uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head + uint32_t n_expert = 0; + uint32_t n_expert_used = 0; + uint32_t n_rel_attn_bkts = 0; + + // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA + uint32_t n_embd_head_k_mla = 0; + uint32_t n_embd_head_v_mla = 0; + + // for WavTokenizer + struct llama_hparams_posnet posnet; + struct llama_hparams_convnext convnext; + + std::array n_head_arr; + std::array n_head_kv_arr; + std::array n_ff_arr; + + uint32_t n_layer_dense_lead = 0; + uint32_t n_lora_q = 0; + uint32_t n_lora_kv = 0; + uint32_t n_ff_exp = 0; + uint32_t n_ff_shexp = 0; + uint32_t n_expert_shared = 0; + uint32_t n_norm_groups = 0; + + float expert_weights_scale = 0.0; + bool expert_weights_norm = false; + uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; + uint32_t moe_every_n_layers = 0; + + float f_norm_eps; + float f_norm_rms_eps; + float f_norm_group_eps; + + float f_attn_logit_softcapping = 50.0f; + float f_final_logit_softcapping = 30.0f; + + // for RWKV + uint32_t rescale_every_n_layers = 0; + uint32_t time_mix_extra_dim = 0; + uint32_t time_decay_extra_dim = 0; + uint32_t wkv_head_size = 0; + uint32_t token_shift_count = 2; + uint32_t n_lora_decay = 0; + uint32_t n_lora_iclr = 0; + uint32_t n_lora_value_res_mix = 0; + uint32_t n_lora_gate = 0; + + float rope_attn_factor = 1.0f; + float rope_freq_base_train; + float rope_freq_base_train_swa; + float rope_freq_scale_train; + float rope_freq_scale_train_swa; + uint32_t n_ctx_orig_yarn; + float rope_yarn_log_mul; + + std::array rope_sections; + + // Sliding Window Attention (SWA) + llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; + // the size of the sliding window (0 - no SWA) + uint32_t n_swa = 0; + // if swa_layers[il] == true, then layer il is SWA + // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA) + // by default, all layers are dense + std::array swa_layers; + + // for State Space Models + uint32_t ssm_d_conv = 0; + uint32_t ssm_d_inner = 0; + uint32_t ssm_d_state = 0; + uint32_t ssm_dt_rank = 0; + + bool ssm_dt_b_c_rms = false; + + float f_clamp_kqv = 0.0f; + float f_max_alibi_bias = 0.0f; + float f_logit_scale = 0.0f; + + // Additional scale factors (Granite/Granite MoE) + float f_residual_scale = 0.0f; + float f_embedding_scale = 0.0f; + float f_attention_scale = 0.0f; + + bool causal_attn = true; + bool use_alibi = false; + bool attn_soft_cap = false; + bool use_kq_norm = true; + + // for Classifiers + uint32_t n_cls_out = 1; + + // llama4 + uint32_t n_moe_layer_step = 0; + uint32_t n_no_rope_layer_step = 4; + uint32_t n_attn_temp_floor_scale = 8192; + float f_attn_temp_scale = 0.1; + + // needed by encoder-decoder models (e.g. T5, FLAN-T5) + // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + llama_token dec_start_token_id = LLAMA_TOKEN_NULL; + + enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; + enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; + enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) + // note that if n_pattern == 0, all layers are SWA + // if n_pattern == 1, all layers are dense + // example: n_pattern = 3 + // il == 0: swa + // il == 1: swa + // il == 2: dense + // il == 3: swa + // il == 4: swa + // il == 5: dense + // il == 6: swa + // etc ... + void set_swa_pattern(uint32_t n_pattern); + + // return true if one of the layers is SWA + bool is_swa_any() const; + + uint32_t n_head(uint32_t il = 0) const; + + uint32_t n_head_kv(uint32_t il = 0) const; + + uint32_t n_ff(uint32_t il = 0) const; + + uint32_t n_gqa(uint32_t il = 0) const; + + // dimension of key embeddings across all k-v heads + uint32_t n_embd_k_gqa(uint32_t il = 0) const; + + // dimension of value embeddings across all k-v heads + uint32_t n_embd_v_gqa(uint32_t il = 0) const; + + // dimension of the rolling state embeddings + // corresponds to Mamba's conv_states size or RWKV's token_shift states size + uint32_t n_embd_k_s() const; + + // dimension of the recurrent state embeddings + uint32_t n_embd_v_s() const; + + bool is_swa(uint32_t il) const; +}; + +static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); + diff --git a/src/llama-impl.cpp b/src/llama-impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6ec709dd323a6d5319d67d64f607c0acc85f3267 --- /dev/null +++ b/src/llama-impl.cpp @@ -0,0 +1,167 @@ +#include "llama-impl.h" + +#include "gguf.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include + +struct llama_logger_state { + ggml_log_callback log_callback = llama_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static llama_logger_state g_logger_state; + +time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} + +time_meas::~time_meas() { + if (t_start_us >= 0) { + t_acc += ggml_time_us() - t_start_us; + } + } + +void llama_log_set(ggml_log_callback log_callback, void * user_data) { + ggml_log_set(log_callback, user_data); + g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default; + g_logger_state.log_callback_user_data = user_data; +} + +static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) { + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data); + } else { + char * buffer2 = new char[len + 1]; + vsnprintf(buffer2, len + 1, format, args_copy); + buffer2[len] = 0; + g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args_copy); +} + +void llama_log_internal(ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + llama_log_internal_v(level, format, args); + va_end(args); +} + +void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +void replace_all(std::string & s, const std::string & search, const std::string & replace) { + if (search.empty()) { + return; + } + std::string builder; + builder.reserve(s.length()); + size_t pos = 0; + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); +} + +std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +std::string llama_format_tensor_shape(const std::vector & ne) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); + for (size_t i = 1; i < ne.size(); i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); + } + return buf; +} + +std::string llama_format_tensor_shape(const struct ggml_tensor * t) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + } + return buf; +} + +static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) { + switch (type) { + case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]); + case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]); + case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]); + case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]); + case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]); + case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]); + case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]); + case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); + case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); + case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); + case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + default: return format("unknown type %d", type); + } +} + +std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + + switch (type) { + case GGUF_TYPE_STRING: + return gguf_get_val_str(ctx_gguf, i); + case GGUF_TYPE_ARRAY: + { + const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); + int arr_n = gguf_get_arr_n(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); + std::stringstream ss; + ss << "["; + for (int j = 0; j < arr_n; j++) { + if (arr_type == GGUF_TYPE_STRING) { + std::string val = gguf_get_arr_str(ctx_gguf, i, j); + // escape quotes + replace_all(val, "\\", "\\\\"); + replace_all(val, "\"", "\\\""); + ss << '"' << val << '"'; + } else if (arr_type == GGUF_TYPE_ARRAY) { + ss << "???"; + } else { + ss << gguf_data_to_str(arr_type, data, j); + } + if (j < arr_n - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } + default: + return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0); + } +} diff --git a/src/llama-impl.h b/src/llama-impl.h new file mode 100644 index 0000000000000000000000000000000000000000..02b1d07f8400dc3caa75f1766bf46cdc33685d65 --- /dev/null +++ b/src/llama-impl.h @@ -0,0 +1,61 @@ +#pragma once + +#include "ggml.h" // for ggml_log_level + +#include +#include + +#ifdef __GNUC__ +# if defined(__MINGW32__) && !defined(__clang__) +# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# else +# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +# endif +#else +# define LLAMA_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +LLAMA_ATTRIBUTE_FORMAT(2, 3) +void llama_log_internal (ggml_log_level level, const char * format, ...); +void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define LLAMA_LOG(...) llama_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__) +#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) +#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#define LLAMA_LOG_CONT(...) llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__) + +// +// helpers +// + +template +struct no_init { + T value; + no_init() { /* do nothing */ } +}; + +struct time_meas { + time_meas(int64_t & t_acc, bool disable = false); + ~time_meas(); + + const int64_t t_start_us; + + int64_t & t_acc; +}; + +void replace_all(std::string & s, const std::string & search, const std::string & replace); + +// TODO: rename to llama_format ? +LLAMA_ATTRIBUTE_FORMAT(1, 2) +std::string format(const char * fmt, ...); + +std::string llama_format_tensor_shape(const std::vector & ne); +std::string llama_format_tensor_shape(const struct ggml_tensor * t); + +std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); diff --git a/src/llama-io.cpp b/src/llama-io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ad70d163343d8936f5dbfbb5443dfb42eaf1ffd --- /dev/null +++ b/src/llama-io.cpp @@ -0,0 +1,15 @@ +#include "llama-io.h" + +void llama_io_write_i::write_string(const std::string & str) { + uint32_t str_size = str.size(); + + write(&str_size, sizeof(str_size)); + write(str.data(), str_size); +} + +void llama_io_read_i::read_string(std::string & str) { + uint32_t str_size; + read_to(&str_size, sizeof(str_size)); + + str.assign((const char *) read(str_size), str_size); +} diff --git a/src/llama-io.h b/src/llama-io.h new file mode 100644 index 0000000000000000000000000000000000000000..ce9216b83b19274feefd7fa2ab13c57a5f187471 --- /dev/null +++ b/src/llama-io.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include + +struct ggml_tensor; + +class llama_io_write_i { +public: + llama_io_write_i() = default; + virtual ~llama_io_write_i() = default; + + virtual void write(const void * src, size_t size) = 0; + virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0; + + // bytes written so far + virtual size_t n_bytes() = 0; + + void write_string(const std::string & str); +}; + +class llama_io_read_i { +public: + llama_io_read_i() = default; + virtual ~llama_io_read_i() = default; + + virtual const uint8_t * read(size_t size) = 0; + virtual void read_to(void * dst, size_t size) = 0; + + // bytes read so far + virtual size_t n_bytes() = 0; + + void read_string(std::string & str); +}; diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f6f120f682b769730113d3c4bf835284209a239 --- /dev/null +++ b/src/llama-kv-cache-recurrent.cpp @@ -0,0 +1,1115 @@ +#include "llama-kv-cache-recurrent.h" + +#include "llama-impl.h" +#include "llama-io.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include +#include +#include +#include +#include + +// +// llama_kv_cache_recurrent +// + +llama_kv_cache_recurrent::llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const int32_t n_layer = hparams.n_layer; + + LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + + head = 0; + size = kv_size; + used = 0; + + cells.clear(); + cells.resize(kv_size); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + k_l.reserve(n_layer); + v_l.reserve(n_layer); + + for (int i = 0; i < n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + k_l.push_back(k); + v_l.push_back(v); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } +} + +void llama_kv_cache_recurrent::clear(bool data) { + for (int32_t i = 0; i < (int32_t) size; ++i) { + cells[i].pos = -1; + cells[i].seq_id.clear(); + cells[i].src = -1; + cells[i].tail = -1; + } + + head = 0; + used = 0; + + if (data) { + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } + } +} + +bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = size; + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].pos >= p0 && cells[i].pos < p1) { + if (seq_id < 0) { + cells[i].seq_id.clear(); + } else if (cells[i].has_seq_id(seq_id)) { + cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cells[i].is_empty()) { + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + cells[i].pos = -1; + cells[i].src = -1; + if (new_head == size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { + kv_cell & tail_src = cells[seq_id_src]; + kv_cell & tail_dst = cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + kv_cell & cell_dst = cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.src = -1; + used -= 1; + } + } + if (tail_src.tail >= 0) { + kv_cell & cell_src = cells[tail_src.tail]; + + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } + } +} + +void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = size; + + for (uint32_t i = 0; i < size; ++i) { + if ((llama_seq_id) i != seq_id) { + cells[i].tail = -1; + } + + if (!cells[i].has_seq_id(seq_id)) { + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + cells[i].seq_id.clear(); + + if (new_head == size){ + new_head = i; + } + } else { + cells[i].seq_id.clear(); + cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += shift; + } + } + } +} + +void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + } +} + +llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const { + llama_pos result = std::numeric_limits::max(); + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::min(result, cells[i].pos); + } + } + + if (result == std::numeric_limits::max()) { + result = -1; + } + + return result; +} + +llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { + llama_pos result = -1; + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::max(result, cells[i].pos); + } + } + + return result; +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) { + auto sbatch = llama_sbatch(batch, hparams.n_embd, false); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch; + + if (embd_all) { + // if all tokens are output, split by sequence + ubatch = sbatch.split_seq(n_ubatch); + } else { + ubatch = sbatch.split_equal(n_ubatch); + } + + ubatches.push_back(ubatch); + } + + if (!prepare(ubatches)) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) { + GGML_UNUSED(lctx); + GGML_UNUSED(optimize); + + return std::make_unique(LLAMA_MEMORY_STATUS_NO_UPDATE); +} + +bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { + // simply remember the full state because it is very small for this type of cache + // TODO: optimize + auto org_cells = cells; + auto org_used = used; + auto org_head = head; + + bool success = true; + + for (const auto & ubatch : ubatches) { + if (!find_slot(ubatch)) { + success = false; + break; + } + } + + // restore the original state + cells = std::move(org_cells); + used = org_used; + head = org_head; + + return success; +} + +bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { + const uint32_t n_seqs = ubatch.n_seqs; + + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*n_seqs) { + head = 0; + } + + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max); + return false; + } + if (j > 0) { + kv_cell & seq = cells[seq_id]; + if (seq.tail >= 0) { + kv_cell & cell = cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(size, -1); + for (uint32_t i = 0; i < size; ++i) { + kv_cell & cell = cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < size; ++i) { + if (tails_verif[i] != cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = head; + + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + kv_cell & seq_meta = cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + kv_cell & cell = cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + kv_cell & empty_cell = cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + kv_cell & orig_cell = cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + for (uint32_t i = 0; i < size; ++i) { + next_empty_cell += 1; + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + const int32_t dst_id = s + min; + const int32_t src_id = cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + kv_cell & dst_cell = cells[dst_id]; + kv_cell & src_cell = cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails + for (uint32_t i = 0; i < size; ++i) { + int32_t & tail = cells[i].tail; + if (tail == src_id) { + tail = dst_id; + } else if (tail == dst_id) { + tail = src_id; + } + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + const int32_t cell_id = s + min; + kv_cell & cell = cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cells[seq_id].tail = cell_id; + } + } + + // Find first cell without src refs, to use as the zero-ed state + { + // TODO: bake-in src refcounts in the cell metadata + std::vector refcounts(size, 0); + for (size_t i = 0; i < size; ++i) { + const int32_t src = cells[i].src; + if (src >= 0) { + refcounts[src] += 1; + } + } + + rs_z = -1; + for (int i = min; i <= max; ++i) { + if (refcounts[i] == 0) { + rs_z = i; + break; + } + } + + for (int i = min; i <= max; ++i) { + if (cells[i].src < 0) { + GGML_ASSERT(rs_z >= 0); + cells[i].src0 = rs_z; + } else { + // Stage the source ids for all used cells to allow correct seq_* behavior + // and still make these values available when setting the inputs + cells[i].src0 = cells[i].src; + } + cells[i].src = i; // avoid moving or clearing twice + } + } + + // allow getting the range of used cells, from head to head + n + head = min; + n = max - min + 1; + used = std::count_if(cells.begin(), cells.end(), + [](const kv_cell & cell){ return !cell.is_empty(); }); + + // sanity check + return n >= n_seqs; +} + +bool llama_kv_cache_recurrent::get_can_shift() const { + // shifting the pos is trivial for recurrent models + return true; +} + +size_t llama_kv_cache_recurrent::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_recurrent::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & k : k_l) { + size_k_bytes += ggml_nbytes(k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_recurrent::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & v : v_l) { + size_v_bytes += ggml_nbytes(v); + } + + return size_v_bytes; +} + +void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = size; + for (uint32_t i = 0; i < size; ++i) { + const auto & cell = cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = size; + } + } + } + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(true); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } + } +} + +void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = 0; + const uint32_t n_layer = hparams.n_layer; + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)k_l[il]->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(v_l[il]->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(v_l[il], src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + + if (!find_slot(batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head + cell_count <= size); + GGML_ASSERT(cells[head].pos == batch.pos[0]); + GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); + GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(true); + + for (uint32_t i = 0; i < cell_count; ++i) { + kv_cell & cell = cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + // TODO: llama_kv_cache_recurrent should have a notion of max sequences + //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + if (seq_id < 0) { + //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + return false; + } + + cell.seq_id.insert(seq_id); + + int32_t & tail = cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + + head = 0; + used = cell_count; + } + + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = head + i; + // make sure the recurrent states will keep their restored state + cells[cell_id].src = cell_id; + } + + return true; +} + +bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + return false; + } + if (false != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * size) * v_size_el; + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_recurrent_state +// + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) { +} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {} + +llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default; + +bool llama_kv_cache_recurrent_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_recurrent_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->find_slot(ubatches[i_next]); + + return true; +} + +std::vector & llama_kv_cache_recurrent_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_recurrent_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_recurrent_state::get_n_kv() const { + return is_full ? kv->size : kv->n; +} + +uint32_t llama_kv_cache_recurrent_state::get_head() const { + return is_full ? 0 : kv->head; +} + +int32_t llama_kv_cache_recurrent_state::get_rs_z() const { + return is_full ? 0 : kv->rs_z; +} + +uint32_t llama_kv_cache_recurrent_state::get_size() const { + return kv->size; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const { + return kv->k_l[il]; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const { + return kv->v_l[il]; +} + +int32_t llama_kv_cache_recurrent_state::s_copy(int i) const { + return kv->cells[i + kv->head].src0; +} diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h new file mode 100644 index 0000000000000000000000000000000000000000..f9b01a6513393fa4fea94eeeda9f51cbc89e5377 --- /dev/null +++ b/src/llama-kv-cache-recurrent.h @@ -0,0 +1,184 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-memory.h" + +#include +#include + +// +// llama_kv_cache_recurrent +// + +// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i +// see the implementation of llama_kv_cache_unified_state_i for an example how to do it +class llama_kv_cache_recurrent : public llama_memory_i { +public: + llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max); + + ~llama_kv_cache_recurrent() = default; + + // + // llama_memory_i + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + bool prepare(const std::vector & ubatches); + + // find a contiguous slot of kv cells and emplace the ubatch there + bool find_slot(const llama_ubatch & ubatch); + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences + uint32_t used = 0; // used cells (i.e. at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + // first zero-ed state + int32_t rs_z = -1; + + // TODO: optimize for recurrent state needs + struct kv_cell { + llama_pos pos = -1; + int32_t src = -1; // used to know where states should be copied from + int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once) + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } + }; + + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + +private: + //const llama_model & model; + const llama_hparams & hparams; + + const uint32_t n_seq_max = 1; + + std::vector ctxs; + std::vector bufs; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; + +class llama_kv_cache_recurrent_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_recurrent_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv); + + // used to create a state from a batch + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches); + + virtual ~llama_kv_cache_recurrent_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_recurrent_state specific API + // + + uint32_t get_n_kv() const; + uint32_t get_head() const; + int32_t get_rs_z() const; + uint32_t get_size() const; + + ggml_tensor * get_k_l(int32_t il) const; + ggml_tensor * get_v_l(int32_t il) const; + + int32_t s_copy(int i) const; + +private: + const llama_memory_status status; + + llama_kv_cache_recurrent * kv; + + llama_sbatch sbatch; + + size_t i_next = 0; + + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // TODO: extract all the state like `head` and `n` here + // + + const bool is_full = false; +}; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a4a4c2b1b859de2c2b4be3904f8ecdf0c2e27d9c --- /dev/null +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -0,0 +1,285 @@ +#include "llama-kv-cache-unified-iswa.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include +#include + +// +// llama_kv_cache_unified_iswa +// + +llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad) : hparams(model.hparams) { + llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; + llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; + + const uint32_t size_base = kv_size; + + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + + // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size + if (swa_full) { + LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + + size_swa = size_base; + } + + LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + + kv_base = std::make_unique( + model, std::move(filter_base), type_k, type_v, + v_trans, offload, size_base, n_seq_max, n_pad, + 0, LLAMA_SWA_TYPE_NONE); + + LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); + + kv_swa = std::make_unique( + model, std::move(filter_swa), type_k, type_v, + v_trans, offload, size_swa, n_seq_max, n_pad, + hparams.n_swa, hparams.swa_type); +} + +void llama_kv_cache_unified_iswa::clear(bool data) { + kv_base->clear(data); + kv_swa ->clear(data); +} + +bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_base->seq_rm(seq_id, p0, p1); + res = res & kv_swa ->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { + kv_base->seq_keep(seq_id); + kv_swa ->seq_keep(seq_id); +} + +void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_base->seq_add(seq_id, p0, p1, shift); + kv_swa ->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_base->seq_div(seq_id, p0, p1, d); + kv_swa ->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the base cache is a superset of the SWA cache, so we can just check the SWA cache + return kv_swa->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { + return kv_swa->seq_pos_max(seq_id); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) { + GGML_UNUSED(embd_all); + + // first try simple split + do { + auto sbatch = llama_sbatch(batch, hparams.n_embd, true); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + auto ubatch = sbatch.split_simple(n_ubatch); + + ubatches.push_back(ubatch); + } + + auto heads_base = kv_base->prepare(ubatches); + if (heads_base.empty()) { + break; + } + + auto heads_swa = kv_swa->prepare(ubatches); + if (heads_swa.empty()) { + break; + } + + assert(heads_base.size() == heads_swa.size()); + + return std::make_unique( + this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); + } while (false); + + // if it fails, try equal split + do { + auto sbatch = llama_sbatch(batch, hparams.n_embd, false); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + auto ubatch = sbatch.split_equal(n_ubatch); + + ubatches.push_back(ubatch); + } + + auto heads_base = kv_base->prepare(ubatches); + if (heads_base.empty()) { + break; + } + + auto heads_swa = kv_swa->prepare(ubatches); + if (heads_swa.empty()) { + break; + } + + assert(heads_base.size() == heads_swa.size()); + + return std::make_unique( + this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); + } while (false); + + // TODO: if we fail again, we should attempt different splitting strategies + // but to do that properly, we first have to refactor the batches to be more flexible + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { + return std::make_unique(this); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); +} + +bool llama_kv_cache_unified_iswa::get_can_shift() const { + return kv_base->get_size() == kv_swa->get_size(); +} + +void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + kv_base->state_write(io, seq_id); + kv_swa ->state_write(io, seq_id); +} + +void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + kv_base->state_read(io, seq_id); + kv_swa ->state_read(io, seq_id); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { + return kv_base.get(); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { + return kv_swa.get(); +} + +// +// llama_kv_cache_unified_iswa_state +// + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) { + state_base = kv->get_base()->init_full(); + state_swa = kv->get_swa ()->init_full(); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); +} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_context * lctx, + bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) { + state_base = kv->get_base()->init_update(lctx, optimize); + state_swa = kv->get_swa ()->init_update(lctx, optimize); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); +} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches) + : status(LLAMA_MEMORY_STATUS_SUCCESS), + sbatch(std::move(sbatch)), + ubatches(std::move(ubatches)) { + // note: here we copy the ubatches. not sure if this is ideal + state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)); + state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); +} + +llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; + +bool llama_kv_cache_unified_iswa_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + state_base->next(); + state_swa ->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_iswa_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + bool res = true; + + res = res & state_base->apply(); + res = res & state_swa ->apply(); + + return res; +} + +std::vector & llama_kv_cache_unified_iswa_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast(state_base.get()); +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast(state_swa.get()); +} diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h new file mode 100644 index 0000000000000000000000000000000000000000..6e941e1a41b88b31f0838c4834ae6eda830a3455 --- /dev/null +++ b/src/llama-kv-cache-unified-iswa.h @@ -0,0 +1,133 @@ +#pragma once + +#include "llama-kv-cache-unified.h" + +#include + +// +// llama_kv_cache_unified_iswa +// + +// utilizes two instances of llama_kv_cache_unified +// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers + +class llama_kv_cache_unified_iswa : public llama_memory_i { +public: + llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad); + + ~llama_kv_cache_unified_iswa() = default; + + // + // llama_memory_i + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_unified_iswa specific API + // + + llama_kv_cache_unified * get_base() const; + llama_kv_cache_unified * get_swa () const; + +private: + const llama_hparams & hparams; + + std::unique_ptr kv_base; + std::unique_ptr kv_swa; +}; + +class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_iswa_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv); + + // used to create an update state + llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_context * lctx, + bool optimize); + + // used to create a state from a batch + llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_iswa_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_iswa_state specific API + // + + const llama_kv_cache_unified_state * get_base() const; + const llama_kv_cache_unified_state * get_swa() const; + +private: + llama_memory_status status; + + //llama_kv_cache_unified_iswa * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + llama_memory_state_ptr state_base; + llama_memory_state_ptr state_swa; +}; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3b37679859d392481c4e56bb1004acbb4004ee76 --- /dev/null +++ b/src/llama-kv-cache-unified.cpp @@ -0,0 +1,1835 @@ +#include "llama-kv-cache-unified.h" + +#include "llama-impl.h" +#include "llama-io.h" +#include "llama-model.h" +#include "llama-context.h" + +#include +#include +#include +#include +#include +#include + +// +// llama_kv_cache_unified +// + +llama_kv_cache_unified::llama_kv_cache_unified( + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type) : + model(model), hparams(model.hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + + GGML_ASSERT(kv_size % n_pad == 0); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + head = 0; + + cells.resize(kv_size); + + for (uint32_t il = 0; il < hparams.n_layer; il++) { + if (filter && !filter(il)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + continue; + } + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(il); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k; + ggml_tensor * v; + + k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); + v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + + ggml_format_name(k, "cache_k_l%d", il); + ggml_format_name(v, "cache_v_l%d", il); + + map_layer_ids[il] = layers.size(); + layers.push_back({ il, k, v }); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + + ggml_backend_buffer_clear(buf, 0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } + + const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); + debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; +} + +void llama_kv_cache_unified::clear(bool data) { + cells.reset(); + + head = 0; + + if (data) { + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } + } +} + +bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = cells.size(); + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + if (seq_id >= 0) { + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + } else { + // match any sequence + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + cells.rm(i); + + if (new_head == cells.size()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); + } + } +} + +void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.seq_keep(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { + return; + } + + uint32_t new_head = cells.size(); + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over all cells. + if (p0 == p1) { + return; + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id)) { + if (cells.pos_add(i, shift)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + head = new_head != cells.size() ? new_head : 0; +} + +void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id)) { + cells.pos_div(i, d); + } + } +} + +llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { + return cells.seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { + return cells.seq_pos_max(seq_id); +} + +llama_memory_state_ptr llama_kv_cache_unified::init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_all) { + GGML_UNUSED(embd_all); + + do { + auto sbatch = llama_sbatch(batch, hparams.n_embd, true); + + std::vector ubatches; + while (sbatch.n_tokens > 0) { + ubatches.push_back(sbatch.split_simple(n_ubatch)); + } + + auto heads = prepare(ubatches); + if (heads.empty()) { + break; + } + + return std::make_unique( + this, std::move(sbatch), std::move(heads), std::move(ubatches)); + } while (false); + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_state_ptr llama_kv_cache_unified::init_full() { + return std::make_unique(this); +} + +llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { + bool do_shift = get_has_shift(); + + defrag_info dinfo; + + // see if we need to defrag + { + bool do_defrag = optimize; + + const auto thold = lctx->get_cparams().defrag_thold; + + if (!do_defrag && thold > 0.0f) { + const auto n_kv = cells.used_max_p1(); + + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; + + if (fragmentation > thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + do_defrag = true; + } + } + + if (do_defrag) { + dinfo = defrag_prepare(lctx->graph_max_nodes()); + } + } + + return std::make_unique(this, lctx, do_shift, std::move(dinfo)); +} + +llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector & ubatches) { + llama_kv_cache_unified::ubatch_heads res; + + struct state { + uint32_t head_old; // old position of the head, before placing the ubatch + uint32_t head_new; // new position of the head, after placing the ubatch + + llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch + }; + + // remember the old state of the cells so we can restore it in the end + std::vector states; + + bool success = true; + + for (const auto & ubatch : ubatches) { + // only find a suitable slot for the ubatch. don't modify the cells yet + const int32_t head_new = find_slot(ubatch); + if (head_new < 0) { + success = false; + break; + } + + // remeber the position that we found + res.push_back(head_new); + + // store the old state of the cells in the recovery stack + states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)}); + + // now emplace the ubatch + apply_ubatch(head_new, ubatch); + } + + // iterate backwards and restore the cells to their original state + for (auto it = states.rbegin(); it != states.rend(); ++it) { + cells.set(it->head_new, it->cells); + head = it->head_old; + } + + if (!success) { + return {}; + } + + return res; +} + +bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) { + bool updated = false; + + auto * sched = lctx->get_sched(); + + if (do_shift) { + if (!get_can_shift()) { + GGML_ABORT("The current KV cache / model configuration does not support K-shift"); + } + + LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); + + // apply K-shift if needed + if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { + ggml_backend_sched_reset(sched); + + auto * gf = lctx->graph_init(); + + auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); + return updated; + } + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__); + return updated; + } + + res->set_inputs(nullptr); + + if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__); + return updated; + } + + updated = true; + } + + cells.reset_shift(); + } + + if (!dinfo.empty()) { + LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + + // apply moves: + { + const auto n_kv = dinfo.ids.size(); + + for (uint32_t i = 0; i < n_kv; ++i) { + assert(dinfo.ids[i] <= n_kv); + + if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) { + continue; + } + + cells.mv(i, dinfo.ids[i]); + } + + // reset the head so we can find the first free slot during the next ubatch + head = 0; + } + + ggml_backend_sched_reset(sched); + + auto * gf = lctx->graph_init(); + + auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); + return updated; + } + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); + return updated; + } + + res->set_inputs(nullptr); + + if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); + return updated; + } + + updated = true; + } + + return updated; +} + +int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { + const uint32_t n_tokens = ubatch.n_tokens; + + uint32_t head_cur = this->head; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { + head_cur = 0; + } + + if (n_tokens > cells.size()) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); + return -1; + } + + if (debug > 0) { + LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa); + + if ((debug == 2 && n_swa > 0) || debug > 2) { + std::string ss; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.is_empty(i)) { + ss += '.'; + } else { + assert(cells.seq_count(i) >= 1); + + if (cells.seq_count(i) == 1) { + ss += std::to_string(cells.seq_get(i)); + } else { + ss += 'M'; + } + } + if (i%256 == 255) { + ss += " *"; + ss += '\n'; + } + } + LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); + } + + if ((debug == 2 && n_swa > 0) || debug > 2) { + std::string ss; + for (uint32_t i = 0; i < cells.size(); ++i) { + std::string cur; + if (cells.is_empty(i)) { + cur = '.'; + } else { + cur = std::to_string(cells.pos_get(i)); + } + const int n = cur.size(); + for (int j = 0; j < 5 - n; ++j) { + cur += ' '; + } + ss += cur; + if (i%256 == 255) { + ss += " *"; + } + if (i%64 == 63) { + ss += '\n'; + } + } + LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); + } + + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (cells.seq_pos_min(s) < 0) { + continue; + } + + LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); + } + } + + uint32_t n_tested = 0; + + while (true) { + if (head_cur + n_tokens > cells.size()) { + n_tested += cells.size() - head_cur; + head_cur = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + //const llama_pos pos = ubatch.pos[i]; + //const llama_seq_id seq_id = ubatch.seq_id[i][0]; + + // can we use this cell? either: + // - the cell is empty + // - the cell is occupied only by one sequence: + // - (disabled) mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(head_cur + i); + + if (!can_use && cells.seq_count(head_cur + i) == 1) { + const llama_pos pos_cell = cells.pos_get(head_cur + i); + + // (disabled) causal mask + // note: it's better to purge any "future" tokens beforehand + //if (cells.seq_has(head_cur + i, seq_id)) { + // can_use = pos_cell >= pos; + //} + + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); + + // SWA mask + if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + can_use = true; + } + } + } + + if (!can_use) { + found = false; + head_cur += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= cells.size()) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return -1; + } + } + + return head_cur; +} + +void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { + if (debug > 0) { + LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__); + LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs); + LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs); + } + + // keep track of the max sequence position that we would overwrite with this ubatch + // for non-SWA cache, this would be always empty + llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + seq_pos_max_rm[s] = -1; + } + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) { + const uint32_t idx = s*ubatch.n_seq_tokens + j; + + if (!cells.is_empty(head_cur + idx)) { + assert(cells.seq_count(head_cur + idx) == 1); + + const llama_seq_id seq_id = cells.seq_get(head_cur + idx); + const llama_pos pos = cells.pos_get(head_cur + idx); + + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + + cells.rm(head_cur + idx); + } + + cells.pos_set(head_cur + idx, ubatch.pos[idx]); + + // TODO: fix indexing [UBATCH_IDX] + for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) { + cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]); + } + } + } + + // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence + // will be present in the cache. so we have to purge any position which is less than those we would overwrite + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq_pos_max_rm[s] == -1) { + continue; + } + + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { + LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", + __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); + + seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + } + } + // move the head at the end of the slot + head = head_cur + ubatch.n_tokens; +} + +bool llama_kv_cache_unified::get_can_shift() const { + return true; +} + +uint32_t llama_kv_cache_unified::get_size() const { + return cells.size(); +} + +bool llama_kv_cache_unified::get_has_shift() const { + return cells.get_has_shift(); +} + +uint32_t llama_kv_cache_unified::get_n_kv() const { + return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); +} + +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + return ggml_view_3d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, + ggml_row_size(k->type, hparams.n_embd_head_k), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), + 0); +} + +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + if (!v_trans) { + // note: v->nb[1] <= v->nb[2] + return ggml_view_3d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] + 0); + } + + // note: v->nb[1] > v->nb[2] + return ggml_view_3d(ctx, v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, + ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, v->ne[1]), // v->nb[2] + 0); +} + +ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + const int64_t n_tokens = k_cur->ne[2]; + + ggml_tensor * k_view = ggml_view_1d(ctx, k, + n_tokens*hparams.n_embd_k_gqa(il), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); + + return ggml_cpy(ctx, k_cur, k_view); +} + +ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + const int64_t n_tokens = v_cur->ne[2]; + + v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); + + ggml_tensor * v_view = nullptr; + + if (!v_trans) { + v_view = ggml_view_1d(ctx, v, + n_tokens*hparams.n_embd_v_gqa(il), + ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); + } else { + // note: the V cache is transposed when not using flash attention + v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), + (v->ne[1])*ggml_element_size(v), + (head_cur)*ggml_element_size(v)); + + v_cur = ggml_transpose(ctx, v_cur); + } + + return ggml_cpy(ctx, v_cur, v_view); +} + +void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + const uint32_t n_tokens = ubatch->n_tokens; + const uint32_t n_seq_tokens = ubatch->n_seq_tokens; + const uint32_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + float * data = (float *) dst->data; + + const int64_t n_kv = dst->ne[0]; + + // Use only the previous KV cells of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. + // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: + // Causal mask: + // xxx------- + // xxxx------ + // xxxxx----- + // Non-causal mask: + // xxxxx----- + // xxxxx----- + // xxxxx----- + // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 + for (uint32_t h = 0; h < 1; ++h) { + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + for (uint32_t j = 0; j < n_seq_tokens; ++j) { + const uint32_t idx = s*n_seq_tokens + j; + + const llama_pos p1 = ubatch->pos[idx]; + + for (uint32_t i = 0; i < n_kv; ++i) { + float f = 0.0f; + + bool masked = false; + + if (cells.is_empty(i)) { + masked = true; + } else { + const llama_pos p0 = cells.pos_get(i); + + // mask the token if not the same sequence + masked = masked || (!cells.seq_has(i, seq_id)); + + // mask future tokens + masked = masked || (causal_attn && p0 > p1); + + // apply SWA if any + masked = masked || (is_masked_swa(p0, p1)); + + if (!masked && hparams.use_alibi) { + f = -std::abs(p0 - p1); + } + } + + if (masked) { + f = -INFINITY; + } + + data[h*(n_kv*n_tokens) + idx*n_kv + i] = f; + } + } + } + + // mask padded tokens + if (data) { + for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) { + for (uint32_t i = 0; i < n_kv; ++i) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } +} + +void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + int32_t * data = (int32_t *) dst->data; + + for (uint32_t i = 0; i < cells.size(); ++i) { + data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i); + } +} + +void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + + int32_t * data = (int32_t *) dst->data; + + const int32_t n_kv = dst->ne[0]; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + // the position when the cells is empty is irrelevant - it will be masked out later in the attention + const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); + + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); + } + } + } +} + +size_t llama_kv_cache_unified::total_size() const { + size_t size = 0; + + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_unified::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & layer : layers) { + size_k_bytes += ggml_nbytes(layer.k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_unified::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & layer : layers) { + size_v_bytes += ggml_nbytes(layer.v); + } + + return size_v_bytes; +} + +ggml_tensor * llama_kv_cache_unified::build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const { + const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; + + const auto & yarn_ext_factor = cparams.yarn_ext_factor; + const auto & yarn_beta_fast = cparams.yarn_beta_fast; + const auto & yarn_beta_slow = cparams.yarn_beta_slow; + + const auto & n_rot = hparams.n_rot; + const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE + // @ngxson : this is a workaround + // for M-RoPE, we want to rotate the whole vector when doing KV shift + // a normal RoPE should work, we just need to use the correct ordering + // ref: https://github.com/ggml-org/llama.cpp/pull/13870 + ? LLAMA_ROPE_TYPE_NEOX + : hparams.rope_type; + + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 + ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) + : cparams.yarn_attn_factor; + + ggml_tensor * tmp; + + if (ggml_is_quantized(cur->type)) { + // dequantize to f32 -> RoPE -> quantize back + tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + + tmp = ggml_rope_ext(ctx, tmp, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + + tmp = ggml_cpy(ctx, tmp, cur); + } else { + // we rotate only the first n_rot dimensions + tmp = ggml_rope_ext_inplace(ctx, cur, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + } + + return tmp; +} + +class llm_graph_input_k_shift : public llm_graph_input_i { +public: + llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_k_shift() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * k_shift; // I32 [kv_size] + + const llama_kv_cache_unified * kv_self; +}; + +void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (k_shift) { + kv_self->set_input_k_shift(k_shift); + } +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & n_embd_head_k = hparams.n_embd_head_k; + //const auto & n_embd_head_v = hparams.n_embd_head_v; + + auto inp = std::make_unique(this); + + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size()); + ggml_set_input(inp->k_shift); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + ggml_tensor * k = + ggml_view_3d(ctx, layer.k, + n_embd_head_k, n_head_kv, cells.size(), + ggml_row_size(layer.k->type, n_embd_head_k), + ggml_row_size(layer.k->type, n_embd_k_gqa), + 0); + + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + + ggml_build_forward_expand(gf, cur); + } + + res->add_input(std::move(inp)); + + return res; +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf, + const defrag_info & dinfo) const { + auto res = std::make_unique(); + + const auto & ids = dinfo.ids; + +#if 0 + // CPU defrag + // + // TODO: optimizations are possible: + // - multiple threads + // - avoid copying to the host memory when already there + // + // likely not worth the effort, as we have ggml_graph based defrag + // + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + const uint32_t kv_size = size; + + std::vector buf_k; + std::vector buf_v; + + for (uint32_t il = 0; il < n_layer; ++il) { + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); + + const size_t v_size_el = ggml_type_size(v_l[il]->type); + const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); + + buf_k.resize(k_size); + buf_v.resize(v_size); + + ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); + + // batch move [i, i+nm) to [id, id+nm) + // note: cells can move only to a lower index + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == n_kv) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < n_kv && ids[i + nm] == id + nm) { + nm++; + } + + // move keys + { + const int64_t os = i*k_size_row; + const int64_t od = id*k_size_row; + + memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); + } + + // move values (note: they are transposed) + { + const int64_t os = i; + const int64_t od = id; + + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); + } + } + + i += nm - 1; + } + + ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); + } +#else + for (uint32_t i = 0; i < ids.size(); ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == ids.size()) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < ids.size() && ids[i + nm] == id + nm) { + nm++; + } + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, + n_embd_k_gqa, nm, + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*i)); + + ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, + n_embd_k_gqa, nm, + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*id)); + + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx, layer.v, + n_embd_v_gqa, nm, + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx, layer.v, + n_embd_v_gqa, nm, + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx, layer.v, + nm, n_embd_v_gqa, + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, i)); + + view_v_dst = ggml_view_2d(ctx, layer.v, + nm, n_embd_v_gqa, + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, id)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); + } + + i += nm - 1; + } + + //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); +#endif + + return res; +} + +llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const { + const uint32_t n_layer = layers.size(); + + const uint32_t n_kv = cells.used_max_p1(); + const uint32_t n_used = cells.get_used(); + + assert(n_used <= n_kv); + + //const int64_t t_start = ggml_time_us(); + + // number of cells moved + uint32_t n_moves = 0; + + // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) + // - source view, destination view, copy operation + // - x2 for keys and values + //const uint32_t max_moves = max_nodes()/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); + + // determine which KV cells to move where + defrag_info res; + auto & ids = res.ids; + + ids.resize(n_kv, n_kv); + + for (uint32_t i0 = 0; i0 < n_used; ++i0) { + if (!cells.is_empty(i0)) { + ids[i0] = i0; + + continue; + } + + // found a hole - fill it with data from the end of the cache + + uint32_t nh = 1; + + // determine the size of the hole + while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { + nh++; + } + + uint32_t nf = 0; + uint32_t is = n_kv - 1; + + // starting from the end, find nh non-empty cells + for (; is > i0; --is) { + if (cells.is_empty(is) || ids[is] != n_kv) { + continue; + } + + // non-empty cell which is not yet moved + nf++; + + if (nf == nh) { + break; + } + } + + // this can only happen if `n_used` is not accurate, which would be a bug + GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); + + nf = 0; + + uint32_t i1 = is; + + // are we moving a continuous block of memory? + bool cont = false; + + // should we stop searching for the next move? + bool stop = false; + + // go back and move the nf cells to the hole + for (; i1 < n_kv; ++i1) { + if (cells.is_empty(i1) || ids[i1] != n_kv) { + if (n_moves == max_moves) { + stop = true; + break; + } + + cont = false; + continue; + } + + // this cell goes to (i0 + nf) + ids[i1] = i0 + nf; + + if (!cont) { + n_moves++; + cont = true; + } + + nf++; + + if (nf == nh) { + break; + } + } + + if (stop || n_moves == max_moves) { + break; + } + + //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); + + i0 += nh - 1; + } + + if (n_moves == 0) { + return {}; + } + + LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); + + LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); + + return res; +} + +bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + } + + return false; +} + +void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + ++cell_count; + if (cell_range_begin == cells.size()) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = cells.size(); + } + } + } + + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, cells.size()); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(true); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + std::vector seq_ids; + + for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { + if (cur == seq_id || seq_id == -1) { + if (cells.seq_has(i, cur)) { + seq_ids.push_back(cur); + } + } + } + + const llama_pos pos = cells.pos_get(i); + const uint32_t n_seq_id = seq_ids.size(); + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + for (const auto & seq_id : seq_ids) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } +} + +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = this->v_trans ? 1 : 0; + const uint32_t n_layer = layers.size(); + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)layer.k->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(layer.k, range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)layer.v->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(layer.v, range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = cells.size(); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)layer.v->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(layer.v->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(layer.v, src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + ubatch.n_tokens = cell_count; + ubatch.n_seq_tokens = cell_count; + ubatch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 1) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + // read the sequence id, but directly discard it - we will use dest_seq_id instead + { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + } + + ubatch.pos[i] = pos; + ubatch.n_seq_id[i] = n_seq_id; + ubatch.seq_id[i] = &dest_seq_id; + } + + const auto head_cur = find_slot(ubatch); + if (head_cur < 0) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + apply_ubatch(head_cur, ubatch); + + // keep the head at the old position because we will read the KV data into it in state_read_data() + head = head_cur; + + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head_cur + cell_count <= cells.size()); + GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]); + GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]); + GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); + GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(true); + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cells.pos_set(i, pos); + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); + return false; + } + + cells.seq_add(i, seq_id); + } + } + + head = 0; + } + + return true; +} + +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != layers.size()) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); + return false; + } + + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size()); + return false; + } + + if (this->v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) layer.k->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!this->v_trans) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)layer.v->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)layer.v->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(layer.v->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_unified_state +// + +llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { + n_kv = kv->get_size(); + head = 0; +} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_context * lctx, + bool do_shift, + defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) { + if (!do_shift && this->dinfo.empty()) { + status = LLAMA_MEMORY_STATUS_NO_UPDATE; + } +} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + llama_kv_cache_unified::ubatch_heads heads, + std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) { +} + +llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; + +bool llama_kv_cache_unified_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + // no ubatches -> this is a KV cache update + if (ubatches.empty()) { + kv->update(lctx, do_shift, dinfo); + + return true; + } + + kv->apply_ubatch(heads[i_next], ubatches[i_next]); + + n_kv = kv->get_n_kv(); + head = heads[i_next]; + + return true; +} + +std::vector & llama_kv_cache_unified_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_unified_state::get_n_kv() const { + return n_kv; +} + +ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { + return kv->get_k(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { + return kv->get_v(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { + return kv->cpy_k(ctx, k_cur, il, head); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { + return kv->cpy_v(ctx, v_cur, il, head); +} + +void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { + kv->set_input_k_shift(dst); +} + +void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + kv->set_input_kq_mask(dst, ubatch, causal_attn); +} + +void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_pos_bucket(dst, ubatch); +} + +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +} diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h new file mode 100644 index 0000000000000000000000000000000000000000..d96571d952b81db3e47b8653018b79b5e3235325 --- /dev/null +++ b/src/llama-kv-cache-unified.h @@ -0,0 +1,308 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cells.h" +#include "llama-memory.h" + +#include +#include + +struct llama_cparams; +struct llama_hparams; +struct llama_model; +struct llama_context; + +// +// llama_kv_cache_unified +// + +class llama_kv_cache_unified : public llama_memory_i { +public: + static uint32_t get_padding(const llama_cparams & cparams); + + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + + using ubatch_heads = std::vector; + + struct defrag_info { + bool empty() const { + return ids.empty(); + } + + // contains information about which cell moves where: + // - cell i moves to ids[i] + // - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved + std::vector ids; + }; + + llama_kv_cache_unified( + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type); + + ~llama_kv_cache_unified() = default; + + // + // llama_memory_i + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_unified specific API + // + + uint32_t get_size() const; + + bool get_has_shift() const; + + // + // graph_build API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const; + + // + // preparation API + // + + // find places for the provided ubatches in the cache, returns the head locations + // return empty vector on failure + ubatch_heads prepare(const std::vector & ubatches); + + bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); + + // return the cell position where we can insert the ubatch + // return -1 on failure to find a contiguous slot of kv cells + int32_t find_slot(const llama_ubatch & ubatch) const; + + // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens) + void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch); + + // + // set_input API + // + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_k_shift (ggml_tensor * dst) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + const llama_model & model; + const llama_hparams & hparams; + + struct kv_layer { + // layer index in the model + // note: can be different from the layer index in the KV cache + uint32_t il; + + ggml_tensor * k; + ggml_tensor * v; + }; + + bool v_trans = true; // the value tensor is transposed + + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed-up the find_slot() method + uint32_t head = 0; + + const uint32_t n_seq_max = 1; + + // required padding + const uint32_t n_pad = 1; + + // SWA + const uint32_t n_swa = 0; + + int debug = 0; + + const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; + + std::vector ctxs; + std::vector bufs; + + llama_kv_cells_unified cells; + + std::vector layers; + + // model layer id -> KV cache layer id + std::unordered_map map_layer_ids; + + // return non-empty vector if cells have been moved + defrag_info defrag_prepare(int32_t n_max_nodes) const; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + + bool is_masked_swa(llama_pos p0, llama_pos p1) const; + + ggml_tensor * build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const; + + llm_graph_result_ptr build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; + + llm_graph_result_ptr build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf, + const defrag_info & dinfo) const; + + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; + +class llama_kv_cache_unified_state : public llama_memory_state_i { +public: + // some shorthands + using ubatch_heads = llama_kv_cache_unified::ubatch_heads; + using defrag_info = llama_kv_cache_unified::defrag_info; + + // used for errors + llama_kv_cache_unified_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_state( + llama_kv_cache_unified * kv); + + // used to create an update state + llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_context * lctx, + bool do_shift, + defrag_info dinfo); + + // used to create a decode state from a batch + llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + ubatch_heads heads, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_state specific API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + + void set_input_k_shift(ggml_tensor * dst) const; + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + llama_memory_status status; + + llama_kv_cache_unified * kv; + llama_context * lctx; + + // + // update state + // + + bool do_shift = false; + + defrag_info dinfo; + + // + // batch processing state + // + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + ubatch_heads heads; + + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // as the cache gets filled, the benefit from this heuristic disappears + int32_t n_kv; + + // the beginning of the current slot in which the ubatch will be inserted + int32_t head; +}; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h new file mode 100644 index 0000000000000000000000000000000000000000..1d4e70f4d321249882287e0bf6b1f56f1c8110dc --- /dev/null +++ b/src/llama-kv-cells.h @@ -0,0 +1,415 @@ +#pragma once + +#include "llama.h" +#include "llama-cparams.h" + +#include +#include +#include +#include + +// meta information about KV cells that can be part of multiple sequences at the same time +// TODO: add unit tests +class llama_kv_cells_unified { +public: + void reset() { + for (uint32_t i = 0; i < pos.size(); ++i) { + pos[i] = -1; + shift[i] = 0; + seq[i].reset(); + } + + has_shift = false; + + used.clear(); + + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + seq_pos[s].clear(); + } + } + + void reset_shift() { + has_shift = false; + + for (uint32_t i = 0; i < shift.size(); ++i) { + shift[i] = 0; + } + } + + uint32_t size() const { + return pos.size(); + } + + void resize(uint32_t n) { + pos.resize(n); + shift.resize(n); + seq.resize(n); + + reset(); + } + + bool is_empty(uint32_t i) const { + assert(i < pos.size()); + assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0); + + return pos[i] == -1; + } + + uint32_t get_used() const { + return used.size(); + } + + // the index of the first cell that is used + // return 0 if no cells are used + uint32_t used_min() const { + return used.empty() ? 0 : *used.begin(); + } + + // the index of the last cell that is used + 1 + // return 0 if no cells are used + uint32_t used_max_p1() const { + return used.empty() ? 0 : *used.rbegin() + 1; + } + + bool get_has_shift() const { + return has_shift; + } + + // move cell isrc to idst (used during defrag) + void mv(uint32_t isrc, uint32_t idst) { + assert(isrc < pos.size()); + assert(idst < pos.size()); + + assert(pos[idst] == -1); + assert(pos[isrc] != -1); + + pos [idst] = pos [isrc]; + shift[idst] = shift[isrc]; + seq [idst] = seq [isrc]; + + pos [isrc] = -1; + shift[isrc] = 0; + seq [isrc].reset(); + + used.erase (isrc); + used.insert(idst); + } + + // copy the state of cells [i, i + n) (used for save/restore the state of the cells) + llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { + assert(i + n <= pos.size()); + + llama_kv_cells_unified res; + + res.resize(n); + + for (uint32_t j = 0; j < n; ++j) { + res.pos[j] = pos[i + j]; + res.seq[j] = seq[i + j]; + + assert(shift[i + j] == 0); + } + + return res; + } + + // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) + void set(uint32_t i, const llama_kv_cells_unified & other) { + assert(i + other.pos.size() <= pos.size()); + + for (uint32_t j = 0; j < other.pos.size(); ++j) { + if (pos[i + j] == -1 && other.pos[j] != -1) { + used.insert(i + j); + } + + if (pos[i + j] != -1 && other.pos[j] == -1) { + used.erase(i + j); + } + + if (pos[i + j] != -1) { + seq_pos_rm(i + j); + } + + pos[i + j] = other.pos[j]; + seq[i + j] = other.seq[j]; + + if (pos[i + j] != -1) { + seq_pos_add(i + j); + } + + assert(shift[i + j] == 0); + } + } + + // clear a non-empty cell + void rm(uint32_t i) { + assert(i < pos.size()); + assert(pos[i] != -1); + + seq_pos_rm(i); + seq[i].reset(); + + pos[i] = -1; + shift[i] = 0; + + used.erase(i); + } + + // note: call only if the cell has seq_id + // return true if the cell becomes empty + bool seq_rm(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + assert(seq[i].test(seq_id)); + assert(pos[i] != -1); + assert(seq_id >= 0); + + seq[i].reset(seq_id); + seq_pos[seq_id].erase(pos[i]); + + if (seq[i].none()) { + pos[i] = -1; + shift[i] = 0; + + used.erase(i); + + return true; + } + + return false; + } + + // return true if the cell becomes empty (i.e. it did not contain seq_id before the call) + bool seq_keep(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + + if (seq[i].test(seq_id)) { + seq_pos_rm(i); + seq[i].reset(); + + seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); + + return false; + } + + if (seq[i].any()) { + seq_pos_rm(i); + seq[i].reset(); + + pos[i] = -1; + shift[i] = 0; + + used.erase(i); + + return true; + } + + assert(pos[i] == -1); + + return false; + } + + // number of different sequences in the cell + int seq_count(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return seq[i].count(); + } + + // check if the cell contains seq_id + bool seq_has(uint32_t i, llama_seq_id seq_id) const { + assert(i < pos.size()); + assert(seq_id >= 0); + + return seq[i].test(seq_id); + } + + // note: call only if the cell is not empty and the seq_id is not in the cell + void seq_add(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + assert(pos[i] != -1); + assert(!seq[i].test(seq_id)); + + seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); + } + + // return the sequence id of this cell + // note: call only for cells with exactly one sequence + llama_seq_id seq_get(uint32_t i) const { + assert(seq[i].count() == 1); + + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq[i].test(s)) { + return s; + } + } + + return -1; + } + + // the minimum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_min(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_SEQ); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].begin(); + } + + // the maximum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_max(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_SEQ); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].rbegin(); + } + + // note: call only if the cell is not empty + llama_pos pos_get(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return pos[i]; + } + + // note: call only if the cell is not empty + llama_pos get_shift(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return shift[i]; + } + + // check if a cell is not empty and its position is within [p0, p1) + bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { + assert(i < pos.size()); + + return pos[i] >= p0 && pos[i] < p1; + } + + // set the position of an empty cell + // does not modify "has_shift" + // note: call only if the cell is empty + void pos_set(uint32_t i, llama_pos p) { + assert(i < pos.size()); + assert(pos[i] == -1); + assert(seq[i].none()); + + pos[i] = p; + + used.insert(i); + } + + // pos[i] = pos[i] + d + // sets "has_shift" to true + // note: call only if the cell is not empty + bool pos_add(uint32_t i, llama_pos d) { + assert(i < pos.size()); + assert(pos[i] != -1); + + seq_pos_rm(i); + + pos[i] += d; + shift[i] += d; + + has_shift = true; + + if (pos[i] < 0) { + seq[i].reset(); + pos[i] = -1; + shift[i] = 0; + + used.erase(i); + + return true; + } + + seq_pos_add(i); + + return false; + } + + // pos[i] = pos[i] / d + // sets "has_shift" to true + // note: call only if the cell is not empty + void pos_div(uint32_t i, int d) { + assert(i < pos.size()); + assert(pos[i] != -1); + + const llama_pos p_old = pos[i]; + + seq_pos_rm(i); + + pos[i] /= d; + shift[i] += p_old - pos[i]; + + seq_pos_add(i); + + has_shift = true; + } + +private: + bool has_shift = false; + + // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id) + std::set used; + + std::vector pos; + + // this array accumulates any applied shifts to the pos array since the last reset_shift() call + // this is used to queue multiple updates to the pos array, which in the end can be applied in one go: + // + // cells.pos_add(x, shift_x); + // cells.pos_div(y, shift_y); + // ... + // + // if (cells.has_shift()) { + // for (int i = 0; i < n; ++i) { + // auto shift_i = cells.get_shift(i); + // ... + // } + // cells.reset_shift(); + // } + // + std::vector shift; + + using bits_t = std::bitset; + + // the bitset seq[i] tells us which sequences are currently occupying the i-th cell + std::vector seq; + + // the set seq_pos[s] tells us which positions are currently present for sequence s + // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache + std::set seq_pos[LLAMA_MAX_SEQ]; + + // helper functions for updating `seq_pos`, once cell at a time: + + // remove cell i + void seq_pos_rm(uint32_t i) { + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq[i].test(s)) { + seq_pos[s].erase(pos[i]); + } + } + } + + // add cell i + void seq_pos_add(uint32_t i) { + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq[i].test(s)) { + seq_pos[s].insert(pos[i]); + } + } + } +}; diff --git a/src/llama-memory.cpp b/src/llama-memory.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f1107672c6476411b04521db02379255328e7728 --- /dev/null +++ b/src/llama-memory.cpp @@ -0,0 +1,42 @@ +#include "llama-memory.h" + +llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) { + bool has_update = false; + + switch (s0) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + has_update = true; + break; + } + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + break; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return s0; + } + } + + switch (s1) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + has_update = true; + break; + } + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + break; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return s1; + } + } + + // if either status has an update, then the combined status has an update + return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE; +} diff --git a/src/llama-memory.h b/src/llama-memory.h new file mode 100644 index 0000000000000000000000000000000000000000..24668f861b976243bb8bb59fd149e615c71cb87a --- /dev/null +++ b/src/llama-memory.h @@ -0,0 +1,117 @@ +#pragma once + +#include "llama.h" + +#include +#include + +struct llama_ubatch; + +class llama_io_write_i; +class llama_io_read_i; + +struct llama_memory_params { + // kv cache + ggml_type type_k; + ggml_type type_v; + + // use full-size SWA cache + bool swa_full; +}; + +enum llama_memory_status { + LLAMA_MEMORY_STATUS_SUCCESS = 0, + LLAMA_MEMORY_STATUS_NO_UPDATE, + LLAMA_MEMORY_STATUS_FAILED_PREPARE, + LLAMA_MEMORY_STATUS_FAILED_COMPUTE, +}; + +// helper function for combining the status of two memory states +// useful for implementing hybrid memory types (e.g. iSWA) +llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); + +// the interface for managing the memory state during batch processing +// this interface is implemented per memory type. see: +// - llama_kv_cache_unified_state +// - llama_kv_cache_unified_iswa_state +// ... +// +// the only method that can mutate the memory and the memory state is llama_memory_i::apply() +// +// TODO: rename to llama_memory_context_i ? +struct llama_memory_state_i { + virtual ~llama_memory_state_i() = default; + + // consume the current ubatch from the state and proceed to the next one + // return false if we are done + virtual bool next() = 0; + + // apply the memory state for the current ubatch to the memory object + // return false on failure + virtual bool apply() = 0; + + // TODO: this might get reworked in the future when refactoring llama_batch + virtual std::vector & out_ids() = 0; + + // get the current ubatch + virtual const llama_ubatch & get_ubatch() const = 0; + + // get the status of the memory state - used for error handling and checking if any updates would be applied + virtual llama_memory_status get_status() const = 0; +}; + +using llama_memory_state_ptr = std::unique_ptr; + +// general concept of LLM memory +// the KV cache is a type of LLM memory, but there can be other types +struct llama_memory_i { + virtual ~llama_memory_i() = default; + + // split the input batch into a set of ubatches and verify that they can fit into the cache + // return a state object containing the ubatches and KV cache state required to process them + // check the llama_memory_state_i::get_status() for the result + virtual llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_all) = 0; + + // simulate full cache, used for allocating worst-case compute buffers + virtual llama_memory_state_ptr init_full() = 0; + + // prepare for any pending memory updates, such as shifts, defrags, etc. + // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update + virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0; + + // getters + virtual bool get_can_shift() const = 0; + + // + // ops + // + + // if data == true, the data buffers will also be cleared together with the metadata + virtual void clear(bool data) = 0; + + virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; + virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; + virtual void seq_keep(llama_seq_id seq_id) = 0; + virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0; + virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; + + virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; + virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; + + // + // state write/read + // + + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; +}; + +using llama_memory_ptr = std::unique_ptr; + +// TODO: temporary until the llama_kv_cache is removed from the public API +struct llama_kv_cache : public llama_memory_i { + virtual ~llama_kv_cache() = default; +}; diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp new file mode 100644 index 0000000000000000000000000000000000000000..47497cf953fd3990d6e147e4837980bed6223b63 --- /dev/null +++ b/src/llama-mmap.cpp @@ -0,0 +1,600 @@ +#include "llama-mmap.h" + +#include "llama-impl.h" + +#include "ggml.h" + +#include +#include +#include +#include +#include + +#ifdef __has_include + #if __has_include() + #include + #if defined(_POSIX_MAPPED_FILES) + #include + #include + #endif + #if defined(_POSIX_MEMLOCK_RANGE) + #include + #endif + #endif +#endif + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #ifndef PATH_MAX + #define PATH_MAX MAX_PATH + #endif + #include +#endif + +#if defined(__APPLE__) +#include +#endif + +// TODO: consider moving to llama-impl.h if needed in more places +#if defined(_WIN32) +static std::string llama_format_win_err(DWORD err) { + LPSTR buf; + size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); + if (!size) { + return "FormatMessageA failed"; + } + std::string ret(buf, size); + LocalFree(buf); + return ret; +} +#endif + +// llama_file + +struct llama_file::impl { +#if defined(_WIN32) + HANDLE fp_win32; + std::string GetErrorMessageWin32(DWORD error_code) const { + std::string ret; + LPSTR lpMsgBuf = NULL; + DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL); + if (!bufLen) { + ret = format("Win32 error code: %lx", error_code); + } else { + ret = lpMsgBuf; + LocalFree(lpMsgBuf); + } + + return ret; + } + + impl(const char * fname, const char * mode) { + fp = ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { + LARGE_INTEGER li; + li.QuadPart = 0; + BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + + return li.QuadPart; + } + + void seek(size_t offset, int whence) const { + static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN"); + static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT"); + static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END"); + + LARGE_INTEGER li; + li.QuadPart = offset; + BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + } + + void read_raw(void * ptr, size_t len) const { + size_t bytes_read = 0; + while (bytes_read < len) { + size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); + DWORD chunk_read = 0; + BOOL result = ReadFile(fp_win32, reinterpret_cast(ptr) + bytes_read, chunk_size, &chunk_read, NULL); + if (!result) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_read < chunk_size || chunk_read == 0) { + throw std::runtime_error("unexpectedly reached end of file"); + } + + bytes_read += chunk_read; + } + } + + uint32_t read_u32() const { + uint32_t val; + read_raw(&val, sizeof(val)); + return val; + } + + void write_raw(const void * ptr, size_t len) const { + size_t bytes_written = 0; + while (bytes_written < len) { + size_t chunk_size = std::min(len - bytes_written, 64*1024*1024); + DWORD chunk_written = 0; + BOOL result = WriteFile(fp_win32, reinterpret_cast(ptr) + bytes_written, chunk_size, &chunk_written, NULL); + if (!result) { + throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_written < chunk_size || chunk_written == 0) { + throw std::runtime_error("unexpectedly failed to write bytes"); + } + + bytes_written += chunk_written; + } + } + + void write_u32(uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~impl() { + if (fp) { + std::fclose(fp); + } + } +#else + impl(const char * fname, const char * mode) { + fp = ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { +// TODO: this ifdef is never true? +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + + return (size_t) ret; + } + + void seek(size_t offset, int whence) const { +// TODO: this ifdef is never true? +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + if (ret != 0) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } + } + + void read_raw(void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, len, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret != 1) { + throw std::runtime_error("unexpectedly reached end of file"); + } + } + + uint32_t read_u32() const { + uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + void write_raw(const void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + size_t ret = std::fwrite(ptr, len, 1, fp); + if (ret != 1) { + throw std::runtime_error(format("write error: %s", strerror(errno))); + } + } + + void write_u32(uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~impl() { + if (fp) { + std::fclose(fp); + } + } +#endif + + FILE * fp; + size_t size; +}; + +llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique(fname, mode)) {} +llama_file::~llama_file() = default; + +size_t llama_file::tell() const { return pimpl->tell(); } +size_t llama_file::size() const { return pimpl->size; } + +int llama_file::file_id() const { +#ifdef _WIN32 + return _fileno(pimpl->fp); +#else +#if defined(fileno) + return fileno(pimpl->fp); +#else + return ::fileno(pimpl->fp); +#endif +#endif +} + +void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } +void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); } + +uint32_t llama_file::read_u32() const { return pimpl->read_u32(); } + +void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); } +void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); } + +// llama_mmap + +struct llama_mmap::impl { +#ifdef _POSIX_MAPPED_FILES + std::vector> mapped_fragments; + + impl(struct llama_file * file, size_t prefetch, bool numa) { + size = file->size(); + int fd = file->file_id(); + int flags = MAP_SHARED; + if (numa) { prefetch = 0; } +#ifdef __linux__ + if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) { + LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n", + strerror(errno)); + } + if (prefetch) { flags |= MAP_POPULATE; } +#endif + addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0); + if (addr == MAP_FAILED) { + throw std::runtime_error(format("mmap failed: %s", strerror(errno))); + } + + if (prefetch > 0) { + if (posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) { + LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n", + strerror(errno)); + } + } + if (numa) { + if (posix_madvise(addr, file->size(), POSIX_MADV_RANDOM)) { + LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n", + strerror(errno)); + } + } + + mapped_fragments.emplace_back(0, file->size()); + } + + static void align_range(size_t * first, size_t * last, size_t page_size) { + size_t offset_in_page = *first & (page_size - 1); + size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page; + *first += offset_to_page; + + *last = *last & ~(page_size - 1); + + if (*last <= *first) { + *last = *first; + } + } + + void unmap_fragment(size_t first, size_t last) { + int page_size = sysconf(_SC_PAGESIZE); + align_range(&first, &last, page_size); + size_t len = last - first; + + if (len == 0) { + return; + } + + GGML_ASSERT(first % page_size == 0); + GGML_ASSERT(last % page_size == 0); + GGML_ASSERT(last > first); + + void * next_page_start = (uint8_t *) addr + first; + + if (munmap(next_page_start, len)) { + LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno)); + } + + std::vector> new_mapped_fragments; + for (const auto & frag : mapped_fragments) { + if (frag.first < first && frag.second > last) { + new_mapped_fragments.emplace_back(frag.first, first); + new_mapped_fragments.emplace_back(last, frag.second); + } else if (frag.first < first && frag.second > first) { + new_mapped_fragments.emplace_back(frag.first, first); + } else if (frag.first < last && frag.second > last) { + new_mapped_fragments.emplace_back(last, frag.second); + } else if (frag.first >= first && frag.second <= last) { + } else { + new_mapped_fragments.push_back(frag); + } + } + mapped_fragments = std::move(new_mapped_fragments); + } + + ~impl() { + for (const auto & frag : mapped_fragments) { + if (munmap((char *) addr + frag.first, frag.second - frag.first)) { + LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno)); + } + } + } +#elif defined(_WIN32) + impl(struct llama_file * file, size_t prefetch, bool numa) { + GGML_UNUSED(numa); + + size = file->size(); + + HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); + + HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + + if (hMapping == NULL) { + DWORD error = GetLastError(); + throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str())); + } + + addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); + DWORD error = GetLastError(); + CloseHandle(hMapping); + + if (addr == NULL) { + throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); + } + + if (prefetch > 0) { +#if _WIN32_WINNT >= 0x602 + BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG); + HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll"); + + pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory"); + + if (pPrefetchVirtualMemory) { + WIN32_MEMORY_RANGE_ENTRY range; + range.VirtualAddress = addr; + range.NumberOfBytes = (SIZE_T) std::min(size, prefetch); + if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { + LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n"); +#endif + } + } + + void unmap_fragment(size_t first, size_t last) { + GGML_UNUSED(first); + GGML_UNUSED(last); + } + + ~impl() { + if (!UnmapViewOfFile(addr)) { + LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + impl(struct llama_file * file, size_t prefetch, bool numa) { + GGML_UNUSED(file); + GGML_UNUSED(prefetch); + GGML_UNUSED(numa); + + throw std::runtime_error("mmap not supported"); + } + + void unmap_fragment(size_t first, size_t last) { + GGML_UNUSED(first); + GGML_UNUSED(last); + + throw std::runtime_error("mmap not supported"); + } +#endif + + void * addr; + size_t size; +}; + +llama_mmap::llama_mmap(struct llama_file * file, size_t prefetch, bool numa) : pimpl(std::make_unique(file, prefetch, numa)) {} +llama_mmap::~llama_mmap() = default; + +size_t llama_mmap::size() const { return pimpl->size; } +void * llama_mmap::addr() const { return pimpl->addr; } + +void llama_mmap::unmap_fragment(size_t first, size_t last) { pimpl->unmap_fragment(first, last); } + +#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32) +const bool llama_mmap::SUPPORTED = true; +#else +const bool llama_mmap::SUPPORTED = false; +#endif + +// llama_mlock + +struct llama_mlock::impl { +#ifdef _POSIX_MEMLOCK_RANGE + static size_t lock_granularity() { + return (size_t) sysconf(_SC_PAGESIZE); + } + + bool raw_lock(const void * addr, size_t size) const { + if (!mlock(addr, size)) { + return true; + } + +#ifdef __APPLE__ +#define MLOCK_SUGGESTION \ + "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \ + "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n" +#else +#define MLOCK_SUGGESTION \ + "Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n" +#endif + + char* errmsg = std::strerror(errno); + bool suggest = (errno == ENOMEM); +#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) + // visionOS/tvOS dont't support RLIMIT_MEMLOCK + // Skip resource limit checks on visionOS/tvOS + suggest = false; +#else + struct rlimit lock_limit; + if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { + suggest = false; + } + if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { + suggest = false; + } +#endif + + LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", + size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : ""); + return false; + } + + static void raw_unlock(void * addr, size_t size) { + if (munlock(addr, size)) { + LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno)); + } + } +#elif defined(_WIN32) + static size_t lock_granularity() { + SYSTEM_INFO si; + GetSystemInfo(&si); + return (size_t) si.dwPageSize; + } + + bool raw_lock(void * ptr, size_t len) const { + for (int tries = 1; ; tries++) { + if (VirtualLock(ptr, len)) { + return true; + } + if (tries == 2) { + LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n", + len, size, llama_format_win_err(GetLastError()).c_str()); + return false; + } + + SIZE_T min_ws_size, max_ws_size; + if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) { + LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + size_t increment = len + 1048576; + min_ws_size += increment; + max_ws_size += increment; + if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) { + LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + } + } + + static void raw_unlock(void * ptr, size_t len) { + if (!VirtualUnlock(ptr, len)) { + LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + static size_t lock_granularity() { + return (size_t) 65536; + } + + bool raw_lock(const void * addr, size_t len) const { + LLAMA_LOG_WARN("warning: mlock not supported on this system\n"); + return false; + } + + static void raw_unlock(const void * addr, size_t len) {} +#endif + + impl() : addr(NULL), size(0), failed_already(false) {} + + void init(void * ptr) { + GGML_ASSERT(addr == NULL && size == 0); + addr = ptr; + } + + void grow_to(size_t target_size) { + GGML_ASSERT(addr); + if (failed_already) { + return; + } + size_t granularity = lock_granularity(); + target_size = (target_size + granularity - 1) & ~(granularity - 1); + if (target_size > size) { + if (raw_lock((uint8_t *) addr + size, target_size - size)) { + size = target_size; + } else { + failed_already = true; + } + } + } + + void * addr; + size_t size; + + bool failed_already; +}; + +llama_mlock::llama_mlock() : pimpl(std::make_unique()) {} +llama_mlock::~llama_mlock() = default; + +void llama_mlock::init(void * ptr) { pimpl->init(ptr); } +void llama_mlock::grow_to(size_t target_size) { pimpl->grow_to(target_size); } + +#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32) +const bool llama_mlock::SUPPORTED = true; +#else +const bool llama_mlock::SUPPORTED = false; +#endif + +size_t llama_path_max() { + return PATH_MAX; +} diff --git a/src/llama-mmap.h b/src/llama-mmap.h new file mode 100644 index 0000000000000000000000000000000000000000..4e5aec3f440d7005d254578e08124289c358054a --- /dev/null +++ b/src/llama-mmap.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include + +struct llama_file; +struct llama_mmap; +struct llama_mlock; + +using llama_files = std::vector>; +using llama_mmaps = std::vector>; +using llama_mlocks = std::vector>; + +struct llama_file { + llama_file(const char * fname, const char * mode); + ~llama_file(); + + size_t tell() const; + size_t size() const; + + int file_id() const; // fileno overload + + void seek(size_t offset, int whence) const; + + void read_raw(void * ptr, size_t len) const; + uint32_t read_u32() const; + + void write_raw(const void * ptr, size_t len) const; + void write_u32(uint32_t val) const; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +struct llama_mmap { + llama_mmap(const llama_mmap &) = delete; + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false); + ~llama_mmap(); + + size_t size() const; + void * addr() const; + + void unmap_fragment(size_t first, size_t last); + + static const bool SUPPORTED; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +struct llama_mlock { + llama_mlock(); + ~llama_mlock(); + + void init(void * ptr); + void grow_to(size_t target_size); + + static const bool SUPPORTED; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +size_t llama_path_max(); diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bd9e6da8832b78c7d5a1f4661ef84c33269bea10 --- /dev/null +++ b/src/llama-model-loader.cpp @@ -0,0 +1,1163 @@ +#include "llama-model-loader.h" + +#include "ggml.h" + +#include +#include +#include +#include + +static const size_t kiB = 1024; +static const size_t MiB = 1024*kiB; +static const size_t GiB = 1024*MiB; + +const char * llama_file_version_name(llama_fver version) { + switch (version) { + case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; + case GGUF_FILE_VERSION_V2: return "GGUF V2"; + case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)"; + } + + return "unknown"; +} + +static std::string llama_model_ftype_name(llama_ftype ftype) { + if (ftype & LLAMA_FTYPE_GUESSED) { + return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; + } + + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + + default: return "unknown, may not work"; + } +} + +// return a list of splits for a given path +// for example, given "-00002-of-00004.gguf", returns list of all 4 splits +static std::vector llama_get_list_splits(const std::string & path, const int idx, const int n_split) { + std::vector paths; + std::string split_prefix; + std::vector buf(llama_path_max(), 0); + + { + int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split); + if (!ret) { + throw std::runtime_error(format("invalid split file name: %s", path.c_str())); + } + split_prefix = std::string(buf.data(), ret); + } + + if (split_prefix.empty()) { + throw std::runtime_error(format("invalid split file: %s", path.c_str())); + } + + for (int idx = 0; idx < n_split; ++idx) { + int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split); + paths.push_back(std::string(buf.data(), ret)); + } + + return paths; +} + +namespace GGUFMeta { + template + struct GKV_Base_Type { + static constexpr gguf_type gt = gt_; + + static T getter(const gguf_context * ctx, const int kid) { + return gfun(ctx, kid); + } + }; + + template struct GKV_Base; + + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + + template<> struct GKV_Base { + static constexpr gguf_type gt = GGUF_TYPE_STRING; + + static std::string getter(const gguf_context * ctx, const int kid) { + return gguf_get_val_str(ctx, kid); + } + }; + + struct ArrayInfo { + const gguf_type gt; + const size_t length; + const void * data; + }; + + template<> struct GKV_Base { + public: + static constexpr gguf_type gt = GGUF_TYPE_ARRAY; + static ArrayInfo getter(const gguf_context *ctx, const int k) { + const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); + return ArrayInfo { + arr_type, + size_t(gguf_get_arr_n(ctx, k)), + arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), + }; + } + }; + + template + class GKV : public GKV_Base { + GKV() = delete; + + public: + static T get_kv(const gguf_context * ctx, const int k) { + const enum gguf_type kt = gguf_get_kv_type(ctx, k); + + if (kt != GKV::gt) { + throw std::runtime_error(format("key %s has wrong type %s but expected type %s", + gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt))); + } + return GKV::getter(ctx, k); + } + + static const char * override_type_to_str(const llama_model_kv_override_type ty) { + switch (ty) { + case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool"; + case LLAMA_KV_OVERRIDE_TYPE_INT: return "int"; + case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float"; + case LLAMA_KV_OVERRIDE_TYPE_STR: return "str"; + } + return "unknown"; + } + + static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) { + if (!ovrd) { return false; } + if (ovrd->tag == expected_type) { + LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ", + __func__, override_type_to_str(ovrd->tag), ovrd->key); + switch (ovrd->tag) { + case LLAMA_KV_OVERRIDE_TYPE_BOOL: { + LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false"); + } break; + case LLAMA_KV_OVERRIDE_TYPE_INT: { + LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_FLOAT: { + LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_STR: { + LLAMA_LOG_INFO("%s\n", ovrd->val_str); + } break; + default: + // Shouldn't be possible to end up here, but just in case... + throw std::runtime_error( + format("Unsupported attempt to override %s type for metadata key %s\n", + override_type_to_str(ovrd->tag), ovrd->key)); + } + return true; + } + LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n", + __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag)); + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(OT & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) { + target = ovrd->val_bool; + return true; + } + return false; + } + + template + static typename std::enable_if::value && std::is_integral::value, bool>::type + try_override(OT & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) { + target = ovrd->val_i64; + return true; + } + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(T & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) { + target = ovrd->val_f64; + return true; + } + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(T & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) { + target = ovrd->val_str; + return true; + } + return false; + } + + static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + if (try_override(target, ovrd)) { + return true; + } + if (k < 0) { return false; } + target = get_kv(ctx, k); + return true; + } + + static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + return set(ctx, gguf_find_key(ctx, key), target, ovrd); + } + + static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + return set(ctx, key.c_str(), target, ovrd); + } + }; +} + + template + typename std::enable_if::value, bool>::type + llama_model_loader::get_arr_n(const std::string & key, T & result, bool required) { + const int kid = gguf_find_key(meta.get(), key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta.get(), kid); + + + result = arr_info.length; + return true; + } + + template + typename std::enable_if::value, bool>::type + llama_model_loader::get_arr_n(enum llm_kv kid, T & result, bool required) { + return get_arr_n(llm_kv(kid), result, required); + } + + template bool llama_model_loader::get_arr_n(enum llm_kv kid, uint32_t & result, bool required); + + template + bool llama_model_loader::get_arr(const std::string & key, std::vector & result, bool required) { + const gguf_context * ctx = meta.get(); + const int kid = gguf_find_key(ctx, key.c_str()); + + if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(ctx, kid); + + switch (arr_info.gt) { + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || + (std::is_same::value)); break; + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str())); + } + + if constexpr (std::is_same::value) { + const size_t n_items = gguf_get_arr_n(ctx, kid); + result.clear(); + + for (size_t i = 0; i < n_items; i++) { + const T value = gguf_get_arr_str(ctx, kid, i); + result.emplace_back(value); + } + } else { + result.resize(arr_info.length); + result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + } + + return true; + } + + template + bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) { + const gguf_context * ctx = meta.get(); + const int kid = gguf_find_key(ctx, key.c_str()); + + if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(ctx, kid); + + switch (arr_info.gt) { + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || + (std::is_same::value)); break; + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str())); + } + + if (arr_info.length > N_MAX) { + throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); + } + + if constexpr (std::is_same::value) { + const size_t n_items = gguf_get_arr_n(ctx, kid); + + for (size_t i = 0; i < n_items; i++) { + const T value = gguf_get_arr_str(ctx, kid, i); + result[i] = value; + } + } else { + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + } + + return true; + } + + template + bool llama_model_loader::get_arr(enum llm_kv kid, T & result, bool required) { + return get_arr(llm_kv(kid), result, required); + } + + template bool llama_model_loader::get_arr>(enum llm_kv kid, std::vector & result, bool required); + + template + bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { + auto it = kv_overrides.find(key); + + const struct llama_model_kv_override * override = + it != kv_overrides.end() ? &it->second : nullptr; + + const bool found = GGUFMeta::GKV::set(meta.get(), key, result, override); + + if (required && !found) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + + return found; + } + + template + bool llama_model_loader::get_key(enum llm_kv kid, T & result, bool required) { + return get_key(llm_kv(kid), result, required); + } + + template bool llama_model_loader::get_key (enum llm_kv kid, bool & result, bool required); + template bool llama_model_loader::get_key (enum llm_kv kid, float & result, bool required); + template bool llama_model_loader::get_key (enum llm_kv kid, uint32_t & result, bool required); + template bool llama_model_loader::get_key(enum llm_kv kid, std::string & result, bool required); + + template<> + bool llama_model_loader::get_key(enum llm_kv kid, enum llama_pooling_type & result, bool required) { + uint32_t tmp; + const bool found = get_key(kid, tmp, required); + if (found) { + result = (enum llama_pooling_type) tmp; + } else { + result = LLAMA_POOLING_TYPE_UNSPECIFIED; + } + return found; + } + + // get array of n <= N_MAX elements, or a single element repeated n times + template + bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required) { + const int kid = gguf_find_key(meta.get(), key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + if (n > N_MAX) { + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + } + + if (gguf_get_kv_type(meta.get(), kid) == GGUF_TYPE_ARRAY) { + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta.get(), kid); + + if (n != arr_info.length) { + throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length)); + } + + return get_arr(key, result, required); + } + + T value; + + bool ok = get_key(key, value, required); + if (!ok) { + return false; + } + + for (uint32_t i = 0; i < n; i++) { + result[i] = value; + } + + return true; + } + + template + bool llama_model_loader::get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required) { + return get_key_or_arr(llm_kv(kid), result, n, required); + } + + // TODO: this is not very clever - figure out something better + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + +llama_model_loader::llama_model_loader( + const std::string & fname, + std::vector & splits, + bool use_mmap, + bool check_tensors, + const llama_model_kv_override * param_overrides_p, + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { + int trace = 0; + if (getenv("LLAMA_TRACE")) { + trace = atoi(getenv("LLAMA_TRACE")); + } + + if (param_overrides_p != nullptr) { + for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) { + kv_overrides.insert({std::string(p->key), *p}); + } + } + + tensor_buft_overrides = param_tensor_buft_overrides_p; + + // Load the main GGUF + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + + meta.reset(gguf_init_from_file(fname.c_str(), params)); + if (!meta) { + throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str())); + } + + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + files.emplace_back(new llama_file(fname.c_str(), "rb")); + contexts.emplace_back(ctx); + + // Save tensors data offset of the main file. + // For subsidiary files, `meta` tensor data offset must not be used, + // so we build a unified tensors index for weights. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur)); + } + uint16_t n_split = 0; + get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); + + // Load additional GGML contexts + if (n_split > 1) { + // make sure the main file is loaded first + uint16_t idx = 0; + const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); + get_key(kv_split_no, idx); + if (idx != 0) { + throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); + } + + // generate list of splits if needed + if (splits.empty()) { + splits = llama_get_list_splits(fname, idx, n_split); + } + + // in case user give a custom list of splits, check if it matches the expected number + if (n_split != (uint16_t)splits.size()) { + throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); + } + + if (trace > 0) { + LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); + } + + // load other splits + for (idx = 1; idx < n_split; idx++) { + const char * fname_split = splits[idx].c_str(); + + struct gguf_init_params split_params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split)); + } + + // check idx + { + const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); + if (kid < 0) { + throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); + } + int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); + if (idx_gguf != idx) { + throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + } + } + + files.emplace_back(new llama_file(fname_split, "rb")); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the shard. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur)); + } + } + + get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); + + // sanity check + { + const int n_tensors_loaded = (int) weights_map.size(); + if (n_tensors != n_tensors_loaded) { + throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); + } + } + + LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); + } + + n_kv = gguf_get_n_kv(meta.get()); + n_tensors = weights_map.size(); + + fver = (enum llama_fver) gguf_get_version(meta.get()); + + LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", + __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + + // determine file type based on the number of tensors for each quantization and print meta data + // TODO: make optional + { + std::map n_type; + + uint32_t n_type_max = 0; + enum ggml_type type_max = GGML_TYPE_F32; + + for (const auto & it : weights_map) { + const llama_tensor_weight & w = it.second; + const ggml_tensor * tensor = w.tensor; + + enum ggml_type type = tensor->type; + + n_type[type]++; + + if (n_type_max < n_type[type]) { + n_type_max = n_type[type]; + type_max = type; + } + + if (trace > 0) { + const uint16_t sid = w.idx; + LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ] %8.2f MiB\n", __func__, + sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str(), + ggml_nbytes(tensor)/1024.0f/1024.0f); + } + } + + switch (type_max) { + case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; + case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; + case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break; + case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; + case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; + case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; + case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; + case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; + case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; + case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; + case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; + case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; + case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + case GGML_TYPE_TQ1_0: ftype = LLAMA_FTYPE_MOSTLY_TQ1_0; break; + case GGML_TYPE_TQ2_0: ftype = LLAMA_FTYPE_MOSTLY_TQ2_0; break; + case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; + case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; + case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break; + case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; + case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; + case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; + case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; + case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; + case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + default: + { + LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); + ftype = LLAMA_FTYPE_ALL_F32; + } break; + } + + // this is a way to mark that we have "guessed" the file type + ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED); + + { + uint32_t ftype_val = 0; + if (get_key(LLM_KV_GENERAL_FILE_TYPE, ftype_val, false)) { + ftype = (llama_ftype) ftype_val; + } + } + + LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); + + for (int i = 0; i < n_kv; i++) { + const char * name = gguf_get_key(meta.get(), i); + const enum gguf_type type = gguf_get_kv_type(meta.get(), i); + const std::string type_name = + type == GGUF_TYPE_ARRAY + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) + : gguf_type_name(type); + + std::string value = gguf_kv_to_str(meta.get(), i); + const size_t MAX_VALUE_LEN = 40; + if (value.size() > MAX_VALUE_LEN) { + value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); + } + replace_all(value, "\n", "\\n"); + + LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str()); + } + + // print type counts + for (auto & kv : n_type) { + if (kv.second == 0) { + continue; + } + + LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); + } + } + + if (!llama_mmap::SUPPORTED) { + LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__); + use_mmap = false; + } + + this->use_mmap = use_mmap; + this->check_tensors = check_tensors; +} + +std::string llama_model_loader::get_arch_name() const { + return arch_name; +} + +enum llm_arch llama_model_loader::get_arch() const { + return llm_kv.arch; +} + +const llama_model_loader::llama_tensor_weight * llama_model_loader::get_weight(const char * name) const { + auto pos = weights_map.find(name); + if (pos != weights_map.end()) { + return &pos->second; + } + + return nullptr; +} + +const llama_model_loader::llama_tensor_weight & llama_model_loader::require_weight(const char * name) const { + const llama_tensor_weight * weight = get_weight(name); + if (!weight) { + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name)); + } + return *weight; +} + +struct ggml_tensor * llama_model_loader::get_tensor_meta(const char * name) const { + const auto * weight = get_weight(name); + if (!weight) { + return nullptr; + } + return weight->tensor; +} + +struct ggml_tensor * llama_model_loader::require_tensor_meta(const std::string & name) const { + struct ggml_tensor * tensor = get_tensor_meta(name.c_str()); + if (!tensor) { + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); + } + return tensor; +} + +const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const { + const struct ggml_tensor * cur = get_tensor_meta(name.c_str()); + + if (cur == NULL) { + if (!required) { + return NULL; + } + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); + } + + { + bool is_ok = true; + for (size_t i = 0; i < GGML_MAX_DIMS; ++i) { + if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) { + is_ok = false; + break; + } + } + if (!is_ok) { + throw std::runtime_error( + format("%s: tensor '%s' has wrong shape; expected %s, got %s", + __func__, name.c_str(), + llama_format_tensor_shape(ne).c_str(), + llama_format_tensor_shape(cur).c_str())); + } + } + + return cur; +} + +struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags) { + const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); + + if (cur == NULL) { + return NULL; + } + + bool duplicated = flags & TENSOR_DUPLICATED; + + struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur); + ggml_set_name(tensor, ggml_get_name(cur)); + + if (duplicated) { + size_data += ggml_nbytes(cur); + } else { + n_created++; + } + + return tensor; + +} + +struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required) { + const struct ggml_tensor * cur = check_tensor_dims(name, ne, required); + + if (cur == NULL) { + return NULL; + } + + if (cur->type != base->type) { + throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type))); + } + + std::array dims; + for (size_t i = 0; i < GGML_MAX_DIMS; ++i) { + dims[i] = i < ne.size() ? ne.begin()[i] : 1; + } + + struct ggml_tensor * tensor = ggml_view_4d(ctx, base, + dims[0], dims[1], dims[2], dims[3], + cur->nb[1], cur->nb[2], cur->nb[3], + offset); + + ggml_set_name(tensor, name.c_str()); + + n_created++; + + return tensor; +} + +void llama_model_loader::done_getting_tensors() const { + if (n_created != n_tensors) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } +} + +void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps) { + if (use_mmap) { + mappings.reserve(files.size()); + mmaps_used.reserve(files.size()); + for (const auto & file : files) { + bool is_numa = false; + + auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (dev) { + auto * reg = ggml_backend_dev_backend_reg(dev); + auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); + if (is_numa_fn) { + is_numa = is_numa_fn(); + } + } + + std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa); + mmaps_used.emplace_back(mapping->size(), 0); + if (mlock_mmaps) { + std::unique_ptr mlock_mmap(new llama_mlock()); + mlock_mmap->init(mapping->addr()); + mlock_mmaps->emplace_back(std::move(mlock_mmap)); + } + mappings.emplace_back(std::move(mapping)); + } + } + + // compute the total size of all tensors for progress reporting + for (const auto & it : weights_map) { + size_data += ggml_nbytes(it.second.tensor); + } +} + +void llama_model_loader::get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const { + GGML_ASSERT(!mappings.empty()); + const auto & mapping = mappings.at(idx); + + *first = mapping->size(); + *last = 0; + *addr = mapping->addr(); + for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) { + const auto * weight = get_weight(ggml_get_name(tensor)); + if (!weight || weight->idx != idx) { + continue; + } + *first = std::min(*first, weight->offs); + *last = std::max(*last, weight->offs + ggml_nbytes(tensor)); + } +} + +void llama_model_loader::load_data_for(struct ggml_tensor * cur) const { + const auto & w = require_weight(ggml_get_name(cur)); + + if (use_mmap) { + const auto & mapping = mappings.at(w.idx); + if (cur->data == nullptr) { + cur->data = (uint8_t *)mapping->addr() + w.offs; + } else { + memcpy(cur->data, (uint8_t *)mapping->addr() + w.offs, ggml_nbytes(cur)); + } + } else { + GGML_ASSERT(cur->data != nullptr); + GGML_ASSERT(w.idx < files.size()); + const auto & file = files.at(w.idx); + file->seek(w.offs, SEEK_SET); + file->read_raw(cur->data, ggml_nbytes(cur)); + } + + if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } +} + +bool llama_model_loader::load_all_data( + struct ggml_context * ctx, + llama_buf_map & bufs, + llama_mlocks * lmlocks, + llama_progress_callback progress_callback, + void * progress_callback_user_data) { + GGML_ASSERT(size_data != 0 && "call init_mappings() first"); + + std::vector> read_buf; + std::vector>> validation_result; + + // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. + // NVMe raid configurations might require more / larger buffers. + constexpr size_t n_buffers = 4; + constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + std::vector host_buffers; + std::vector events; + std::vector host_ptrs; + size_t buffer_idx = 0; // buffer to use for async loads + ggml_backend_t upload_backend = [&](const char * func) -> ggml_backend_t { + if (use_mmap || check_tensors) { + return nullptr; + } + // When not using mmaped io use async uploads from pinned memory to GPU memory. + // First determine if the backend supports the necessary features for async uploads. + auto * buf = bufs.count(0) ? bufs.at(0) : nullptr; + if (!buf) { + LLAMA_LOG_DEBUG("%s: no buffer found for async uploads\n", func); + return nullptr; + } + + auto * buft = ggml_backend_buffer_get_type(buf); + auto * dev = ggml_backend_buft_get_device(buft); + if (!dev) { + LLAMA_LOG_DEBUG("%s: no device found for buffer type %s for async uploads\n", func, + ggml_backend_buft_name(buft)); + return nullptr; + } + + if (buft != ggml_backend_dev_buffer_type(dev)) { + LLAMA_LOG_DEBUG("%s: buffer type %s is not the default buffer type for device %s for async uploads\n", func, + ggml_backend_buft_name(buft), ggml_backend_dev_name(dev)); + return nullptr; + } + + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + if (!props.caps.async || !props.caps.host_buffer || !props.caps.events) { + LLAMA_LOG_DEBUG("%s: device %s does not support async, host buffers or events\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + auto * host_buft = ggml_backend_dev_host_buffer_type(dev); + if (!host_buft) { + LLAMA_LOG_DEBUG("%s: no host buffer type found for device %s\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + // If the backend is supported, create pinned memory buffers and events for synchronisation. + for (size_t idx = 0; idx < n_buffers; ++idx) { + auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size); + if (!buf) { + LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + host_buffers.emplace_back(buf); + host_ptrs.emplace_back(ggml_backend_buffer_get_base(buf)); + + auto * event = ggml_backend_event_new(dev); + if (!event) { + LLAMA_LOG_DEBUG("%s: failed to create event for async uploads for device %s\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + events.emplace_back(event); + } + + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + LLAMA_LOG_DEBUG("%s: failed to initialize backend for device %s for async uploads\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + return backend; + }(__func__); + + if (upload_backend) { + LLAMA_LOG_DEBUG("%s: using async uploads for device %s, buffer type %s, backend %s\n", __func__, + ggml_backend_dev_name(ggml_backend_get_device(upload_backend)), + ggml_backend_buft_name(ggml_backend_buffer_get_type(bufs.at(0))), + ggml_backend_name(upload_backend)); + } + + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) { + const auto * weight = get_weight(ggml_get_name(cur)); + if (weight == nullptr) { + // this can happen with split experts models + continue; + } + + if (progress_callback) { + if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) { + return false; + } + } + + size_t n_size = ggml_nbytes(cur); + + if (use_mmap) { + const auto & mapping = mappings.at(weight->idx); + ggml_backend_buffer_t buf_mmap = nullptr; + if (bufs.count(weight->idx)) { + buf_mmap = bufs.at(weight->idx); + } + uint8_t * data = (uint8_t *) mapping->addr() + weight->offs; + + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size)); + })); + } + + GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated + if (buf_mmap && cur->data == nullptr) { + ggml_backend_tensor_alloc(buf_mmap, cur, data); + if (lmlocks) { + const auto & lmlock = lmlocks->at(weight->idx); + lmlock->grow_to(weight->offs + n_size); + } + + auto & mmap_used = mmaps_used[weight->idx]; + mmap_used.first = std::min(mmap_used.first, weight->offs); + mmap_used.second = std::max(mmap_used.second, weight->offs + n_size); + } else { + ggml_backend_tensor_set(cur, data, 0, n_size); + } + } else { + const auto & file = files.at(weight->idx); + if (ggml_backend_buffer_is_host(cur->buffer)) { + file->seek(weight->offs, SEEK_SET); + file->read_raw(cur->data, n_size); + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size)); + })); + } + } else { + // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. + if (upload_backend) { + file->seek(weight->offs, SEEK_SET); + + size_t bytes_read = 0; + + while (bytes_read < n_size) { + size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + + ggml_backend_event_synchronize(events[buffer_idx]); + file->read_raw(host_ptrs[buffer_idx], read_iteration); + ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + ggml_backend_event_record(events[buffer_idx], upload_backend); + + bytes_read += read_iteration; + ++buffer_idx; + buffer_idx %= n_buffers; + } + } else { + read_buf.resize(n_size); + file->seek(weight->offs, SEEK_SET); + file->read_raw(read_buf.data(), n_size); + ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } + } + } + } + + size_done += n_size; + } + + // free temporary resources used for async uploads + for (auto * event : events) { + ggml_backend_event_synchronize(event); + ggml_backend_event_free(event); + } + for (auto * buf : host_buffers) { + ggml_backend_buffer_free(buf); + } + ggml_backend_free(upload_backend); + + // check validation results + bool validation_failed = false; + for (auto & future : validation_result) { + auto result = future.get(); + if (!result.second) { + LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first)); + validation_failed = true; + } + } + if (validation_failed) { + throw std::runtime_error("found tensors with invalid data"); + } + + // check if this is the last call and do final cleanup + if (size_done >= size_data) { + // unmap offloaded tensors and metadata + if (use_mmap) { + for (uint32_t idx = 0; idx < mappings.size(); idx++) { + const auto & mmap_used = mmaps_used.at(idx); + auto & mapping = mappings.at(idx); + mapping->unmap_fragment(0, mmap_used.first); + if (mmap_used.second != 0) { + mapping->unmap_fragment(mmap_used.second, mapping->size()); + } + } + } + if (progress_callback) { + // Even though the model is done loading, we still honor + // cancellation since we need to free allocations. + return progress_callback(1.0f, progress_callback_user_data); + } + } + + return true; +} + +std::string llama_model_loader::ftype_name() const { + return llama_model_ftype_name(ftype); +} + +void llama_model_loader::print_info() const { + LLAMA_LOG_INFO("%s: file format = %s\n", __func__, llama_file_version_name(fver)); + LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_model_ftype_name(ftype).c_str()); + if (n_bytes < GiB) { + LLAMA_LOG_INFO("%s: file size = %.2f MiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0, n_bytes*8.0/n_elements); + } else { + LLAMA_LOG_INFO("%s: file size = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements); + } +} diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h new file mode 100644 index 0000000000000000000000000000000000000000..0f52b011b698624e560c767d4ad5d6cd3140343c --- /dev/null +++ b/src/llama-model-loader.h @@ -0,0 +1,169 @@ +#pragma once + +#include "llama.h" + +#include "llama-impl.h" +#include "llama-arch.h" +#include "llama-mmap.h" + +#include "ggml-cpp.h" + +#include +#include +#include +#include + +using llama_buf_map = std::unordered_map; + +enum llama_fver { + GGUF_FILE_VERSION_V1 = 1, + GGUF_FILE_VERSION_V2 = 2, + GGUF_FILE_VERSION_V3 = 3, +}; + +const char * llama_file_version_name(llama_fver version); + +struct llama_model_loader { + // Holds information on a model weight + struct llama_tensor_weight { + uint16_t idx; // source file index + size_t offs; // tensor data offset in the original file + + ggml_tensor * tensor; + + llama_tensor_weight(const llama_file * file, uint16_t idx, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { + const int tensor_idx = gguf_find_tensor(gguf_ctx, ggml_get_name(tensor)); + if (tensor_idx < 0) { + throw std::runtime_error(format("tensor '%s' not found in the model", ggml_get_name(tensor))); + } + + offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx); + if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) { + throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", ggml_get_name(tensor))); + } + } + }; + + // custom comparator to sort weights more nicely by layer + struct weight_name_comparer { + bool operator()(const std::string & a, const std::string & b) const { + int a_layer = -1; + int b_layer = -1; + sscanf(a.c_str(), "blk.%d.", &a_layer); + sscanf(b.c_str(), "blk.%d.", &b_layer); + if (a_layer != b_layer) { + return a_layer < b_layer; + } + return a < b; + } + }; + + static const int TENSOR_NOT_REQUIRED = 1; + static const int TENSOR_DUPLICATED = 2; + + int n_kv = 0; + int n_tensors = 0; + int n_created = 0; + + uint64_t n_elements = 0; + size_t n_bytes = 0; + + bool use_mmap = false; + bool check_tensors; + + llama_files files; + llama_ftype ftype; + llama_fver fver; + + llama_mmaps mappings; + + std::map weights_map; + std::unordered_map kv_overrides; + const llama_model_tensor_buft_override * tensor_buft_overrides; + + gguf_context_ptr meta; + std::vector contexts; + + std::string arch_name; + LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); + + size_t size_done = 0; + size_t size_data = 0; + std::vector> mmaps_used; + + llama_model_loader( + const std::string & fname, + std::vector & splits, // optional, only need if the split does not follow naming scheme + bool use_mmap, + bool check_tensors, + const llama_model_kv_override * param_overrides_p, + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p); + + template + typename std::enable_if::value, bool>::type + get_arr_n(const std::string & key, T & result, bool required = true); + + template + typename std::enable_if::value, bool>::type + get_arr_n(enum llm_kv kid, T & result, bool required = true); + + template + bool get_arr(const std::string & key, std::vector & result, bool required = true); + + template + bool get_arr(const std::string & key, std::array & result, bool required = true); + + template + bool get_arr(enum llm_kv kid, T & result, bool required = true); + + template + bool get_key(const std::string & key, T & result, bool required = true); + + template + bool get_key(enum llm_kv kid, T & result, bool required = true); + + template + bool get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required = true); + + template + bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true); + + std::string get_arch_name() const; + + enum llm_arch get_arch() const; + + const llama_tensor_weight * get_weight(const char * name) const; + + const llama_tensor_weight & require_weight(const char * name) const; + + struct ggml_tensor * get_tensor_meta(const char * name) const; + + struct ggml_tensor * require_tensor_meta(const std::string & name) const; + + const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const; + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags = 0); + + struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); + + void done_getting_tensors() const; + + void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); + + void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const; + + // for backwards compatibility, does not support ggml-backend + void load_data_for(struct ggml_tensor * cur) const; + + // Returns false if cancelled by progress_callback + bool load_all_data( + struct ggml_context * ctx, + llama_buf_map & bufs, + llama_mlocks * lmlocks, + llama_progress_callback progress_callback, + void * progress_callback_user_data); + + std::string ftype_name() const; + + void print_info() const; +}; diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a70b9892347cb0a0ea1204e6adf4416597e4f66e --- /dev/null +++ b/src/llama-model-saver.cpp @@ -0,0 +1,281 @@ +#include "llama-model-saver.h" + +#include "gguf.h" + +#include "llama.h" +#include "llama-hparams.h" +#include "llama-model.h" +#include "llama-vocab.h" + +#include + +llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) { + gguf_ctx = gguf_init_empty(); +} + +llama_model_saver::~llama_model_saver() { + gguf_free(gguf_ctx); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) { + gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) { + gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const float value) { + gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const bool value) { + gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const char * value) { + gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value); +} + +[[noreturn]] +void llama_model_saver::add_kv(const enum llm_kv key, const char value) { + GGML_UNUSED(key); + GGML_UNUSED(value); + GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile +} + +template +void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) { + const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size(); + GGML_ASSERT(n_values <= value.size()); + + if (n_values == 0) { + return; + } + + if (per_layer) { + bool all_values_the_same = true; + for (size_t i = 1; i < n_values; ++i) { + if (value[i] != value[0]) { + all_values_the_same = false; + break; + } + } + if (all_values_the_same) { + add_kv(key, value[0]); + return; + } + } + + if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast(value.data())); + } else { + GGML_ABORT("fatal error"); + } +} + +void llama_model_saver::add_kv(const enum llm_kv key, const std::vector & value) { + std::vector tmp(value.size()); + for (size_t i = 0; i < value.size(); ++i) { + tmp[i] = value[i].c_str(); + } + gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size()); +} + +void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) { + if (!tensor) { + return; + } + if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) { + GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME + return; + } + gguf_add_tensor(gguf_ctx, tensor); +} + +void llama_model_saver::add_kv_from_model() { + const llama_hparams & hparams = model.hparams; + const llama_vocab & vocab = model.vocab; + + const int32_t n_vocab = vocab.n_tokens(); + std::vector tokens(n_vocab); + std::vector scores(n_vocab); + std::vector token_types(n_vocab); + + for (int32_t id = 0; id < n_vocab; ++id) { + const llama_vocab::token_data & token_data = vocab.get_token_data(id); + + tokens[id] = token_data.text; + scores[id] = token_data.score; + + switch(token_data.attr) { + case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; + case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; + case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break; + case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; + case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; + case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; + case LLAMA_TOKEN_ATTR_UNDEFINED: + default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; + } + } + + // add_kv(LLM_KV_GENERAL_TYPE, ???); + add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name()); + // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???); + // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???); + add_kv(LLM_KV_GENERAL_NAME, model.name); + // add_kv(LLM_KV_GENERAL_AUTHOR, ???); + // add_kv(LLM_KV_GENERAL_VERSION, ???); + // add_kv(LLM_KV_GENERAL_URL, ???); + // add_kv(LLM_KV_GENERAL_DESCRIPTION, ???); + // add_kv(LLM_KV_GENERAL_LICENSE, ???); + // add_kv(LLM_KV_GENERAL_SOURCE_URL, ???); + // add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO, ???); + + add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); + add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); + add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); + add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + // add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???); + add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert); + add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); + add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); + add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping); + add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping); + add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm); + add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers); + add_kv(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + + add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); + add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); + add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v); + add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + + const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; + + add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train); + // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name + add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train)); + add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor); + add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor); + add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn); + add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned); + add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + + // TODO: implement split file support + // add_kv(LLM_KV_SPLIT_NO, ???); + // add_kv(LLM_KV_SPLIT_COUNT, ???); + // add_kv(LLM_KV_SPLIT_TENSORS_COUNT, ???); + + add_kv(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms); + + add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + + add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model()); + add_kv(LLM_KV_TOKENIZER_PRE, vocab.get_tokenizer_pre()); + add_kv(LLM_KV_TOKENIZER_LIST, tokens); + add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE, token_types); + add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, vocab.n_token_types()); + add_kv(LLM_KV_TOKENIZER_SCORES, scores); + add_kv(LLM_KV_TOKENIZER_MERGES, vocab.get_bpe_merges()); + // FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though + add_kv(LLM_KV_TOKENIZER_BOS_ID, uint32_t(vocab.token_bos())); + add_kv(LLM_KV_TOKENIZER_EOS_ID, uint32_t(vocab.token_eos())); + add_kv(LLM_KV_TOKENIZER_EOT_ID, uint32_t(vocab.token_eot())); + add_kv(LLM_KV_TOKENIZER_EOM_ID, uint32_t(vocab.token_eom())); + add_kv(LLM_KV_TOKENIZER_UNK_ID, uint32_t(vocab.token_unk())); + add_kv(LLM_KV_TOKENIZER_SEP_ID, uint32_t(vocab.token_sep())); + add_kv(LLM_KV_TOKENIZER_PAD_ID, uint32_t(vocab.token_pad())); + // add_kv(LLM_KV_TOKENIZER_CLS_ID, uint32_t(vocab.token_bos())); // deprecated + // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???); + add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos()); + add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos()); + add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix()); + add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces()); + add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap()); + // add_kv(LLM_KV_TOKENIZER_HF_JSON, ???); + // add_kv(LLM_KV_TOKENIZER_RWKV, ???); + add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID, uint32_t(vocab.token_fim_pre())); + add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID, uint32_t(vocab.token_fim_suf())); + add_kv(LLM_KV_TOKENIZER_FIM_MID_ID, uint32_t(vocab.token_fim_mid())); + add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID, uint32_t(vocab.token_fim_pad())); + add_kv(LLM_KV_TOKENIZER_FIM_REP_ID, uint32_t(vocab.token_fim_rep())); + add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID, uint32_t(vocab.token_fim_sep())); + + // TODO: implement LoRA support + // add_kv(LLM_KV_ADAPTER_TYPE, ???); + // add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???); + + // deprecated + // add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???); + // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???); + // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???); +} + +void llama_model_saver::add_tensors_from_model() { + if (std::string(model.output->name) != std::string(model.tok_embd->name)) { + add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output + } + add_tensor(model.type_embd); + add_tensor(model.pos_embd); + add_tensor(model.tok_norm); + add_tensor(model.tok_norm_b); + add_tensor(model.output_norm); + add_tensor(model.output_norm_b); + add_tensor(model.output); + add_tensor(model.output_b); + add_tensor(model.output_norm_enc); + add_tensor(model.cls); + add_tensor(model.cls_b); + add_tensor(model.cls_out); + add_tensor(model.cls_out_b); + + for (const struct llama_layer & layer : model.layers) { + for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { + add_tensor(reinterpret_cast(&layer)[i]); + } + } +} + +void llama_model_saver::save(const std::string & path_model) { + gguf_write_to_file(gguf_ctx, path_model.c_str(), false); +} + diff --git a/src/llama-model-saver.h b/src/llama-model-saver.h new file mode 100644 index 0000000000000000000000000000000000000000..a5a434c30698a432b1c52604383fcad140a36004 --- /dev/null +++ b/src/llama-model-saver.h @@ -0,0 +1,37 @@ +#pragma once + +#include "llama.h" +#include "llama-arch.h" + +#include + +struct llama_model_saver { + struct gguf_context * gguf_ctx = nullptr; + const struct llama_model & model; + const struct LLM_KV llm_kv; + + llama_model_saver(const struct llama_model & model); + ~llama_model_saver(); + + void add_kv(enum llm_kv key, uint32_t value); + void add_kv(enum llm_kv key, int32_t value); + void add_kv(enum llm_kv key, float value); + void add_kv(enum llm_kv key, bool value); + void add_kv(enum llm_kv key, const char * value); + + [[noreturn]] + void add_kv(enum llm_kv key, char value); // needed to make the template below compile + + template + void add_kv(enum llm_kv key, const Container & value, bool per_layer = false); + + void add_kv(enum llm_kv key, const std::vector & value); + + void add_tensor(const struct ggml_tensor * tensor); + + void add_kv_from_model(); + + void add_tensors_from_model(); + + void save(const std::string & path_model); +}; diff --git a/src/llama-model.cpp b/src/llama-model.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a5eb122f998d85cbcd309953dd71518271578ceb --- /dev/null +++ b/src/llama-model.cpp @@ -0,0 +1,14392 @@ +#include "llama-model.h" + +#include "llama-impl.h" +#include "llama-mmap.h" +#include "llama-batch.h" +#include "llama-cparams.h" +#include "llama-model-loader.h" + +#include "llama-kv-cache-unified.h" +#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-recurrent.h" + +#include "ggml-cpp.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const char * llm_type_name(llm_type type) { + switch (type) { + case LLM_TYPE_14M: return "14M"; + case LLM_TYPE_17M: return "17M"; + case LLM_TYPE_22M: return "22M"; + case LLM_TYPE_33M: return "33M"; + case LLM_TYPE_60M: return "60M"; + case LLM_TYPE_70M: return "70M"; + case LLM_TYPE_80M: return "80M"; + case LLM_TYPE_109M: return "109M"; + case LLM_TYPE_137M: return "137M"; + case LLM_TYPE_160M: return "160M"; + case LLM_TYPE_190M: return "190M"; + case LLM_TYPE_220M: return "220M"; + case LLM_TYPE_250M: return "250M"; + case LLM_TYPE_270M: return "270M"; + case LLM_TYPE_335M: return "335M"; + case LLM_TYPE_410M: return "410M"; + case LLM_TYPE_450M: return "450M"; + case LLM_TYPE_475M: return "475M"; + case LLM_TYPE_770M: return "770M"; + case LLM_TYPE_780M: return "780M"; + case LLM_TYPE_0_5B: return "0.5B"; + case LLM_TYPE_0_6B: return "0.6B"; + case LLM_TYPE_1B: return "1B"; + case LLM_TYPE_1_3B: return "1.3B"; + case LLM_TYPE_1_4B: return "1.4B"; + case LLM_TYPE_1_5B: return "1.5B"; + case LLM_TYPE_1_6B: return "1.6B"; + case LLM_TYPE_1_7B: return "1.7B"; + case LLM_TYPE_1_8B: return "1.8B"; + case LLM_TYPE_2B: return "2B"; + case LLM_TYPE_2_8B: return "2.8B"; + case LLM_TYPE_2_9B: return "2.9B"; + case LLM_TYPE_3B: return "3B"; + case LLM_TYPE_4B: return "4B"; + case LLM_TYPE_6B: return "6B"; + case LLM_TYPE_6_9B: return "6.9B"; + case LLM_TYPE_7B: return "7B"; + case LLM_TYPE_8B: return "8B"; + case LLM_TYPE_9B: return "9B"; + case LLM_TYPE_11B: return "11B"; + case LLM_TYPE_12B: return "12B"; + case LLM_TYPE_13B: return "13B"; + case LLM_TYPE_14B: return "14B"; + case LLM_TYPE_15B: return "15B"; + case LLM_TYPE_16B: return "16B"; + case LLM_TYPE_20B: return "20B"; + case LLM_TYPE_27B: return "27B"; + case LLM_TYPE_30B: return "30B"; + case LLM_TYPE_32B: return "32B"; + case LLM_TYPE_34B: return "34B"; + case LLM_TYPE_35B: return "35B"; + case LLM_TYPE_40B: return "40B"; + case LLM_TYPE_65B: return "65B"; + case LLM_TYPE_70B: return "70B"; + case LLM_TYPE_142B: return "142B"; + case LLM_TYPE_236B: return "236B"; + case LLM_TYPE_290B: return "290B"; + case LLM_TYPE_314B: return "314B"; + case LLM_TYPE_405B: return "405B"; + case LLM_TYPE_671B: return "671B"; + case LLM_TYPE_SMALL: return "0.1B"; + case LLM_TYPE_MEDIUM: return "0.4B"; + case LLM_TYPE_LARGE: return "0.8B"; + case LLM_TYPE_XL: return "1.5B"; + case LLM_TYPE_A1_7B: return "A1.7B"; + case LLM_TYPE_A2_7B: return "A2.7B"; + case LLM_TYPE_8x7B: return "8x7B"; + case LLM_TYPE_8x22B: return "8x22B"; + case LLM_TYPE_16x12B: return "16x12B"; + case LLM_TYPE_16x3_8B: return "16x3.8B"; + case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B"; + case LLM_TYPE_57B_A14B: return "57B.A14B"; + case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; + case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; + case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_235B_A22B: return "235B.A22B"; + default: return "?B"; + } +} + +static const char * llama_expert_gating_func_name(llama_expert_gating_func_type type) { + switch (type) { + case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: return "softmax"; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: return "sigmoid"; + default: return "unknown"; + } +} + +static const std::map LLAMA_ROPE_SCALING_TYPES = { + { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, + { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, + { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" }, + { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" }, +}; + +std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) { + return LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type); +} + +static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { + for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { + if (kv.second == name) { + return (llama_rope_scaling_type) kv.first; + } + } + + return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; +} + +// checks if the weight tensor can be used with the specified buffer type and device +static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + GGML_ASSERT(w != nullptr); + + if (op == GGML_OP_NONE) { + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + switch (op) { + case GGML_OP_GET_ROWS: + { + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_get_rows(ctx, w, b); + } break; + case GGML_OP_MUL_MAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } break; + case GGML_OP_MUL_MAT_ID: + { + int n_expert_used = hparams.n_expert_used; + ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_mul_mat_id(ctx, w, b, ids); + } break; + case GGML_OP_ADD: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_add(ctx, a, w); + } break; + case GGML_OP_MUL: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_mul(ctx, a, w); + } break; + case GGML_OP_DIV: + { + ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); + op_tensor = ggml_div(ctx, a, w); + } break; + case GGML_OP_ROPE: + { + int n_embd_head = hparams.n_embd_head_v; + int n_head = hparams.n_head(); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_rope_ext( + ctx, a, b, w, + 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ); + + } break; + case GGML_OP_SSM_CONV: + { + // FIXME + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + op_tensor = ggml_ssm_conv(ctx, conv_x, w); + } break; + case GGML_OP_SSM_SCAN: + { + // FIXME + const int64_t d_state = w->ne[0]; + const int64_t d_inner = w->ne[1]; + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 1; + ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); + ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + } break; + case GGML_OP_RWKV_WKV6: + { + // FIXME + const int64_t S = 123; + const int64_t H = 123; + const int64_t n_tokens = 123; + const int64_t n_seqs = 123; + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * tf = w; + ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); + } break; + case GGML_OP_IM2COL: + { + const int n_embd = hparams.n_embd; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); + op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); + } break; + default: + GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + + return op_supported; +} + +// lists of buffer types used for each layer +using buft_list_t = std::vector>; + +// find the first buffer type in the list that can use the tensor +static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t & buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { + return cur_buft; + } + } + + return nullptr; +} + +// CPU: ACCEL -> GPU host -> CPU extra -> CPU +static buft_list_t make_cpu_buft_list(const std::vector & devices) { + buft_list_t buft_list; + + // add ACCEL buffer types + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + auto * buft = ggml_backend_dev_buffer_type(dev); + // skip + if (buft != ggml_backend_cpu_buffer_type()) { + buft_list.emplace_back(dev, buft); + } + } + } + + // add a host buffer type + // storing the tensors in a host buffer is useful when the processing of large batches + // is offloaded to a GPU device, since it reduces the time spent on data transfers + // generally, this will be done using the first device in the list + // a better approach would be to handle this on a weight-by-weight basis using the offload_op + // function of the device to determine if it would benefit from being stored in a host buffer + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } + } + + // add extra buffer types, only if no GPU device is present + // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094 + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // add the CPU buffer type + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) { + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + } + } + + return buft_list; +} + +// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU +static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode split_mode, const float * tensor_split) { + buft_list_t buft_list; + + // add the device split buffer type if requested and available + if (split_mode == LLAMA_SPLIT_MODE_ROW) { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type"); + if (ggml_backend_split_buffer_type_fn) { + size_t dev_index = [&]() { + auto * reg = ggml_backend_dev_backend_reg(dev); + for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) { + if (ggml_backend_reg_dev_get(reg, i) == dev) { + return i; + } + } + throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev))); + }(); + auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split); + if (buft != nullptr) { + buft_list.emplace_back(dev, buft); + } + } + } + + // add the device default buffer type + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + + return buft_list; +} + +struct llama_model::impl { + impl() {} + ~impl() {} + + uint64_t n_elements = 0; + + size_t n_bytes = 0; + + std::string desc_str; + + // model memory mapped files + llama_mmaps mappings; + + // objects representing data potentially being locked in memory + llama_mlocks mlock_bufs; + llama_mlocks mlock_mmaps; + + // contexts where the model tensors metadata is stored + std::vector ctxs; + + // the model memory buffers for the tensor data + std::vector bufs; + + buft_list_t cpu_buft_list; + std::map gpu_buft_list; + + struct layer_dev { + ggml_backend_dev_t dev; + buft_list_t * buft_list; + }; + + layer_dev dev_input = {}; + layer_dev dev_output = {}; + std::vector dev_layer; + + bool has_tensor_overrides; +}; + +llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique()) { + pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; +} + +llama_model::~llama_model() {} + +void llama_model::load_stats(llama_model_loader & ml) { + pimpl->n_elements = ml.n_elements; + pimpl->n_bytes = ml.n_bytes; +} + +void llama_model::load_arch(llama_model_loader & ml) { + arch = ml.get_arch(); + if (arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); + } +} + +void llama_model::load_hparams(llama_model_loader & ml) { + const gguf_context * ctx = ml.meta.get(); + + // get metadata as string + for (int i = 0; i < gguf_get_n_kv(ctx); i++) { + gguf_type type = gguf_get_kv_type(ctx, i); + if (type == GGUF_TYPE_ARRAY) { + continue; + } + const char * name = gguf_get_key(ctx, i); + const std::string value = gguf_kv_to_str(ctx, i); + gguf_kv.emplace(name, value); + } + + // get general kv + ml.get_key(LLM_KV_GENERAL_NAME, name, false); + + // everything past this point is not vocab-related + if (hparams.vocab_only) { + return; + } + + ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { + ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); + + ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); + ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); + + ml.get_key(LLM_KV_CONVNEXT_EMBEDDING_LENGTH, hparams.convnext.n_embd); + ml.get_key(LLM_KV_CONVNEXT_BLOCK_COUNT, hparams.convnext.n_layer); + } + + GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); + GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); + if (hparams.n_expert > 0) { + GGML_ASSERT(hparams.n_expert_used > 0); + } else { + GGML_ASSERT(hparams.n_expert_used == 0); + } + + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); + std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + + std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); + + std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + + // n_head_kv is optional, default to n_head + hparams.n_head_kv_arr = hparams.n_head_arr; + + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false); + + bool rope_finetuned = false; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + hparams.n_ctx_orig_yarn = hparams.n_ctx_train; + ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false); + + // rope_freq_base (optional) + hparams.rope_freq_base_train = 10000.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train, false); + + std::string rope_scaling("linear"); + ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false); + hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling); + GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED); + + // rope_freq_scale (inverse of the kv) is optional + float ropescale = 0.0f; + if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) { + // try the old key name + ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false); + } + hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; + + // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + + // non-transformer models do not have attention heads + if (hparams.n_head() > 0) { + // gpt-neox n_rot = rotary_pct * (n_embd / n_head) + // gpt-j n_rot = rotary_dim + + hparams.n_embd_head_k = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + + hparams.n_embd_head_v = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); + + // sanity check for n_rot (optional) + hparams.n_rot = hparams.n_embd_head_k; + + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + + if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) { + if (hparams.n_rot != hparams.n_embd_head_k) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); + } + } + } else { + hparams.n_rot = 0; + hparams.n_embd_head_k = 0; + hparams.n_embd_head_v = 0; + } + + // for differentiating model types + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + + // for classifier models + ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false); + if (!classifier_labels.empty()) { + hparams.n_cls_out = classifier_labels.size(); + } + + // arch-specific KVs + switch (arch) { + case LLM_ARCH_LLAMA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 8) { + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8x7B; break; + case 56: type = LLM_TYPE_8x22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } else { + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B + case 22: type = LLM_TYPE_1B; break; + case 26: type = LLM_TYPE_3B; break; + case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B + // granite uses a vocab with len 49152 + case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; + case 36: type = LLM_TYPE_8B; break; // granite + case 40: type = LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_34B; break; + case 60: type = LLM_TYPE_30B; break; + case 80: type = hparams.n_head() == hparams.n_head_kv() ? LLM_TYPE_65B : LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } + } break; + case LLM_ARCH_LLAMA4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick + hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full + + switch (hparams.n_expert) { + case 16: type = LLM_TYPE_17B_16E; break; + case 128: type = LLM_TYPE_17B_128E; break; + default: type = LLM_TYPE_UNKNOWN; + } + + if (type == LLM_TYPE_17B_128E) { + hparams.use_kq_norm = false; + } + } break; + case LLM_ARCH_ARCEE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Arcee uses the same structure as Llama + switch (hparams.n_layer) { + case 36: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_DECI: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + case 162: type = LLM_TYPE_405B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_MINICPM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + + switch (hparams.n_layer) { + case 52: type = LLM_TYPE_1B; break; + case 40: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_MINICPM3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GROK: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 64: type = LLM_TYPE_314B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_FALCON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 60: type = LLM_TYPE_40B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_BAICHUAN: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + if (type == LLM_TYPE_13B) { + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } + } break; + case LLM_ARCH_STARCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + case 42: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_REFACT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } break; + case LLM_ARCH_BERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 3: + type = LLM_TYPE_17M; break; // bge-micro + case 6: + type = LLM_TYPE_22M; break; // MiniLM-L6 + case 12: + switch (hparams.n_embd) { + case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small + case 768: type = LLM_TYPE_109M; break; // bge-base + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + type = LLM_TYPE_335M; break; // bge-large + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_JINA_BERT_V2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + hparams.f_max_alibi_bias = 8.0f; + + switch (hparams.n_layer) { + case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small + case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); + + if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (arch == LLM_ARCH_NOMIC_BERT) { + type = LLM_TYPE_137M; + } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { + type = LLM_TYPE_475M; + } + } + } break; + case LLM_ARCH_NEO_BERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + if (hparams.n_layer == 28) { + type = LLM_TYPE_250M; + } + } break; + case LLM_ARCH_BLOOM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 30: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } break; + case LLM_ARCH_MPT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_30B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_STABLELM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN2VL: + { + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + } + // fall through + case LLM_ARCH_QWEN2: + { + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; + case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; + case 32: type = LLM_TYPE_7B; break; + case 36: type = LLM_TYPE_3B; break; + case 40: type = hparams.n_head() == 20 ? LLM_TYPE_4B : LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN2MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_A2_7B; break; + case 28: type = LLM_TYPE_57B_A14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; + case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN3MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + case 94: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_PHI2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_PHI3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (found_swa && hparams.n_swa > 0) { + LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13676"); + + // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern` + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + + hparams.n_swa = 0; + hparams.set_swa_pattern(1); + } + } break; + case LLM_ARCH_PHIMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_16x3_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_PLAMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GPT2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_SMALL; break; + case 24: type = LLM_TYPE_MEDIUM; break; + case 36: type = LLM_TYPE_LARGE; break; + case 48: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_CODESHELL: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 42: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_ORION: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_INTERNLM2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GEMMA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 18: type = LLM_TYPE_2B; break; + case 28: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GEMMA2: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; // default value of gemma 2 + hparams.set_swa_pattern(2); + hparams.attn_soft_cap = true; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_2B; break; + case 42: type = LLM_TYPE_9B; break; + case 46: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + } break; + case LLM_ARCH_GEMMA3: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(6); + + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_1B; break; + case 34: type = LLM_TYPE_4B; break; + case 48: type = LLM_TYPE_12B; break; + case 62: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + } break; + case LLM_ARCH_STARCODER2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_3B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + case 52: type = LLM_TYPE_20B; break; // granite + case 88: type = LLM_TYPE_34B; break; // granite + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_MAMBA: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_XVERSE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 80: type = LLM_TYPE_65B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_COMMAND_R: + { + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_35B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_COHERE2: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(4); + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_DBRX: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_OLMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + + switch (hparams.n_layer) { + case 22: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_OLMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_OLMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_OPENELM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GPTNEOX: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + switch (hparams.n_layer) { + case 6: + switch (hparams.n_ff()) { + case 512: type = LLM_TYPE_14M; break; + case 2048: type = LLM_TYPE_70M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 12: + switch (hparams.n_ff()) { + case 3072: type = LLM_TYPE_160M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 16: + switch (hparams.n_ff()) { + case 8192: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: type = LLM_TYPE_410M; break; + case 8192: type = LLM_TYPE_1_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_ff()) { + case 10240: type = LLM_TYPE_2_8B; break; + case 16384: type = LLM_TYPE_6_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 36: + switch (hparams.n_ff()) { + case 20480: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 44: + switch (hparams.n_ff()) { + case 24576: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_ARCTIC: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 128) { + switch (hparams.n_layer) { + case 35: type = LLM_TYPE_10B_128x3_66B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } else { + type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_DEEPSEEK: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_DEEPSEEK2: + { + bool is_lite = (hparams.n_layer == 27); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + if (!is_lite) { + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + } + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + // for compatibility with existing DeepSeek V2 and V2.5 GGUFs + // that have no expert_gating_func model parameter set + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_16B; break; + case 60: type = LLM_TYPE_236B; break; + case 61: type = LLM_TYPE_671B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_PLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_CHATGLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: { + if (hparams.n_head(0) == 16) { + type = LLM_TYPE_1_5B; + } else { + type = LLM_TYPE_6B; + } + } break; + case 40: { + if (hparams.n_head(0) == 24) { + type = LLM_TYPE_4B; + } else { + type = LLM_TYPE_9B; + } + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GLM4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_9B; break; + case 61: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_BITNET: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_T5: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + + uint32_t dec_start_token_id; + if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { + hparams.dec_start_token_id = dec_start_token_id; + } + + switch (hparams.n_layer) { + case 6: type = LLM_TYPE_60M; break; // t5-small + case 8: type = LLM_TYPE_80M; break; // flan-t5-small + case 12: + switch (hparams.n_ff()) { + case 3072: type = LLM_TYPE_220M; break; // t5-base + case 2048: type = LLM_TYPE_250M; break; // flan-t5-base + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: type = LLM_TYPE_770M; break; // t5-large + case 2816: type = LLM_TYPE_780M; break; // flan-t5-large + case 16384: type = LLM_TYPE_3B; break; // t5-3b + case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl + case 65536: type = LLM_TYPE_11B; break; // t5-11b + case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_T5ENCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + type = LLM_TYPE_UNKNOWN; + } break; + case LLM_ARCH_JAIS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1_3B; break; + case 40: type = LLM_TYPE_13B; break; + /* TODO: add variants */ + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_NEMOTRON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_EXAONE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_190M; break; + case 24: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_450M; break; + case 2048: type = LLM_TYPE_1_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 28: + switch (hparams.n_embd) { + case 1536: type = LLM_TYPE_1_5B; break; + case 3584: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: type = LLM_TYPE_UNKNOWN; + } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); + } break; + case LLM_ARCH_CHAMELEON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_34B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_WAVTOKENIZER_DEC: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + } break; + case LLM_ARCH_BAILINGMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_16B; break; + case 88: type = LLM_TYPE_290B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_DOTS1: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_142B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + default: throw std::runtime_error("unsupported model architecture"); + } + + pimpl->n_bytes = ml.n_bytes; + + pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); + + if (hparams.f_max_alibi_bias > 0.0f) { + hparams.use_alibi = true; + } + + hparams.rope_type = llama_model_rope_type(this); +} + +void llama_model::load_vocab(llama_model_loader & ml) { + const auto kv = LLM_KV(arch); + + vocab.load(ml, kv); +} + +bool llama_model::load_tensors(llama_model_loader & ml) { + const auto & split_mode = params.split_mode; + const auto & n_gpu_layers = params.n_gpu_layers; + const auto & use_mlock = params.use_mlock; + const auto & tensor_split = params.tensor_split; + + const int n_layer = hparams.n_layer; + + const bool use_mmap_buffer = true; + + LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false"); + + // build a list of buffer types for the CPU and GPU devices + pimpl->cpu_buft_list = make_cpu_buft_list(devices); + for (auto * dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); + // add CPU buffer types as a fallback + buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); + pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); + } + + // calculate the split points + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); + std::vector splits(n_devices()); + if (all_zero) { + // default split, by free memory + for (size_t i = 0; i < n_devices(); ++i) { + ggml_backend_dev_t dev = devices[i]; + size_t total; + size_t free; + ggml_backend_dev_memory(dev, &free, &total); + splits[i] = free; + } + } else { + std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); + } + + // sum and normalize the splits to get the split points + float split_sum = 0.0f; + for (size_t i = 0; i < n_devices(); ++i) { + split_sum += splits[i]; + splits[i] = split_sum; + } + for (size_t i = 0; i < n_devices(); ++i) { + splits[i] /= split_sum; + } + + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); + auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { + const bool is_swa = il < (int) hparams.n_layer && hparams.is_swa(il); + if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); + return {cpu_dev, &pimpl->cpu_buft_list}; + } + const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); + auto * dev = devices.at(layer_gpu); + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); + return {dev, &pimpl->gpu_buft_list.at(dev)}; + }; + + // assign the input layer + // there is very little benefit to offloading the input layer, so always keep it on the CPU + pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; + + // assign the repeating layers to the devices according to the splits + pimpl->dev_layer.resize(n_layer); + for (int il = 0; il < n_layer; ++il) { + pimpl->dev_layer[il] = get_layer_buft_list(il); + } + + // assign the output layer + pimpl->dev_output = get_layer_buft_list(n_layer); + + // one ggml context per buffer type + int max_n_tensors = ml.n_tensors; + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += n_layer*2; // duplicated rope freq tensors + const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; + + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ctx_map[buft] = ctx; + pimpl->ctxs.emplace_back(ctx); + + return ctx; + } + return it->second; + }; + + const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; + const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; + + // create tensors for the weights + { + // note: cast to int64_t since we will use these for the tensor dimensions + const int64_t n_head = hparams.n_head(); + const int64_t n_head_kv = hparams.n_head_kv(); + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_ff = hparams.n_ff(); + const int64_t n_embd_gqa = n_embd_v_gqa; + const int64_t n_vocab = vocab.n_tokens(); + const int64_t n_token_types = vocab.n_token_types(); + const int64_t n_rot = hparams.n_rot; + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_ctx_train = hparams.n_ctx_train; + + if (n_expert > 0 && hparams.n_expert_used == 0) { + throw std::runtime_error("model has expert layers but no expert layers are used"); + } + + int n_moved_tensors = 0; + ggml_tensor * first_moved_tensor = nullptr; + ggml_backend_buffer_type_t first_moved_from_buft = nullptr; + ggml_backend_buffer_type_t first_moved_to_buft = nullptr; + + auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { + ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); + + if (!t_meta) { + if (flags & TENSOR_NOT_REQUIRED) { + return nullptr; + } + throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); + } + + // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops + // the tensor is duplicated + // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor + llm_tensor tn_tensor = tn.tensor; + if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & TENSOR_DUPLICATED) { + tn_tensor = LLM_TENSOR_OUTPUT; + } + + llm_tensor_info info; + try { + info = llm_tensor_info_for(tn_tensor); + } catch (const std::out_of_range & e) { + throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); + } + + // skip unused tensors + if (info.op == GGML_OP_NONE) { + const size_t nbytes = ggml_nbytes(t_meta); + LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); + + ml.size_data -= nbytes; + ml.n_created++; + + return nullptr; + } + + // tensors with "bias" suffix are always used with GGML_OP_ADD + ggml_op op; + bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; + if (bias) { + op = GGML_OP_ADD; + } else { + op = info.op; + } + + // sanity checks + if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { + if (tn.bid != -1) { + GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); + } + } else { + if (tn.bid == -1) { + GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str()); + } + } + + // select the buffer type for this tensor + buft_list_t * buft_list; + switch (info.layer) { + case LLM_TENSOR_LAYER_INPUT: + buft_list = pimpl->dev_input.buft_list; + break; + case LLM_TENSOR_LAYER_OUTPUT: + buft_list = pimpl->dev_output.buft_list; + break; + case LLM_TENSOR_LAYER_REPEATING: + buft_list = pimpl->dev_layer.at(tn.bid).buft_list; + break; + default: + GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); + } + + ggml_backend_buffer_type_t buft = nullptr; + + // check overrides + if (ml.tensor_buft_overrides) { + std::string tensor_name = tn.str(); + for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { + std::regex pattern(overrides->pattern); + if (std::regex_search(tensor_name, pattern)) { + buft = overrides->buft; + LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", + tensor_name.c_str(), + ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), + ggml_backend_buft_name(buft)); + break; + } + } + } + + if (!buft) { + buft = select_weight_buft(hparams, t_meta, op, *buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + } + } + + // avoid using a host buffer when using mmap + auto * buft_dev = ggml_backend_buft_get_device(buft); + if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error("no CPU backend found"); + } + buft = ggml_backend_dev_buffer_type(cpu_dev); + } + + if (buft != buft_list->front().second) { + n_moved_tensors++; + if (!first_moved_tensor) { + first_moved_tensor = t_meta; + first_moved_from_buft = buft_list->front().second; + first_moved_to_buft = buft; + } + } + + ggml_context * ctx = ctx_for_buft(buft); + + // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one + if (flags & TENSOR_DUPLICATED) { + ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str()); + if (t) { + return t; + } + } + return ml.create_tensor(ctx, tn, ne, flags); + }; + + layers.resize(n_layer); + + // TODO: move to a separate function + const auto tn = LLM_TN(arch); + switch (arch) { + case LLM_ARCH_LLAMA: + case LLM_ARCH_REFACT: + case LLM_ARCH_MINICPM: + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } + } break; + case LLM_ARCH_LLAMA4: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0"); + for (int i = 0; i < n_layer; ++i) { + bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0; + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + if (is_moe_layer) { + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert + const int64_t n_ff_shexp = n_ff_exp; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; + case LLM_ARCH_DECI: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_ff = hparams.n_ff(i); + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_kv = hparams.n_head_kv(i); + + if (n_head_kv == 0 && n_head > 0) { + // linear attention for DeciLMCausalModel + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + else if (n_head_kv > 0) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + } + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (n_ff > 0) { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_ff > 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_MINICPM3: + { + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_GROK: + { + if (n_expert == 0) { + throw std::runtime_error("Grok model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_DBRX: + { + if (n_expert == 0) { + throw std::runtime_error("DBRX model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } break; + case LLM_ARCH_BAICHUAN: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_FALCON: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_STARCODER: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + // needs to be on GPU + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_BERT: + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (!layer.wqkv) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + + if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) { + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_NEO_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } + } break; + case LLM_ARCH_JINA_BERT_V2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; // JinaBertLayer + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_BLOOM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_MPT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + // AWQ ScaleActivation layer + layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_STABLELM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors, present in Stable LM 2 1.6B + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + + // optional q and k layernorms, present in StableLM 2 12B + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + // optional FFN norm, not present in StableLM 2 12B which uses parallel residual + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_QWEN: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0); + } + } break; + case LLM_ARCH_QWEN2: + case LLM_ARCH_QWEN2VL: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_QWEN2MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN2MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } + } break; + case LLM_ARCH_QWEN3: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_QWEN3MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } + } break; + case LLM_ARCH_PHI2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_PHI3: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_PHIMOE: + { + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_PLAMO: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_GPT2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_CODESHELL: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if tok embd is NULL, init from output + if (tok_embd == NULL) { + tok_embd = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_ORION: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_INTERNLM2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_GEMMA: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } + } break; + case LLM_ARCH_GEMMA2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_GEMMA3: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_STARCODER2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional bias tensors + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0); + } + } break; + case LLM_ARCH_MAMBA: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + if (2 * n_embd != d_inner) { + throw std::runtime_error("only an expansion factor of 2 is supported for now"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } + } break; + case LLM_ARCH_XVERSE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_COMMAND_R: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (n_layer >= 64){ + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_COHERE2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, + TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } + } + break; + case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_OLMO2: + { + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_OLMOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } break; + case LLM_ARCH_OPENELM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head; + const int64_t n_ff = hparams.n_ff(i); + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_GPTNEOX: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_ARCTIC: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } break; + case LLM_ARCH_DEEPSEEK: + { + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; + case LLM_ARCH_DEEPSEEK2: + { + const bool is_lite = (hparams.n_layer == 27); + + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + if (!is_lite) { + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + } + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (!is_lite) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + if (is_mla) { + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } else { + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; + case LLM_ARCH_PLM: + { + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_BITNET: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_scale = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_T5: + { + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0); + // this tensor seems to be unused in HF transformers implementation + layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_T5ENCODER: + { + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_JAIS: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_CHATGLM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } + } break; + case LLM_ARCH_GLM4: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_NEMOTRON: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_EXAONE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_RWKV6: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // Block 0, LN0 + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); + + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); + + layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); + layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, TENSOR_NOT_REQUIRED); + GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL)); + + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0); + layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); + layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); + layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); + layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0); + + layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); + layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); + layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0); + } + + } break; + case LLM_ARCH_RWKV6QWEN2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + const int n_head_kv = hparams.n_head_kv(); + int attn_key_value_size; + if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) { + attn_key_value_size = attn_hidden_size; + } else { + attn_key_value_size = n_head_kv * head_size; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); + + layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); + + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); + layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); + layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); + // optional bias tensors + layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, TENSOR_NOT_REQUIRED); + + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_RWKV7: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // Block 0, LN0 + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + if (i == 0) { + // actually not used + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); + } else { + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + } + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0); + + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); + + layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); + + layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); + layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); + } + + } break; + case LLM_ARCH_ARWKV7: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + if (i == 0) { + // actually not used + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); + } else { + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + } + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED); + + try { + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); + } catch(std::runtime_error & e) { + // ARWKV models may not have gate tensors + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); + } + + layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + } break; + case LLM_ARCH_CHAMELEON: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_WAVTOKENIZER_DEC: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0); + + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0); + conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); + + // posnet + { + const int64_t n_embd = hparams.posnet.n_embd; + + for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) { + auto & layer = layers[i].posnet; + + // posnet: + // + // - resnet + // - resnet + // - attn + // - resnet + // - resnet + // - norm + // + switch (i) { + case 0: + case 1: + case 3: + case 4: + { + layer.norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0); + layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", i), {1, n_embd}, 0); + + layer.conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", i), {1, n_embd}, 0); + + layer.norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0); + layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", i), {1, n_embd}, 0); + + layer.conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", i), {1, n_embd}, 0); + } break; + case 2: + { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + + layer.attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", i), {1, n_embd}, 0); + + layer.attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", i), {1, n_embd}, 0); + + layer.attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", i), {1, n_embd}, 0); + + layer.attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", i), {1, n_embd}, 0); + } break; + case 5: + { + layer.norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + } break; + default: GGML_ABORT("unknown posnet layer"); + }; + } + } + + GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0); + + // convnext + { + const int64_t n_embd = hparams.convnext.n_embd; + + for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) { + auto & layer = layers[i].convnext; + + layer.dw = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "weight", i), {7, 1, n_embd}, 0); + layer.dw_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "bias", i), {1, n_embd}, 0); + + layer.norm = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "weight", i), {n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "bias", i), {n_embd}, 0); + + layer.pw1 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "weight", i), {n_embd, n_ff}, 0); + layer.pw1_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "bias", i), {n_ff}, 0); + + layer.pw2 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "weight", i), {n_ff, n_embd}, 0); + layer.pw2_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "bias", i), {n_embd}, 0); + + layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); + } + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); + } break; + case LLM_ARCH_BAILINGMOE: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } break; + case LLM_ARCH_DOTS1: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; + case LLM_ARCH_ARCEE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + default: + throw std::runtime_error("unknown architecture"); + } + + if (n_moved_tensors > 0) { + LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n", + __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1, + ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft)); + } + } + + ml.done_getting_tensors(); + + ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); + pimpl->mappings.reserve(ml.mappings.size()); + + // create the backend buffers + std::vector> ctx_bufs; + ctx_bufs.reserve(ctx_map.size()); + + // Ensure we have enough capacity for the maximum backend buffer we will potentially create + const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size(); + pimpl->bufs.reserve(n_max_backend_buffer); + + for (auto & it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx = it.second; + + // skip contexts without tensors + if (ggml_get_first_tensor(ctx) == nullptr) { + continue; + } + + llama_buf_map buf_map; + buf_map.reserve(n_max_backend_buffer); + + // check if it is possible to use buffer_from_host_ptr with this buffer type + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + if (!dev) { + // FIXME: workaround for CPU backend buft having a NULL device + dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + } + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; + bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); + + if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { + for (uint32_t idx = 0; idx < ml.files.size(); idx++) { + // only the mmap region containing the tensors in the model is mapped to the backend buffer + // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers + // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size + void * addr = nullptr; + size_t first, last; // NOLINT + ml.get_mapping_range(&first, &last, &addr, idx, ctx); + if (first >= last) { + continue; + } + const size_t max_size = ggml_get_max_tensor_size(ctx); + ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); + if (buf == nullptr) { + throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); + } + pimpl->bufs.emplace_back(buf); + buf_map.emplace(idx, buf); + } + } + else { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (buf == nullptr) { + throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); + } + pimpl->bufs.emplace_back(buf); + if (use_mlock && ggml_backend_buffer_is_host(buf)) { + pimpl->mlock_bufs.emplace_back(new llama_mlock); + auto & mlock_buf = pimpl->mlock_bufs.back(); + mlock_buf->init (ggml_backend_buffer_get_base(buf)); + mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); + } + for (uint32_t idx = 0; idx < ml.files.size(); idx++) { + buf_map.emplace(idx, buf); + } + } + + if (pimpl->bufs.empty()) { + throw std::runtime_error("failed to allocate buffer"); + } + + for (auto & buf : buf_map) { + // indicate that this buffer contains weights + // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight + ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } + + ctx_bufs.emplace_back(ctx, buf_map); + } + + if (llama_supports_gpu_offload()) { + const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); + if (n_gpu_layers > (int) hparams.n_layer) { + LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); + } + + const int max_backend_supported_layers = hparams.n_layer + 1; + const int max_offloadable_layers = hparams.n_layer + 1; + + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + } + + // print memory requirements per buffer type + for (auto & buf : pimpl->bufs) { + LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); + } + + // populate tensors_by_name + for (auto & ctx : pimpl->ctxs) { + for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) { + tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + } + + // load tensor data + for (auto & it : ctx_bufs) { + ggml_context * ctx = it.first; + auto & bufs = it.second; + if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { + return false; + } + } + + if (use_mmap_buffer) { + for (auto & mapping : ml.mappings) { + pimpl->mappings.emplace_back(std::move(mapping)); + } + } + + return true; +} + +std::string llama_model::arch_name() const { + return llm_arch_name(arch); +} + +std::string llama_model::type_name() const { + return llm_type_name(type); +} + +std::string llama_model::desc() const { + return pimpl->desc_str; +} + +size_t llama_model::size() const { + return pimpl->n_bytes; +} + +size_t llama_model::n_tensors() const { + return tensors_by_name.size(); +} + +size_t llama_model::n_devices() const { + return devices.size(); +} + +uint64_t llama_model::n_elements() const { + return pimpl->n_elements; +} + +void llama_model::print_info() const { + const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); + + auto print_f = [](const std::function & f, uint32_t n) { + bool is_var = false; + + std::vector v; + for (uint32_t i = 0; i < n; ++i) { + v.push_back(f(i)); + if (v[i] != v[0]) { + is_var = true; + } + } + + std::stringstream ss; + + if (is_var) { + ss << "["; + for (uint32_t i = 0; i < n; ++i) { + ss << v[i]; + if (i < n - 1) { + ss << ", "; + } + } + ss << "]"; + } else { + ss << v[0]; + } + + return ss.str(); + }; + + // hparams + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); + LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); + + if (!hparams.vocab_only) { + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); + LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); + LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); + LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); + LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); + LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); + LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); + LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); + LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + + if (!classifier_labels.empty()) { + LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); + + size_t i = 0; + for (auto label : classifier_labels) { + LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); + } + } + } + + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + } + + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + + if (arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } + + if (arch == LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + } + + if (arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } + + if (arch == LLM_ARCH_QWEN3MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } + + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } + + if (arch == LLM_ARCH_BAILINGMOE) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } + + vocab.print_info(); +} + +ggml_backend_dev_t llama_model::dev_layer(int il) const { + return pimpl->dev_layer.at(il).dev; +} + +ggml_backend_dev_t llama_model::dev_output() const { + return pimpl->dev_output.dev; +} + +template +static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx { ggml_init(params) }; + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; + ggml_tensor * op_tensor = fn(ctx.get()); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op_tensor->src[i] != nullptr) { + assert(op_tensor->src[i]->buffer == nullptr); + op_tensor->src[i]->buffer = buf.get(); + } + } + + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + + return op_supported; +} + +template +static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) { + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (buft_supported(cur_buft, cur_dev, fn)) { + return cur_buft; + } + } + + throw std::runtime_error(format("no suitable buffer type found")); +} + +ggml_backend_buffer_type_t llama_model::select_buft(int il) const { + return ::select_buft( + *pimpl->dev_layer.at(il).buft_list, + [&](ggml_context * ctx) { + ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + return ggml_add(ctx, cur, layer_dir); + }); +} + +bool llama_model::has_tensor_overrides() const { + return pimpl->has_tensor_overrides; +} + +const ggml_tensor * llama_model::get_tensor(const char * name) const { + auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(), + [name](const std::pair & it) { + return it.first == name; + }); + if (it == tensors_by_name.end()) { + return nullptr; + } + + return it->second; +} + +float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; +} + +float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; +} + +ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const { + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + + // choose long/short freq factors based on the context size + if (layers[il].rope_freqs != nullptr) { + return layers[il].rope_freqs; + } + + if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { + return layers[il].rope_long; + } + + return layers[il].rope_short; +} + +struct llm_build_llama : public llm_graph_context { + llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_llama_iswa : public llm_graph_context { + llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // temperature tuning + ggml_tensor * inp_attn_scale = nullptr; + inp_attn_scale = build_inp_attn_scale(); + + auto * inp_attn = build_attn_inp_kv_unified_iswa(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (use_rope) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else if (inp_attn_scale) { + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if (use_rope && hparams.use_kq_norm) { + // Llama4TextL2Norm + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * ffn_inp_normed = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(ffn_inp_normed, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + + // Shared experts + ggml_tensor * shexp_out = build_ffn(ffn_inp_normed, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shexp_out, "ffn_moe_shexp", il); + + cur = ggml_add(ctx0, moe_out, shexp_out); + cb(cur, "ffn_moe_out_merged", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_deci : public llm_graph_context { + llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_head = hparams.n_head(il); + const int64_t n_ff = hparams.n_ff(il); + + if (n_head == 0) { + // attention-free layer of Llama-3_1-Nemotron-51B + cur = inpL; + } else { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + } + + if (n_head > 0 && n_head_kv == 0) { + // "linear attention" of Llama-3_1-Nemotron-51B + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "wo", il); + } else if (n_head > 0) { + // self-attention + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // FFN-free layer of Llama-3_1-Nemotron-Ultra-253B + if (n_ff == 0) { + continue; + } + + // modified to support attention-free layer of Llama-3_1-Nemotron-51B + ggml_tensor * ffn_inp = cur; + if (n_head > 0) { + ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + } + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_baichuan : public llm_graph_context { + llm_build_baichuan(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr; + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + switch (model.type) { + case LLM_TYPE_7B: + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + break; + case LLM_TYPE_13B: + break; + default: + GGML_ABORT("fatal error"); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_xverse : public llm_graph_context { + llm_build_xverse(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_falcon : public llm_graph_context { + llm_build_falcon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * attn_norm; + + attn_norm = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(attn_norm, "attn_norm", il); + + // self-attention + { + if (model.layers[il].attn_norm_2) { + // Falcon-40B + cur = build_norm(inpL, + model.layers[il].attn_norm_2, + model.layers[il].attn_norm_2_b, + LLM_NORM, il); + cb(cur, "attn_norm_2", il); + } else { + cur = attn_norm; + } + + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // using mode = 2 for neox mode + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids); + } + + ggml_tensor * ffn_inp = cur; + + // feed forward + { + cur = build_ffn(attn_norm, // !! use the attn norm, not the result + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // norm + cur = build_norm(cur, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_grok : public llm_graph_context { + llm_build_grok(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // multiply by embedding_multiplier_scale of 78.38367176906169 + inpL = ggml_scale(ctx0, inpL, 78.38367176906169f); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Grok + // if attn_out_norm is present then apply it before adding the input + if (model.layers[il].attn_out_norm) { + cur = build_norm(cur, + model.layers[il].attn_out_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_out_norm", il); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_GELU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + + // Grok + // if layer_out_norm is present then apply it before adding the input + // Idea: maybe ffn_out_norm is a better name + if (model.layers[il].layer_out_norm) { + cur = build_norm(cur, + model.layers[il].layer_out_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "layer_out_norm", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // Grok + // multiply logits by output_multiplier_scale of 0.5773502691896257 + + cur = ggml_scale(ctx0, cur, 0.5773502691896257f); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_dbrx : public llm_graph_context { + llm_build_dbrx(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(cur, "wqkv_clamped", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].attn_out_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_out_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_starcoder : public llm_graph_context { + llm_build_starcoder(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_refact : public llm_graph_context { + llm_build_refact(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_bert : public llm_graph_context { + llm_build_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = nullptr; + + if (model.arch != LLM_ARCH_JINA_BERT_V2) { + inp_pos = build_inp_pos(); + } + + // construct input embeddings (token, type, position) + inpL = build_inp_embd(model.tok_embd); + + // token types are hardcoded to zero ("Sentence A") + if (model.type_embd) { + ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); + inpL = ggml_add(ctx0, inpL, type_row0); + } + if (model.arch == LLM_ARCH_BERT) { + inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); + } + cb(inpL, "inp_embd", -1); + + // embed layer norm + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + cb(inpL, "inp_norm", -1); + + auto * inp_attn = build_attn_inp_no_cache(); + + // iterate layers + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * cur = inpL; + + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + + // self-attention + if (model.layers[il].wqkv) { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + } + + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, il); + } + + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // RoPE + if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + + if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // re-add the layer input + cur = ggml_add(ctx0, cur, inpL); + + // attention layer norm + cur = build_norm(cur, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, il); + + if (model.layers[il].attn_norm_2 != nullptr) { + cur = ggml_add(ctx0, cur, inpL); // re-add the layer input + cur = build_norm(cur, model.layers[il].attn_norm_2, model.layers[il].attn_norm_2_b, LLM_NORM, il); + } + + ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1) { + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + nullptr, + model.layers[il].ffn_down_exps, + nullptr, + hparams.n_expert, + hparams.n_expert_used, + LLM_FFN_GELU, + false, false, + 0.0f, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cb(cur, "ffn_moe_out", il); + } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); + + // output layer norm + cur = build_norm(cur, model.layers[il].layer_out_norm, model.layers[il].layer_out_norm_b, LLM_NORM, il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_neo_bert : public llm_graph_context { + llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = build_inp_pos(); + + // construct input embeddings (token, type, position) + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "inp_embd", -1); + + auto * inp_attn = build_attn_inp_no_cache(); + + // iterate layers + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * cur = inpL; + + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + + // pre-norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + + // self-attention + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + + if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // re-add the layer input + cur = ggml_add(ctx0, cur, inpL); + + ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); + + // pre-norm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + cur = build_ffn(cur, + model.layers[il].ffn_up, + NULL, NULL, NULL, NULL, NULL, + model.layers[il].ffn_down, + NULL, NULL, NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_bloom : public llm_graph_context { + llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp_attn = build_attn_inp_kv_unified(); + + inpL = build_norm(inpL, + model.tok_norm, + model.tok_norm_b, + LLM_NORM, -1); + cb(inpL, "inp_norm", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // Add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_mpt : public llm_graph_context { + llm_build_mpt(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * pos; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp_attn = build_attn_inp_kv_unified(); + + if (model.pos_embd) { + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + } + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * attn_norm; + + attn_norm = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(attn_norm, "attn_norm", il); + + // self-attention + { + cur = attn_norm; + + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + if (model.layers[il].bqkv){ + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + + if (hparams.f_clamp_kqv > 0.0f) { + cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(cur, "wqkv_clamped", il); + } + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Q/K Layernorm + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, il); + cb(Qcur, "Qcur", il); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // Add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // feed forward + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_act, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_stablelm : public llm_graph_context { + llm_build_stablelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + ggml_tensor * inpSA = cur; + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + NULL, + LLM_NORM, il); + cb(Qcur, "Qcur", il); + } + + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + NULL, + LLM_NORM, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + if (model.layers[il].ffn_norm) { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + } else { + // parallel residual + cur = inpSA; + } + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen : public llm_graph_context { + llm_build_qwen(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // using mode = 2 for neox mode + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward forward + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen2 : public llm_graph_context { + llm_build_qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen2vl : public llm_graph_context { + llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen2moe : public llm_graph_context { + llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * cur_gate_inp = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(cur_gate_inp, "ffn_shexp_gate_inp", il); + + // sigmoid + ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp); + cb(cur_gate, "ffn_shexp_gate", il); + + ggml_tensor * cur_ffn = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_ffn, "ffn_shexp", il); + + ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate); + cb(ffn_shexp_out, "ffn_shexp_out", il); + + moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out); + cb(moe_out, "ffn_out", il); + + cur = moe_out; + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen3 : public llm_graph_context { + llm_build_qwen3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen3moe : public llm_graph_context { + llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_phi2 : public llm_graph_context { + llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * attn_norm_output; + ggml_tensor * ffn_output; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + attn_norm_output = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(attn_norm_output, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv) { + cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // with phi2, we scale the Q to avoid precision issues + // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 + Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids); + } + + // FF + { + ffn_output = build_ffn(attn_norm_output, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(ffn_output, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_output); + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output_no_bias", -1); + + cur = ggml_add(ctx0, cur, model.output_b); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +template +struct llm_build_phi3 : public llm_graph_context { + llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_unified_iswa(); + } else { + inp_attn = build_attn_inp_kv_unified(); + } + + for (int il = 0; il < n_layer; ++il) { + auto * residual = inpL; + + // self-attention + { + // rope freq factors for 128k context + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + ggml_tensor* attn_norm_output = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM_RMS, il); + cb(attn_norm_output, "attn_norm", il); + + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv) { + cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); + cb(Qcur, "Qcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor* inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, residual); + residual = cur; + + cur = build_norm(cur, + model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + + cur = ggml_add(ctx0, residual, cur); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + if (model.output_b != nullptr) { + cb(cur, "result_output_no_bias", -1); + cur = ggml_add(ctx0, cur, model.output_b); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_plamo : public llm_graph_context { + llm_build_plamo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * attention_norm = cur; + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + ggml_tensor * sa_out = cur; + + cur = attention_norm; + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, sa_out); + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gpt2 : public llm_graph_context { + llm_build_gpt2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * pos; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_codeshell : public llm_graph_context { + llm_build_codeshell(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_orion : public llm_graph_context { + llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + // if (model.layers[il].bq) { + // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + // cb(Qcur, "Qcur", il); + // } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + // if (model.layers[il].bk) { + // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + // cb(Kcur, "Kcur", il); + // } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + // if (model.layers[il].bv) { + // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + // cb(Vcur, "Vcur", il); + // } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_internlm2 : public llm_graph_context { + llm_build_internlm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_minicpm3 : public llm_graph_context { + llm_build_minicpm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + //TODO: if the model varies, these parameters need to be read from the model + const int64_t n_embd_base = 256; + const float scale_embd = 12.0f; + const float scale_depth = 1.4f; + const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k)); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // scale the input embeddings + inpL = ggml_scale(ctx0, inpL, scale_embd); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * q = NULL; + // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = build_norm(q, + model.layers[il].attn_q_a_norm, NULL, + LLM_NORM_RMS, il); + cb(q, "q", il); + + // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + 0); + cb(q_nope, "q_nope", il); + + // and {n_head * n_embd_head_qk_rope, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} + ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_pe_compresseed, "kv_pe_compresseed", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, + kv_pe_compresseed->nb[1], + 0); + cb(kv_compressed, "kv_compressed", il); + + // and {n_embd_head_qk_rope, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont + kv_compressed = ggml_cont(ctx0, kv_compressed); + kv_compressed = build_norm(kv_compressed, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, il); + cb(kv_compressed, "kv_compressed", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // scale_res - scale the hidden states for residual connection + const float scale_res = scale_depth/sqrtf(float(n_layer)); + cur = ggml_scale(ctx0, cur, scale_res); + cb(cur, "hidden_scaled", il); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + // scale the hidden states for residual connection + cur = ggml_scale(ctx0, cur, scale_res); + cb(cur, "hidden_scaled_ffn", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head scaling + const float scale_lmhead = float(n_embd_base)/float(n_embd); + cur = ggml_scale(ctx0, cur, scale_lmhead); + cb(cur, "lmhead_scaling", -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gemma : public llm_graph_context { + llm_build_gemma(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); + cb(Qcur, "Qcur_scaled", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gemma2_iswa : public llm_graph_context { + llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified_iswa(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // final logit soft-capping + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gemma3_iswa : public llm_graph_context { + llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + if (ubatch.token) { + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + } + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: is causal == true correct? might need some changes + auto * inp_attn = build_attn_inp_kv_unified_iswa(); + + for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315 + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// TODO: move up next to build_starcoder +struct llm_build_starcoder2 : public llm_graph_context { + llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_mamba : public llm_graph_context { + const llama_model & model; + + llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * state_copy = build_inp_s_copy(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + cur = build_mamba_layer(gf, cur, state_copy, ubatch, il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // residual + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = build_norm(inpL, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + // TODO: split + ggml_tensor * build_mamba_layer( + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * state_copy, + const llama_ubatch & ubatch, + int il) const { + const auto * kv_state = static_cast(mstate); + + const auto kv_head = kv_state->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_seqs = ubatch.n_seqs; + // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) + const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; + // Use the same RMS norm as the final layer norm + const float norm_rms_eps = hparams.f_norm_rms_eps; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = kv_state->get_k_l(il); + ggml_tensor * ssm_states_all = kv_state->get_v_l(il); + + // (ab)using the KV cache to store the states + ggml_tensor * conv = build_recurrent_state( + gf, conv_states_all, state_copy, + hparams.n_embd_k_s(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); + ggml_tensor * ssm = build_recurrent_state( + gf, ssm_states_all, state_copy, + hparams.n_embd_v_s(), n_seqs); + ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} + ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur); + // split the above in two + // => {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); + ggml_tensor * z = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + + // bias + x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); + + x = ggml_silu(ctx0, x); + } + + // ssm + { + // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} + ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x); + // split + ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); + ggml_tensor * B = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + ggml_tensor * C = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + + // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers + if (ssm_dt_b_c_rms) { + dt = ggml_rms_norm(ctx0, dt, norm_rms_eps); + B = ggml_rms_norm(ctx0, B, norm_rms_eps); + C = ggml_rms_norm(ctx0, C, norm_rms_eps); + } + + // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} + dt = build_lora_mm(model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + + // TODO: skip computing output earlier for unused tokens + + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + //cb(cur, "mamba_out", il); + + return cur; + } +}; + +struct llm_build_command_r : public llm_graph_context { + llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + const float f_logit_scale = hparams.f_logit_scale; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); + ggml_tensor * ffn_inp = cur; + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + NULL, + LLM_NORM, il); + cb(Qcur, "Qcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + NULL, + LLM_NORM, il); + cb(Kcur, "Kcur", il); + } + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + + ggml_tensor * attn_out = cur; + + // feed-forward network + { + cur = build_ffn(ffn_inp, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + // add together residual + FFN + self-attention + cur = ggml_add(ctx0, cur, inpL); + cur = ggml_add(ctx0, cur, attn_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + if (f_logit_scale) { + cur = ggml_scale(ctx0, cur, f_logit_scale); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_cohere2_iswa : public llm_graph_context { + llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + const float f_logit_scale = hparams.f_logit_scale; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified_iswa(); + + for (int il = 0; il < n_layer; ++il) { + const bool is_swa = hparams.is_swa(il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il); + cb(cur, "attn_norm", il); + ggml_tensor * ffn_inp = cur; + + // self-attention + { + // rope freq factors for 128k context + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (is_swa) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + + ggml_tensor * attn_out = cur; + + // feed-forward network + { + cur = build_ffn(ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, + NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, + il); + cb(cur, "ffn_out", il); + } + + // add together residual + FFN + self-attention + cur = ggml_add(ctx0, cur, inpL); + cur = ggml_add(ctx0, cur, attn_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + if (f_logit_scale) { + cur = ggml_scale(ctx0, cur, f_logit_scale); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// ref: https://allenai.org/olmo +// based on the original build_llama() function, changes: +// * non-parametric layer norm +// * clamp qkv +// * removed bias +// * removed MoE +struct llm_build_olmo : public llm_graph_context { + llm_build_olmo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + NULL, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + NULL, NULL, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + NULL, NULL, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_olmo2 : public llm_graph_context { + llm_build_olmo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = inpL; + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_ffn(ffn_inp, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// based on the build_qwen2moe() function, changes: +// * removed shared experts +// * removed bias +// * added q, k norm +struct llm_build_olmoe : public llm_graph_context { + llm_build_olmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_openelm : public llm_graph_context { + llm_build_openelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_head_qkv = 2*n_head_kv + n_head; + + cur = inpL; + ggml_tensor * residual = cur; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0)); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head)); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv))); + cb(Vcur, "Vcur", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, NULL, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur", il); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, NULL, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, NULL, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, NULL, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Qcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = inpL; + + // norm + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gptneox : public llm_graph_context { + llm_build_gptneox(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // ffn + if (hparams.use_par_res) { + // attention and ffn are computed in parallel + // x = x + attn(ln1(x)) + ffn(ln2(x)) + + ggml_tensor * attn_out = cur; + + cur = build_norm(inpL, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, attn_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } else { + // attention and ffn are computed sequentially + // x = x + attn(ln1(x)) + // x = x + ffn(ln2(x)) + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_arctic : public llm_graph_context { + llm_build_arctic(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp); + cb(ffn_out, "ffn_out", il); + + // MoE + cur = build_norm(inpSA, + model.layers[il].ffn_norm_exps, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm_exps", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_deepseek : public llm_graph_context { + llm_build_deepseek(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_deepseek2 : public llm_graph_context { + llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + bool is_lite = (hparams.n_layer == 27); + + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k)); + const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * q = NULL; + if (!is_lite) { + q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = build_norm(q, + model.layers[il].attn_q_a_norm, nullptr, + LLM_NORM_RMS, il); + cb(q, "q", il); + + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + } else { + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + } + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, + 0); + cb(q_nope, "q_nope", il); + + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, + n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, + ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, + kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + 0); + cb(kv_cmpr, "kv_cmpr", il); + + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, + n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + kv_cmpr = build_norm(kv_cmpr, + model.layers[il].attn_kv_a_norm, nullptr, + LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr", il); + + if (is_mla) { + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + ggml_tensor * Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il); + } else { + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); + cb(kv, "kv", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + 0); + cb(k_nope, "k_nope_view", il); + + // and {n_embd_head_v, n_head, n_tokens} + ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, + n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + ggml_row_size(kv->type, n_embd_head_qk_nope)); + cb(Vcur, "Vcur_view", il); + + Vcur = ggml_cont(ctx0, Vcur); + cb(Vcur, "Vcur_cont", il); + + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); + cb(Kcur, "Kcur", il); + + // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + } + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_bitnet : public llm_graph_context { + llm_build_bitnet(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].wq_scale) { + Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); + } + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + // B1.K + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].wk_scale) { + Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); + } + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + // B1.V + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].wv_scale) { + Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); + } + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + NULL, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + + cur = build_norm(cur, + model.layers[il].attn_sub_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_sub_norm", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + if (model.layers[il].wo_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + } + if (model.layers[il].bo) { + cur = ggml_add(ctx0, cur, model.layers[il].bo); + } + cb(cur, "attn_o_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward forward + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, + NULL, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_sub_out", il); + + cur = build_norm(cur, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_sub_norm", il); + + cur = build_lora_mm(model.layers[il].ffn_down, cur); + if (model.layers[il].ffn_down_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + } + cb(cur, "ffn_down", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + // FIXME: do not use model.tok_embd directly, duplicate as model.output + cur = build_lora_mm(model.tok_embd, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_t5_enc : public llm_graph_context { + llm_build_t5_enc(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc(); + + auto * inp_attn = build_attn_inp_no_cache(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; + ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo_enc, nullptr, + Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up_enc, NULL, NULL, + model.layers[il].ffn_gate_enc, NULL, NULL, + model.layers[il].ffn_down_enc, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cb(cur, "result_embd", -1); + + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_t5_dec : public llm_graph_context { + llm_build_t5_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * embd_enc = build_inp_cross_embd(); + ggml_tensor * pos_bucket_dec = build_inp_pos_bucket_dec(); + + const int64_t n_outputs_enc = embd_enc->ne[1]; + + auto * inp_attn_self = build_attn_inp_kv_unified(); + auto * inp_attn_cross = build_attn_inp_cross(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; + ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); + + cur = build_attn(inp_attn_self, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); + cb(cur, "kqv_out", il); + } + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "cross_inp", il); + + ggml_tensor * inpCA = cur; + + // norm + cur = build_norm(cur, + model.layers[il].attn_norm_cross, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm_cross", il); + + // cross-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc); + + cur = build_attn(inp_attn_cross, gf, + model.layers[il].wo_cross, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + cb(cur, "kqv_out", il); + + //ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + //ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + + //ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + //cb(kq, "kq", il); + + //kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias); + //cb(kq, "kq_soft_max_ext", il); + + //ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc))); + //cb(v, "v", il); + + //ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq); + //cb(kqv, "kqv", il); + + //ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + //cb(kqv_merged, "kqv_merged", il); + + //cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + //cb(cur, "kqv_merged_cont", il); + + //ggml_build_forward_expand(gf, cur); + + //cur = build_lora_mm(model.layers[il].wo_cross, cur); + //cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cb(cur, "result_embd", -1); + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_jais : public llm_graph_context { + llm_build_jais(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_chatglm : public llm_graph_context { + llm_build_chatglm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + } else { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = build_norm(inpL, + model.output_norm, + NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_glm4 : public llm_graph_context { + llm_build_glm4(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + } else { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Post-attention norm (new!) + cur = build_norm(cur, + model.layers[il].attn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Add the input (residual connection after post-attention norm) + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + // Pre-MLP norm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // MLP + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // Post-MLP norm + cur = build_norm(cur, + model.layers[il].ffn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_mlp_norm", il); + } + + // Add residual connection after post-MLP norm + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + // Final norm + cur = build_norm(inpL, + model.output_norm, + NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // Output projection + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_nemotron : public llm_graph_context { + llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + //GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_exaone : public llm_graph_context { + llm_build_exaone(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_rwkv6_base : public llm_graph_context { + const llama_model & model; + + llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) { + } + + ggml_tensor * build_rwkv6_channel_mix( + const llama_layer * layer, + ggml_tensor * cur, + ggml_tensor * x_prev, + llm_arch arch) const { + ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + switch (arch) { + case LLM_ARCH_RWKV6: + { + ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); + ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur); + + ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr)); + ggml_tensor * k = ggml_sqr( + ctx0, + ggml_relu( + ctx0, + build_lora_mm(layer->channel_mix_key, xk) + ) + ); + cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k)); + } break; + default: + GGML_ABORT("fatal error"); + } + + return cur; + } + + ggml_tensor * build_rwkv6_time_mix( + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * x_prev, + ggml_tensor * state_copy, + const llama_ubatch & ubatch, + int il) const { + const auto * kv_state = static_cast(mstate); + + const auto n_tokens = ubatch.n_tokens; + const auto n_seqs = ubatch.n_seqs; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_embd = hparams.n_embd; + const auto head_size = hparams.wkv_head_size; + const auto n_head = n_embd / head_size; + const auto n_head_kv = hparams.n_head_kv(il); + + const auto kv_head = kv_state->get_head(); + + const auto & layer = model.layers[il]; + + bool is_qrwkv = layer.time_mix_first == nullptr; + + ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + + sx = ggml_reshape_2d(ctx0, sx, n_embd, n_tokens); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + + ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur); + + xxx = ggml_reshape_4d( + ctx0, + ggml_tanh( + ctx0, + ggml_mul_mat(ctx0, layer.time_mix_w1, xxx) + ), + layer.time_mix_w1->ne[1] / 5, 1, 5, n_tokens + ); + + xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2)); + + xxx = ggml_mul_mat( + ctx0, + ggml_reshape_4d( + ctx0, + layer.time_mix_w2, + layer.time_mix_w2->ne[0], layer.time_mix_w2->ne[1], 1, 5 + ), + xxx + ); + + ggml_tensor *xw, *xk, *xv, *xr, *xg; + if (layer.time_mix_lerp_fused) { + // fusing these weights makes some performance improvement + sx = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens); + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); + xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer.time_mix_lerp_fused), sx), cur); + xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); + xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); + xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float)); + xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float)); + xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float)); + } else { + // for backward compatibility + xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); + xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); + xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float)); + xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float)); + xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float)); + + xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer.time_mix_lerp_w), sx), cur); + xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer.time_mix_lerp_k), sx), cur); + xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer.time_mix_lerp_v), sx), cur); + xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer.time_mix_lerp_r), sx), cur); + xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer.time_mix_lerp_g), sx), cur); + } + + ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr); + ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk); + ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv); + if (layer.time_mix_receptance_b) { + r = ggml_add(ctx0, r, layer.time_mix_receptance_b); + } + if (layer.time_mix_key_b) { + k = ggml_add(ctx0, k, layer.time_mix_key_b); + } + if (layer.time_mix_value_b) { + v = ggml_add(ctx0, v, layer.time_mix_value_b); + } + + ggml_tensor * g = build_lora_mm(layer.time_mix_gate, xg); + if (is_qrwkv) { + g = ggml_sigmoid(ctx0, g); + } else { + g = ggml_silu(ctx0, g); + } + + if (n_head_kv != 0 && n_head_kv != n_head) { + GGML_ASSERT(n_head % n_head_kv == 0); + k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens); + v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens); + ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens); + k = ggml_repeat(ctx0, k, tmp); + v = ggml_repeat(ctx0, v, tmp); + } + + k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens); + v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens); + r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens); + + ggml_tensor * w = ggml_mul_mat( + ctx0, + layer.time_mix_decay_w2, + ggml_tanh( + ctx0, + ggml_mul_mat(ctx0, layer.time_mix_decay_w1, xw) + ) + ); + + w = ggml_add(ctx0, w, layer.time_mix_decay); + w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w))); + w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens); + + if (is_qrwkv) { + // k = k * (1 - w) + k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w)); + } + + ggml_tensor * wkv_state = build_recurrent_state( + gf, kv_state->get_v_l(il), state_copy, + hparams.n_embd_v_s(), n_seqs); + + ggml_tensor * wkv_output; + if (is_qrwkv) { + wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f)); + } else { + wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer.time_mix_first, w, wkv_state); + } + cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0); + wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); + + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + wkv_state, + ggml_view_1d( + ctx0, + kv_state->get_v_l(il), + hparams.n_embd_v_s() * n_seqs, + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) + ) + ) + ); + + if (!is_qrwkv) { + // group norm with head_count groups + cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens); + cur = ggml_norm(ctx0, cur, 64e-5f); + + // Convert back to regular vectors. + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b); + } else { + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + } + + cur = ggml_mul(ctx0, cur, g); + cur = build_lora_mm(layer.time_mix_output, cur); + + return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs); + } +}; + +struct llm_build_rwkv6 : public llm_build_rwkv6_base { + llm_build_rwkv6(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) { + GGML_ASSERT(hparams.token_shift_count == 2); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + + ggml_tensor * state_copy = build_inp_s_copy(); + + const auto n_embd = hparams.n_embd; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_seqs = ubatch.n_seqs; + + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); + + ggml_tensor * token_shift = build_rwkv_token_shift_load( + gf, state_copy, ubatch, il + ); + + ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); + ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); + + ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il); + cb(att_norm, "attn_norm", il); + + ggml_tensor * x_prev = ggml_concat( + ctx0, + att_shift, + ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0), + 1 + ); + + cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il); + cb(ffn_norm, "ffn_norm", il); + + x_prev = ggml_concat( + ctx0, + ffn_shift, + ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0), + 1 + ); + + token_shift = ggml_concat(ctx0, + ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)), + ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)), + 1 + ); + ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids); + ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids); + x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids); + cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids); + } + + cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6); + cur = ggml_add(ctx0, cur, ffn_inp); + + if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) { + cur = ggml_scale(ctx0, cur, 0.5F); + } + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py +struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { + llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) { + GGML_ASSERT(n_embd == hparams.n_embd_k_s()); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * state_copy = build_inp_s_copy(); + + const auto n_embd = hparams.n_embd; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_seqs = ubatch.n_seqs; + + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); + + ggml_tensor * token_shift = build_rwkv_token_shift_load( + gf, state_copy, ubatch, il + ); + + ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); + cb(att_norm, "attn_norm", il); + + ggml_tensor * x_prev = ggml_concat( + ctx0, + token_shift, + ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0), + 1 + ); + + cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); + + token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); + ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids); + } + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_rwkv7_base : public llm_graph_context { + const llama_model & model; + + llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) { + } + + ggml_tensor * build_rwkv7_channel_mix( + const llama_layer * layer, + ggml_tensor * cur, + ggml_tensor * x_prev, + llm_arch arch) const { + ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + switch (arch) { + case LLM_ARCH_RWKV7: + { + ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); + + ggml_tensor * k = ggml_sqr( + ctx0, + ggml_relu( + ctx0, + build_lora_mm(layer->channel_mix_key, xk) + ) + ); + + cur = build_lora_mm(layer->channel_mix_value, k); + } break; + default: + GGML_ABORT("fatal error"); + } + + return cur; + } + + ggml_tensor * build_rwkv7_time_mix( + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * x_prev, + ggml_tensor * state_copy, + ggml_tensor *& first_layer_value, + const llama_ubatch & ubatch, + int il) const { + const auto * kv_state = static_cast(mstate); + + const auto n_tokens = ubatch.n_tokens; + const auto n_seqs = ubatch.n_seqs; + const auto n_embd = hparams.n_embd; + const auto head_size = hparams.wkv_head_size; + const auto head_count = n_embd / head_size; + const auto n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = kv_state->get_head(); + + const auto & layer = model.layers[il]; + + bool has_gating = layer.time_mix_g1 && layer.time_mix_g2; + + ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5); + sx = ggml_repeat(ctx0, sx, dummy); + + ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur); + + ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); + ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); + ggml_tensor * xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float)); + ggml_tensor * xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float)); + ggml_tensor * xa = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float)); + ggml_tensor * xg = has_gating ? ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) : nullptr; + + ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr); + ggml_tensor * w = ggml_add( + ctx0, + ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))), + layer.time_mix_w0 + ); + w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531)); + + ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk); + ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv); + if (first_layer_value == nullptr) { + first_layer_value = v; + } else { + // Add the first layer value as a residual connection. + v = ggml_add(ctx0, v, + ggml_mul(ctx0, + ggml_sub(ctx0, first_layer_value, v), + ggml_sigmoid(ctx0, ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.time_mix_v2, ggml_mul_mat(ctx0, layer.time_mix_v1, xv)), + layer.time_mix_v0 + ) + ) + ) + ); + } + + ggml_tensor * g = nullptr; + if (layer.time_mix_g1 && layer.time_mix_g2) { + g = ggml_mul_mat(ctx0, layer.time_mix_g2, ggml_sigmoid(ctx0, ggml_mul_mat(ctx0, layer.time_mix_g1, xg))); + } + + ggml_tensor * a = ggml_sigmoid(ctx0, + ggml_add( + ctx0, + ggml_mul_mat(ctx0, layer.time_mix_a2, ggml_mul_mat(ctx0, layer.time_mix_a1, xa)), + layer.time_mix_a0 + ) + ); + + ggml_tensor * kk = ggml_reshape_3d(ctx0, ggml_mul(ctx0, k, layer.time_mix_k_k), head_size, head_count, n_tokens); + kk = ggml_l2_norm(ctx0, kk, 1e-12); + + ggml_tensor * ka = ggml_mul(ctx0, k, layer.time_mix_k_a); + k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka)); + + r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens); + w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens); + k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens); + v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); + a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); + + ggml_tensor * wkv_state = build_recurrent_state( + gf, kv_state->get_v_l(il), state_copy, + hparams.n_embd_v_s(), n_seqs); + + ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); + cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0); + wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); + + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + wkv_state, + ggml_view_1d( + ctx0, + kv_state->get_v_l(il), + hparams.n_embd_v_s() * n_seqs, + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) + ) + ) + ); + + if (layer.time_mix_ln && layer.time_mix_ln_b) { + // group norm with head_count groups + cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens); + cur = ggml_norm(ctx0, cur, 64e-5f); + + // Convert back to regular vectors. + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b); + } else { + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + } + + ggml_tensor * rk = ggml_sum_rows(ctx0, + ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count))); + cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens)); + + if (has_gating) { + cur = ggml_mul(ctx0, cur, g); + } + cur = build_lora_mm(layer.time_mix_output, cur); + + return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs); + } +}; + +struct llm_build_rwkv7 : public llm_build_rwkv7_base { + llm_build_rwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) { + GGML_ASSERT(hparams.token_shift_count == 2); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * v_first = nullptr; + + inpL = build_inp_embd(model.tok_embd); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + + ggml_tensor * state_copy = build_inp_s_copy(); + + const auto n_embd = hparams.n_embd; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_seqs = ubatch.n_seqs; + + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); + + ggml_tensor * token_shift = build_rwkv_token_shift_load( + gf, state_copy, ubatch, il + ); + + ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); + ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); + + ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il); + cb(att_norm, "attn_norm", il); + + ggml_tensor * x_prev = ggml_concat( + ctx0, + att_shift, + ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0), + 1 + ); + + cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il); + cb(ffn_norm, "ffn_norm", il); + + x_prev = ggml_concat( + ctx0, + ffn_shift, + ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0), + 1 + ); + + token_shift = ggml_concat(ctx0, + ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)), + ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)), + 1 + ); + ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids); + ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids); + x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids); + } + + cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7); + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + + +struct llm_build_arwkv7 : public llm_build_rwkv7_base { + llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) { + GGML_ASSERT(n_embd == hparams.n_embd_k_s()); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * v_first = nullptr; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * state_copy = build_inp_s_copy(); + + const auto n_embd = hparams.n_embd; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_seqs = ubatch.n_seqs; + + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); + + ggml_tensor * token_shift = build_rwkv_token_shift_load( + gf, state_copy, ubatch, il + ); + + ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); + cb(att_norm, "attn_norm", il); + + ggml_tensor * x_prev = ggml_concat( + ctx0, + token_shift, + ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0), + 1 + ); + + cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); + + token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); + ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids); + } + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + + +struct llm_build_granite : public llm_graph_context { + llm_build_granite( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_rope = true) + : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - built only if rope enabled + ggml_tensor * inp_pos = nullptr; + if (use_rope) { + inp_pos = build_inp_pos(); + } + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (use_rope) { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architectures - scale logits + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// ref: https://github.com/facebookresearch/chameleon +// based on the original build_llama() function, changes: +// * qk-norm +// * swin-norm +// * removed bias +// * removed MoE +struct llm_build_chameleon : public llm_graph_context { + llm_build_chameleon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + if (hparams.swin_norm) { + cur = inpL; + } else { + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + } + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + if (model.layers[il].attn_q_norm) { + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur) * n_embd_head, + ggml_element_size(Qcur) * n_embd_head * n_head, + 0); + cb(Qcur, "Qcur", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, il); + cb(Qcur, "Qcur", il); + } + + if (model.layers[il].attn_k_norm) { + Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, + ggml_element_size(Kcur) * n_embd_head, + ggml_element_size(Kcur) * n_embd_head * n_head_kv, + 0); + cb(Kcur, "Kcur", il); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + + if (hparams.swin_norm) { + cur = build_norm(cur, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + } + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (!hparams.swin_norm) { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + } + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + if (hparams.swin_norm) { + cur = build_norm(cur, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output_with_img_logits", -1); + + // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. + // Needs to be removed once image outputs are supported. + int img_token_end_idx = 8196; + int img_token_start_idx = 4; + int num_img_tokens = img_token_end_idx - img_token_start_idx; + // creates 1d tensor of size num_img_tokens and values -FLT_MAX, + // which ensures that text token values are always at least larger than image token values + ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens); + img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX); + cb(img_logits, "img_logits", -1); + + cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_wavtokenizer_dec : public llm_graph_context { + llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL)); + + cur = ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1); + cur = ggml_add(ctx0, cur, model.conv1d_b); + + // posnet + for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) { + const auto & layer = model.layers[il].posnet; + + inpL = cur; + + switch (il) { + case 0: + case 1: + case 3: + case 4: + { + cur = build_norm(cur, + layer.norm1, + layer.norm1_b, + LLM_NORM_GROUP, 0); + + cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur); + + cur = ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1); + cur = ggml_add(ctx0, cur, layer.conv1_b); + + cur = build_norm(cur, + layer.norm2, + layer.norm2_b, + LLM_NORM_GROUP, 0); + + cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur); + + cur = ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1); + cur = ggml_add(ctx0, cur, layer.conv2_b); + + cur = ggml_add(ctx0, cur, inpL); + } break; + case 2: + { + cur = build_norm(cur, + layer.attn_norm, + layer.attn_norm_b, + LLM_NORM_GROUP, 0); + + ggml_tensor * q; + ggml_tensor * k; + ggml_tensor * v; + + q = ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1); + k = ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1); + v = ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1); + + q = ggml_add(ctx0, q, layer.attn_q_b); + k = ggml_add(ctx0, k, layer.attn_k_b); + v = ggml_add(ctx0, v, layer.attn_v_b); + + q = ggml_cont(ctx0, ggml_transpose(ctx0, q)); + k = ggml_cont(ctx0, ggml_transpose(ctx0, k)); + + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + + kq = ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f); + + cur = ggml_mul_mat(ctx0, kq, v); + + cur = ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1); + cur = ggml_add(ctx0, cur, layer.attn_o_b); + + cur = ggml_add(ctx0, cur, inpL); + } break; + case 5: + { + cur = build_norm(cur, + layer.norm, + layer.norm_b, + LLM_NORM_GROUP, 0); + } break; + default: GGML_ABORT("unknown posnet layer"); + }; + } + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + cur = build_norm(cur, + model.tok_norm, + model.tok_norm_b, + LLM_NORM, -1); + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + inpL = cur; + + // convnext + for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) { + const auto & layer = model.layers[il].convnext; + + cur = inpL; + + cur = ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1); + cur = ggml_add(ctx0, cur, layer.dw_b); + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + cur = build_norm(cur, + layer.norm, + layer.norm_b, + LLM_NORM, -1); + + cur = build_ffn(cur, + layer.pw1, layer.pw1_b, NULL, + NULL, NULL, NULL, + layer.pw2, layer.pw2_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + + cur = ggml_mul(ctx0, cur, layer.gamma); + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + inpL = ggml_add(ctx0, cur, inpL); + } + + cur = inpL; + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + cur = build_norm(cur, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + + cur = ggml_add(ctx0, cur, model.output_b); + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_plm : public llm_graph_context { + llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k)); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * q = NULL; + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + 0); + cb(q_nope, "q_nope", il); + + // and {n_head * n_embd_head_qk_rope, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} + ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_pe_compresseed, "kv_pe_compresseed", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, + kv_pe_compresseed->nb[1], + 0); + cb(kv_compressed, "kv_compressed", il); + + // and {n_embd_head_qk_rope, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + kv_compressed = build_norm(kv_compressed, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, il); + cb(kv_compressed, "kv_compressed", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_bailingmoe : public llm_graph_context { + llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + false, hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_dots1 : public llm_graph_context { + llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_arcee : public llm_graph_context { + llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + // ARCEE uses relu^2 instead of silu + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { + llama_memory_i * res; + + switch (arch) { + case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_NEO_BERT: + case LLM_ARCH_WAVTOKENIZER_DEC: + { + res = nullptr; + } break; + case LLM_ARCH_MAMBA: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + { + res = new llama_kv_cache_recurrent( + *this, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max); + } break; + default: + { + const auto padding = llama_kv_cache_unified::get_padding(cparams); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(hparams.is_swa_any()); + + res = new llama_kv_cache_unified_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.n_ctx, + cparams.n_seq_max, + cparams.n_ubatch, + padding); + } else { + GGML_ASSERT(!hparams.is_swa_any()); + + res = new llama_kv_cache_unified( + *this, + nullptr, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + cparams.n_seq_max, + padding, + hparams.n_swa, + hparams.swa_type); + } + } + } + + return res; +} + +llm_graph_result_ptr llama_model::build_graph( + const llm_graph_params & params, + ggml_cgraph * gf, + llm_graph_type type) const { + std::unique_ptr llm; + + switch (arch) { + case LLM_ARCH_LLAMA: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_LLAMA4: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_DECI: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_BAICHUAN: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_FALCON: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GROK: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_STARCODER: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_REFACT: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_NEO_BERT: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_BLOOM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_MPT: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_STABLELM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN2VL: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN2MOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN3: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN3MOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_PHI2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_PHI3: + case LLM_ARCH_PHIMOE: + { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + llm = std::make_unique> (*this, params, gf); + } else { + llm = std::make_unique>(*this, params, gf); + } + } break; + case LLM_ARCH_PLAMO: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GPT2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_CODESHELL: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_ORION: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_INTERNLM2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_MINICPM3: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GEMMA: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GEMMA2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GEMMA3: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_STARCODER2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_MAMBA: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_XVERSE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_COMMAND_R: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_COHERE2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_DBRX: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_OLMO: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_OLMO2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_OLMOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_OPENELM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GPTNEOX: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_ARCTIC: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_DEEPSEEK: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_DEEPSEEK2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_CHATGLM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GLM4: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_BITNET: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_T5: + { + switch (type) { + case LLM_GRAPH_TYPE_ENCODER: + llm = std::make_unique(*this, params, gf); + break; + case LLM_GRAPH_TYPE_DEFAULT: + case LLM_GRAPH_TYPE_DECODER: + llm = std::make_unique(*this, params, gf); + break; + default: + GGML_ABORT("invalid graph type"); + }; + } break; + case LLM_ARCH_T5ENCODER: + { + llm = std::make_unique(*this, params, gf); + } + break; + case LLM_ARCH_JAIS: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_NEMOTRON: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_EXAONE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_RWKV6: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_RWKV6QWEN2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_RWKV7: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_ARWKV7: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_MINICPM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_CHAMELEON: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_WAVTOKENIZER_DEC: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_PLM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_BAILINGMOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_DOTS1: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_ARCEE: + { + llm = std::make_unique(*this, params, gf); + } break; + default: + GGML_ABORT("fatal error"); + } + + // add on pooling layer + llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b); + + return std::move(llm->res); +} + +// +// interface implementation +// + +llama_model_params llama_model_default_params() { + llama_model_params result = { + /*.devices =*/ nullptr, + /*.tensor_buft_overrides =*/ nullptr, + /*.n_gpu_layers =*/ 0, + /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, + /*.main_gpu =*/ 0, + /*.tensor_split =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.kv_overrides =*/ nullptr, + /*.vocab_only =*/ false, + /*.use_mmap =*/ true, + /*.use_mlock =*/ false, + /*.check_tensors =*/ false, + }; + +#ifdef GGML_USE_METAL + // note: we usually have plenty of VRAM, so by default offload all layers to the GPU + result.n_gpu_layers = 999; +#endif + + return result; +} + +const llama_vocab * llama_model_get_vocab(const llama_model * model) { + return &model->vocab; +} + +void llama_free_model(llama_model * model) { + llama_model_free(model); +} + +void llama_model_free(llama_model * model) { + delete model; +} + +int32_t llama_model_n_ctx_train(const llama_model * model) { + return model->hparams.n_ctx_train; +} + +int32_t llama_model_n_embd(const llama_model * model) { + return model->hparams.n_embd; +} + +int32_t llama_model_n_layer(const llama_model * model) { + return model->hparams.n_layer; +} + +int32_t llama_model_n_head(const llama_model * model) { + return model->hparams.n_head(); +} + +int32_t llama_model_n_head_kv(const llama_model * model) { + return model->hparams.n_head_kv(); +} + +int32_t llama_model_n_swa(const llama_model * model) { + return model->hparams.n_swa; +} + +uint32_t llama_model_n_cls_out(const struct llama_model * model) { + return model->hparams.n_cls_out; +} + +const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) { + if (i < model->classifier_labels.size()) { + return model->classifier_labels[i].c_str(); + } + + return nullptr; +} + +// deprecated +int32_t llama_n_ctx_train(const llama_model * model) { + return llama_model_n_ctx_train(model); +} + +// deprecated +int32_t llama_n_embd(const llama_model * model) { + return llama_model_n_embd(model); +} + +// deprecated +int32_t llama_n_layer(const llama_model * model) { + return llama_model_n_layer(model); +} + +// deprecated +int32_t llama_n_head(const llama_model * model) { + return llama_model_n_head(model); +} + +llama_rope_type llama_model_rope_type(const llama_model * model) { + switch (model->arch) { + // these models do not use RoPE + case LLM_ARCH_GPT2: + case LLM_ARCH_GPTJ: + case LLM_ARCH_MPT: + case LLM_ARCH_REFACT: + case LLM_ARCH_BLOOM: + case LLM_ARCH_MAMBA: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_T5: + case LLM_ARCH_T5ENCODER: + case LLM_ARCH_JAIS: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + case LLM_ARCH_WAVTOKENIZER_DEC: + return LLAMA_ROPE_TYPE_NONE; + + // use what we call a normal RoPE, operating on pairs of consecutive head values + case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: + case LLM_ARCH_DECI: + case LLM_ARCH_BAICHUAN: + case LLM_ARCH_STARCODER: + case LLM_ARCH_INTERNLM2: + case LLM_ARCH_MINICPM: + case LLM_ARCH_XVERSE: + case LLM_ARCH_COMMAND_R: + case LLM_ARCH_COHERE2: + case LLM_ARCH_OLMO: + case LLM_ARCH_ARCTIC: + case LLM_ARCH_DEEPSEEK: + case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_PLM: + case LLM_ARCH_CHATGLM: + case LLM_ARCH_GLM4: + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_CHAMELEON: + case LLM_ARCH_BAILINGMOE: + case LLM_ARCH_NEO_BERT: + case LLM_ARCH_ARCEE: + return LLAMA_ROPE_TYPE_NORM; + + // the pairs of head values are offset by n_rot/2 + case LLM_ARCH_FALCON: + case LLM_ARCH_GROK: + case LLM_ARCH_DBRX: + case LLM_ARCH_BERT: + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_STABLELM: + case LLM_ARCH_BITNET: + case LLM_ARCH_QWEN: + case LLM_ARCH_QWEN2: + case LLM_ARCH_QWEN2MOE: + case LLM_ARCH_QWEN3: + case LLM_ARCH_QWEN3MOE: + case LLM_ARCH_OLMO2: + case LLM_ARCH_OLMOE: + case LLM_ARCH_PHI2: + case LLM_ARCH_PHI3: + case LLM_ARCH_PHIMOE: + case LLM_ARCH_PLAMO: + case LLM_ARCH_GEMMA: + case LLM_ARCH_GEMMA2: + case LLM_ARCH_GEMMA3: + case LLM_ARCH_STARCODER2: + case LLM_ARCH_OPENELM: + case LLM_ARCH_GPTNEOX: + case LLM_ARCH_CODESHELL: + case LLM_ARCH_ORION: + case LLM_ARCH_NEMOTRON: + case LLM_ARCH_EXAONE: + case LLM_ARCH_MINICPM3: + case LLM_ARCH_DOTS1: + return LLAMA_ROPE_TYPE_NEOX; + + case LLM_ARCH_QWEN2VL: + return LLAMA_ROPE_TYPE_MROPE; + + // all model arches should be listed explicitly here + case LLM_ARCH_UNKNOWN: + GGML_ABORT("unknown architecture"); + } + + return LLAMA_ROPE_TYPE_NONE; +} + +float llama_model_rope_freq_scale_train(const llama_model * model) { + return model->hparams.rope_freq_scale_train; +} + +int32_t llama_model_meta_val_str(const llama_model * model, const char * key, char * buf, size_t buf_size) { + const auto & it = model->gguf_kv.find(key); + if (it == model->gguf_kv.end()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + +int32_t llama_model_meta_count(const llama_model * model) { + return (int)model->gguf_kv.size(); +} + +int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)model->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = model->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->first.c_str()); +} + +int32_t llama_model_meta_val_str_by_index(const llama_model * model, int32_t i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)model->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = model->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + +int32_t llama_model_desc(const llama_model * model, char * buf, size_t buf_size) { + return snprintf(buf, buf_size, "%s", model->desc().c_str()); +} + +uint64_t llama_model_size(const llama_model * model) { + return model->size(); +} + +const char * llama_model_chat_template(const llama_model * model, const char * name) { + const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE) + : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); + const auto & it = model->gguf_kv.find(key); + if (it == model->gguf_kv.end()) { + // one-off fix for very popular models (so we are not flooded with issues) + // do not extend this list unless absolutely necessary + // Mistral-Small-2503 does not have built-in chat template + llama_vocab_pre_type pre_type = model->vocab.get_pre_type(); + if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) { + return "mistral-v7-tekken"; + } + + return nullptr; + } + + return it->second.c_str(); +} + +uint64_t llama_model_n_params(const llama_model * model) { + return model->n_elements(); +} + +bool llama_model_has_encoder(const llama_model * model) { + switch (model->arch) { + case LLM_ARCH_T5: return true; + case LLM_ARCH_T5ENCODER: return true; + default: return false; + } +} + +bool llama_model_has_decoder(const llama_model * model) { + switch (model->arch) { + case LLM_ARCH_T5ENCODER: return false; + default: return true; + } +} + +llama_token llama_model_decoder_start_token(const llama_model * model) { + return model->hparams.dec_start_token_id; +} + +bool llama_model_is_recurrent(const llama_model * model) { + switch (model->arch) { + case LLM_ARCH_MAMBA: return true; + case LLM_ARCH_RWKV6: return true; + case LLM_ARCH_RWKV6QWEN2: return true; + case LLM_ARCH_RWKV7: return true; + case LLM_ARCH_ARWKV7: return true; + default: return false; + } +} + +const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { + return model->tensors_by_name; +} diff --git a/src/llama-model.h b/src/llama-model.h new file mode 100644 index 0000000000000000000000000000000000000000..06e6c687943cc23e615bd1f49f773347d2b4247b --- /dev/null +++ b/src/llama-model.h @@ -0,0 +1,429 @@ +#pragma once + +#include "llama.h" +#include "llama-arch.h" +#include "llama-graph.h" +#include "llama-hparams.h" +#include "llama-memory.h" +#include "llama-vocab.h" + +#include +#include +#include +#include + +struct llama_cparams; +struct llama_ubatch; +struct llama_model_loader; + +// available models +enum llm_type { + LLM_TYPE_UNKNOWN, + LLM_TYPE_14M, + LLM_TYPE_17M, + LLM_TYPE_22M, + LLM_TYPE_33M, + LLM_TYPE_60M, + LLM_TYPE_70M, + LLM_TYPE_80M, + LLM_TYPE_109M, + LLM_TYPE_137M, + LLM_TYPE_160M, + LLM_TYPE_190M, + LLM_TYPE_220M, + LLM_TYPE_250M, + LLM_TYPE_270M, + LLM_TYPE_335M, + LLM_TYPE_410M, + LLM_TYPE_450M, + LLM_TYPE_475M, + LLM_TYPE_770M, + LLM_TYPE_780M, + LLM_TYPE_0_5B, + LLM_TYPE_0_6B, + LLM_TYPE_1B, + LLM_TYPE_1_3B, + LLM_TYPE_1_4B, + LLM_TYPE_1_5B, + LLM_TYPE_1_6B, + LLM_TYPE_1_7B, + LLM_TYPE_1_8B, + LLM_TYPE_2B, + LLM_TYPE_2_8B, + LLM_TYPE_2_9B, + LLM_TYPE_3B, + LLM_TYPE_4B, + LLM_TYPE_6B, + LLM_TYPE_6_9B, + LLM_TYPE_7B, + LLM_TYPE_8B, + LLM_TYPE_9B, + LLM_TYPE_11B, + LLM_TYPE_12B, + LLM_TYPE_13B, + LLM_TYPE_14B, + LLM_TYPE_15B, + LLM_TYPE_16B, + LLM_TYPE_20B, + LLM_TYPE_27B, + LLM_TYPE_30B, + LLM_TYPE_32B, + LLM_TYPE_34B, + LLM_TYPE_35B, + LLM_TYPE_40B, + LLM_TYPE_65B, + LLM_TYPE_70B, + LLM_TYPE_142B, + LLM_TYPE_236B, + LLM_TYPE_290B, + LLM_TYPE_314B, + LLM_TYPE_405B, + LLM_TYPE_671B, + LLM_TYPE_SMALL, + LLM_TYPE_MEDIUM, + LLM_TYPE_LARGE, + LLM_TYPE_XL, + LLM_TYPE_A1_7B, + LLM_TYPE_A2_7B, + LLM_TYPE_8x7B, + LLM_TYPE_8x22B, + LLM_TYPE_16x12B, + LLM_TYPE_16x3_8B, + LLM_TYPE_10B_128x3_66B, + LLM_TYPE_57B_A14B, + LLM_TYPE_17B_16E, // llama4 Scout + LLM_TYPE_17B_128E, // llama4 Maverick + LLM_TYPE_30B_A3B, + LLM_TYPE_235B_A22B, +}; + +std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type); + +struct llama_layer_posnet { + // resnet + struct ggml_tensor * norm1 = nullptr; + struct ggml_tensor * norm1_b = nullptr; + + struct ggml_tensor * conv1 = nullptr; + struct ggml_tensor * conv1_b = nullptr; + + struct ggml_tensor * norm2 = nullptr; + struct ggml_tensor * norm2_b = nullptr; + + struct ggml_tensor * conv2 = nullptr; + struct ggml_tensor * conv2_b = nullptr; + + // attention + struct ggml_tensor * attn_norm = nullptr; + struct ggml_tensor * attn_norm_b = nullptr; + + struct ggml_tensor * attn_q = nullptr; + struct ggml_tensor * attn_q_b = nullptr; + + struct ggml_tensor * attn_k = nullptr; + struct ggml_tensor * attn_k_b = nullptr; + + struct ggml_tensor * attn_v = nullptr; + struct ggml_tensor * attn_v_b = nullptr; + + struct ggml_tensor * attn_o = nullptr; + struct ggml_tensor * attn_o_b = nullptr; + + // normalize + struct ggml_tensor * norm = nullptr; + struct ggml_tensor * norm_b = nullptr; +}; + +struct llama_layer_convnext { + struct ggml_tensor * dw = nullptr; + struct ggml_tensor * dw_b = nullptr; + + struct ggml_tensor * norm = nullptr; + struct ggml_tensor * norm_b = nullptr; + + struct ggml_tensor * pw1 = nullptr; + struct ggml_tensor * pw1_b = nullptr; + + struct ggml_tensor * pw2 = nullptr; + struct ggml_tensor * pw2_b = nullptr; + + struct ggml_tensor * gamma = nullptr; +}; + +struct llama_layer { + // normalization + struct ggml_tensor * attn_norm = nullptr; + struct ggml_tensor * attn_norm_b = nullptr; + struct ggml_tensor * attn_norm_2 = nullptr; + struct ggml_tensor * attn_norm_2_b = nullptr; + struct ggml_tensor * attn_q_norm = nullptr; + struct ggml_tensor * attn_q_norm_b = nullptr; + struct ggml_tensor * attn_k_norm = nullptr; + struct ggml_tensor * attn_k_norm_b = nullptr; + struct ggml_tensor * attn_out_norm = nullptr; + struct ggml_tensor * attn_out_norm_b = nullptr; + struct ggml_tensor * attn_q_a_norm = nullptr; + struct ggml_tensor * attn_kv_a_norm = nullptr; + struct ggml_tensor * attn_sub_norm = nullptr; + struct ggml_tensor * attn_post_norm = nullptr; + struct ggml_tensor * ffn_sub_norm = nullptr; + struct ggml_tensor * attn_norm_cross = nullptr; + struct ggml_tensor * attn_norm_enc = nullptr; + + // attention + struct ggml_tensor * wq = nullptr; + struct ggml_tensor * wk = nullptr; + struct ggml_tensor * wv = nullptr; + struct ggml_tensor * wo = nullptr; + struct ggml_tensor * wqkv = nullptr; + struct ggml_tensor * wq_a = nullptr; + struct ggml_tensor * wq_b = nullptr; + struct ggml_tensor * wkv_a_mqa = nullptr; + struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wk_b = nullptr; + struct ggml_tensor * wv_b = nullptr; + struct ggml_tensor * wq_cross = nullptr; + struct ggml_tensor * wk_cross = nullptr; + struct ggml_tensor * wv_cross = nullptr; + struct ggml_tensor * wo_cross = nullptr; + struct ggml_tensor * wq_enc = nullptr; + struct ggml_tensor * wk_enc = nullptr; + struct ggml_tensor * wv_enc = nullptr; + struct ggml_tensor * wo_enc = nullptr; + + // attention bias + struct ggml_tensor * bq = nullptr; + struct ggml_tensor * bk = nullptr; + struct ggml_tensor * bv = nullptr; + struct ggml_tensor * bo = nullptr; + struct ggml_tensor * bqkv = nullptr; + + // relative position bias + struct ggml_tensor * attn_rel_b = nullptr; + struct ggml_tensor * attn_rel_b_enc = nullptr; + struct ggml_tensor * attn_rel_b_cross = nullptr; + + // normalization + struct ggml_tensor * ffn_norm = nullptr; + struct ggml_tensor * ffn_norm_b = nullptr; + struct ggml_tensor * ffn_post_norm = nullptr; + struct ggml_tensor * layer_out_norm = nullptr; + struct ggml_tensor * layer_out_norm_b = nullptr; + struct ggml_tensor * ffn_norm_exps = nullptr; + struct ggml_tensor * ffn_norm_enc = nullptr; + + // ff + struct ggml_tensor * ffn_gate = nullptr; // w1 + struct ggml_tensor * ffn_down = nullptr; // w2 + struct ggml_tensor * ffn_up = nullptr; // w3 + struct ggml_tensor * ffn_gate_enc = nullptr; + struct ggml_tensor * ffn_down_enc = nullptr; + struct ggml_tensor * ffn_up_enc = nullptr; + + // ff MoE + struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_exps = nullptr; + struct ggml_tensor * ffn_down_exps = nullptr; + struct ggml_tensor * ffn_up_exps = nullptr; + + // ff shared expert (shexp) + struct ggml_tensor * ffn_gate_inp_shexp = nullptr; + struct ggml_tensor * ffn_gate_shexp = nullptr; + struct ggml_tensor * ffn_down_shexp = nullptr; + struct ggml_tensor * ffn_up_shexp = nullptr; + + // ff bias + struct ggml_tensor * ffn_gate_b = nullptr; + struct ggml_tensor * ffn_down_b = nullptr; // b2 + struct ggml_tensor * ffn_up_b = nullptr; // b3 + struct ggml_tensor * ffn_act = nullptr; + struct ggml_tensor * ffn_exp_probs_b = nullptr; + + // mamba proj + struct ggml_tensor * ssm_in = nullptr; + struct ggml_tensor * ssm_x = nullptr; + struct ggml_tensor * ssm_dt = nullptr; + struct ggml_tensor * ssm_out = nullptr; + + // mamba + struct ggml_tensor * ssm_conv1d = nullptr; + struct ggml_tensor * ssm_a = nullptr; + struct ggml_tensor * ssm_d = nullptr; + + // mamba bias + struct ggml_tensor * ssm_conv1d_b = nullptr; + struct ggml_tensor * ssm_dt_b = nullptr; + + // rwkv + struct ggml_tensor * time_mix_w1 = nullptr; + struct ggml_tensor * time_mix_w2 = nullptr; + struct ggml_tensor * time_mix_lerp_x = nullptr; + struct ggml_tensor * time_mix_lerp_w = nullptr; + struct ggml_tensor * time_mix_lerp_k = nullptr; + struct ggml_tensor * time_mix_lerp_v = nullptr; + struct ggml_tensor * time_mix_lerp_r = nullptr; + struct ggml_tensor * time_mix_lerp_g = nullptr; + struct ggml_tensor * time_mix_lerp_fused = nullptr; + + struct ggml_tensor * time_mix_first = nullptr; + struct ggml_tensor * time_mix_decay = nullptr; + struct ggml_tensor * time_mix_decay_w1 = nullptr; + struct ggml_tensor * time_mix_decay_w2 = nullptr; + struct ggml_tensor * time_mix_key = nullptr; + struct ggml_tensor * time_mix_key_b = nullptr; + struct ggml_tensor * time_mix_value = nullptr; + struct ggml_tensor * time_mix_value_b = nullptr; + struct ggml_tensor * time_mix_receptance = nullptr; + struct ggml_tensor * time_mix_receptance_b = nullptr; + struct ggml_tensor * time_mix_gate = nullptr; + + // rwkv7 + struct ggml_tensor * time_mix_w0 = nullptr; + struct ggml_tensor * time_mix_a0 = nullptr; + struct ggml_tensor * time_mix_a1 = nullptr; + struct ggml_tensor * time_mix_a2 = nullptr; + struct ggml_tensor * time_mix_v0 = nullptr; + struct ggml_tensor * time_mix_v1 = nullptr; + struct ggml_tensor * time_mix_v2 = nullptr; + struct ggml_tensor * time_mix_g1 = nullptr; + struct ggml_tensor * time_mix_g2 = nullptr; + struct ggml_tensor * time_mix_k_k = nullptr; + struct ggml_tensor * time_mix_k_a = nullptr; + struct ggml_tensor * time_mix_r_k = nullptr; + + struct ggml_tensor * time_mix_ln = nullptr; + struct ggml_tensor * time_mix_ln_b = nullptr; + struct ggml_tensor * time_mix_output = nullptr; + + struct ggml_tensor * channel_mix_lerp_k = nullptr; + struct ggml_tensor * channel_mix_lerp_r = nullptr; + + struct ggml_tensor * channel_mix_key = nullptr; + struct ggml_tensor * channel_mix_receptance = nullptr; + struct ggml_tensor * channel_mix_value = nullptr; + + // long rope factors + struct ggml_tensor * rope_long = nullptr; + struct ggml_tensor * rope_short = nullptr; + struct ggml_tensor * rope_freqs = nullptr; + + // bitnet scale + struct ggml_tensor * wq_scale = nullptr; + struct ggml_tensor * wk_scale = nullptr; + struct ggml_tensor * wv_scale = nullptr; + struct ggml_tensor * wo_scale = nullptr; + struct ggml_tensor * ffn_gate_scale = nullptr; + struct ggml_tensor * ffn_up_scale = nullptr; + struct ggml_tensor * ffn_down_scale = nullptr; + + struct llama_layer_posnet posnet; + + struct llama_layer_convnext convnext; +}; + +struct llama_model { + llm_type type = LLM_TYPE_UNKNOWN; + llm_arch arch = LLM_ARCH_UNKNOWN; + + std::string name = "n/a"; + + llama_hparams hparams = {}; + llama_vocab vocab; + + // for classifier models + std::vector classifier_labels; + + struct ggml_tensor * tok_embd = nullptr; + struct ggml_tensor * type_embd = nullptr; + struct ggml_tensor * pos_embd = nullptr; + struct ggml_tensor * tok_norm = nullptr; + struct ggml_tensor * tok_norm_b = nullptr; + + struct ggml_tensor * output_norm = nullptr; + struct ggml_tensor * output_norm_b = nullptr; + struct ggml_tensor * output = nullptr; + struct ggml_tensor * output_b = nullptr; + struct ggml_tensor * output_norm_enc = nullptr; + + // classifier + struct ggml_tensor * cls = nullptr; + struct ggml_tensor * cls_b = nullptr; + struct ggml_tensor * cls_out = nullptr; + struct ggml_tensor * cls_out_b = nullptr; + + struct ggml_tensor * conv1d = nullptr; + struct ggml_tensor * conv1d_b = nullptr; + + std::vector layers; + + llama_model_params params; + + // gguf metadata + std::unordered_map gguf_kv; + + // list of devices used in this model + std::vector devices; + + // for quantize-stats only + std::vector> tensors_by_name; + + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + explicit llama_model(const struct llama_model_params & params); + ~llama_model(); + + void load_stats (llama_model_loader & ml); + void load_arch (llama_model_loader & ml); + void load_hparams(llama_model_loader & ml); + void load_vocab (llama_model_loader & ml); + bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback + + std::string arch_name() const; + std::string type_name() const; + + std::string desc() const; + + size_t size() const; + size_t n_tensors() const; + size_t n_devices() const; + + // total number of parameters in the model + uint64_t n_elements() const; + + void print_info() const; + + ggml_backend_dev_t dev_layer(int il) const; + ggml_backend_dev_t dev_output() const; + + ggml_backend_buffer_type_t select_buft(int il) const; + + bool has_tensor_overrides() const; + + const struct ggml_tensor * get_tensor(const char * name) const; + + float get_rope_freq_base (const llama_cparams & cparams, int il) const; + float get_rope_freq_scale(const llama_cparams & cparams, int il) const; + + ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const; + + // note: can mutate `cparams` + // TODO: move this to new llm_arch_model_i interface + llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; + + // TODO: move this to new llm_arch_model_i interface + llm_graph_result_ptr build_graph( + const llm_graph_params & params, + ggml_cgraph * gf, + llm_graph_type type) const; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +const char * llm_type_name(llm_type type); + +// For internal test use +// TODO: remove +const std::vector> & llama_internal_get_tensor_map(const llama_model * model); diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8cf45732fd6d4817cdafd6fd3de6c2bea421fec4 --- /dev/null +++ b/src/llama-quant.cpp @@ -0,0 +1,967 @@ +#include "llama-quant.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-model-loader.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Quantization types. Changes to this struct must be replicated in quantize.cpp +struct tensor_quantization { + std::string name; + ggml_type quant = GGML_TYPE_COUNT; +}; + +static void zeros(std::ofstream & file, size_t n) { + char zero = 0; + for (size_t i = 0; i < n; ++i) { + file.write(&zero, 1); + } +} + +struct quantize_state_impl { + const llama_model & model; + const llama_model_quantize_params * params; + + int n_attention_wv = 0; + int n_ffn_down = 0; + int n_ffn_gate = 0; + int n_ffn_up = 0; + int i_attention_wv = 0; + int i_ffn_down = 0; + int i_ffn_gate = 0; + int i_ffn_up = 0; + + int n_k_quantized = 0; + int n_fallback = 0; + + bool has_imatrix = false; + + // used to figure out if a model shares tok_embd with the output weight + bool has_output = false; + + quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params) + : model(model) + , params(params) + {} +}; + +static void llama_tensor_dequantize_impl( + ggml_tensor * tensor, std::vector> & output, std::vector & workers, + const size_t nelements, const int nthread +) { + if (output.size() < nelements) { + output.resize(nelements); + } + float * f32_output = (float *) output.data(); + + const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type); + if (ggml_is_quantized(tensor->type)) { + if (qtype->to_float == NULL) { + throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type))); + } + } else if (tensor->type != GGML_TYPE_F16 && + tensor->type != GGML_TYPE_BF16) { + throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type))); + } + + if (nthread < 2) { + if (tensor->type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements); + } else if (tensor->type == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements); + } else if (ggml_is_quantized(tensor->type)) { + qtype->to_float(tensor->data, f32_output, nelements); + } else { + GGML_ABORT("fatal error"); // unreachable + } + return; + } + + size_t block_size; + if (tensor->type == GGML_TYPE_F16 || + tensor->type == GGML_TYPE_BF16) { + block_size = 1; + } else { + block_size = (size_t)ggml_blck_size(tensor->type); + } + + size_t block_size_bytes = ggml_type_size(tensor->type); + + GGML_ASSERT(nelements % block_size == 0); + size_t nblocks = nelements / block_size; + size_t blocks_per_thread = nblocks / nthread; + size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count + + size_t in_buff_offs = 0; + size_t out_buff_offs = 0; + + for (int tnum = 0; tnum < nthread; tnum++) { + size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread + size_t thr_elems = thr_blocks * block_size; // number of elements for this thread + size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread + + auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { + if (typ == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); + } else if (typ == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels); + } else { + qtype->to_float(inbuf, outbuf, nels); + } + }; + workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems); + in_buff_offs += thr_block_bytes; + out_buff_offs += thr_elems; + } + for (auto & w : workers) { w.join(); } + workers.clear(); +} + +static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { + const std::string name = ggml_get_name(tensor); + + // TODO: avoid hardcoded tensor names - use the TN_* constants + const llm_arch arch = qs.model.arch; + const auto tn = LLM_TN(arch); + + auto use_more_bits = [](int i_layer, int n_layers) -> bool { + return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2; + }; + const int n_expert = std::max(1, (int)qs.model.hparams.n_expert); + auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) { + if (n_expert > 1) { + // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly + // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work + // for getting the current layer as I initially thought, and we need to resort to parsing the + // tensor name. + if (sscanf(name, "blk.%d.", &i_layer) != 1) { + throw std::runtime_error(format("Failed to determine layer for tensor %s", name)); + } + if (i_layer < 0 || i_layer >= n_layer) { + throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer)); + } + } + return std::make_pair(i_layer, n_layer); + }; + + // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings + // with the quantization of the output tensor + if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { + if (qs.params->output_tensor_type < GGML_TYPE_COUNT) { + new_type = qs.params->output_tensor_type; + } else { + const int64_t nx = tensor->ne[0]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) { + new_type = GGML_TYPE_Q8_0; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { + new_type = GGML_TYPE_Q5_K; + } + else if (new_type != GGML_TYPE_Q8_0) { + new_type = GGML_TYPE_Q6_K; + } + } + } else if (name == "token_embd.weight") { + if (qs.params->token_embedding_type < GGML_TYPE_COUNT) { + new_type = qs.params->token_embedding_type; + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { + new_type = GGML_TYPE_Q2_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) { + new_type = GGML_TYPE_IQ3_S; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = GGML_TYPE_IQ3_S; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) { + new_type = GGML_TYPE_Q4_K; + } + } + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { + if (name.find("attn_v.weight") != std::string::npos) { + if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K; + else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; + ++qs.i_attention_wv; + } + else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) { + new_type = GGML_TYPE_Q4_K; + } + else if (name.find("ffn_down") != std::string::npos) { + if (qs.i_ffn_down < qs.n_ffn_down/8) { + new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; + } + ++qs.i_ffn_down; + } + else if (name.find("attn_output.weight") != std::string::npos) { + if (qs.model.hparams.n_expert == 8) { + new_type = GGML_TYPE_Q5_K; + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S; + } + } + } else if (name.find("attn_v.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { + new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS; + } + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { + new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_Q5_K; + } + else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && + use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; + if (qs.model.type == LLM_TYPE_70B) { + // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is + // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with + // nearly negligible increase in model size by quantizing this tensor with more bits: + if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K; + } + if (qs.model.hparams.n_expert == 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = GGML_TYPE_Q8_0; + } + ++qs.i_attention_wv; + } else if (name.find("attn_k.weight") != std::string::npos) { + if (qs.model.hparams.n_expert == 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = GGML_TYPE_Q8_0; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { + new_type = GGML_TYPE_IQ3_XXS; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = GGML_TYPE_IQ2_S; + } + } else if (name.find("attn_q.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { + new_type = GGML_TYPE_IQ3_XXS; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = GGML_TYPE_IQ2_S; + } + } else if (name.find("ffn_down") != std::string::npos) { + auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str()); + int i_layer = info.first, n_layer = info.second; + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) { + if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) { + new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { + new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K + : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K + : GGML_TYPE_Q3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 || + (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) { + new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { + if (arch == LLM_ARCH_FALCON) { + new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K : + use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; + } else { + if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; + } + } + else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) { + new_type = GGML_TYPE_Q5_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) { + new_type = GGML_TYPE_Q5_K; + } + else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0) + && qs.has_imatrix && i_layer < n_layer/8) { + // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0. + // We only do it when an imatrix is provided because a) we want to make sure that one can always get the + // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix. + new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1; + } + ++qs.i_ffn_down; + } else if (name.find("attn_output.weight") != std::string::npos) { + if (arch != LLM_ARCH_FALCON) { + if (qs.model.hparams.n_expert == 8) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || + ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || + ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || + ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) { + new_type = GGML_TYPE_Q5_K; + } + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_Q4_K; + } + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; + } + } + else if (name.find("attn_qkv.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; + } + else if (name.find("ffn_gate") != std::string::npos) { + auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str()); + int i_layer = info.first, n_layer = info.second; + if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { + new_type = GGML_TYPE_IQ3_XXS; + } + ++qs.i_ffn_gate; + } + else if (name.find("ffn_up") != std::string::npos) { + auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str()); + int i_layer = info.first, n_layer = info.second; + if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { + new_type = GGML_TYPE_IQ3_XXS; + } + ++qs.i_ffn_up; + } + + // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + //} + // IK: let's remove this, else Q2_K is almost the same as Q3_K_S + //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { + // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + //} + // This can be used to reduce the size of the Q5_K_S model. + // The associated PPL increase is fully in line with the size reduction + //else { + // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; + //} + bool convert_incompatible_tensor = false; + { + const int64_t nx = tensor->ne[0]; + const int64_t ny = tensor->ne[1]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (nx % qk_k != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); + convert_incompatible_tensor = true; + } else { + ++qs.n_k_quantized; + } + } + + if (convert_incompatible_tensor) { + switch (new_type) { + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + } + if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { + new_type = GGML_TYPE_F16; + } + LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + ++qs.n_fallback; + } + + return new_type; +} + +static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) { + if (nthread < 2) { + // single-thread + size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix); + if (!ggml_validate_row_data(new_type, new_data, new_size)) { + throw std::runtime_error("quantized data validation failed"); + } + return new_size; + } + + std::mutex mutex; + int64_t counter = 0; + size_t new_size = 0; + bool valid = true; + auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size, + nrows, n_per_row, imatrix]() { + const int64_t nrows_per_chunk = chunk_size / n_per_row; + size_t local_size = 0; + while (true) { + std::unique_lock lock(mutex); + int64_t first_row = counter; counter += nrows_per_chunk; + if (first_row >= nrows) { + if (local_size > 0) { + new_size += local_size; + } + break; + } + lock.unlock(); + const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk); + size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix); + local_size += this_size; + + // validate the quantized data + const size_t row_size = ggml_row_size(new_type, n_per_row); + void * this_data = (char *) new_data + first_row * row_size; + if (!ggml_validate_row_data(new_type, this_data, this_size)) { + std::unique_lock lock(mutex); + valid = false; + break; + } + } + }; + for (int it = 0; it < nthread - 1; ++it) { + workers.emplace_back(compute); + } + compute(); + for (auto & w : workers) { w.join(); } + workers.clear(); + if (!valid) { + throw std::runtime_error("quantized data validation failed"); + } + return new_size; +} + +static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { + ggml_type default_type; + llama_ftype ftype = params->ftype; + + switch (params->ftype) { + case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; + case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; + case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; + case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; + case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; + case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; + case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; + + // K-quants + case LLAMA_FTYPE_MOSTLY_Q2_K_S: + case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: + case LLAMA_FTYPE_MOSTLY_Q3_K_M: + case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: + case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: + case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break; + case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break; + case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break; + case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break; + case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; + case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; + + default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); + } + + int nthread = params->nthread; + + if (nthread <= 0) { + nthread = std::thread::hardware_concurrency(); + } + + // mmap consistently increases speed on Linux, and also increases speed on Windows with + // hot cache. It may cause a slowdown on macOS, possibly related to free memory. +#if defined(__linux__) || defined(_WIN32) + constexpr bool use_mmap = true; +#else + constexpr bool use_mmap = false; +#endif + + llama_model_kv_override * kv_overrides = nullptr; + if (params->kv_overrides) { + auto * v = (std::vector*)params->kv_overrides; + kv_overrides = v->data(); + } + + std::vector splits = {}; + llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr); + ml.init_mappings(false); // no prefetching + + llama_model model(llama_model_default_params()); + + model.load_arch (ml); + model.load_hparams(ml); + model.load_stats (ml); + + quantize_state_impl qs(model, params); + + if (params->only_copy) { + ftype = ml.ftype; + } + const std::unordered_map> * imatrix_data = nullptr; + if (params->imatrix) { + imatrix_data = static_cast>*>(params->imatrix); + if (imatrix_data) { + LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size())); + qs.has_imatrix = true; + // check imatrix for nans or infs + for (const auto & kv : *imatrix_data) { + for (float f : kv.second) { + if (!std::isfinite(f)) { + throw std::runtime_error(format("imatrix contains non-finite value %f\n", f)); + } + } + } + } + } + + const size_t align = GGUF_DEFAULT_ALIGNMENT; + gguf_context_ptr ctx_out { gguf_init_empty() }; + + // copy the KV pairs from the input file + gguf_set_kv (ctx_out.get(), ml.meta.get()); + gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV + gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV + + // Remove split metadata + gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str()); + gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str()); + gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str()); + + if (params->kv_overrides) { + const std::vector & overrides = *(const std::vector *)params->kv_overrides; + for (const auto & o : overrides) { + if (o.key[0] == 0) break; + if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { + gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64); + } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { + // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context + gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)abs(o.val_i64)); + } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { + gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool); + } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { + gguf_set_val_str(ctx_out.get(), o.key, o.val_str); + } else { + LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); + } + } + } + + // make a list of weights + std::vector tensors; + tensors.reserve(ml.weights_map.size()); + for (const auto & it : ml.weights_map) { + tensors.push_back(&it.second); + } + + // keep_split requires that the weights are sorted by split index + if (params->keep_split) { + std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) { + if (a->idx == b->idx) { + return a->offs < b->offs; + } + return a->idx < b->idx; + }); + } + + for (const auto * it : tensors) { + const struct ggml_tensor * tensor = it->tensor; + + const std::string name = ggml_get_name(tensor); + + // TODO: avoid hardcoded tensor names - use the TN_* constants + if (name.find("attn_v.weight") != std::string::npos || + name.find("attn_qkv.weight") != std::string::npos || + name.find("attn_kv_b.weight")!= std::string::npos) { + ++qs.n_attention_wv; + } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { + qs.has_output = true; + } + } + + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; + + // sanity checks for models that have attention layers + if (qs.n_attention_wv != 0) + { + const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); + // attention layers have a non-zero number of kv heads + int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); + if (llama_model_has_encoder(&model)) { + n_attn_layer *= 3; + } + GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); + } + + size_t total_size_org = 0; + size_t total_size_new = 0; + + std::vector workers; + workers.reserve(nthread); + + int idx = 0; + + std::vector> read_data; + std::vector> work; + std::vector> f32_conv_buf; + + uint16_t n_split = 1; + + // Assume split index is continuous + if (params->keep_split) { + for (const auto * it : tensors) { + n_split = std::max(uint16_t(it->idx + 1), n_split); + } + } + std::vector ctx_outs(n_split); + ctx_outs[0] = std::move(ctx_out); + + // populate the original tensors so we get an initial meta data + for (const auto * it : tensors) { + uint16_t i_split = params->keep_split ? it->idx : 0; + ggml_tensor * tensor = it->tensor; + if (!ctx_outs[i_split]) { + ctx_outs[i_split].reset(gguf_init_empty()); + } + gguf_add_tensor(ctx_outs[i_split].get(), tensor); + } + + // Set split info if needed + if (n_split > 1) { + for (size_t i = 0; i < ctx_outs.size(); ++i) { + gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i); + gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split); + gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors); + } + } + + int cur_split = -1; + std::ofstream fout; + auto close_ofstream = [&]() { + // Write metadata and close file handler + if (fout.is_open()) { + fout.seekp(0); + std::vector data(gguf_get_meta_size(ctx_outs[cur_split].get())); + gguf_get_meta_data(ctx_outs[cur_split].get(), data.data()); + fout.write((const char *) data.data(), data.size()); + fout.close(); + } + }; + auto new_ofstream = [&](int index) { + cur_split = index; + GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context"); + std::string fname = fname_out; + if (params->keep_split) { + std::vector split_path(llama_path_max(), 0); + llama_split_path(split_path.data(), split_path.size(), fname_out.c_str(), cur_split, n_split); + fname = std::string(split_path.data()); + } + + fout = std::ofstream(fname, std::ios::binary); + fout.exceptions(std::ofstream::failbit); // fail fast on write errors + const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get()); + // placeholder for the meta data + ::zeros(fout, meta_size); + }; + + const auto tn = LLM_TN(model.arch); + new_ofstream(0); + for (const auto * it : tensors) { + const auto & weight = *it; + ggml_tensor * tensor = weight.tensor; + if (weight.idx != cur_split && params->keep_split) { + close_ofstream(); + new_ofstream(weight.idx); + } + + const std::string name = ggml_get_name(tensor); + + if (!ml.use_mmap) { + if (read_data.size() < ggml_nbytes(tensor)) { + read_data.resize(ggml_nbytes(tensor)); + } + tensor->data = read_data.data(); + } + ml.load_data_for(tensor); + + LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", + ++idx, ml.n_tensors, + ggml_get_name(tensor), + llama_format_tensor_shape(tensor).c_str(), + ggml_type_name(tensor->type)); + + // This used to be a regex, but has an extreme cost to compile times. + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + + // quantize only 2D and 3D tensors (experts) + quantize &= (ggml_n_dims(tensor) >= 2); + + // do not quantize norm tensors + quantize &= name.find("_norm.weight") == std::string::npos; + + quantize &= params->quantize_output_tensor || name != "output.weight"; + quantize &= !params->only_copy; + + // do not quantize expert gating tensors + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; + + // do not quantize positional embeddings and token types (BERT) + quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); + quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); + + // do not quantize Mamba's small yet 2D weights + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ssm_conv1d.weight") == std::string::npos; + + // do not quantize RWKV's small yet 2D weights + quantize &= name.find("time_mix_first.weight") == std::string::npos; + quantize &= name.find("time_mix_w0.weight") == std::string::npos; + quantize &= name.find("time_mix_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_v0.weight") == std::string::npos; + quantize &= name.find("time_mix_v1.weight") == std::string::npos; + quantize &= name.find("time_mix_v2.weight") == std::string::npos; + quantize &= name.find("time_mix_a0.weight") == std::string::npos; + quantize &= name.find("time_mix_a1.weight") == std::string::npos; + quantize &= name.find("time_mix_a2.weight") == std::string::npos; + quantize &= name.find("time_mix_g1.weight") == std::string::npos; + quantize &= name.find("time_mix_g2.weight") == std::string::npos; + quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; + + // do not quantize relative position bias (T5) + quantize &= name.find("attn_rel_b.weight") == std::string::npos; + + ggml_type new_type; + void * new_data; + size_t new_size; + + if (quantize) { + new_type = default_type; + + // get more optimal quantization type based on the tensor shape, layer, etc. + if (!params->pure && ggml_is_quantized(default_type)) { + new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + // unless the user specifies a type + if (params->tensor_types) { + const std::vector & tensor_types = *static_cast *>(params->tensor_types); + const std::string tensor_name(tensor->name); + for (const auto & [tname, qtype] : tensor_types) { + if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { + if (qtype != new_type) { + LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); + new_type = qtype; + break; // if two or more types are specified for the tensor, first match wins + } + } + } + } + } + + if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { + new_type = params->token_embedding_type; + } + if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { + new_type = params->output_tensor_type; + } + + // If we've decided to quantize to the same type the tensor is already + // in then there's nothing to do. + quantize = tensor->type != new_type; + } + + if (!quantize) { + new_type = tensor->type; + new_data = tensor->data; + new_size = ggml_nbytes(tensor); + LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0); + } else { + const int64_t nelements = ggml_nelements(tensor); + + const float * imatrix = nullptr; + if (imatrix_data) { + auto it = imatrix_data->find(tensor->name); + if (it == imatrix_data->end()) { + LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); + } else { + if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { + imatrix = it->second.data(); + } else { + LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); + + // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix + // this is a significant error and it may be good idea to abort the process if this happens, + // since many people will miss the error and not realize that most of the model is being quantized without an imatrix + // tok_embd should be ignored in this case, since it always causes this warning + if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { + throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); + } + } + } + } + if ((new_type == GGML_TYPE_IQ2_XXS || + new_type == GGML_TYPE_IQ2_XS || + new_type == GGML_TYPE_IQ2_S || + new_type == GGML_TYPE_IQ1_S || + (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || + (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) { + LLAMA_LOG_ERROR("\n\n============================================================\n"); + LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); + LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); + LLAMA_LOG_ERROR("============================================================\n\n"); + throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name)); + } + + float * f32_data; + + if (tensor->type == GGML_TYPE_F32) { + f32_data = (float *) tensor->data; + } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); + } else { + llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); + f32_data = (float *) f32_conv_buf.data(); + } + + LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); + fflush(stdout); + + if (work.size() < (size_t)nelements * 4) { + work.resize(nelements * 4); // upper bound on size + } + new_data = work.data(); + + const int64_t n_per_row = tensor->ne[0]; + const int64_t nrows = tensor->ne[1]; + + static const int64_t min_chunk_size = 32 * 512; + const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); + + const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; + const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; + const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; + + // quantize each expert separately since they have different importance matrices + new_size = 0; + for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { + const float * f32_data_03 = f32_data + i03 * nelements_matrix; + void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; + const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; + + new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); + } + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); + } + total_size_org += ggml_nbytes(tensor); + total_size_new += new_size; + + // update the gguf meta data as we go + gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); + GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); + gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); + } + close_ofstream(); + + LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); + LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + + if (qs.n_fallback > 0) { + LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n", + __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback); + } +} + +// +// interface implementation +// + +llama_model_quantize_params llama_model_quantize_default_params() { + llama_model_quantize_params result = { + /*.nthread =*/ 0, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.output_tensor_type =*/ GGML_TYPE_COUNT, + /*.token_embedding_type =*/ GGML_TYPE_COUNT, + /*.allow_requantize =*/ false, + /*.quantize_output_tensor =*/ true, + /*.only_copy =*/ false, + /*.pure =*/ false, + /*.keep_split =*/ false, + /*.imatrix =*/ nullptr, + /*.kv_overrides =*/ nullptr, + /*.tensor_type =*/ nullptr, + }; + + return result; +} + +uint32_t llama_model_quantize( + const char * fname_inp, + const char * fname_out, + const llama_model_quantize_params * params) { + try { + llama_model_quantize_impl(fname_inp, fname_out, params); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what()); + return 1; + } + + return 0; +} diff --git a/src/llama-quant.h b/src/llama-quant.h new file mode 100644 index 0000000000000000000000000000000000000000..6f70f09beec2219624baeca92e2cd7deaa104fb4 --- /dev/null +++ b/src/llama-quant.h @@ -0,0 +1 @@ +#pragma once diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bfbf5fa23011240c0dec57b390670ef1ff47079b --- /dev/null +++ b/src/llama-sampling.cpp @@ -0,0 +1,2575 @@ +#include "llama-sampling.h" + +#include "llama-impl.h" +#include "llama-vocab.h" +#include "llama-grammar.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// the ring buffer works similarly to std::deque, but with a fixed capacity +template +struct ring_buffer { + ring_buffer(size_t cap) : capacity(cap), data(cap) {} + + T & front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + const T & front() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + T & back() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + const T & back() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + void push_back(const T & value) { + if (capacity == 0) { + throw std::runtime_error("ring buffer: capacity is zero"); + } + + if (sz == capacity) { + // advance the start when buffer is full + first = (first + 1) % capacity; + } else { + sz++; + } + data[pos] = value; + pos = (pos + 1) % capacity; + } + + T pop_front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + T value = data[first]; + first = (first + 1) % capacity; + sz--; + return value; + } + + //T & operator[](size_t i) { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + //const T & at(size_t i) const { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + const T & rat(size_t i) const { + if (i >= sz) { + throw std::runtime_error("ring buffer: index out of bounds"); + } + return data[(first + sz - i - 1) % capacity]; + } + + std::vector to_vector() const { + std::vector result; + result.reserve(sz); + for (size_t i = 0; i < sz; i++) { + result.push_back(data[(first + i) % capacity]); + } + return result; + } + + void clear() { + // here only reset the status of the buffer + sz = 0; + first = 0; + pos = 0; + } + + bool empty() const { + return sz == 0; + } + + size_t size() const { + return sz; + } + + size_t capacity = 0; + size_t sz = 0; + size_t first = 0; + size_t pos = 0; + + std::vector data; +}; + +static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) { + // iterator for the probabilities +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + + struct probs_iterator { + typedef std::input_iterator_tag iterator_category; + typedef float value_type; + typedef float * pointer; + typedef float & reference; + typedef ptrdiff_t difference_type; + + const llama_token_data * data; + + bool operator==(const probs_iterator & other) const { return data == other.data; } + bool operator!=(const probs_iterator & other) const { return data != other.data; } + const float & operator*() const { return data->p; } + probs_iterator & operator++() { ++data; return *this; } + probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; } + }; + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif + + std::discrete_distribution dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size}); + + return dist(rng); +} + +/* +static void llama_log_softmax(float * array, size_t size) { + float max_l = *std::max_element(array, array + size); + float sum = 0.f; + for (size_t i = 0; i < size; ++i) { + float p = expf(array[i] - max_l); + sum += p; + array[i] = p; + } + + for (size_t i = 0; i < size; ++i) { + array[i] = logf(array[i] / sum); + } +} +*/ + +static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { + if (temp <= 0.0f) { + // find the token with the highest logit and set the rest to -inf + size_t max_i = 0; + float max_l = cur_p->data[0].logit; + + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i ].logit > max_l) { + cur_p->data[max_i].logit = -INFINITY; + max_i = i; + max_l = cur_p->data[i].logit; + } else { + cur_p->data[i].logit = -INFINITY; + } + } + + return; + } + + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= temp; + } +} + +static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { + GGML_ASSERT(cur_p->size > 0); + + // Sort the logits in descending order + if (!cur_p->sorted) { + std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + cur_p->sorted = true; + } + + float max_l = cur_p->data[0].logit; + float cum_sum = 0.0f; + + for (size_t i = 0; i < cur_p->size; ++i) { + float p = expf(cur_p->data[i].logit - max_l); + cur_p->data[i].p = p; + cum_sum += p; + } + + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= cum_sum; + } +} + +static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) { + // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast + // if (k >= (int32_t)cur_p->size) { + // return; + // } + + if (k <= 0) { + return; + } + + k = std::min(k, (int) cur_p->size); + + // Sort scores in descending order + if (!cur_p->sorted) { + auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + if (k <= 128) { + std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp); + } else { + constexpr int nbuckets = 128; + constexpr float bucket_low = -10.0f; + constexpr float bucket_high = 10.0f; + constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); + constexpr float bucket_inter = -bucket_low * bucket_scale; + + std::vector bucket_idx(cur_p->size); + std::vector histo(nbuckets, 0); + + for (int i = 0; i < (int)cur_p->size; ++i) { + const float val = cur_p->data[i].logit; + int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + ib = std::max(0, std::min(nbuckets - 1, ib)); + bucket_idx[i] = ib; + ++histo[ib]; + } + int nhave = 0; + int ib = nbuckets - 1; + for ( ; ib >= 0; --ib) { + nhave += histo[ib]; + if (nhave >= k) { + break; + } + } + std::vector tmp_tokens(nhave); + auto * ptr = tmp_tokens.data(); + std::vector bucket_ptrs; + bucket_ptrs.reserve(nbuckets - ib); + for (int j = nbuckets - 1; j >= ib; --j) { + bucket_ptrs.push_back(ptr); + ptr += histo[j]; + } + for (int i = 0; i < (int)cur_p->size; ++i) { + int j = bucket_idx[i]; + if (j >= ib) { + *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i]; + } + } + + ptr = tmp_tokens.data(); + int ndone = 0; + for (int j = nbuckets - 1; j > ib; --j) { + std::sort(ptr, ptr + histo[j], comp); + ptr += histo[j]; + ndone += histo[j]; + } + std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); + + std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data)); + + } + cur_p->sorted = true; + } + + cur_p->size = k; +} + +static uint32_t get_rng_seed(uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + // use system clock if std::random_device is not a true RNG + static bool is_rd_prng = std::random_device().entropy() == 0; + if (is_rd_prng) { + return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count(); + } + std::random_device rd; + return rd(); + } + return seed; +} + +// llama_sampler API + +struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) { + return new llama_sampler { + /* .iface = */ iface, + /* .ctx = */ ctx, + }; +} + +const char * llama_sampler_name(const struct llama_sampler * smpl) { + if (!smpl->iface) { + return "(null)"; + } + + return smpl->iface->name(smpl); +} + +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + if (smpl->iface->accept) { + smpl->iface->accept(smpl, token); + } +} + +void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { + GGML_ASSERT(smpl->iface->apply); + smpl->iface->apply(smpl, cur_p); +} + +void llama_sampler_reset(struct llama_sampler * smpl) { + if (smpl->iface->reset) { + smpl->iface->reset(smpl); + } +} + +struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + if (smpl->iface->clone) { + return smpl->iface->clone(smpl); + } + + if (smpl->ctx == nullptr) { + return llama_sampler_init( + /* .iface = */ smpl->iface, + /* .ctx = */ nullptr + ); + } + + GGML_ABORT("the sampler does not support cloning"); +} + +void llama_sampler_free(struct llama_sampler * smpl) { + if (smpl == nullptr) { + return; + } + + if (smpl->iface->free) { + smpl->iface->free(smpl); + } + + delete smpl; +} + +llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + // TODO: do not allocate each time + std::vector cur; + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { + /* .data = */ cur.data(), + /* .size = */ cur.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; + + llama_sampler_apply(smpl, &cur_p); + + GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); + + auto token = cur_p.data[cur_p.selected].id; + + llama_sampler_accept(smpl, token); + + return token; +} + +// sampler chain + +static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { + return "chain"; +} + +static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto * smpl : chain->samplers) { + llama_sampler_accept(smpl, token); + } + + chain->n_sample++; +} + +static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto * smpl : chain->samplers) { + llama_sampler_apply(smpl, cur_p); + } +} + +static void llama_sampler_chain_reset(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_reset(smpl); + } + + chain->t_sample_us = 0; + chain->n_sample = 0; +} + +static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { + const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; + + auto * result = llama_sampler_chain_init(chain_src->params); + + for (auto * smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl)); + } + + return result; +} + +static void llama_sampler_chain_free(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_free(smpl); + } + + delete chain; +} + +static struct llama_sampler_i llama_sampler_chain_i = { + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, +}; + +struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_chain_i, + /* .ctx = */ new llama_sampler_chain { + /* .params = */ params, + /* .samplers = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + } + ); +} + +void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { + auto * p = (llama_sampler_chain *) chain->ctx; + p->samplers.push_back(smpl); +} + +struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + + if (i < 0 || (size_t) i >= p->samplers.size()) { + return nullptr; + } + + return p->samplers[i]; +} + +struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { + auto * p = (llama_sampler_chain *) chain->ctx; + + if (i < 0 || (size_t) i >= p->samplers.size()) { + return nullptr; + } + + auto * result = p->samplers[i]; + p->samplers.erase(p->samplers.begin() + i); + + return result; +} + +int llama_sampler_chain_n(const struct llama_sampler * chain) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + + return p->samplers.size(); +} + +// +// samplers +// + +// greedy + +static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) { + return "greedy"; +} + +static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + cur_p->selected = 0; + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) { + cur_p->selected = i; + } + } +} + +static struct llama_sampler_i llama_sampler_greedy_i = { + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_greedy() { + return llama_sampler_init( + /* .iface = */ &llama_sampler_greedy_i, + /* .ctx = */ nullptr + ); +} + +// dist + +struct llama_sampler_dist { + const uint32_t seed; + uint32_t seed_cur; + + std::mt19937 rng; +}; + +static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { + return "dist"; +} + +static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p); + + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); +} + +static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_dist *) smpl->ctx; + auto * result = llama_sampler_init_dist(ctx->seed); + + // copy the state + { + auto * result_ctx = (llama_sampler_dist *) result->ctx; + + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_dist_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static void llama_sampler_dist_free(struct llama_sampler * smpl) { + delete (llama_sampler_dist *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_dist_i = { + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, +}; + +struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { + auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( + /* .iface = */ &llama_sampler_dist_i, + /* .ctx = */ new llama_sampler_dist { + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + } + ); +} + +// softmax + +static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) { + return "softmax"; +} + +static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + llama_sampler_softmax_impl(cur_p); +} + +static struct llama_sampler_i llama_sampler_softmax_i = { + /* .name = */ llama_sampler_softmax_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_softmax_apply, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_softmax() { + return llama_sampler_init( + /* .iface = */ &llama_sampler_softmax_i, + /* .ctx = */ nullptr + ); +} + +// top-k + +struct llama_sampler_top_k { + const int32_t k; +}; + +static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) { + return "top-k"; +} + +static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_k *) smpl->ctx; + llama_sampler_top_k_impl(cur_p, ctx->k); +} + +static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_k *) smpl->ctx; + return llama_sampler_init_top_k(ctx->k); +} + +static void llama_sampler_top_k_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_k *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_top_k_i = { + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, +}; + +struct llama_sampler * llama_sampler_init_top_k(int32_t k) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_top_k_i, + /* .ctx = */ new llama_sampler_top_k { + /* .k = */ k, + } + ); +} + +// top-p + +struct llama_sampler_top_p { + const float p; + const size_t min_keep; +}; + +static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { + return "top-p"; +} + +static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_p *) smpl->ctx; + + if (ctx->p >= 1.0f) { + return; + } + + llama_sampler_softmax_impl(cur_p); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = cur_p->size; + + for (size_t i = 0; i < cur_p->size; ++i) { + cum_sum += cur_p->data[i].p; + + // Check if the running sum is at least p or if we have kept at least min_keep tokens + // we set the last index to i+1 to indicate that the current iterate should be included in the set + if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the top-p tokens + cur_p->size = last_idx; +} + +static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_p *) smpl->ctx; + return llama_sampler_init_top_p(ctx->p, ctx->min_keep); +} + +static void llama_sampler_top_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_top_p_i = { + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, +}; + +struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_top_p_i, + /* .ctx = */ new llama_sampler_top_p { + /* .p = */ p, + /* .min_keep = */ min_keep, + } + ); +} + +// min-p + +struct llama_sampler_min_p { + const float p; + const size_t min_keep; +}; + +static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) { + return "min-p"; +} + +static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_min_p *) smpl->ctx; + + if (ctx->p <= 0.0f || !cur_p->size) { + return; + } + + bool min_p_applied = false; + + // if the cur_p aren't sorted, try the unsorted implementation first + if (!cur_p->sorted) { + std::vector filtered_tokens; + + float max_logit = -FLT_MAX; + for (size_t i = 0; i < cur_p->size; ++i) { + max_logit = std::max(max_logit, cur_p->data[i].logit); + } + const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max + + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit >= min_logit) { + filtered_tokens.push_back(cur_p->data[i]); + } + } + + // if we have enough values the operation was a success + if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) { + memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); + cur_p->size = filtered_tokens.size(); + min_p_applied = true; + } + } + + // if the cur_p are sorted or the unsorted implementation failed, use this implementation + if (!min_p_applied) { + // Sort the logits in descending order + if (!cur_p->sorted) { + std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + cur_p->sorted = true; + } + + const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max + size_t i = 1; // first token always matches + + for (; i < cur_p->size; ++i) { + if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) { + break; // prob too small + } + } + + // Resize the output vector to keep only the matching tokens + cur_p->size = i; + } +} + +static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_min_p *) smpl->ctx; + return llama_sampler_init_min_p(ctx->p, ctx->min_keep); +} + +static void llama_sampler_min_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_min_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_min_p_i = { + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, +}; + +struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_min_p_i, + /* .ctx = */ new llama_sampler_min_p { + /* .p = */ p, + /* .min_keep = */ min_keep, + } + ); +} + +// typical + +struct llama_sampler_typical { + const float p; + const size_t min_keep; +}; + +static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) { + return "typical"; +} + +static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_typical *) smpl->ctx; + + // Reference implementation: + // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr + if (ctx->p >= 1.0f) { + return; + } + + // Compute the softmax of logits and calculate entropy + llama_sampler_softmax_impl(cur_p); + + float entropy = 0.0f; + for (size_t i = 0; i < cur_p->size; ++i) { + entropy += -cur_p->data[i].p * logf(cur_p->data[i].p); + } + + // Compute the absolute difference between negative log probability and entropy for each candidate + std::vector shifted_scores; + for (size_t i = 0; i < cur_p->size; ++i) { + float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy); + shifted_scores.push_back(shifted_score); + } + + // Sort tokens based on the shifted_scores and their corresponding indices + std::vector indices(cur_p->size); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return shifted_scores[a] < shifted_scores[b]; + }); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = indices.size(); + + for (size_t i = 0; i < indices.size(); ++i) { + size_t idx = indices[i]; + cum_sum += cur_p->data[idx].p; + + // Check if the running sum is greater than typical or if we have kept at least min_keep tokens + if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the locally typical tokens + std::vector cur_p_new; + for (size_t i = 0; i < last_idx; ++i) { + size_t idx = indices[i]; + cur_p_new.push_back(cur_p->data[idx]); + } + + // Replace the data in cur_p with the cur_p_new data + std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data); + cur_p->size = cur_p_new.size(); + cur_p->sorted = false; +} + +static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_typical *) smpl->ctx; + return llama_sampler_init_typical(ctx->p, ctx->min_keep); +} + +static void llama_sampler_typical_free(struct llama_sampler * smpl) { + delete (llama_sampler_typical *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_typical_i = { + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, +}; + +struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_typical_i, + /* .ctx = */ new llama_sampler_typical { + /* .p = */ p, + /* .min_keep = */ min_keep, + } + ); +} + +// temp + +struct llama_sampler_temp { + const float temp; +}; + +static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) { + return "temp"; +} + +static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_temp *) smpl->ctx; + + llama_sampler_temp_impl(cur_p, ctx->temp); +} + +static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_temp *) smpl->ctx; + return llama_sampler_init_temp(ctx->temp); +} + +static void llama_sampler_temp_free(struct llama_sampler * smpl) { + delete (llama_sampler_temp *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_temp_i = { + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, +}; + +struct llama_sampler * llama_sampler_init_temp(float temp) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_temp_i, + /* .ctx = */ new llama_sampler_temp { + /*.temp = */ temp, + } + ); +} + +// temp-ext + +struct llama_sampler_temp_ext { + const float temp; + const float delta; + const float exponent; +}; + +static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) { + return "temp-ext"; +} + +static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx; + if (ctx->delta > 0) { + const float min_temp = std::max(0.0f, ctx->temp - ctx->delta); + const float max_temp = ctx->temp + ctx->delta; + + float exponent_val = ctx->exponent; + + // no need to do anything if there is only one (or zero) candidates + if (cur_p->size <= 1) { + return; + } + + // Calculate maximum possible entropy + float max_entropy = -logf(1.0f / cur_p->size); + + llama_sampler_softmax_impl(cur_p); + + // Calculate entropy of the softmax probabilities + float entropy = 0.0f; + for (size_t i = 0; i < cur_p->size; ++i) { + float prob = cur_p->data[i].p; + if (prob > 0.0f) { // Ensure no log(0) + entropy -= prob * logf(prob); + } + } + + // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above) + float normalized_entropy = entropy / max_entropy; + + // Map the normalized entropy to the desired temperature range using the power function + float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); + + #ifdef DEBUG + LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp); + LLAMA_LOG_INFO("Entropy: %f\n", entropy); + LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy); + LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy); + LLAMA_LOG_INFO("Exponent: %f\n", exponent_val); + LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); + #endif + + // Apply the dynamically calculated temperature scaling + llama_sampler_temp_impl(cur_p, dyn_temp); + + // Re-compute softmax probabilities after scaling logits with dynamic temperature + const double max_l_double = cur_p->data[0].logit; + + double cum_sum_double = 0.0; + for (size_t i = 0; i < cur_p->size; ++i) { + double p = exp(cur_p->data[i].logit - max_l_double); + cur_p->data[i].p = p; // Store the scaled probability + cum_sum_double += p; + } + + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities + } + + #ifdef DEBUG + // Print the updated top 25 probabilities after temperature scaling + LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); + for (size_t i = 0; i < 25 && i < cur_p->size; ++i) { + LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f); + } + #endif + } else { + llama_sampler_temp_impl(cur_p, ctx->temp); + } +} + +static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx; + return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent); +} + +static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { + delete (llama_sampler_temp_ext *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_temp_ext_i = { + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, +}; + +struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_temp_ext_i, + /* .ctx = */ new llama_sampler_temp_ext { + /* .temp = */ temp, + /* .delta = */ delta, + /* .exponent = */ exponent, + } + ); +} + +// xtc + +struct llama_sampler_xtc { + const float probability; + const float threshold; + const size_t min_keep; + + const uint32_t seed; + uint32_t seed_cur; + + std::mt19937 rng; +}; + +static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) { + return "xtc"; +} + +static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_xtc *) smpl->ctx; + + if (ctx->probability <= 0.0f + || ctx->threshold > 0.5f + || cur_p->size < 2) { + return; + } + + std::uniform_real_distribution distribution(0.0f, 1.0f); + float chance = distribution(ctx->rng); + if (chance > ctx->probability) return; + + // in case it's not sorted/recalculated yet + llama_sampler_softmax_impl(cur_p); + + int pos_last = 0; + + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].p >= ctx->threshold) { + pos_last = i; + } else break; + } + + if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) { + cur_p->data += pos_last; + cur_p->size -= pos_last; + } +} + +static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_xtc *) smpl->ctx; + auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed); + + // copy the state + { + auto * result_ctx = (llama_sampler_xtc *) result->ctx; + + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_xtc_free(struct llama_sampler * smpl) { + delete (llama_sampler_xtc *) smpl->ctx; +} + +static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_xtc *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler_i llama_sampler_xtc_i = { + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, +}; + +struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { + auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( + /* .iface = */ &llama_sampler_xtc_i, + /* .ctx = */ new llama_sampler_xtc { + /* .probability = */ p, + /* .threshold = */ t, + /* .min_keep = */ min_keep, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + } + ); +} + +// mirostat + +struct llama_sampler_mirostat { + const int32_t n_vocab; + + const uint32_t seed; + uint32_t seed_cur; + + const float tau; + const float eta; + + const int32_t m; + + float mu; + + std::mt19937 rng; +}; + +static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { + return "mirostat"; +} + +static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_mirostat *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat); + + llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); + llama_sampler_softmax_impl(cur_p); + + const int idx = llama_sample_dist(cur_p, ctx->rng); + + cur_p->selected = idx; + + float observed_surprise = -log2f(cur_p->data[idx].p); + float e = observed_surprise - ctx->tau; + + // Update mu using the learning rate and error + ctx->mu = ctx->mu - ctx->eta * e; +} + +static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; + auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); + + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_mirostat *) smpl->ctx; + ctx->mu = 2.0f*ctx->tau; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { + delete (llama_sampler_mirostat *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_mirostat_i = { + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ llama_sampler_mirostat_free, +}; + +struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { + auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( + /* .iface = */ &llama_sampler_mirostat_i, + /* .ctx = */ new llama_sampler_mirostat { + /* .n_vocab = */ n_vocab, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .tau = */ tau, + /* .eta = */ eta, + /* .m = */ m, + /* .mu = */ 2.0f*tau, + /* .rng = */ std::mt19937(seed_cur), + } + ); +} + +// mirostat v2 + +struct llama_sampler_mirostat_v2 { + const uint32_t seed; + uint32_t seed_cur; + + const float tau; + const float eta; + + float mu; + + std::mt19937 rng; +}; + +static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) { + return "mirostat-v2"; +} + +static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p); + + // Truncate the words with surprise values greater than mu + cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > ctx->mu; + })); + + if (cur_p->size == 0) { + cur_p->size = 1; + } + + // Normalize the probabilities of the remaining words + llama_sampler_softmax_impl(cur_p); + + const int idx = llama_sample_dist(cur_p, ctx->rng); + + cur_p->selected = idx; + + float observed_surprise = -log2f(cur_p->data[idx].p); + float e = observed_surprise - ctx->tau; + + // Update mu using the learning rate and error + ctx->mu = ctx->mu - ctx->eta * e; +} + +static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; + ctx->mu = 2.0f*ctx->tau; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx; + + auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); + + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { + delete (llama_sampler_mirostat_v2 *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_mirostat_v2_i = { + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, +}; + +struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { + auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( + /* .iface = */ &llama_sampler_mirostat_v2_i, + /* .ctx = */ new llama_sampler_mirostat_v2 { + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .tau = */ tau, + /* .eta = */ eta, + /* .mu = */ 2.0f*tau, + /* .rng = */ std::mt19937(seed_cur), + } + ); +} + +// grammar + +struct llama_sampler_grammar { + const struct llama_vocab * vocab; + + std::string grammar_str; + std::string grammar_root; + + struct llama_grammar * grammar; +}; + +static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) { + return "grammar"; +} + +static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (ctx->grammar) { + llama_grammar_accept_impl(*ctx->grammar, token); + } +} + +static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (ctx->grammar) { + llama_grammar_apply_impl(*ctx->grammar, cur_p); + } +} + +// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle. +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens, + const char ** trigger_patterns, + size_t num_trigger_patterns); + +static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (!ctx->grammar) { + return; + } + + std::vector trigger_patterns_c; + trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size()); + for (auto & trigger_pattern : ctx->grammar->trigger_patterns) { + trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); + } + + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), + ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), + ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); + + llama_grammar_free_impl(ctx->grammar); + ctx->grammar = grammar_new; +} + +static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; + + auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0); + GGML_ASSERT(result); + + // copy the state + { + auto * result_ctx = (llama_sampler_grammar *) result->ctx; + + if (ctx->grammar) { + result_ctx->grammar_str = ctx->grammar_str; + result_ctx->grammar_root = ctx->grammar_root; + + result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar); + } + } + + return result; +} + +static void llama_sampler_grammar_free(struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_grammar *) smpl->ctx; + + if (ctx->grammar) { + llama_grammar_free_impl(ctx->grammar); + } + + delete ctx; +} + +static struct llama_sampler_i llama_sampler_grammar_i = { + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, +}; + +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens, + const char ** trigger_patterns, + size_t num_trigger_patterns) { + auto * ctx = new llama_sampler_grammar; + + if (grammar_str != nullptr && grammar_str[0] != '\0') { + // TODO: remove trigger_words support. + if (trigger_words != nullptr && num_trigger_words > 0) { + GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0); + std::string trigger_pattern("[\\s\\S]*?("); + for (size_t i = 0; i < num_trigger_words; ++i) { + static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); + if (i > 0) { + trigger_pattern += "|"; + } + trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0"); + } + trigger_pattern += ")[\\s\\S]*"; + auto trigger_pattern_c = trigger_pattern.c_str(); + trigger_patterns = &trigger_pattern_c; + num_trigger_patterns = 1; + } + *ctx = { + /* .vocab = */ vocab, + /* .grammar_str = */ grammar_str, + /* .grammar_root = */ grammar_root, + /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens), + }; + if (!ctx->grammar) { + delete ctx; + return nullptr; + } + } else { + *ctx = { + /* .vocab = */ vocab, + /* .grammar_str = */ {}, + /* .grammar_root = */ {}, + /* .grammar = */ nullptr, + }; + } + + return llama_sampler_init( + /* .iface = */ &llama_sampler_grammar_i, + /* .ctx = */ ctx + ); +} + +struct llama_sampler * llama_sampler_init_grammar( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0); +} + +struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0); +} + +struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns); +} + +// penalties + +struct llama_sampler_penalties { + const int32_t penalty_last_n; + const float penalty_repeat; + const float penalty_freq; + const float penalty_present; + + ring_buffer prev; + + // a frequency map to count token occurrences + std::unordered_map token_count; +}; + +static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) { + return "penalties"; +} + +static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + if (ctx->penalty_last_n == 0) { + return; + } + + ctx->token_count[token]++; + + // if the ring buffer is full, remove the oldest token + if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) { + const auto old = ctx->prev.front(); + + ctx->token_count[old]--; + if (ctx->token_count[old] == 0) { + ctx->token_count.erase(old); + } + } + + ctx->prev.push_back(token); + +#if 0 + // sanity check + std::unordered_map tmp; + for (int i = 0; i < std::min(ctx->penalty_last_n, ctx->prev.size()); ++i) { + tmp[ctx->prev.rat(i)]++; + } + + assert(ctx->token_count == tmp); +#endif +} + +static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + + if ((ctx->penalty_last_n == 0) || + (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { + return; + } + + // Apply frequency and presence penalties to the cur_p + for (size_t i = 0; i < cur_p->size; ++i) { + const auto token_iter = ctx->token_count.find(cur_p->data[i].id); + if (token_iter == ctx->token_count.end()) { + continue; + } + + const int count = token_iter->second; + + assert(count > 0 && count <= ctx->penalty_last_n); + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (cur_p->data[i].logit <= 0) { + cur_p->data[i].logit *= ctx->penalty_repeat; + } else { + cur_p->data[i].logit /= ctx->penalty_repeat; + } + + cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present; + } + + cur_p->sorted = false; +} + +static void llama_sampler_penalties_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + ctx->prev.clear(); + ctx->token_count.clear(); +} + +static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; + auto * result = llama_sampler_init_penalties( + ctx->penalty_last_n, + ctx->penalty_repeat, + ctx->penalty_freq, + ctx->penalty_present); + + // copy the state + { + auto * result_ctx = (llama_sampler_penalties *) result->ctx; + + result_ctx->prev = ctx->prev; + } + + return result; +} + +static void llama_sampler_penalties_free(struct llama_sampler * smpl) { + delete (llama_sampler_penalties *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_penalties_i = { + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, +}; + +struct llama_sampler * llama_sampler_init_penalties( + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present) { + penalty_last_n = std::max(penalty_last_n, 0); + + return llama_sampler_init( + /* .iface = */ &llama_sampler_penalties_i, + /* .ctx = */ new llama_sampler_penalties { + /* .penalty_last_n = */ penalty_last_n, + /* .penalty_repeat = */ penalty_repeat, + /* .penalty_freq = */ penalty_freq, + /* .penalty_present = */ penalty_present, + /* .prev = */ ring_buffer(penalty_last_n), + /* .token_count = */ {}, + } + ); +} + +// top-n-sigma + +struct llama_sampler_top_n_sigma { + const float n; +}; + +static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { + return "top-n-sigma"; +} + +static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; + + if (ctx->n <= 0.0f || cur_p->size <= 1) { + return; + } + + // find max logit and calculate mean + float max = cur_p->data[0].logit; + float logits_sum = 0; + size_t valid_count = 0; + for (size_t i = 0; i < cur_p->size; ++i) { + // Only count non-negative infinity values + if (cur_p->data[i].logit != -INFINITY) { + if (cur_p->data[i].logit > max) { + max = cur_p->data[i].logit; + } + logits_sum += cur_p->data[i].logit; + valid_count++; + } + } + float mean = valid_count > 0 ? logits_sum/valid_count : 0; + + // calculate standard deviation + float acc = 0; + for (size_t i = 0; i < cur_p->size; ++i) { + // Skip -infinity in std calculation + if (cur_p->data[i].logit != -INFINITY) { + acc += pow(cur_p->data[i].logit - mean, 2); + } + } + float std = valid_count > 0 ? sqrt(acc/valid_count) : 0; + + //apply mask + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit < max - (ctx->n * std)) { + cur_p->data[i].logit = -INFINITY; + } + } + llama_sampler_softmax_impl(cur_p); +} + +static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx; + return llama_sampler_init_top_n_sigma(ctx->n); +} + +static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_n_sigma *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_top_n_sigma_i = { + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, +}; + +struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_top_n_sigma_i, + /* .ctx = */ new llama_sampler_top_n_sigma { + /* .n = */ n, + } + ); +} + +// DRY + +struct llama_sampler_dry { + int32_t total_context_size; + + const float dry_multiplier; + const float dry_base; + const int32_t dry_allowed_length; + const int32_t dry_penalty_last_n; + + std::unordered_multimap> dry_processed_breakers; + std::vector dry_repeat_count; + std::unordered_map dry_max_token_repeat; + ring_buffer last_tokens; +}; + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap>& token_sequences, int max_tail_len = -1) { + for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) { + std::string word = vocab.detokenize({token_id}, true); + if (word.find(str) != std::string::npos) { + token_sequences.emplace(token_id, std::vector()); + } else { + size_t word_len = word.size(); + size_t str_len = str.size(); + size_t pos = -1; + while ((pos = word.find(str[0], pos + 1)) != std::string::npos) { + bool match = true; + size_t i; + for (i = 1; i < str_len && i + pos < word_len; ++i) { + if (word[pos + i] != str[i]) { + match = false; + break; + } + } + if (match) { + std::vector tokenization = vocab.tokenize(str.substr(i), false, false); + if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { + tokenization.resize(max_tail_len); + } + + // Ensure we don't already have a duplicate matching tokenization + auto its = token_sequences.equal_range(token_id); + bool found = false; + for (auto it = its.first; it != its.second; ++it) { + if (tokenization == it->second) { + found = true; + break; + } + } + if (!found) { + token_sequences.emplace(token_id, tokenization); + } + } + } + } + } +} + +static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) { + return "dry"; +} + +static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { + return; + } + + ctx->last_tokens.push_back(token); +} + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + + if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { + return; + } + + int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0); + int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size); + + if (last_n_repeat <= ctx->dry_allowed_length) { + return; + } + + ctx->dry_repeat_count.assign(last_n_repeat, 0); + ctx->dry_max_token_repeat.clear(); + + // Step 1: Look for restart sequences to limit the maximum repetition length. + // Work backwards through the context looking for any token that begins a restart sequence. + // + // The collection `restart_sequences` is a mapping from a "head" token to all "tail" + // sequences that together comprise a restart sequence. This allows us to quickly check + // whether each token is the head of a complete sequence. Most restart sequences are actually + // a single token, and for these the "tail" is an empty vector. + // + // If the token is a "head", test all restart sequences that begin with this token + // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and + // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The + // longest matching sequence (if any) is used to limit the maximum repetition length. + // + // Note that in the case case of a short sequence contained in a longer one, this might fail to + // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as + // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress + // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare. + // + // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we + // have already clamped the maximum tail sequence length when generating `restart_sequences`. + // With clamping, this scan is O(N) in the context length. + + int rep_limit = last_n_repeat; + for (int i = 0; i < last_n_repeat; ++i) { + llama_token token = ctx->last_tokens.rat(i); + auto its = ctx->dry_processed_breakers.equal_range(token); + if (its.first == ctx->dry_processed_breakers.end()) { + continue; + } + int longest_match = -1; + for (auto it = its.first; it != its.second; ++it) { + // Note that (*it) does not contain the head character, so seq_len will be + // the restart sequence length minus 1. + // In the common case of a single-token restart sequence, (*it) will be empty + // and we will trivially match. + int seq_len = (int)it->second.size(); + if (seq_len > longest_match && seq_len <= (int)i) { + bool match = true; + for (int offset = 0; offset < seq_len; ++offset) { + // The -1 when indexing `last_tokens` is because we already matched the head. + if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) { + match = false; + break; + } + } + if (match) { + longest_match = seq_len; + } + } + } + if (longest_match >= 0) { + // We found a restart sequence starting `i` tokens from the end and continuing for + // `longest_match` tokens. + rep_limit = i - longest_match; + break; + } + } + if (rep_limit < ctx->dry_allowed_length) { + return; + } + + // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in + // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing + // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences. + // + // This algorithm is not currently documented on Wikipedia, but there is a clear description here: + // https://ivanyu.me/blog/2014/10/15/z-algorithm/ + // + // The code below is adapted from the public domain implementation by the same author here: + // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py + // + // Example: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // ^ + // This `3` means that the last three tokens of the context (a b c) also appear here. + // + // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested + // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each + // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables + // ensure that the inner while loops only examine each token in the context once as the outer + // for loop iterates over the context. + + { + const int last = last_n_repeat - 1; + int rt = 0, lt = 0; + + for (int k = 1; k < last_n_repeat; ++k) { + if (k > rt) { + // If k is outside the current Z-box, do naive computation. + int n = 0; + while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) { + ++n; + } + ctx->dry_repeat_count[last - k] = std::min(n, rep_limit); + if (n > 0) { + lt = k; + rt = k + n - 1; + } + } else { + // If k is inside the current Z-box, consider two cases. + + int p = k - lt; // Pair index. + int right_part_len = rt - k + 1; + + if (ctx->dry_repeat_count[last - p] < right_part_len) { + int n = std::min(ctx->dry_repeat_count[last - p], rep_limit); + ctx->dry_repeat_count[last - k] = n; + } else { + int i = rt + 1; + while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) { + i += 1; + } + + int n = std::min(i - k, rep_limit); + ctx->dry_repeat_count[last - k] = n; + lt = k; + rt = i - 1; + } + } + } + } + + // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length + // that would be generated by emitting each new token that would extend a sequence. + // + // Following the same example as above: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // + // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition. + // c: 3 -> 4 (from `a b c` to `a b c c`) + // b: 1 -> 2 (from `c` to `c b`) + // y: 2 -> 3 (from `b c` to `b c y`) + + for (int i = 0; i < last_n_repeat - 1; ++i) { + int repeat_len = ctx->dry_repeat_count[i]; + if (repeat_len >= ctx->dry_allowed_length) { + // This token ends a repeat, so the next token would continue one. + // By convention, the value of `repeat_len` only includes the tokens currently + // in the context, not the new token that would be added. + llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i); + // Track the maximum sequence ending in this token. + const auto& it = ctx->dry_max_token_repeat.find(token); + if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) { + ctx->dry_max_token_repeat[token] = repeat_len; + } + } + } + + // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens. + + // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`. + // Compute it from `penalty_base` and the approximate log of `std::numeric_limits::max()` + const float FLOAT_MAX_LOG = 88.7228391f; + int max_exponent = 0; + if (ctx->dry_base > 1.000001f) { + max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base); + } + + for (size_t i = 0; i < cur_p->size; ++i) { + const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id); + if (af_kvp != ctx->dry_max_token_repeat.end()) { + // Check all sequence breakers starting with this token + auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id); + bool is_single_token_breaker = false; + + for (auto it = range.first; it != range.second; ++it) { + if (it->second.empty()) { + is_single_token_breaker = true; + break; + } + } + + // Apply penalty only if it's not a single-token sequence breaker + if (!is_single_token_breaker) { + int repeat_exp = af_kvp->second - ctx->dry_allowed_length; + if (max_exponent > 0 && repeat_exp > max_exponent) { + repeat_exp = max_exponent; + } + float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp); + cur_p->data[i].logit -= penalty; + } + } + } + + cur_p->sorted = false; +} + +static void llama_sampler_dry_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + ctx->last_tokens.clear(); + ctx->dry_repeat_count.clear(); + ctx->dry_max_token_repeat.clear(); +} + +static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_dry *) smpl->ctx; + + llama_vocab dummy_vocab; + + // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying + auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); + + // Copy the state, including the processed breakers + { + auto * result_ctx = (llama_sampler_dry *) result->ctx; + result_ctx->dry_processed_breakers = ctx->dry_processed_breakers; + result_ctx->dry_repeat_count = ctx->dry_repeat_count; + result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat; + result_ctx->last_tokens = ctx->last_tokens; + } + + return result; +} + +static void llama_sampler_dry_free(struct llama_sampler * smpl) { + delete (llama_sampler_dry *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_dry_i = { + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, +}; + +struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); + std::unordered_multimap> processed_breakers; + const int MAX_CHAR_LEN = 40; + const int MAX_SEQ_LEN = 20; + + const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { + // Process sequence breakers + for (size_t i = 0; i < num_breakers; ++i) { + if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) { + LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i); + continue; + } + + std::string sequence_break(seq_breakers[i]); + if (sequence_break.empty()) { + LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n"); + continue; + } + + if (sequence_break.size() > MAX_CHAR_LEN) { + LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN); + sequence_break.resize(MAX_CHAR_LEN); + } + + get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); + } + } + + return llama_sampler_init( + /* .iface = */ &llama_sampler_dry_i, + /* .ctx = */ new llama_sampler_dry { + /* .total_context_size = */ context_size, + /* .dry_multiplier = */ dry_multiplier, + /* .dry_base = */ dry_base, + /* .dry_allowed_length = */ dry_allowed_length, + /* .dry_penalty_last_n = */ dry_penalty_last_n, + /* .dry_processed_breakers = */ std::move(processed_breakers), + /* .dry_repeat_count = */ dry_enabled ? std::vector(effective_dry_penalty_last_n, 0) : std::vector{}, + /* .dry_max_token_repeat = */ {}, + /* .last_tokens = */ dry_enabled ? ring_buffer(effective_dry_penalty_last_n) : ring_buffer(0), + } + ); +} + +// wrapper for test-sampling.cpp +struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector>& seq_breakers) { + llama_vocab dummy_vocab; + auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0); + auto * ctx = (llama_sampler_dry *) result->ctx; + + // Process the token-based sequence breakers + ctx->dry_processed_breakers.clear(); + if (seq_breakers.empty()) { + LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n"); + } else { + for (const auto& breaker : seq_breakers) { + if (breaker.empty()) { + LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n"); + continue; + } + llama_token head_token = breaker[0]; + std::vector tail_tokens(breaker.begin() + 1, breaker.end()); + ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens)); + } + + if (ctx->dry_processed_breakers.empty()) { + LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n"); + } + } + + return result; +} + +// logit-bias + +struct llama_sampler_logit_bias { + const int32_t n_vocab; + + const std::vector logit_bias; + + std::vector to_search; +}; + +static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) { + return "logit-bias"; +} + +static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; + + if (ctx->logit_bias.empty()) { + return; + } + + ctx->to_search.clear(); + + // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) + for (const auto & lb : ctx->logit_bias) { + if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) { + cur_p->data[lb.token].logit += lb.bias; + } else { + ctx->to_search.push_back(lb); + } + } + + if (ctx->to_search.empty()) { + return; + } + + // search for the remaining candidates that were not found in the previous step + for (size_t i = 0; i < cur_p->size; ++i) { + for (const auto & lb : ctx->to_search) { + if (cur_p->data[i].id == lb.token) { + cur_p->data[i].logit += lb.bias; + break; + } + } + } +} + +static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; + return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); +} + +static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { + delete (llama_sampler_logit_bias *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_logit_bias_i = { + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_logit_bias_clone, + /* .free = */ llama_sampler_logit_bias_free, +}; + +struct llama_sampler * llama_sampler_init_logit_bias( + int32_t n_vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_logit_bias_i, + /* .ctx = */ new llama_sampler_logit_bias { + /* .n_vocab = */ n_vocab, + /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), + /* .to_search = */ {}, + } + ); +} + +// infill + +//#define GGML_DEBUG_SAMPLER_INFILL + +struct llama_sampler_infill { + const struct llama_vocab * vocab; + + std::vector buf0; + std::vector buf1; +}; + +static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) { + return "infill"; +} + +static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_infill *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p); + +#if defined(GGML_DEBUG_SAMPLER_INFILL) +#define LOG_DBG_CUR LLAMA_LOG_DEBUG +#else +#define LOG_DBG_CUR(...) +#endif + + for (size_t i = 0; i < cur_p->size; ++i) { + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + + float p_txt_sum = 0.0f; + float p_eog_sum = 0.0f; + + for (size_t i = 0; i < cur_p->size; ++i) { + if (ctx->vocab->is_eog(cur_p->data[i].id)) { + p_eog_sum += cur_p->data[i].p; + } else { + p_txt_sum += cur_p->data[i].p; + } + } + + const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat); + + LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size); + + if (3*p_eog_sum*cur_p->size > p_txt_sum) { + LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum); + + // keep just the EOG tokens + const auto size_org = cur_p->size; + + cur_p->size = 0; + + float p_sum = 0.0f; + + for (size_t i = 0; i < size_org; ++i) { + if (ctx->vocab->is_eog(cur_p->data[i].id)) { + p_sum += cur_p->data[i].p; + + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + } + + return; + } + + size_t n_combined = 0; GGML_UNUSED(n_combined); + + // combine tokens with common prefix + for (size_t i0 = 0; i0 < cur_p->size; ++i0) { + for (size_t i1 = 0; i1 < cur_p->size; ++i1) { + if (cur_p->data[i0].logit == -INFINITY) { + break; + } + + if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) { + continue; + } + + int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); + if (len0 < 0) { + ctx->buf0.resize(len0); + len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); + assert(len0 > 0); + } + + int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); + if (len1 < 0) { + ctx->buf1.resize(len1); + len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); + assert(len1 > 0); + } + + // token i0 is a prefix of token i1 + if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) { + int dst = i0; + int src = i1; + + // merge into the token with higher probability + if (cur_p->data[i1].p > cur_p->data[i0].p) { + std::swap(dst, src); + } + + cur_p->data[dst].p += cur_p->data[src].p; + cur_p->data[src].logit = -INFINITY; + cur_p->data[src].p = 0.0f; + + n_combined++; + } + } + } + + size_t n_non_eog = 0; + + size_t size_org = cur_p->size; + + float p_sum = 0.0f; + float thold = 0.2f; + + cur_p->size = 0; + + LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold); + + for (size_t i = 0; i < size_org; ++i) { + const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id); + + if (cur_p->data[i].p < thold && !is_eog) { + continue; + } + + if (!is_eog) { + ++n_non_eog; + } + + p_sum += cur_p->data[i].p; + + // keep this token + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + + LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog); + + // if no non-EOG tokens are left -> reduce cur_p to single EOT token + if (n_non_eog == 0) { + cur_p->size = 1; + cur_p->data[0].id = ctx->vocab->token_eot(); + cur_p->data[0].logit = 1.0f; + + return; + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + + size_org = cur_p->size; + p_sum = 0.0f; + thold = 1.0/(n_non_eog + 1); + + cur_p->size = 0; + + LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold); + + for (size_t i = 0; i < size_org; ++i) { + const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id); + + if (cur_p->data[i].p < thold && !is_eog) { + continue; + } + + p_sum += cur_p->data[i].p; + + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + +#undef LOG_DBG_CUR +} + +static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_infill *) smpl->ctx; + return llama_sampler_init_infill(ctx->vocab); +} + +static void llama_sampler_infill_free(struct llama_sampler * smpl) { + delete (llama_sampler_infill *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_infill_i = { + /* .name = */ llama_sampler_infill_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_infill_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_infill_clone, + /* .free = */ llama_sampler_infill_free, +}; + +struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_infill_i, + /* .ctx = */ new llama_sampler_infill { + /* .vocab = */ vocab, + /* .buf0 = */ std::vector(512), + /* .buf1 = */ std::vector(512), + } + ); +} + +// utils + +uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { + if (smpl->iface == &llama_sampler_dist_i) { + return ((const llama_sampler_dist *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_mirostat_i) { + return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_mirostat_v2_i) { + return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_chain_i) { + const auto * ctx = (const llama_sampler_chain *) smpl->ctx; + for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) { + const uint32_t seed = llama_sampler_get_seed(*it); + if (seed != LLAMA_DEFAULT_SEED) { + return seed; + } + } + } + + return LLAMA_DEFAULT_SEED; +} + +// perf + +struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) { + struct llama_perf_sampler_data data = {}; + + if (chain == nullptr || chain->iface != &llama_sampler_chain_i) { + GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__); + } + + const auto * ctx = (const struct llama_sampler_chain *) chain->ctx; + + data.t_sample_ms = 1e-3 * ctx->t_sample_us; + data.n_sample = std::max(0, ctx->n_sample); + + return data; +} + +void llama_perf_sampler_print(const struct llama_sampler * chain) { + const auto data = llama_perf_sampler(chain); + + LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample); +} + +void llama_perf_sampler_reset(struct llama_sampler * chain) { + if (chain == nullptr || chain->iface != &llama_sampler_chain_i) { + GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__); + } + + auto * ctx = (struct llama_sampler_chain *) chain->ctx; + + ctx->t_sample_us = ctx->n_sample = 0; +} diff --git a/src/llama-sampling.h b/src/llama-sampling.h new file mode 100644 index 0000000000000000000000000000000000000000..759dd7dcb7042e182013a34a232769e534e70220 --- /dev/null +++ b/src/llama-sampling.h @@ -0,0 +1,32 @@ +#pragma once + +// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? + +#include "llama.h" + +#include + +struct llama_vocab; +struct llama_grammar; + +// sampler chain + +struct llama_sampler_chain { + llama_sampler_chain_params params; + + std::vector samplers; + + // timing + + mutable int64_t t_sample_us; + + mutable int32_t n_sample; +}; + +struct llama_sampler * llama_sampler_init_dry_testing( + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const std::vector>& seq_breakers); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dd2251ef3cbefa7bc6d88fecab69cec4bc70b7ad --- /dev/null +++ b/src/llama-vocab.cpp @@ -0,0 +1,3354 @@ +#include "llama-vocab.h" + +#include "ggml.h" +#include "gguf.h" +#include "llama-impl.h" +#include "llama-model-loader.h" + +#include "unicode.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// +// helpers +// + +struct naive_trie { + naive_trie() : has_value(false), value(0) { + } + void insert(const char * key, size_t len, int32_t value = 0) { + if (len == 0) { + this->has_value = true; + this->value = value; + return; + } + char c = key[0]; + auto res = children.find(c); + if (res != children.end()) { + res->second.insert(key + 1, len - 1, value); + } else { + auto res = children.insert(std::make_pair(c, naive_trie())); + res.first->second.insert(key + 1, len - 1, value); + } + } + std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) const { + if (len == 0 || offset == len) { + return std::make_pair(key, offset); + } + char c = key[offset]; + auto res = children.find(c); + if (res != children.end()) { + return res->second.get_longest_prefix(key, len, offset + 1); + } + + return std::make_pair(key, offset); + } + const struct naive_trie * traverse(const char c) const { + auto res = children.find(c); + if (res != children.end()) { + return &res->second; + } + + return NULL; + } + std::map children; + bool has_value; + llama_token value; +}; + +// +// tokenizers +// + +struct llm_tokenizer { + llm_tokenizer() {} + virtual ~llm_tokenizer() = default; +}; + +struct llm_symbol { + using index = int; + index prev; + index next; + const char * text; + size_t n; +}; + +static_assert(std::is_trivially_copyable::value, "llm_symbol is not trivially copyable"); + +// +// SPM tokenizer +// original implementation: +// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 +// + +struct llm_bigram_spm { + struct comparator { + bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) { + return (l.score < r.score) || (l.score == r.score && l.left > r.left); + } + }; + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + float score; + size_t size; +}; + +struct llm_tokenizer_spm : llm_tokenizer { + llm_tokenizer_spm(const llama_vocab & /*vocab*/) {} +}; + +struct llm_tokenizer_spm_session { + llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { + // split string into utf8 chars + int index = 0; + size_t offs = 0; + while (offs < text.size()) { + llm_symbol sym; + size_t len = unicode_len_utf8(text[offs]); + sym.text = text.c_str() + offs; + sym.n = std::min(len, text.size() - offs); + offs += sym.n; + sym.prev = index - 1; + sym.next = offs == text.size() ? -1 : index + 1; + index++; + symbols.emplace_back(sym); + } + + // seed the work queue with all possible 2-character tokens. + for (int i = 1; i < (int) symbols.size(); ++i) { + try_add_bigram(i - 1, i); + } + + // keep substituting the highest frequency pairs for as long as we can. + while (!work_queue.empty()) { + auto bigram = work_queue.top(); + work_queue.pop(); + + auto & left_sym = symbols[bigram.left]; + auto & right_sym = symbols[bigram.right]; + + // if one of the symbols already got merged, skip it. + if (left_sym.n == 0 || right_sym.n == 0 || + left_sym.n + right_sym.n != bigram.size) { + continue; + } + + // merge the right sym into the left one + left_sym.n += right_sym.n; + right_sym.n = 0; + + //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size); + + // remove the right sym from the chain + left_sym.next = right_sym.next; + if (right_sym.next >= 0) { + symbols[right_sym.next].prev = bigram.left; + } + + // find more substitutions + try_add_bigram(left_sym.prev, bigram.left); + try_add_bigram(bigram.left, left_sym.next); + } + + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + resegment(symbol, output); + } + } + +private: + void resegment(llm_symbol & symbol, std::vector & output) { + auto text = std::string(symbol.text, symbol.n); + auto token = vocab.text_to_token(text); + + // Do we need to support is_unused? + if (token != LLAMA_TOKEN_NULL) { + output.push_back(token); + return; + } + + const auto p = rev_merge.find(text); + + if (p == rev_merge.end()) { + // output any symbols that did not form tokens as bytes. + output.reserve(output.size() + symbol.n); + for (int j = 0; j < (int)symbol.n; ++j) { + llama_token id = vocab.byte_to_token(symbol.text[j]); + output.push_back(id); + } + return; + } + + resegment(symbols[p->second.first], output); + resegment(symbols[p->second.second], output); + } + + void try_add_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; + } + const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); + auto token = vocab.text_to_token(text); + + if (token == LLAMA_TOKEN_NULL) { + return; + } + + if (static_cast(token) >= vocab.n_tokens()) { + return; + } + + const auto & tok_data = vocab.get_token_data(token); + + llm_bigram_spm bigram; + bigram.left = left; + bigram.right = right; + bigram.score = tok_data.score; + bigram.size = text.size(); + + work_queue.push(bigram); + + // Do we need to support is_unused? + rev_merge[text] = std::make_pair(left, right); + } + + const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_spm * spm_tokenizer; + + std::vector symbols; + llm_bigram_spm::queue work_queue; + std::map> rev_merge; +}; + +// +// BPE tokenizer +// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License] +// tried to simplify unicode stuff, so most likely does not work 100% correctly! +// + +// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused + +template, typename Compare = std::less> +class llama_priority_queue : public std::priority_queue { +public: + using std::priority_queue::priority_queue; + + T pop_move() { + T item = std::move(this->c.front()); + std::pop_heap(this->c.begin(), this->c.end(), this->comp); + this->c.pop_back(); + return item; + } + + void pop() = delete; +}; + +struct llm_bigram_bpe { + struct comparator { + bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { + return l.rank > r.rank || (l.rank == r.rank && l.left > r.left); + } + }; + + using queue_storage = std::vector; + using queue = llama_priority_queue; + llm_symbol::index left; + llm_symbol::index right; + std::string text; + int rank; + size_t size; +}; + +struct llm_tokenizer_bpe : llm_tokenizer { + llm_tokenizer_bpe(const llama_vocab & vocab) { + GGML_ASSERT(vocab.get_type() == LLAMA_VOCAB_TYPE_BPE); + switch (vocab.get_pre_type()) { + case LLAMA_VOCAB_PRE_TYPE_LLAMA3: + regex_exprs = { + // original regex from tokenizer.json + //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + + // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_DBRX: + case LLAMA_VOCAB_PRE_TYPE_SMAUG: + regex_exprs = { + // same as llama3 + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM: + regex_exprs = { + "[\r\n]", + "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", + "\\s?[!-/:-~!-/:-~‘-‟ -。]+", + "\\s+$", + "[一-龥ࠀ-一가-퟿]+", + "\\p{N}+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: + regex_exprs = { + "\\p{N}{1,3}", + "[一-龥぀-ゟ゠-ヿ]+", + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: + regex_exprs = { + "[\r\n]", + "\\s?\\p{L}+", + "\\s?\\p{P}+", + "[一-龥ࠀ-一가-퟿]+", + "\\p{N}", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_FALCON: + regex_exprs = { + "[\\p{P}\\$\\+<=>\\^~\\|`]+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "[0-9][0-9][0-9]", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_STARCODER: + case LLAMA_VOCAB_PRE_TYPE_REFACT: + case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: + case LLAMA_VOCAB_PRE_TYPE_SMOLLM: + case LLAMA_VOCAB_PRE_TYPE_CODESHELL: + case LLAMA_VOCAB_PRE_TYPE_EXAONE: + case LLAMA_VOCAB_PRE_TYPE_MINERVA: + regex_exprs = { + "\\p{N}", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_GPT2: + case LLAMA_VOCAB_PRE_TYPE_MPT: + case LLAMA_VOCAB_PRE_TYPE_OLMO: + case LLAMA_VOCAB_PRE_TYPE_JAIS: + case LLAMA_VOCAB_PRE_TYPE_TRILLION: + regex_exprs = { + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_STABLELM2: + case LLAMA_VOCAB_PRE_TYPE_QWEN2: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_PORO: + case LLAMA_VOCAB_PRE_TYPE_BLOOM: + case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: + regex_exprs = { + " ?[^(\\s|.,!?…。,、।۔،)]+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_CHATGLM4: + regex_exprs = { + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_VIKING: + regex_exprs = { + " ?[^(\\s|.,!?…。,、।۔،)]+", + "\\p{N}", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_TEKKEN: + // original regex from tokenizer.json + // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + regex_exprs = { + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: + // Note: in theory, the special token (sentinel and image token) regex_exprs below + // are unnecessary, as they are split in `tokenizer_st_partition` anyway. + // However, since the upstream pre-tokenizer uses them, they are also + // included here (see https://huggingface.co/facebook/chameleon-7b). + regex_exprs = { + "", // Sentinel tokens + "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens + "([\\t\\n]| | )", // directly from tokenizer.json + "\\p{N}", // Individual digits + "[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_GPT4O: + regex_exprs = { + // original regex from tokenizer.json + // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_SUPERBPE: + regex_exprs = { + "\\p{N}+", + "(?=(\\d{3})+(?!\\d))", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE: + regex_exprs = { + // original regex from tokenizer.json + // "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+" + // FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?) + "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_SEED_CODER: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + default: + // default regex for BPE tokenization pre-processing + regex_exprs = { + "[\\p{P}\\$\\+<=>\\^~\\|]+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "\\p{N}+", + "[0-9][0-9][0-9]", + }; + break; + } + } + + std::vector regex_exprs; +}; + +struct llm_tokenizer_bpe_session { + llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + + static void append(const llama_token token_id, std::vector & output) { + output.push_back(token_id); + } + + bool append_bos(std::vector & output) const { + if (vocab.get_add_bos()) { + GGML_ASSERT(vocab.token_bos() != LLAMA_TOKEN_NULL); + output.push_back(vocab.token_bos()); + return true; + } + return false; + } + + bool append_eos(std::vector & output) const { + if (vocab.get_add_eos()) { + GGML_ASSERT(vocab.token_eos() != LLAMA_TOKEN_NULL); + output.push_back(vocab.token_eos()); + return true; + } + return false; + } + + void check_double_bos_eos(const std::vector & output) const { + if (vocab.get_add_bos() && output.size() >= 2 && output[1] == vocab.token_bos()) { + LLAMA_LOG_WARN( + "%s: Added a BOS token to the prompt as specified by the model but the prompt " + "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " + "Are you sure this is what you want?\n", __FUNCTION__); + } + if (vocab.get_add_eos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) { + LLAMA_LOG_WARN( + "%s: Added a EOS token to the prompt as specified by the model but the prompt " + "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. " + "Are you sure this is what you want?\n", __FUNCTION__); + } + } + + void tokenize(const std::string & text, std::vector & output) { + int final_prev_index = -1; + const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs); + + symbols_final.clear(); + + for (const auto & word : word_collection) { + work_queue = llm_bigram_bpe::queue(); + symbols.clear(); + + int index = 0; + size_t offset = 0; + + //if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { + if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) { + symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); + offset = word.size(); + } + + while (offset < word.size()) { + llm_symbol sym; + size_t char_len = std::min(word.size() - offset, (size_t) unicode_len_utf8(word[offset])); + sym.text = word.c_str() + offset; + sym.n = char_len; + offset += sym.n; + sym.prev = index - 1; + sym.next = offset == word.size() ? -1 : index + 1; + index++; + symbols.emplace_back(sym); + } + for (int i = 1; i < (int) symbols.size(); ++i) { + add_new_bigram(i - 1, i); + } + + // build token(s) + while (!work_queue.empty()) { + auto bigram = work_queue.pop_move(); + + auto & left_symbol = symbols[bigram.left]; + auto & right_symbol = symbols[bigram.right]; + + if (left_symbol.n == 0 || right_symbol.n == 0) { + continue; + } + std::string left_token = std::string(left_symbol.text, left_symbol.n); + std::string right_token = std::string(right_symbol.text, right_symbol.n); + if (left_token + right_token != bigram.text) { + continue; // Skip this bigram if it's outdated + } + + // merge the right sym into the left one + left_symbol.n += right_symbol.n; + right_symbol.n = 0; + + // remove the right sym from the chain + left_symbol.next = right_symbol.next; + if (right_symbol.next >= 0) { + symbols[right_symbol.next].prev = bigram.left; + } + + add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol + add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol + } + + // add the finished tokens to the final list keeping correct order for next and prev + for (auto & sym : symbols) { + if (sym.n > 0) { + sym.prev = final_prev_index; + sym.next = -1; + if (final_prev_index != -1) { + symbols_final[final_prev_index].next = symbols_final.size(); + } + symbols_final.emplace_back(sym); + final_prev_index = symbols_final.size() - 1; + } + } + } + + symbols = symbols_final; + + if (!symbols.empty()) { + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + if (symbol.n == 0) { + continue; + } + + const std::string str = std::string(symbol.text, symbol.n); + const auto token = vocab.text_to_token(str); + + if (token == LLAMA_TOKEN_NULL) { + for (auto j = str.begin(); j != str.end(); ++j) { + std::string byte_str(1, *j); + auto token_multibyte = vocab.text_to_token(byte_str); + if (token_multibyte != LLAMA_TOKEN_NULL) { + output.push_back(token_multibyte); + } + } + } else { + output.push_back(token); + } + } + } + } + +private: + void add_new_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; + } + std::string left_token = std::string(symbols[left].text, symbols[left].n); + std::string right_token = std::string(symbols[right].text, symbols[right].n); + + int rank_found = -1; + + rank_found = vocab.find_bpe_rank(left_token, right_token); + + if (rank_found < 0) { + return; + } + + llm_bigram_bpe bigram; + + bigram.left = left; + bigram.right = right; + bigram.text = left_token + right_token; + bigram.size = left_token.size() + right_token.size(); + bigram.rank = rank_found; + + work_queue.push(bigram); + } + + const llama_vocab & vocab; + const llm_tokenizer_bpe & tokenizer; + + std::vector symbols; + std::vector symbols_final; + llm_bigram_bpe::queue work_queue; +}; + +// +// WPM tokenizer +// + +struct llm_tokenizer_wpm : llm_tokenizer { + llm_tokenizer_wpm(const llama_vocab & /*vocab*/) {} +}; + +struct llm_tokenizer_wpm_session { + llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { + // normalize and split by whitespace + std::vector words = preprocess(text); + // bos token prepended already + + // find the longest tokens that form the words + for (const std::string & word : words) { + // skip empty words + if (word.size() == 0) { + continue; + } + + // prepend phantom space + const std::string word1 = "\xe2\x96\x81" + word; + const int n = word1.size(); + + const size_t current_tokens = output.size(); + + // we're at the start of a new word + // move through character position in word + for (int i = 0; i < n; ++i) { + // loop through possible match length + bool match = false; + for (int j = std::min(n, i + vocab.max_token_len() + 1); j > i; j--) { + auto id = vocab.text_to_token(word1.substr(i, j - i)); + if (id != LLAMA_TOKEN_NULL) { + output.push_back(id); + match = true; + i = j - 1; + break; + } + } + + if (!match) { // discard all + output.resize(current_tokens); + break; // and discard next tokens + } + } + + // we didn't find any matches for this word + if (current_tokens == output.size()) { + output.push_back(vocab.token_unk()); + } + } + } + + // TODO: reduce string copies by using cpts_offs array + static std::vector preprocess(const std::string & text) { + const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + std::vector words(1, ""); + + for (const uint32_t cpt : cpts_nfd) { + const auto flags = unicode_cpt_flags_from_cpt(cpt); + + if (flags.is_whitespace) { + if (words.back().size()) { // finish previous word if any + words.emplace_back(); + } + continue; + } + + assert (!flags.is_separator); + if (cpt == 0 || cpt == 0xFFFD || flags.is_control) { + continue; + } + + const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); + if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { + if (words.back().size()) { // finish previous word if any + words.emplace_back(); + } + words.back() = s; // single char word + words.emplace_back(); // start a new word + } else { + words.back() += s; // append char to word + } + } + + if (!words.back().size()) { + words.pop_back(); + } + + return words; + } + + static bool is_chinese_char(uint32_t cpt) { + return + (cpt >= 0x04E00 && cpt <= 0x09FFF) || + (cpt >= 0x03400 && cpt <= 0x04DBF) || + (cpt >= 0x20000 && cpt <= 0x2A6DF) || + (cpt >= 0x2A700 && cpt <= 0x2B73F) || + (cpt >= 0x2B740 && cpt <= 0x2B81F) || + (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920 + (cpt >= 0x0F900 && cpt <= 0x0FAFF) || + (cpt >= 0x2F800 && cpt <= 0x2FA1F); + //(cpt >= 0x3000 && cpt <= 0x303F) || + //(cpt >= 0xFF00 && cpt <= 0xFFEF); + } + +private: + const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_wpm * wpm_tokenizer; +}; + +// +// UGM tokenizer +// + +struct llm_tokenizer_ugm : llm_tokenizer { + llm_tokenizer_ugm(const llama_vocab & vocab, const std::vector & precompiled_charsmap) { + if (precompiled_charsmap.size() > 0) { + size_t charsmap_offset = 0; + + // First four bytes of precompiled_charsmap contains length of binary + // blob containing XOR-compressed compact double array (XCDA) entries + uint32_t xcda_blob_size = *(const uint32_t *) &precompiled_charsmap[0]; + charsmap_offset += sizeof(xcda_blob_size); + if (xcda_blob_size + charsmap_offset >= precompiled_charsmap.size()) { + throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); + } + + // Next xcda_blob_size bytes contain entries of XOR-compressed compact + // double array (XCDA). Each entry is bit-packed into a 32-bit integer. + xcda_array = (const uint32_t *) &precompiled_charsmap[charsmap_offset]; + xcda_array_size = xcda_blob_size / sizeof(uint32_t); + charsmap_offset += xcda_blob_size; + + // Remaining bytes of precompiled charsmap contain null-terminated + // replacement strings for prefixes matched by the XCDA. + prefix_replacements = &precompiled_charsmap[charsmap_offset]; + prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset; + } + + for (uint32_t id = 0; id < vocab.n_tokens(); ++id) { + const auto & token_data = vocab.get_token_data(id); + + if (vocab.is_normal(id)) { + min_score = std::min(min_score, token_data.score); + max_score = std::max(max_score, token_data.score); + } + + if (vocab.is_normal(id) || + vocab.is_user_defined(id) || + vocab.is_unused(id)) { + token_matcher.insert(token_data.text.data(), token_data.text.size(), id); + } + + if (vocab.is_user_defined(id)) { + user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size()); + } + } + + unknown_token_score = min_score - unknown_token_score_penalty; + } + + // escaped space symbol - U+2581 (Lower One Eighth Block) + const std::string escaped_space = "\xE2\x96\x81"; + + const char * prefix_replacements = NULL; + size_t prefix_replacements_size = 0; + + const uint32_t * xcda_array = NULL; + size_t xcda_array_size = 0; + + struct naive_trie user_defined_token_matcher; + + float min_score = FLT_MAX; + float max_score = -FLT_MAX; + + float unknown_token_score_penalty = 10.0; + float unknown_token_score; + + struct naive_trie token_matcher; +}; + +struct llm_tokenizer_ugm_session { + llm_tokenizer_ugm_session(const llama_vocab & vocab, const llm_tokenizer_ugm & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + + /* This implementation is based on SentencePiece optimized Viterbi algorithm for + * unigram language models. The general idea is to: + * - move along the input sequence in steps of one UTF code point, + * - at each step find all possible tokenizations of the prefix by + * traversing the tokens trie, + * - for each tokenization store the best one so far (by higher score) + * - use the position in sequence after given token as an index to store + * results + * - if there was no valid tokenization of the current UTF code point + * then use unknown token with additional score penalty + * After processing the whole sequence we backtrack from the end to get + * the best tokenization. + */ + void tokenize(const std::string & text, std::vector & output) { + // get current size of output (for reversal later) + size_t output_size = output.size(); + + // normalize the input first + std::string normalized; + normalize(text, &normalized); + size_t input_len = normalized.size(); + if (input_len == 0) { + return; + } + + // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores + std::vector tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX}); + // at the beginning tokenization score is zero + tokenization_results[0] = { vocab.token_unk(), 0, 0 }; + + for (size_t input_offset = 0; input_offset < input_len;) { + size_t prefix_offset = input_offset; + // calculate how many code units are in the currently processed UTF code point + size_t n_utf8_code_units = std::min(unicode_len_utf8(normalized[input_offset]), input_len - input_offset); + + // traverse the token matcher trie to find a matching token + bool single_codepoint_token_found = false; + const struct best_tokenization & current_best = tokenization_results[input_offset]; + const struct naive_trie * node = tokenizer.token_matcher.traverse(normalized[prefix_offset++]); + + while (prefix_offset <= input_len && node != NULL) { + // check if we found valid token in prefix + if (node->has_value) { + // check if it corresponds to the whole UTF code point + if (prefix_offset - input_offset == n_utf8_code_units) { + single_codepoint_token_found = true; + } + llama_token token_id = node->value; + const auto & token_data = vocab.get_token_data(token_id); + + // we set the user-defined token scores to 0 to make them more likely to be selected + // (normal token scores are log probabilities, so they are negative) + // score type is double here to make tokenization results exactly + // the same as in the HF tokenizer using SentencePiece + const double token_score = vocab.is_user_defined(token_id) ? 0.0 : token_data.score; + const double challenger_score = current_best.score_sum + token_score; + struct best_tokenization & current_champ = tokenization_results[prefix_offset]; + if (challenger_score > current_champ.score_sum) { + struct best_tokenization challenger = { token_id, input_offset, challenger_score }; + current_champ = challenger; + } + } + node = node->traverse(normalized[prefix_offset++]); + } + + // if we didn't find a valid token corresponding to the whole UTF code point + // then use unknown token as the tokenization of this UTF code point + if (!single_codepoint_token_found) { + const double challenger_score = current_best.score_sum + tokenizer.unknown_token_score; + prefix_offset = input_offset + n_utf8_code_units; + struct best_tokenization & current_champ = tokenization_results[prefix_offset]; + if (challenger_score > current_champ.score_sum) { + struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score }; + current_champ = challenger; + } + } + + // move to the next UTF code point + input_offset += n_utf8_code_units; + } + + // now backtrack from the end to gather token ids of the best tokenization + // merge sequences of consecutive unknown tokens into single unknown tokens + bool is_prev_unknown = false; + for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) { + bool is_unknown = tokenization.token_id == vocab.token_unk(); + if (!(is_prev_unknown && is_unknown)) { + output.push_back(tokenization.token_id); + } + if (tokenization.input_offset == 0) { + break; + } + is_prev_unknown = is_unknown; + } + + // reverse the output since we added tokens starting from the end of the input + std::reverse(output.begin() + output_size, output.end()); + } + +private: + + // helper structure for returning normalization results + struct normalization_result { + const char * normalized; + size_t normalized_len; + size_t consumed_input; + }; + + void normalize(const std::string& input, std::string * normalized) { + normalized->clear(); + normalized->reserve(input.size() * 3); + + const std::string space = vocab.get_escape_whitespaces() ? tokenizer.escaped_space : " "; + + const bool shall_prepend_space = !vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix(); + const bool shall_append_space = vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix(); + const bool shall_merge_spaces = vocab.get_remove_extra_whitespaces(); + + bool is_space_prepended = false; + bool processing_non_ws = false; + + size_t input_len = input.size(); + + for (size_t input_offset = 0; input_offset < input_len; ) { + auto norm_res = normalize_prefix(input, input_offset); + for (size_t i = 0; i < norm_res.normalized_len; i++) { + char c = norm_res.normalized[i]; + if (c != ' ') { + if (!processing_non_ws) { + processing_non_ws = true; + if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) { + normalized->append(space); + is_space_prepended = true; + } + } + normalized->push_back(c); + } else { + if (processing_non_ws) { + processing_non_ws = false; + } + if (!shall_merge_spaces) { + normalized->append(space); + } + } + } + + input_offset += norm_res.consumed_input; + } + + if (shall_append_space) { + normalized->append(space); + } + } + + /* + * This structure is a view wrapper for XOR-compressed double array (XCDA) + * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries. + * Each bit-packed entry contains: + * - BASE array value in bits 10-30 + * - LCHECK array value in bits 0-7 + * - LEAF array value in bit 9 + * Entries containing indexes of replacement sequences have set bit 31 + */ + struct xcda_array_view { + public: + xcda_array_view(const uint32_t * xcda_array, size_t xcda_array_size) : xcda_array(xcda_array), xcda_array_size(xcda_array_size) { + } + uint32_t get_base(size_t index) { + uint32_t packed_node = get_node(index); + return (packed_node >> 10) << ((packed_node & (1U << 9)) >> 6); + } + uint32_t get_lcheck(size_t index) { + uint32_t packed_node = get_node(index); + return packed_node & ((1U << 31) | 0xff); + } + bool get_leaf(size_t index) { + uint32_t packed_node = get_node(index); + return (packed_node >> 8) & 1; + } + uint32_t get_value(size_t index) { + uint32_t packed_node = get_node(index); + return packed_node & ((1U << 31) - 1); + } + private: + uint32_t get_node(size_t index) { + if (index > xcda_array_size) { + throw std::runtime_error("Index out of array bounds in XCDA array!"); + } + return xcda_array[index]; + } + const uint32_t * xcda_array; + size_t xcda_array_size; + }; + + // this structure stores the best tokenization so far at input_offset + struct best_tokenization { + llama_token token_id; + size_t input_offset; + double score_sum; + }; + + struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { + if (input_offset == input.size()) { + return { &input[input_offset], 0, 0 }; + } + + // if input prefix matches some user-defined token return this token as normalization result + auto user_defined_token_match = + tokenizer.user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); + if (user_defined_token_match.second > 0) { + return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; + } + + size_t longest_prefix_length = 0; + size_t longest_prefix_offset = 0; + + if (tokenizer.xcda_array_size > 0) { + struct xcda_array_view xcda_view(tokenizer.xcda_array, tokenizer.xcda_array_size); + + // Find the longest normalized sequence matching the input prefix by walking + // the XOR-compressed compact double array (XCDA) starting from the root node + // We find the index of the next node by calculating BASE[s] ^ c where s is + // the index of the previous node and c is a numerical character value + uint32_t node_index = 0; + // get BASE of the root node + node_index = xcda_view.get_base(node_index); + for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) { + unsigned char c = input[prefix_offset]; + if (c == 0) { + break; + } + node_index ^= c; + // if value of LCHECK is not c it means that this is not a child of + // the previous node, so we stop matching + if (xcda_view.get_lcheck(node_index) != c) { + break; + } + bool is_leaf = xcda_view.get_leaf(node_index); + // get BASE of the current node + node_index ^= xcda_view.get_base(node_index); + // if LEAF of the current node is true, it means that its BASE points to the node + // containing index of replacement sequence for currently matched input prefix + if (is_leaf) + { + longest_prefix_length = prefix_offset - input_offset + 1; + // get index of replacement sequence for currently matched input prefix + longest_prefix_offset = xcda_view.get_value(node_index); + } + } + } + + if (longest_prefix_length > 0) { + // we have a match, so return the replacement sequence + if (longest_prefix_offset >= tokenizer.prefix_replacements_size) { + throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); + } + const char * prefix_replacement = &(tokenizer.prefix_replacements)[longest_prefix_offset]; + return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length }; + } + + // check if the input prefix contains a valid sequence of UTF-8 code units + try { + // if yes, return this sequence unmodified + size_t prefix_offset = input_offset; + unicode_cpt_from_utf8(input, prefix_offset); + return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; + } catch (std::invalid_argument & /*ex*/) { + // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER + return { "\xEF\xBF\xBD", 3, 1 }; + } + } + + const llama_vocab & vocab; + const llm_tokenizer_ugm & tokenizer; +}; + +// +// RWKV tokenizer +// + +static std::vector llama_unescape_rwkv_token(const std::string & escaped) { + std::vector output; + output.reserve(escaped.size()); + + // Parser state + bool escaping = false; + uint8_t hex_remaining = 0; + uint8_t hex_acc = 0; + + // Step through characters, performing parsing + for (const char & c : escaped) { + // If we're parsing a hex code, interpret the next character + if (hex_remaining != 0) { + uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0'); + hex_acc = (hex_acc << 4) + value; + + hex_remaining -= 1; + if (hex_remaining == 0) { + output.push_back(hex_acc); + hex_acc = 0; + } + + continue; + } + + // If we got an escape character, interpret it + if (escaping) { + if (c == 't') { + output.push_back('\t'); + } else if (c == 'n') { + output.push_back('\n'); + } else if (c == 'r') { + output.push_back('\r'); + } else if (c == 'x') { + hex_remaining = 2; + } else { + output.push_back(c); + } + + escaping = false; + continue; + } + + if (c == '\\') { + escaping = true; + continue; + } + + output.push_back(c); + } + + return output; +} + +struct llm_tokenizer_rwkv : llm_tokenizer { + llm_tokenizer_rwkv(const llama_vocab & vocab) { + // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. + // For now, we decode the vocab here into the lookup we'll use for tokenization. + + // build trie + for (uint32_t id = 0; id < vocab.n_tokens(); ++id) { + const auto & data = vocab.get_token_data(id); + const auto text = llama_unescape_rwkv_token(data.text); + token_matcher.insert((const char *) text.data(), text.size(), id); + } + } + + struct naive_trie token_matcher; +}; + +struct llm_tokenizer_rwkv_session { + llm_tokenizer_rwkv_session(const llama_vocab & vocab, const llm_tokenizer_rwkv & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + + void tokenize(const std::string & text, std::vector & output) { + uint32_t position = 0; + while (position < text.size()) { + const struct naive_trie * node = tokenizer.token_matcher.traverse(text[position]); + if (node == NULL) { + // no matching token found, add unknown token + output.push_back(vocab.token_unk()); + position += 1; + continue; + } + + // traverse the trie to find the longest matching token + uint32_t token_id = 0; + uint32_t token_length = 0; + while (node != NULL) { + if (node->has_value) { + token_id = node->value; + token_length = position + 1; + } + node = node->traverse(text[++position]); + } + + // add the longest matching token + output.push_back(token_id); + position = token_length; + } + } + +private: + const llama_vocab & vocab; + const llm_tokenizer_rwkv & tokenizer; +}; + +// +// impl +// + +typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { + FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, + FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT +} FRAGMENT_BUFFER_VARIANT_TYPE; + +struct fragment_buffer_variant { + fragment_buffer_variant(llama_token _token) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), + token(_token), + raw_text(_dummy), + offset(0), + length(0) {} + + fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), + token((llama_token) - 1), + raw_text(_raw_text), + offset(_offset), + length(_length){ + GGML_ASSERT(_offset >= 0); + GGML_ASSERT(_length >= 1); + GGML_ASSERT(offset + length <= raw_text.length()); + } + + const FRAGMENT_BUFFER_VARIANT_TYPE type; + const llama_token token; + const std::string _dummy; + const std::string & raw_text; + const uint64_t offset; + const uint64_t length; +}; + +struct llama_vocab::impl { + uint32_t n_token_types = 0; // for BERT-style token types + + std::string tokenizer_model; + std::string tokenizer_pre; + + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + + int max_token_len = 0; // used for optimizing longest token search + + // default LLaMA special tokens + // TODO: should we set all of these to LLAMA_TOKEN_NULL? + llama_token special_bos_id = 1; + llama_token special_eos_id = 2; + llama_token special_eot_id = LLAMA_TOKEN_NULL; + llama_token special_eom_id = LLAMA_TOKEN_NULL; + llama_token special_unk_id = 0; + llama_token special_sep_id = LLAMA_TOKEN_NULL; + llama_token special_pad_id = LLAMA_TOKEN_NULL; + llama_token special_mask_id = LLAMA_TOKEN_NULL; + + llama_token linefeed_id = 13; + + // fim tokens + llama_token special_fim_pre_id = LLAMA_TOKEN_NULL; + llama_token special_fim_suf_id = LLAMA_TOKEN_NULL; + llama_token special_fim_mid_id = LLAMA_TOKEN_NULL; + llama_token special_fim_pad_id = LLAMA_TOKEN_NULL; + llama_token special_fim_rep_id = LLAMA_TOKEN_NULL; // repo + llama_token special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator + + // tokenizer flags + bool add_space_prefix = false; + bool add_bos = false; + bool add_eos = false; + bool ignore_merges = false; + bool clean_spaces = false; // clean_up_tokenization_spaces + bool remove_extra_whitespaces = false; + bool escape_whitespaces = true; + bool treat_whitespace_as_suffix = false; + + std::unordered_map token_to_id; + std::vector id_to_token; + + std::vector cache_special_tokens; + std::vector cache_token_to_piece; // llama_token_to_piece(special = true); + struct pair_hash { + size_t operator()(const std::pair & p) const { + return std::hash{}(p.first) ^ //create some hash for pair + (std::hash{}(p.second) << 1); + } + }; + std::unordered_map, int, pair_hash> bpe_ranks; + + // set of all tokens that cause "end of generation" + std::set special_eog_ids; + + std::unique_ptr tokenizer; + + std::vector precompiled_charsmap; + + impl(const llama_vocab & vocab) : vocab(vocab) { + } + + ~impl() = default; + + void load(llama_model_loader & ml, const LLM_KV & kv); + + enum llama_vocab_type get_type() const; + + std::string type_name() const; + + bool is_normal (llama_token id) const; + bool is_unknown (llama_token id) const; + bool is_control (llama_token id) const; + bool is_byte (llama_token id) const; + bool is_user_defined(llama_token id) const; + bool is_unused (llama_token id) const; + bool is_eog (llama_token id) const; + + uint8_t token_to_byte(llama_token id) const; + + llama_token_attr token_get_attr(llama_token id) const; + + void init_tokenizer(enum llama_vocab_type type); + + void tokenizer_st_partition(std::forward_list & buffer, bool parse_special) const; + + std::string token_to_piece_for_cache( + llama_token token, + bool special) const; + + + std::vector tokenize( + const std::string & raw_text, + bool add_special, + bool parse_special = false) const; + + int32_t tokenize( + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special) const; + + // does not write null-terminator to buf + int32_t token_to_piece( + llama_token token, + char * buf, + int32_t length, + int32_t lstrip, + bool special) const; + + // use cached data + const std::string & token_to_piece(llama_token token) const; + + int32_t detokenize( + const llama_token * tokens, + int32_t n_tokens, + char * text, + int32_t text_len_max, + bool remove_special, + bool unparse_special) const; + + std::string detokenize( + const std::vector & tokens, + bool special) const; + + void print_info() const; + +private: + const llama_vocab & vocab; +}; + +void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { + struct gguf_context * ctx = ml.meta.get(); + + // determine vocab type + { + ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); + ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); + + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false); + + if (tokenizer_model == "no_vocab" || tokenizer_model == "none") { + type = LLAMA_VOCAB_TYPE_NONE; + + // default special tokens + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + linefeed_id = LLAMA_TOKEN_NULL; + + // read vocab size from metadata + uint32_t n_tokens = 0; + if (ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) { + LLAMA_LOG_WARN("%s: adding %u dummy tokens\n", __func__, n_tokens); + id_to_token.resize(n_tokens); + } + + return; + } + + if (tokenizer_model == "llama") { + type = LLAMA_VOCAB_TYPE_SPM; + + // default special tokens + special_bos_id = 1; + special_eos_id = 2; + special_unk_id = 0; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "bert") { + type = LLAMA_VOCAB_TYPE_WPM; + + // default special tokens + special_bos_id = 101; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = 100; + special_sep_id = 102; + special_pad_id = 0; + special_mask_id = 103; + } else if (tokenizer_model == "gpt2") { + type = LLAMA_VOCAB_TYPE_BPE; + + // read bpe merges and populate bpe ranks + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_ranks.emplace(std::make_pair(first, second), i); + } + + // default special tokens + special_bos_id = 11; + special_eos_id = 11; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "t5") { + type = LLAMA_VOCAB_TYPE_UGM; + + // default special tokens + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = 1; + special_unk_id = 2; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = 0; + special_mask_id = LLAMA_TOKEN_NULL; + + const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); + if (precompiled_charsmap_keyidx != -1) { + const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx); + GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8); + + const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); + precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); +#ifdef IS_BIG_ENDIAN + // correct endiannes of data in precompiled_charsmap binary blob + uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0]; + *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); + assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); + size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t); + uint32_t * xcda_array = (uint32_t *) &precompiled_charsmap[sizeof(uint32_t)]; + for (size_t i = 0; i < xcda_array_size; ++i) { + xcda_array[i] = __builtin_bswap32(xcda_array[i]); + } +#endif + } + } else if (tokenizer_model == "rwkv") { + type = LLAMA_VOCAB_TYPE_RWKV; + + // default special tokens + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + } else { + throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); + } + + // for now, only BPE models have pre-tokenizers + if (type == LLAMA_VOCAB_TYPE_BPE) { + add_space_prefix = false; + clean_spaces = true; + if (tokenizer_pre.empty()) { + LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__); + LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if (tokenizer_pre == "default") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if ( + tokenizer_pre == "llama3" || + tokenizer_pre == "llama-v3" || + tokenizer_pre == "llama-bpe"|| + tokenizer_pre == "falcon3" || + tokenizer_pre == "pixtral") { + pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; + ignore_merges = true; + add_bos = true; + } else if ( + tokenizer_pre == "deepseek-llm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; + clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-coder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; + clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-v3") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; + clean_spaces = false; + } else if ( + tokenizer_pre == "falcon") { + pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; + } else if ( + tokenizer_pre == "mpt") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MPT; + } else if ( + tokenizer_pre == "starcoder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_STARCODER; + } else if ( + tokenizer_pre == "gpt-2" || + tokenizer_pre == "phi-2" || + tokenizer_pre == "jina-es" || + tokenizer_pre == "jina-de" || + tokenizer_pre == "gigachat" || + tokenizer_pre == "jina-v1-en" || + tokenizer_pre == "jina-v2-es" || + tokenizer_pre == "jina-v2-de" || + tokenizer_pre == "jina-v2-code" || + tokenizer_pre == "roberta-bpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "refact") { + pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; + } else if ( + tokenizer_pre == "command-r") { + pre_type = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + clean_spaces = false; + } else if ( + tokenizer_pre == "qwen2" || + tokenizer_pre == "deepseek-r1-qwen") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; + clean_spaces = false; + } else if ( + tokenizer_pre == "stablelm2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2; + } else if ( + tokenizer_pre == "olmo") { + pre_type = LLAMA_VOCAB_PRE_TYPE_OLMO; + } else if ( + tokenizer_pre == "dbrx") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DBRX; + } else if ( + tokenizer_pre == "smaug-bpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SMAUG; + } else if ( + tokenizer_pre == "poro-chat") { + pre_type = LLAMA_VOCAB_PRE_TYPE_PORO; + clean_spaces = false; + } else if ( + tokenizer_pre == "glm4" || + tokenizer_pre == "chatglm-bpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; + special_bos_id = LLAMA_TOKEN_NULL; + } else if ( + tokenizer_pre == "viking") { + pre_type = LLAMA_VOCAB_PRE_TYPE_VIKING; + clean_spaces = false; + } else if ( + tokenizer_pre == "jais") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS; + } else if ( + tokenizer_pre == "tekken") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TEKKEN; + clean_spaces = false; + ignore_merges = true; + add_bos = true; + } else if ( + tokenizer_pre == "smollm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SMOLLM; + clean_spaces = false; + } else if ( + tokenizer_pre == "codeshell") { + pre_type = LLAMA_VOCAB_PRE_TYPE_CODESHELL; + } else if ( + tokenizer_pre == "bloom") { + pre_type = LLAMA_VOCAB_PRE_TYPE_BLOOM; + } else if ( + tokenizer_pre == "gpt3-finnish") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH; + } else if ( + tokenizer_pre == "exaone") { + pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE; + } else if ( + tokenizer_pre == "chameleon") { + pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; + add_bos = true; + clean_spaces = false; + } else if ( + tokenizer_pre == "minerva-7b") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MINERVA; + } else if ( + tokenizer_pre == "megrez") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; + } else if ( + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; + clean_spaces = false; + } else if ( + tokenizer_pre == "superbpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; + clean_spaces = false; + } else if ( + tokenizer_pre == "trillion") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; + clean_spaces = false; + } else if ( + tokenizer_pre == "bailingmoe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; + clean_spaces = false; + } else if ( + tokenizer_pre == "seed-coder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER; + clean_spaces = false; + } else { + throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); + } + } else if (type == LLAMA_VOCAB_TYPE_SPM) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_space_prefix = true; + clean_spaces = false; + add_bos = true; + add_eos = false; + } else if (type == LLAMA_VOCAB_TYPE_WPM) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_space_prefix = false; + clean_spaces = true; + add_bos = true; + add_eos = false; + } else if (type == LLAMA_VOCAB_TYPE_UGM) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_bos = false; + add_eos = true; + } else if (type == LLAMA_VOCAB_TYPE_RWKV) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_space_prefix = false; + clean_spaces = false; + add_bos = false; + add_eos = false; + } else { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } + + ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, add_space_prefix, false); + ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false); + } + + const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } + + const float * scores = nullptr; + const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); + if (score_idx != -1) { + scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + } + + const int * toktypes = nullptr; + const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); + if (toktype_idx != -1) { + toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + } + + uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); + id_to_token.resize(n_tokens); + + for (uint32_t i = 0; i < n_tokens; i++) { + std::string word = gguf_get_arr_str(ctx, token_idx, i); + if (word.empty()) { + LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); + word = "[EMPTY_" + std::to_string(i) + "]"; + } + + token_to_id[word] = i; + max_token_len = std::max(max_token_len, (int) word.size()); + + auto & token_data = id_to_token[i]; + token_data.text = std::move(word); + token_data.score = scores ? scores[i] : 0.0f; + token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; + + if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file + switch(toktypes[i]) { + case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break; + case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break; + case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break; + case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break; + case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break; + case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break; + case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + } + } + } + GGML_ASSERT(id_to_token.size() == token_to_id.size()); + + init_tokenizer(type); + + // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' + if (type == LLAMA_VOCAB_TYPE_SPM) { + try { + linefeed_id = vocab.byte_to_token('\n'); + } catch (const std::exception & e) { + LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what()); + linefeed_id = special_pad_id; + } + } else if (type == LLAMA_VOCAB_TYPE_WPM) { + linefeed_id = special_pad_id; + } else if (type == LLAMA_VOCAB_TYPE_RWKV) { + const std::vector ids = tokenize("\n", false); + GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + linefeed_id = ids[0]; + } else { + const std::vector ids = tokenize("\n", false); + + //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + if (ids.empty()) { + LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__); + linefeed_id = special_pad_id; + } else { + linefeed_id = ids[0]; + } + } + + // special tokens + { + const std::vector> special_token_types = { + { LLM_KV_TOKENIZER_BOS_ID, special_bos_id }, + { LLM_KV_TOKENIZER_EOS_ID, special_eos_id }, + { LLM_KV_TOKENIZER_EOT_ID, special_eot_id }, + { LLM_KV_TOKENIZER_EOM_ID, special_eom_id }, + { LLM_KV_TOKENIZER_UNK_ID, special_unk_id }, + { LLM_KV_TOKENIZER_SEP_ID, special_sep_id }, + { LLM_KV_TOKENIZER_PAD_ID, special_pad_id }, + { LLM_KV_TOKENIZER_MASK_ID, special_mask_id }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, special_fim_pre_id }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, special_fim_suf_id }, + { LLM_KV_TOKENIZER_FIM_MID_ID, special_fim_mid_id }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, special_fim_pad_id }, + { LLM_KV_TOKENIZER_FIM_REP_ID, special_fim_rep_id }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, special_fim_sep_id }, + + // deprecated + { LLM_KV_TOKENIZER_PREFIX_ID, special_fim_pre_id }, + { LLM_KV_TOKENIZER_SUFFIX_ID, special_fim_suf_id }, + { LLM_KV_TOKENIZER_MIDDLE_ID, special_fim_mid_id }, + }; + + for (const auto & it : special_token_types) { + const std::string & key = kv(std::get<0>(it)); + int32_t & id = std::get<1>(it); + + uint32_t new_id; + if (!ml.get_key(std::get<0>(it), new_id, false)) { + continue; + } + if (new_id >= id_to_token.size()) { + LLAMA_LOG_WARN("%s: bad special token: '%s' = %u, using default id %d\n", + __func__, key.c_str(), new_id, id); + } else { + id = new_id; + } + } + + // Handle add_bos and add_eos + { + bool temp = true; + + if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) { + add_bos = temp; + } + if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) { + add_eos = temp; + } + } + + // auto-detect special tokens by text + // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_... + // for now, we apply this workaround to find the tokens based on their text + + for (const auto & t : token_to_id) { + // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. + if (special_eot_id == LLAMA_TOKEN_NULL) { + if (false + || t.first == "<|eot_id|>" + || t.first == "<|im_end|>" + || t.first == "<|end|>" + || t.first == "" + || t.first == "<|endoftext|>" + || t.first == "" + || t.first == "_" + || t.first == "<|end▁of▁sentence|>" // DeepSeek + ) { + special_eot_id = t.second; + if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", + __func__, t.second, t.first.c_str()); + id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + } + } + } + + // find EOM token: "<|eom_id|>" + if (special_eom_id == LLAMA_TOKEN_NULL) { + if (false + || t.first == "<|eom_id|>" + ) { + special_eom_id = t.second; + if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", + __func__, t.second, t.first.c_str()); + id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + } + } + } + + // find FIM_PRE token: "<|fim_prefix|>", "", "
", etc.
+            if (special_fim_pre_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == ""
+                        || t.first == ""    // Granite
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "
"
+                        || t.first == "▁
"          // CodeLlama
+                        ) {
+                    special_fim_pre_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SUF token: "<|fim_suffix|>", "", "", etc.
+            if (special_fim_suf_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == ""
+                        || t.first == ""   // Granite
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == ""
+                        || t.first == "▁"         // CodeLlama
+                        ) {
+                    special_fim_suf_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_MID token: "<|fim_middle|>", "", "", etc.
+            if (special_fim_mid_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == ""
+                        || t.first == ""   // Granite
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == ""
+                        || t.first == "▁"         // CodeLlama
+                        ) {
+                    special_fim_mid_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PAD token: "<|fim_pad|>", "", "", etc.
+            if (special_fim_pad_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == ""
+                        || t.first == ""   // Granite
+                        || t.first == ""
+                        ) {
+                    special_fim_pad_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_REP token: "<|fim_repo|>", "", "", etc.
+            if (special_fim_rep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == ""
+                        || t.first == ""
+                        || t.first == ""    // Granite
+                        ) {
+                    special_fim_rep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SEP token: "<|file_sep|>"
+            if (special_fim_sep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    special_fim_sep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+        }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        special_eog_ids.clear();
+
+        if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
+            special_eog_ids.insert(special_fim_pad_id);
+        }
+
+        if (special_fim_rep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_rep_id) == 0) {
+            special_eog_ids.insert(special_fim_rep_id);
+        }
+
+        if (special_fim_sep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_sep_id) == 0) {
+            special_eog_ids.insert(special_fim_sep_id);
+        }
+
+        for (const auto & t : token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == ""
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == ""
+                    || t.first == "_"
+                    || t.first == "<|end_of_text|>"
+               ) {
+                special_eog_ids.insert(t.second);
+                if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
+                    id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            } else {
+                // token is control, but not marked as EOG -> print a debug log
+                if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
+            }
+        }
+
+        // sanity checks
+        if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
+            special_eog_ids.insert(special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eot_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eot_id) == 0) {
+            special_eog_ids.insert(special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eom_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eom_id) == 0) {
+            special_eog_ids.insert(special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+    }
+
+    // build special tokens cache
+    {
+        for (llama_token id = 0; id < (llama_token) n_tokens; ++id) {
+            if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
+                cache_special_tokens.push_back(id);
+            }
+        }
+
+        std::sort(cache_special_tokens.begin(), cache_special_tokens.end(),
+            [&] (const llama_token a, const llama_token b) {
+                return id_to_token[a].text.size() > id_to_token[b].text.size();
+            }
+        );
+
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t) cache_special_tokens.size());
+    }
+
+    // build token to piece cache
+    {
+        size_t size_cache = 0;
+
+        std::vector cache(n_tokens);
+
+        for (uint32_t id = 0; id < n_tokens; ++id) {
+            cache[id] = token_to_piece_for_cache(id, true);
+
+            size_cache += cache[id].size();
+        }
+
+        std::swap(cache_token_to_piece, cache);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
+    }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string & str, const std::vector & substrs) -> bool {
+            for (const auto & substr : substrs) {
+                if (str.find(substr) < std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_token id, llama_token_attr attr, bool value) {
+            uint32_t current = id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(token_to_id.at(token), attr, value);
+        };
+
+        std::string model_name;
+        std::string tokenizer_pre;
+        std::string general_arch;
+
+        ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+        ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer/architecture name
+        if (false
+                || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
+                || _contains_any(general_arch, {"nomic-bert-moe"})
+           ) {
+            if (token_to_id.count("") == 0) {
+                LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
+            } else {
+                _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
+            }
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {""}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {"", "", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
+        }
+    }
+}
+
+enum llama_vocab_type llama_vocab::impl::get_type() const {
+    return type;
+}
+
+std::string llama_vocab::impl::type_name() const{
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
+        default:                    return "unknown";
+    }
+}
+
+bool llama_vocab::impl::is_normal(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
+}
+
+bool llama_vocab::impl::is_unknown(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
+}
+
+bool llama_vocab::impl::is_control(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
+}
+
+bool llama_vocab::impl::is_byte(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
+}
+
+bool llama_vocab::impl::is_user_defined(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
+}
+
+bool llama_vocab::impl::is_unused(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
+}
+
+bool llama_vocab::impl::is_eog(llama_token id) const {
+    return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
+}
+
+uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    GGML_ASSERT(is_byte(id));
+    const auto & token_data = id_to_token.at(id);
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ABORT("fatal error");
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ABORT("fatal error");
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+llama_token_attr llama_vocab::impl::token_get_attr(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token.at(id).attr;
+}
+
+void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
+    LLAMA_LOG_DEBUG("%s: initializing tokenizer for type %d\n", __func__, type);
+
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            tokenizer = std::make_unique(vocab, precompiled_charsmap);
+            break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            tokenizer = std::make_unique(vocab);
+            break;
+        default:
+            GGML_ABORT("unsupported vocab type");
+    }
+}
+
+//
+// (de-) tokenize
+//
+
+// #define PRETOKENIZERDEBUG
+
+void llama_vocab::impl::tokenizer_st_partition(std::forward_list & buffer, bool parse_special) const {
+    // for each special token
+    for (const llama_token special_id : cache_special_tokens) {
+        const auto & data = vocab.get_token_data(special_id);
+        const auto & text = data.text;
+
+        if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
+            // Ignore control and unknown tokens when parse_special == false
+            continue;
+            // User-defined tokens are still pre-tokenized before everything else
+            // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
+            // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
+        }
+
+        // for each text fragment
+        std::forward_list::iterator it = buffer.begin();
+        while (it != buffer.end()) {
+            auto & fragment = (*it);
+
+            // if a fragment is text ( not yet processed )
+            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                const auto & raw_text = fragment.raw_text;
+
+                auto raw_text_base_offset = fragment.offset;
+                auto raw_text_base_length = fragment.length;
+
+                // loop over the text
+                while (true) {
+                    // find the first occurrence of a given special token in this fragment
+                    //  passing offset argument only limit the "search area" but match coordinates
+                    //  are still relative to the source full raw_text
+                    //  string_view begins at pos 0 for the same reason
+                    auto match = std::string_view(raw_text.data(), raw_text_base_offset + raw_text_base_length).find(text, raw_text_base_offset);
+
+                    // no occurrences found, stop processing this fragment for a given special token
+                    if (match == std::string::npos) break;
+
+#ifdef PRETOKENIZERDEBUG
+                    LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    auto source = std::distance(buffer.begin(), it);
+
+                    // if match is further than base offset
+                    //  then we have some text to the left of it
+                    if (match > raw_text_base_offset) {
+                        // left
+                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
+                        int64_t left_reminder_length = match - raw_text_base_offset;
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
+                            while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
+                                left_reminder_length--;
+                            }
+                        }
+
+                        if (left_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                            it++;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
+#endif
+                    }
+
+                    // special token
+                    buffer.emplace_after(it, special_id);
+                    it++;
+
+                    // right
+                    if (match + text.length() < raw_text_base_offset + raw_text_base_length) {
+                        int64_t right_reminder_offset = match + text.length();
+                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + text.length());
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
+                            while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
+                                right_reminder_offset++;
+                                right_reminder_length--;
+                            }
+                        }
+
+                        if (right_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                            it++;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
+#endif
+
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
+                        }
+
+                        // repeat for the right side
+                        raw_text_base_offset = right_reminder_offset;
+                        raw_text_base_length = right_reminder_length;
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    } else {
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
+                        }
+                        break;
+                    }
+                }
+            }
+            it++;
+        }
+    }
+}
+
+// NOTE: avoid ever using this except for building the token_to_piece caches
+std::string llama_vocab::impl::token_to_piece_for_cache(llama_token token, bool special) const {
+    std::string piece;
+    piece.resize(piece.capacity());  // using string internal cache
+    const int n_chars = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
+    }
+    else {
+        piece.resize(n_chars);
+    }
+
+    return piece;
+}
+
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
+}
+
+static void llama_unescape_whitespace(std::string & word) {
+    replace_all(word, "\xe2\x96\x81", " ");
+}
+
+static std::string llama_decode_text(const std::string & text) {
+    std::string decoded_text;
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+    for (const auto cpt : cpts) {
+        const auto utf8 = unicode_cpt_to_utf8(cpt);
+        try {
+            decoded_text += unicode_utf8_to_byte(utf8);
+        } catch (const std::out_of_range & /*e*/) {
+            decoded_text += "[UNK_BYTE_0x";
+            for (const auto c : utf8) {
+                decoded_text += format("%02x", (uint8_t) c);
+            }
+            decoded_text += text + "]";
+        }
+    }
+
+    return decoded_text;
+}
+
+std::vector llama_vocab::impl::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
+    std::vector output;
+    std::forward_list fragment_buffer;
+
+    if (!raw_text.empty()) {
+        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
+        tokenizer_st_partition(fragment_buffer, parse_special);
+    }
+
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            {
+                // OG tokenizer behavior:
+                //
+                // tokenizer.encode('', add_special_tokens=True)  returns [1]
+                // tokenizer.encode('', add_special_tokens=False) returns []
+
+                bool is_prev_special = true;  // prefix with space if first token
+
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                    is_prev_special = true;
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text;
+
+                        // prefix with space if previous is special
+                        if (add_space_prefix && is_prev_special) {
+                            text = ' ';
+                        }
+
+                        text += fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        llama_escape_whitespace(text);
+                        llm_tokenizer_spm_session session(vocab);
+                        session.tokenize(text, output);
+                        is_prev_special = false;
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                        is_prev_special = true;
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            {
+                llm_tokenizer_bpe_session session(vocab, *static_cast(tokenizer.get()));
+                // it calls some other methods that are not exist in llm_tokenizer,
+                // here just cast it to bpe tokenizer object
+                if (add_special) {
+                    session.append_bos(output);
+                }
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        session.append(fragment.token, output);
+                    }
+                }
+
+                if (add_special) {
+                    session.append_eos(output);
+                    session.check_double_bos_eos(output);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            {
+                if (add_special) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+
+                llm_tokenizer_wpm_session session(vocab);
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special) {
+                    GGML_ASSERT(special_sep_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_sep_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            {
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+                llm_tokenizer_ugm_session session(vocab, *static_cast(tokenizer.get()));
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                llm_tokenizer_rwkv_session session(vocab, *static_cast(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_NONE:
+            GGML_ABORT("fatal error");
+    }
+
+    return output;
+}
+
+int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
+    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+    static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
+    const llama_token_attr attr = token_get_attr(token);
+    if (!special && (attr & attr_special)) {
+        return 0;
+    }
+
+    // copy piece chars to output text buffer
+    // skip up to 'lstrip' leading spaces before copying
+    auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
+        if (size >= static_cast(std::numeric_limits::max())) {
+            GGML_ABORT("invalid token size: %zu exceeds int32_t limit", size);
+        }
+
+        for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
+            token++;
+            size--;
+        }
+        if (length < (int32_t)size) {
+            return -(int32_t) size;
+        }
+        memcpy(buf, token, size);
+        return (int32_t) size;
+    };
+
+    // if we have a cache - use it
+    {
+        const auto & cache = cache_token_to_piece;
+
+        if (!cache.empty()) {
+            const auto & result = cache.at(token);
+            return _try_copy(result.data(), result.size());
+        }
+    }
+
+    if (0 <= token && token < (int32_t) id_to_token.size()) {
+        const std::string & token_text = id_to_token[token].text;
+        switch (get_type()) {
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM:
+            case LLAMA_VOCAB_TYPE_UGM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                }
+                if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = token_text;
+                    llama_unescape_whitespace(result);
+                    return _try_copy(result.data(), result.size());
+                }
+                if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+                    char byte = (char) token_to_byte(token);
+                    return _try_copy((char*) &byte, 1);
+                }
+                break;
+            }
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                }
+                if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = llama_decode_text(token_text);
+                    return _try_copy(result.data(), result.size());
+                }
+                break;
+            }
+            case LLAMA_VOCAB_TYPE_RWKV: {
+                std::vector result = llama_unescape_rwkv_token(token_text);
+
+                // If we don't have enough space, return an error
+                if (result.size() > (size_t)length) {
+                    return -(int)result.size();
+                }
+
+                memcpy(buf, result.data(), result.size());
+                return (int)result.size();
+            }
+            default:
+                GGML_ABORT("fatal error");
+        }
+    }
+
+    return 0;
+}
+
+const std::string & llama_vocab::impl::token_to_piece(llama_token token) const {
+    return cache_token_to_piece.at(token);
+}
+
+int32_t llama_vocab::impl::detokenize(
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) const {
+    if (type == LLAMA_VOCAB_TYPE_NONE) {
+        return 0;
+    }
+
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
+    int32_t avail = text_len_max;
+    int32_t total = 0;
+
+    // remove the leading space
+    bool remove_space = add_space_prefix;
+
+    if (remove_special && add_bos) {
+        if (n_tokens > 0 && tokens[0] == special_bos_id) {
+            remove_space = false;
+            n_tokens--;
+            tokens++;
+        }
+    }
+
+    if (remove_special && add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens - 1] == special_eos_id) {
+            n_tokens--;
+        }
+    }
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        GGML_ASSERT(avail >= 0);
+        int32_t n_chars = token_to_piece(tokens[i], text, avail, remove_space, unparse_special);
+        remove_space = false;
+        if (n_chars < 0) {
+            avail = 0;
+            total -= n_chars;
+        } else if (n_chars > 0) {
+            avail -= n_chars;
+            text  += n_chars;
+            total += n_chars;
+        }
+    }
+
+    if (total > text_len_max) {
+        return -total;
+    }
+
+    if (clean_spaces) {
+        text -= total;  // restart text
+
+        // first pass: characters ?!.,  //TODO: where do these characters come from?
+        const int32_t total1 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total1; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '?' || x == '!' || x == '.' || x == ',') {  // " ?", " !", " .", " ,"
+                    total--;  // remove space
+                }
+            }
+            text[total++] = x;
+        }
+
+        // second pass: strip single apostrophe between spaces
+        const int32_t total2 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total2; ++i) {
+            const char x = text[i];
+            if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') {  // " ' "
+                total--;           // remove prev space
+                text[++i] = '\0';  // remove next space
+            }
+            text[total++] = x;
+        }
+
+        // third pass: apostrophe contractions  //NOTE: this makes sense?
+        const int32_t total3 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total3; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '\'' && i + 1 < total3) {
+                    const char x1 = text[i + 1];
+                    if (x1 == 't' || x1 == 'd') {  // " 't", " 'd"
+                        //total--;  // remove space
+                    } else if (x1 == 's' || x1 == 'm') {  // " 's", " 'm"
+                        total--;  // remove space
+                    } else if (i + 2 < total3) {
+                        const char x2 = text[i + 2];
+                        if ((x1 == 'l' && x2 == 'l')) {  // " 'll"
+                            //total--;  // remove space
+                        } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) {  // " 're", " 've"
+                            total--;  // remove space
+                        } else {
+                            //total--;  // remove space
+                        }
+                    } else {
+                        //total--;  // remove space
+                    }
+                }
+            }
+            text[total++] = x;
+        }
+    }
+
+    return total <= text_len_max ? total : -total;
+}
+
+void llama_vocab::impl::print_info() const {
+    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, type_name().c_str());
+    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, vocab.n_tokens());
+    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
+
+    // special tokens
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token.at(special_bos_id).text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token.at(special_eos_id).text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token.at(special_eot_id).text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token.at(special_eom_id).text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token.at(special_unk_id).text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token.at(special_sep_id).text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token.at(special_pad_id).text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token.at(special_mask_id).text.c_str() ); }
+
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token.at(linefeed_id).text.c_str() ); }
+
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
+
+    for (const auto & id : special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
+    }
+
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
+}
+
+llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
+}
+
+llama_vocab::~llama_vocab() {
+}
+
+void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
+    pimpl->load(ml, kv);
+}
+
+std::string llama_vocab::get_tokenizer_model() const {
+    return pimpl->tokenizer_model;
+}
+
+std::string llama_vocab::get_tokenizer_pre() const {
+    return pimpl->tokenizer_pre;
+}
+
+enum llama_vocab_type llama_vocab::get_type() const {
+    return pimpl->type;
+}
+
+enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
+    return pimpl->pre_type;
+}
+
+uint32_t llama_vocab::n_tokens() const {
+    return (uint32_t) pimpl->id_to_token.size();
+}
+
+uint32_t llama_vocab::n_token_types() const {
+    return (uint32_t) pimpl->n_token_types;
+}
+
+std::string llama_vocab::type_name() const{
+    return pimpl->type_name();
+}
+
+bool llama_vocab::is_normal(llama_token id) const {
+    return pimpl->is_normal(id);
+}
+
+bool llama_vocab::is_unknown(llama_token id) const {
+    return pimpl->is_unknown(id);
+}
+
+bool llama_vocab::is_control(llama_token id) const {
+    return pimpl->is_control(id);
+}
+
+bool llama_vocab::is_byte(llama_token id) const {
+    return pimpl->is_byte(id);
+}
+
+bool llama_vocab::is_user_defined(llama_token id) const {
+    return pimpl->is_user_defined(id);
+}
+
+bool llama_vocab::is_unused(llama_token id) const {
+    return pimpl->is_unused(id);
+}
+
+bool llama_vocab::is_eog(llama_token id) const {
+    return pimpl->is_eog(id);
+}
+
+uint8_t llama_vocab::token_to_byte(llama_token id) const {
+    return pimpl->token_to_byte(id);
+}
+
+llama_token llama_vocab::byte_to_token(uint8_t ch) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    static const char * hex = "0123456789ABCDEF";
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
+            auto token = pimpl->token_to_id.find(buf);
+            if (token != pimpl->token_to_id.end()) {
+                return (*token).second;
+            }
+            // Try to fall back to just the byte as a string
+            const char buf2[2] = { (char)ch, 0 };
+            return pimpl->token_to_id.at(buf2);
+        }
+        case LLAMA_VOCAB_TYPE_WPM:
+        case LLAMA_VOCAB_TYPE_BPE: {
+            return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+llama_token llama_vocab::text_to_token(const std::string & text) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    auto it = pimpl->token_to_id.find(text);
+    if (it != pimpl->token_to_id.end()) {
+        return (*it).second;
+    }
+    return LLAMA_TOKEN_NULL;
+}
+
+const llama_vocab::token_data & llama_vocab::get_token_data(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id);
+}
+
+const char * llama_vocab::token_get_text(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).text.c_str();
+}
+
+float llama_vocab::token_get_score(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).score;
+}
+
+llama_token_attr llama_vocab::token_get_attr(llama_token id) const {
+    return pimpl->token_get_attr(id);
+}
+
+llama_token llama_vocab::token_bos() const {
+    return pimpl->special_bos_id;
+}
+
+llama_token llama_vocab::token_eos() const {
+    return pimpl->special_eos_id;
+}
+
+llama_token llama_vocab::token_eot() const {
+    return pimpl->special_eot_id;
+}
+
+llama_token llama_vocab::token_eom() const {
+    return pimpl->special_eom_id;
+}
+
+llama_token llama_vocab::token_unk() const {
+    return pimpl->special_unk_id;
+}
+
+llama_token llama_vocab::token_sep() const {
+    return pimpl->special_sep_id;
+}
+
+llama_token llama_vocab::token_nl() const {
+    return pimpl->linefeed_id;
+}
+
+llama_token llama_vocab::token_pad() const {
+    return pimpl->special_pad_id;
+}
+
+llama_token llama_vocab::token_prefix() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_middle() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_suffix() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_pre() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_fim_suf() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_mid() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_fim_pad() const {
+    return pimpl->special_fim_pad_id;
+}
+
+llama_token llama_vocab::token_fim_rep() const {
+    return pimpl->special_fim_rep_id;
+}
+
+llama_token llama_vocab::token_fim_sep() const {
+    return pimpl->special_fim_sep_id;
+}
+
+bool llama_vocab::get_add_space_prefix() const {
+    return pimpl->add_space_prefix;
+}
+
+bool llama_vocab::get_add_bos() const {
+    return pimpl->add_bos;
+}
+
+bool llama_vocab::get_add_eos() const {
+    return pimpl->add_eos;
+}
+
+bool llama_vocab::get_ignore_merges() const {
+    return pimpl->ignore_merges;
+}
+
+bool llama_vocab::get_clean_spaces() const {
+    return pimpl->clean_spaces;
+}
+
+bool llama_vocab::get_remove_extra_whitespaces() const {
+    return pimpl->remove_extra_whitespaces;
+}
+
+bool llama_vocab::get_escape_whitespaces() const {
+    return pimpl->escape_whitespaces;
+}
+
+bool llama_vocab::get_treat_whitespace_as_suffix() const {
+    return pimpl->treat_whitespace_as_suffix;
+}
+
+int llama_vocab::max_token_len() const {
+    return pimpl->max_token_len;
+}
+
+int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
+    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
+    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
+    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
+    GGML_ASSERT(token_right.find('\n') == std::string::npos);
+
+    auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right));
+    if (it == pimpl->bpe_ranks.end()) {
+        return -1;
+    }
+
+    return it->second;
+}
+
+std::vector llama_vocab::get_bpe_merges() const {
+    std::vector result(pimpl->bpe_ranks.size());
+
+    for (const auto & pair : pimpl->bpe_ranks) {
+        result[pair.second] = pair.first.first + " " + pair.first.second;
+    }
+
+    return result;
+}
+
+std::vector llama_vocab::get_precompiled_charsmap() const {
+    return pimpl->precompiled_charsmap;
+}
+
+int32_t llama_vocab::tokenize(
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) const {
+    auto res = tokenize(std::string(text, text_len), add_special, parse_special);
+    if (n_tokens_max < (int) res.size()) {
+        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
+        return -((int) res.size());
+    }
+
+    for (size_t i = 0; i < res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return res.size();
+}
+
+std::vector llama_vocab::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    return pimpl->tokenize(raw_text, add_special, parse_special);
+}
+
+const std::string & llama_vocab::token_to_piece(llama_token token) const {
+    return pimpl->token_to_piece(token);
+}
+
+int32_t llama_vocab::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
+    return pimpl->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_vocab::detokenize(
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) const {
+    return pimpl->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
+std::string llama_vocab::detokenize(const std::vector & tokens, bool special) const {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
+
+void llama_vocab::print_info() const {
+    pimpl->print_info();
+}
+
+//
+// interface implementation
+//
+
+int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab) {
+    return vocab->n_tokens();
+}
+
+// deprecated
+int32_t llama_n_vocab(const struct llama_vocab * vocab) {
+    return llama_vocab_n_tokens(vocab);
+}
+
+enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) {
+    return vocab->get_type();
+}
+
+const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_text(token);
+}
+
+float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_score(token);
+}
+
+enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_attr(token);
+}
+
+bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_eog(token);
+}
+
+bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_control(token);
+}
+
+llama_token llama_vocab_bos(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
+}
+
+llama_token llama_vocab_eos(const struct llama_vocab * vocab) {
+    return vocab->token_eos();
+}
+
+llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
+    return vocab->token_eot();
+}
+
+// deprecated
+llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
+}
+
+llama_token llama_vocab_sep(const struct llama_vocab * vocab) {
+    return vocab->token_sep();
+}
+
+llama_token llama_vocab_nl (const struct llama_vocab * vocab) {
+    return vocab->token_nl();
+}
+
+llama_token llama_vocab_pad(const struct llama_vocab * vocab) {
+    return vocab->token_pad();
+}
+
+bool llama_vocab_get_add_bos(const struct llama_vocab * vocab) {
+    return vocab->get_add_bos();
+}
+
+bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
+    return vocab->get_add_eos();
+}
+
+llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pre();
+}
+
+llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab) {
+    return vocab->token_fim_suf();
+}
+
+llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab) {
+    return vocab->token_fim_mid();
+}
+
+llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pad();
+}
+
+llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_rep();
+}
+
+llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_sep();
+}
+
+// deprecated
+const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_text(vocab, token);
+}
+
+// deprecated
+float llama_token_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_score(vocab, token);
+}
+
+// deprecated
+enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_attr(vocab, token);
+}
+
+// deprecated
+bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_eog(vocab, token);
+}
+
+// deprecated
+bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_control(vocab, token);
+}
+
+// deprecated
+llama_token llama_token_bos(const struct llama_vocab * vocab) {
+    return llama_vocab_bos(vocab);
+}
+
+// deprecated
+llama_token llama_token_eos(const struct llama_vocab * vocab) {
+    return llama_vocab_eos(vocab);
+}
+
+// deprecated
+llama_token llama_token_eot(const struct llama_vocab * vocab) {
+    return llama_vocab_eot(vocab);
+}
+
+// deprecated
+llama_token llama_token_cls(const struct llama_vocab * vocab) {
+    //return llama_vocab_cls(vocab);
+    return llama_vocab_bos(vocab); // avoid deprecation warning
+}
+
+// deprecated
+llama_token llama_token_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_sep(vocab);
+}
+
+// deprecated
+llama_token llama_token_nl (const struct llama_vocab * vocab) {
+    return llama_vocab_nl(vocab);
+}
+
+// deprecated
+llama_token llama_token_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_pad(vocab);
+}
+
+// deprecated
+bool llama_add_bos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_bos(vocab);
+}
+
+// deprecated
+bool llama_add_eos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_eos(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pre(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pre(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_suf(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_suf(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_mid(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_mid(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pad(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_rep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_rep(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_sep(vocab);
+}
+
+//
+// tokenization
+//
+
+int32_t llama_tokenize(
+    const struct llama_vocab * vocab,
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) {
+    return vocab->tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
+}
+
+int32_t llama_token_to_piece(
+    const struct llama_vocab * vocab,
+                 llama_token   token,
+                        char * buf,
+                     int32_t   length,
+                     int32_t   lstrip,
+                        bool   special) {
+    return vocab->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_detokenize(
+    const struct llama_vocab * vocab,
+           const llama_token * tokens,
+                     int32_t   n_tokens,
+                        char * text,
+                     int32_t   text_len_max,
+                        bool   remove_special,
+                        bool   unparse_special) {
+    return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
new file mode 100644
index 0000000000000000000000000000000000000000..daa6cf3082f90a3dc1ace5fff4b379bd3220c51e
--- /dev/null
+++ b/src/llama-vocab.h
@@ -0,0 +1,131 @@
+#pragma once
+
+#include "llama.h"
+
+#include 
+#include 
+#include 
+
+struct LLM_KV;
+struct llama_model_loader;
+
+struct llama_vocab {
+    struct token_data {
+        std::string      text;
+        float            score;
+        llama_token_attr attr;
+    };
+
+    llama_vocab();
+    ~llama_vocab();
+
+    void load(llama_model_loader & ml, const LLM_KV & kv);
+
+    std::string get_tokenizer_model() const;
+    std::string get_tokenizer_pre() const;
+
+    enum llama_vocab_type     get_type()     const;
+    enum llama_vocab_pre_type get_pre_type() const;
+
+    uint32_t n_tokens() const;
+    uint32_t n_token_types() const;
+
+    std::string type_name() const;
+
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
+
+    uint8_t     token_to_byte(llama_token id) const;
+    llama_token byte_to_token(uint8_t ch)     const;
+
+    llama_token text_to_token(const std::string & text) const;
+
+    const token_data & get_token_data(llama_token id) const;
+
+    const char *     token_get_text (llama_token id) const;
+    float            token_get_score(llama_token id) const;
+    llama_token_attr token_get_attr (llama_token id) const;
+
+    llama_token token_bos() const;
+    llama_token token_eos() const;
+    llama_token token_eot() const;
+    llama_token token_eom() const;
+    llama_token token_unk() const;
+    llama_token token_sep() const;
+    llama_token token_nl () const;
+    llama_token token_pad() const;
+
+    llama_token token_prefix() const;
+    llama_token token_middle() const;
+    llama_token token_suffix() const;
+
+    llama_token token_fim_pre() const;
+    llama_token token_fim_suf() const;
+    llama_token token_fim_mid() const;
+    llama_token token_fim_pad() const;
+    llama_token token_fim_rep() const;
+    llama_token token_fim_sep() const;
+
+    bool get_add_space_prefix          () const;
+    bool get_add_bos                   () const;
+    bool get_add_eos                   () const;
+    bool get_ignore_merges             () const;
+    bool get_clean_spaces              () const;
+    bool get_remove_extra_whitespaces  () const;
+    bool get_escape_whitespaces        () const;
+    bool get_treat_whitespace_as_suffix() const;
+
+    int max_token_len() const;
+
+    int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
+    std::vector get_bpe_merges() const;
+
+    std::vector get_precompiled_charsmap() const;
+
+    int32_t tokenize(
+                   const char * text,
+                      int32_t   text_len,
+                  llama_token * tokens,
+                      int32_t   n_tokens_max,
+                         bool   add_special,
+                         bool   parse_special) const;
+
+    std::vector tokenize(
+            const std::string & raw_text,
+                         bool   add_special,
+                         bool   parse_special = false) const;
+
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+                  llama_token   token,
+                         char * buf,
+                      int32_t   length,
+                      int32_t   lstrip,
+                         bool   special) const;
+
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
+
+    int32_t detokenize(
+            const llama_token * tokens,
+                      int32_t   n_tokens,
+                         char * text,
+                      int32_t   text_len_max,
+                         bool   remove_special,
+                         bool   unparse_special) const;
+
+    std::string detokenize(
+            const std::vector & tokens,
+                                      bool   special) const;
+
+    void print_info() const;
+
+private:
+    struct impl;
+    std::unique_ptr pimpl;
+};
diff --git a/src/llama.cpp b/src/llama.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..34906cdb62844875bf572a2a1df6118a2a8aa885
--- /dev/null
+++ b/src/llama.cpp
@@ -0,0 +1,358 @@
+#include "llama-impl.h"
+
+#include "llama-chat.h"
+#include "llama-mmap.h"
+#include "llama-vocab.h"
+#include "llama-model-loader.h"
+#include "llama-model-saver.h"
+#include "llama-model.h"
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+//
+// interface implementation
+//
+
+struct llama_sampler_chain_params llama_sampler_chain_default_params() {
+    struct llama_sampler_chain_params result = {
+        /*.no_perf                     =*/ true,
+    };
+
+    return result;
+}
+
+size_t llama_max_devices(void) {
+    return 16;
+}
+
+bool llama_supports_mmap(void) {
+    return llama_mmap::SUPPORTED;
+}
+
+bool llama_supports_mlock(void) {
+    return llama_mlock::SUPPORTED;
+}
+
+bool llama_supports_gpu_offload(void) {
+    return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr ||
+           llama_supports_rpc();
+}
+
+bool llama_supports_rpc(void) {
+    return ggml_backend_reg_by_name("RPC") != nullptr;
+}
+
+void llama_backend_init(void) {
+    ggml_time_init();
+
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = { 0, NULL, false };
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+}
+
+void llama_numa_init(enum ggml_numa_strategy numa) {
+    if (numa != GGML_NUMA_STRATEGY_DISABLED) {
+        auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        GGML_ASSERT(dev && "CPU backend is not loaded");
+        auto * reg = ggml_backend_dev_backend_reg(dev);
+        auto * numa_init_fn = (decltype(ggml_numa_init) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_numa_init");
+        numa_init_fn(numa);
+    }
+}
+
+void llama_backend_free(void) {
+    ggml_quantize_free();
+}
+
+int64_t llama_time_us(void) {
+    return ggml_time_us();
+}
+
+// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
+static int llama_model_load(const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) {
+    // loading time will be recalculated after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    model.t_load_us = 0;
+    time_meas tm(model.t_load_us);
+
+    model.t_start_us = tm.t_start_us;
+
+    try {
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides);
+
+        ml.print_info();
+
+        model.hparams.vocab_only = params.vocab_only;
+
+        try {
+            model.load_arch(ml);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
+        }
+        try {
+            model.load_hparams(ml);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
+        }
+        try {
+            model.load_vocab(ml);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
+        }
+
+        model.load_stats(ml);
+        model.print_info();
+
+        if (params.vocab_only) {
+            LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
+            return 0;
+        }
+
+        if (!model.load_tensors(ml)) {
+            return -2;
+        }
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
+        return -1;
+    }
+
+    return 0;
+}
+
+static struct llama_model * llama_model_load_from_file_impl(
+        const std::string & path_model,
+        std::vector & splits,
+        struct llama_model_params params) {
+    ggml_time_init();
+
+    if (!params.vocab_only && ggml_backend_reg_count() == 0) {
+        LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__);
+        return nullptr;
+    }
+
+    unsigned cur_percentage = 0;
+    if (params.progress_callback == NULL) {
+        params.progress_callback_user_data = &cur_percentage;
+        params.progress_callback = [](float progress, void * ctx) {
+            unsigned * cur_percentage_p = (unsigned *) ctx;
+            unsigned percentage = (unsigned) (100 * progress);
+            while (percentage > *cur_percentage_p) {
+                *cur_percentage_p = percentage;
+                LLAMA_LOG_CONT(".");
+                if (percentage >= 100) {
+                    LLAMA_LOG_CONT("\n");
+                }
+            }
+            return true;
+        };
+    }
+
+    llama_model * model = new llama_model(params);
+
+    // create list of devices to use with this model
+    if (params.devices) {
+        for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
+            model->devices.push_back(*dev);
+        }
+    } else {
+        std::vector rpc_servers;
+        // use all available devices
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            switch (ggml_backend_dev_type(dev)) {
+                case GGML_BACKEND_DEVICE_TYPE_CPU:
+                case GGML_BACKEND_DEVICE_TYPE_ACCEL:
+                    // skip CPU backends since they are handled separately
+                    break;
+
+                case GGML_BACKEND_DEVICE_TYPE_GPU:
+                    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
+                    if (ggml_backend_reg_name(reg) == std::string("RPC")) {
+                        rpc_servers.push_back(dev);
+                    } else {
+                        model->devices.push_back(dev);
+                    }
+                    break;
+            }
+        }
+        // add RPC servers at the front of the list
+        if (!rpc_servers.empty()) {
+            model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
+        }
+    }
+
+    // if using single GPU mode, remove all except the main GPU
+    if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
+        if (params.main_gpu < 0) {
+            model->devices.clear();
+        } else {
+            if (params.main_gpu >= (int)model->devices.size()) {
+                LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %zu)\n", __func__, params.main_gpu, model->devices.size());
+                llama_model_free(model);
+                return nullptr;
+            }
+            ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
+            model->devices.clear();
+            model->devices.push_back(main_gpu);
+        }
+    }
+
+    for (auto * dev : model->devices) {
+        size_t free, total; // NOLINT
+        ggml_backend_dev_memory(dev, &free, &total);
+        LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
+    }
+
+    const int status = llama_model_load(path_model, splits, *model, params);
+    GGML_ASSERT(status <= 0);
+    if (status < 0) {
+        if (status == -1) {
+            LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
+        } else if (status == -2) {
+            LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
+        }
+
+        llama_model_free(model);
+        return nullptr;
+    }
+
+    return model;
+}
+
+// deprecated
+struct llama_model * llama_load_model_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
+    return llama_model_load_from_file(path_model, params);
+}
+
+struct llama_model * llama_model_load_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
+    std::vector splits = {};
+    return llama_model_load_from_file_impl(path_model, splits, params);
+}
+
+struct llama_model * llama_model_load_from_splits(
+        const char ** paths,
+        size_t n_paths,
+        struct llama_model_params params) {
+    std::vector splits;
+    if (n_paths == 0) {
+        LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
+        return nullptr;
+    }
+    for (size_t i = 0; i < n_paths; ++i) {
+        splits.push_back(paths[i]);
+    }
+    return llama_model_load_from_file_impl(splits.front(), splits, params);
+}
+
+void llama_model_save_to_file(const struct llama_model * model, const char * path_model) {
+    llama_model_saver ms(*model);
+    ms.add_kv_from_model();
+    ms.add_tensors_from_model();
+    ms.save(path_model);
+}
+
+//
+// chat templates
+//
+
+int32_t llama_chat_apply_template(
+                              const char * tmpl,
+         const struct llama_chat_message * chat,
+                                  size_t   n_msg,
+                                    bool   add_ass,
+                                    char * buf,
+                                 int32_t   length) {
+    const std::string curr_tmpl(tmpl == nullptr ? "chatml" : tmpl);
+
+    // format the chat to string
+    std::vector chat_vec;
+    chat_vec.resize(n_msg);
+    for (size_t i = 0; i < n_msg; i++) {
+        chat_vec[i] = &chat[i];
+    }
+
+    std::string formatted_chat;
+    llm_chat_template detected_tmpl = llm_chat_detect_template(curr_tmpl);
+    if (detected_tmpl == LLM_CHAT_TEMPLATE_UNKNOWN) {
+        return -1;
+    }
+    int32_t res = llm_chat_apply_template(detected_tmpl, chat_vec, formatted_chat, add_ass);
+    if (res < 0) {
+        return res;
+    }
+    if (buf && length > 0) {
+        strncpy(buf, formatted_chat.c_str(), length);
+    }
+    return res;
+}
+
+//
+// model split
+//
+
+int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
+    static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
+    if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
+        return strlen(split_path);
+    }
+    return 0;
+}
+
+int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
+    std::string str_split_path(split_path);
+    char postfix[32];
+    snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
+    std::string str_postfix(postfix);
+
+    // check if split_prefix ends with postfix
+    int size_prefix = str_split_path.size() - str_postfix.size();
+    if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
+        snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
+        return size_prefix;
+    }
+
+    return 0;
+}
+
+const char * llama_print_system_info(void) {
+    static std::string s;
+    s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls.
+
+    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
+        auto * reg = ggml_backend_reg_get(i);
+        auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
+        if (get_features_fn) {
+            ggml_backend_feature * features = get_features_fn(reg);
+            s += ggml_backend_reg_name(reg);
+            s += " : ";
+            for (; features->name; features++) {
+                s += features->name;
+                s += " = ";
+                s += features->value;
+                s += " | ";
+            }
+        }
+    }
+
+    return s.c_str();
+}
+
diff --git a/src/unicode-data.cpp b/src/unicode-data.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..04dcd7fcfbcea90a4b23a2e62855b4ccbbb3521c
--- /dev/null
+++ b/src/unicode-data.cpp
@@ -0,0 +1,7034 @@
+// generated with scripts/gen-unicode-data.py
+
+#include "unicode-data.h"
+
+#include 
+#include 
+#include 
+#include 
+
+const std::initializer_list> unicode_ranges_flags = {  // start, flags // last=next_start-1
+{0x000000, 0x0080},
+{0x000020, 0x0008},
+{0x000021, 0x0020},
+{0x000024, 0x0040},
+{0x000025, 0x0020},
+{0x00002B, 0x0040},
+{0x00002C, 0x0020},
+{0x000030, 0x0002},
+{0x00003A, 0x0020},
+{0x00003C, 0x0040},
+{0x00003F, 0x0020},
+{0x000041, 0x0004},
+{0x00005B, 0x0020},
+{0x00005E, 0x0040},
+{0x00005F, 0x0020},
+{0x000060, 0x0040},
+{0x000061, 0x0004},
+{0x00007B, 0x0020},
+{0x00007C, 0x0040},
+{0x00007D, 0x0020},
+{0x00007E, 0x0040},
+{0x00007F, 0x0080},
+{0x0000A0, 0x0008},
+{0x0000A1, 0x0020},
+{0x0000A2, 0x0040},
+{0x0000A7, 0x0020},
+{0x0000A8, 0x0040},
+{0x0000AA, 0x0004},
+{0x0000AB, 0x0020},
+{0x0000AC, 0x0040},
+{0x0000AD, 0x0080},
+{0x0000AE, 0x0040},
+{0x0000B2, 0x0002},
+{0x0000B4, 0x0040},
+{0x0000B5, 0x0004},
+{0x0000B6, 0x0020},
+{0x0000B8, 0x0040},
+{0x0000B9, 0x0002},
+{0x0000BA, 0x0004},
+{0x0000BB, 0x0020},
+{0x0000BC, 0x0002},
+{0x0000BF, 0x0020},
+{0x0000C0, 0x0004},
+{0x0000D7, 0x0040},
+{0x0000D8, 0x0004},
+{0x0000F7, 0x0040},
+{0x0000F8, 0x0004},
+{0x0002C2, 0x0040},
+{0x0002C6, 0x0004},
+{0x0002D2, 0x0040},
+{0x0002E0, 0x0004},
+{0x0002E5, 0x0040},
+{0x0002EC, 0x0004},
+{0x0002ED, 0x0040},
+{0x0002EE, 0x0004},
+{0x0002EF, 0x0040},
+{0x000300, 0x0010},
+{0x000370, 0x0004},
+{0x000375, 0x0040},
+{0x000376, 0x0004},
+{0x000378, 0x0001},
+{0x00037A, 0x0004},
+{0x00037E, 0x0020},
+{0x00037F, 0x0004},
+{0x000380, 0x0001},
+{0x000384, 0x0040},
+{0x000386, 0x0004},
+{0x000387, 0x0020},
+{0x000388, 0x0004},
+{0x00038B, 0x0001},
+{0x00038C, 0x0004},
+{0x00038D, 0x0001},
+{0x00038E, 0x0004},
+{0x0003A2, 0x0001},
+{0x0003A3, 0x0004},
+{0x0003F6, 0x0040},
+{0x0003F7, 0x0004},
+{0x000482, 0x0040},
+{0x000483, 0x0010},
+{0x00048A, 0x0004},
+{0x000530, 0x0001},
+{0x000531, 0x0004},
+{0x000557, 0x0001},
+{0x000559, 0x0004},
+{0x00055A, 0x0020},
+{0x000560, 0x0004},
+{0x000589, 0x0020},
+{0x00058B, 0x0001},
+{0x00058D, 0x0040},
+{0x000590, 0x0001},
+{0x000591, 0x0010},
+{0x0005BE, 0x0020},
+{0x0005BF, 0x0010},
+{0x0005C0, 0x0020},
+{0x0005C1, 0x0010},
+{0x0005C3, 0x0020},
+{0x0005C4, 0x0010},
+{0x0005C6, 0x0020},
+{0x0005C7, 0x0010},
+{0x0005C8, 0x0001},
+{0x0005D0, 0x0004},
+{0x0005EB, 0x0001},
+{0x0005EF, 0x0004},
+{0x0005F3, 0x0020},
+{0x0005F5, 0x0001},
+{0x000600, 0x0080},
+{0x000606, 0x0040},
+{0x000609, 0x0020},
+{0x00060B, 0x0040},
+{0x00060C, 0x0020},
+{0x00060E, 0x0040},
+{0x000610, 0x0010},
+{0x00061B, 0x0020},
+{0x00061C, 0x0080},
+{0x00061D, 0x0020},
+{0x000620, 0x0004},
+{0x00064B, 0x0010},
+{0x000660, 0x0002},
+{0x00066A, 0x0020},
+{0x00066E, 0x0004},
+{0x000670, 0x0010},
+{0x000671, 0x0004},
+{0x0006D4, 0x0020},
+{0x0006D5, 0x0004},
+{0x0006D6, 0x0010},
+{0x0006DD, 0x0080},
+{0x0006DE, 0x0040},
+{0x0006DF, 0x0010},
+{0x0006E5, 0x0004},
+{0x0006E7, 0x0010},
+{0x0006E9, 0x0040},
+{0x0006EA, 0x0010},
+{0x0006EE, 0x0004},
+{0x0006F0, 0x0002},
+{0x0006FA, 0x0004},
+{0x0006FD, 0x0040},
+{0x0006FF, 0x0004},
+{0x000700, 0x0020},
+{0x00070E, 0x0001},
+{0x00070F, 0x0080},
+{0x000710, 0x0004},
+{0x000711, 0x0010},
+{0x000712, 0x0004},
+{0x000730, 0x0010},
+{0x00074B, 0x0001},
+{0x00074D, 0x0004},
+{0x0007A6, 0x0010},
+{0x0007B1, 0x0004},
+{0x0007B2, 0x0001},
+{0x0007C0, 0x0002},
+{0x0007CA, 0x0004},
+{0x0007EB, 0x0010},
+{0x0007F4, 0x0004},
+{0x0007F6, 0x0040},
+{0x0007F7, 0x0020},
+{0x0007FA, 0x0004},
+{0x0007FB, 0x0001},
+{0x0007FD, 0x0010},
+{0x0007FE, 0x0040},
+{0x000800, 0x0004},
+{0x000816, 0x0010},
+{0x00081A, 0x0004},
+{0x00081B, 0x0010},
+{0x000824, 0x0004},
+{0x000825, 0x0010},
+{0x000828, 0x0004},
+{0x000829, 0x0010},
+{0x00082E, 0x0001},
+{0x000830, 0x0020},
+{0x00083F, 0x0001},
+{0x000840, 0x0004},
+{0x000859, 0x0010},
+{0x00085C, 0x0001},
+{0x00085E, 0x0020},
+{0x00085F, 0x0001},
+{0x000860, 0x0004},
+{0x00086B, 0x0001},
+{0x000870, 0x0004},
+{0x000888, 0x0040},
+{0x000889, 0x0004},
+{0x00088F, 0x0001},
+{0x000890, 0x0080},
+{0x000892, 0x0001},
+{0x000898, 0x0010},
+{0x0008A0, 0x0004},
+{0x0008CA, 0x0010},
+{0x0008E2, 0x0080},
+{0x0008E3, 0x0010},
+{0x000904, 0x0004},
+{0x00093A, 0x0010},
+{0x00093D, 0x0004},
+{0x00093E, 0x0010},
+{0x000950, 0x0004},
+{0x000951, 0x0010},
+{0x000958, 0x0004},
+{0x000962, 0x0010},
+{0x000964, 0x0020},
+{0x000966, 0x0002},
+{0x000970, 0x0020},
+{0x000971, 0x0004},
+{0x000981, 0x0010},
+{0x000984, 0x0001},
+{0x000985, 0x0004},
+{0x00098D, 0x0001},
+{0x00098F, 0x0004},
+{0x000991, 0x0001},
+{0x000993, 0x0004},
+{0x0009A9, 0x0001},
+{0x0009AA, 0x0004},
+{0x0009B1, 0x0001},
+{0x0009B2, 0x0004},
+{0x0009B3, 0x0001},
+{0x0009B6, 0x0004},
+{0x0009BA, 0x0001},
+{0x0009BC, 0x0010},
+{0x0009BD, 0x0004},
+{0x0009BE, 0x0010},
+{0x0009C5, 0x0001},
+{0x0009C7, 0x0010},
+{0x0009C9, 0x0001},
+{0x0009CB, 0x0010},
+{0x0009CE, 0x0004},
+{0x0009CF, 0x0001},
+{0x0009D7, 0x0010},
+{0x0009D8, 0x0001},
+{0x0009DC, 0x0004},
+{0x0009DE, 0x0001},
+{0x0009DF, 0x0004},
+{0x0009E2, 0x0010},
+{0x0009E4, 0x0001},
+{0x0009E6, 0x0002},
+{0x0009F0, 0x0004},
+{0x0009F2, 0x0040},
+{0x0009F4, 0x0002},
+{0x0009FA, 0x0040},
+{0x0009FC, 0x0004},
+{0x0009FD, 0x0020},
+{0x0009FE, 0x0010},
+{0x0009FF, 0x0001},
+{0x000A01, 0x0010},
+{0x000A04, 0x0001},
+{0x000A05, 0x0004},
+{0x000A0B, 0x0001},
+{0x000A0F, 0x0004},
+{0x000A11, 0x0001},
+{0x000A13, 0x0004},
+{0x000A29, 0x0001},
+{0x000A2A, 0x0004},
+{0x000A31, 0x0001},
+{0x000A32, 0x0004},
+{0x000A34, 0x0001},
+{0x000A35, 0x0004},
+{0x000A37, 0x0001},
+{0x000A38, 0x0004},
+{0x000A3A, 0x0001},
+{0x000A3C, 0x0010},
+{0x000A3D, 0x0001},
+{0x000A3E, 0x0010},
+{0x000A43, 0x0001},
+{0x000A47, 0x0010},
+{0x000A49, 0x0001},
+{0x000A4B, 0x0010},
+{0x000A4E, 0x0001},
+{0x000A51, 0x0010},
+{0x000A52, 0x0001},
+{0x000A59, 0x0004},
+{0x000A5D, 0x0001},
+{0x000A5E, 0x0004},
+{0x000A5F, 0x0001},
+{0x000A66, 0x0002},
+{0x000A70, 0x0010},
+{0x000A72, 0x0004},
+{0x000A75, 0x0010},
+{0x000A76, 0x0020},
+{0x000A77, 0x0001},
+{0x000A81, 0x0010},
+{0x000A84, 0x0001},
+{0x000A85, 0x0004},
+{0x000A8E, 0x0001},
+{0x000A8F, 0x0004},
+{0x000A92, 0x0001},
+{0x000A93, 0x0004},
+{0x000AA9, 0x0001},
+{0x000AAA, 0x0004},
+{0x000AB1, 0x0001},
+{0x000AB2, 0x0004},
+{0x000AB4, 0x0001},
+{0x000AB5, 0x0004},
+{0x000ABA, 0x0001},
+{0x000ABC, 0x0010},
+{0x000ABD, 0x0004},
+{0x000ABE, 0x0010},
+{0x000AC6, 0x0001},
+{0x000AC7, 0x0010},
+{0x000ACA, 0x0001},
+{0x000ACB, 0x0010},
+{0x000ACE, 0x0001},
+{0x000AD0, 0x0004},
+{0x000AD1, 0x0001},
+{0x000AE0, 0x0004},
+{0x000AE2, 0x0010},
+{0x000AE4, 0x0001},
+{0x000AE6, 0x0002},
+{0x000AF0, 0x0020},
+{0x000AF1, 0x0040},
+{0x000AF2, 0x0001},
+{0x000AF9, 0x0004},
+{0x000AFA, 0x0010},
+{0x000B00, 0x0001},
+{0x000B01, 0x0010},
+{0x000B04, 0x0001},
+{0x000B05, 0x0004},
+{0x000B0D, 0x0001},
+{0x000B0F, 0x0004},
+{0x000B11, 0x0001},
+{0x000B13, 0x0004},
+{0x000B29, 0x0001},
+{0x000B2A, 0x0004},
+{0x000B31, 0x0001},
+{0x000B32, 0x0004},
+{0x000B34, 0x0001},
+{0x000B35, 0x0004},
+{0x000B3A, 0x0001},
+{0x000B3C, 0x0010},
+{0x000B3D, 0x0004},
+{0x000B3E, 0x0010},
+{0x000B45, 0x0001},
+{0x000B47, 0x0010},
+{0x000B49, 0x0001},
+{0x000B4B, 0x0010},
+{0x000B4E, 0x0001},
+{0x000B55, 0x0010},
+{0x000B58, 0x0001},
+{0x000B5C, 0x0004},
+{0x000B5E, 0x0001},
+{0x000B5F, 0x0004},
+{0x000B62, 0x0010},
+{0x000B64, 0x0001},
+{0x000B66, 0x0002},
+{0x000B70, 0x0040},
+{0x000B71, 0x0004},
+{0x000B72, 0x0002},
+{0x000B78, 0x0001},
+{0x000B82, 0x0010},
+{0x000B83, 0x0004},
+{0x000B84, 0x0001},
+{0x000B85, 0x0004},
+{0x000B8B, 0x0001},
+{0x000B8E, 0x0004},
+{0x000B91, 0x0001},
+{0x000B92, 0x0004},
+{0x000B96, 0x0001},
+{0x000B99, 0x0004},
+{0x000B9B, 0x0001},
+{0x000B9C, 0x0004},
+{0x000B9D, 0x0001},
+{0x000B9E, 0x0004},
+{0x000BA0, 0x0001},
+{0x000BA3, 0x0004},
+{0x000BA5, 0x0001},
+{0x000BA8, 0x0004},
+{0x000BAB, 0x0001},
+{0x000BAE, 0x0004},
+{0x000BBA, 0x0001},
+{0x000BBE, 0x0010},
+{0x000BC3, 0x0001},
+{0x000BC6, 0x0010},
+{0x000BC9, 0x0001},
+{0x000BCA, 0x0010},
+{0x000BCE, 0x0001},
+{0x000BD0, 0x0004},
+{0x000BD1, 0x0001},
+{0x000BD7, 0x0010},
+{0x000BD8, 0x0001},
+{0x000BE6, 0x0002},
+{0x000BF3, 0x0040},
+{0x000BFB, 0x0001},
+{0x000C00, 0x0010},
+{0x000C05, 0x0004},
+{0x000C0D, 0x0001},
+{0x000C0E, 0x0004},
+{0x000C11, 0x0001},
+{0x000C12, 0x0004},
+{0x000C29, 0x0001},
+{0x000C2A, 0x0004},
+{0x000C3A, 0x0001},
+{0x000C3C, 0x0010},
+{0x000C3D, 0x0004},
+{0x000C3E, 0x0010},
+{0x000C45, 0x0001},
+{0x000C46, 0x0010},
+{0x000C49, 0x0001},
+{0x000C4A, 0x0010},
+{0x000C4E, 0x0001},
+{0x000C55, 0x0010},
+{0x000C57, 0x0001},
+{0x000C58, 0x0004},
+{0x000C5B, 0x0001},
+{0x000C5D, 0x0004},
+{0x000C5E, 0x0001},
+{0x000C60, 0x0004},
+{0x000C62, 0x0010},
+{0x000C64, 0x0001},
+{0x000C66, 0x0002},
+{0x000C70, 0x0001},
+{0x000C77, 0x0020},
+{0x000C78, 0x0002},
+{0x000C7F, 0x0040},
+{0x000C80, 0x0004},
+{0x000C81, 0x0010},
+{0x000C84, 0x0020},
+{0x000C85, 0x0004},
+{0x000C8D, 0x0001},
+{0x000C8E, 0x0004},
+{0x000C91, 0x0001},
+{0x000C92, 0x0004},
+{0x000CA9, 0x0001},
+{0x000CAA, 0x0004},
+{0x000CB4, 0x0001},
+{0x000CB5, 0x0004},
+{0x000CBA, 0x0001},
+{0x000CBC, 0x0010},
+{0x000CBD, 0x0004},
+{0x000CBE, 0x0010},
+{0x000CC5, 0x0001},
+{0x000CC6, 0x0010},
+{0x000CC9, 0x0001},
+{0x000CCA, 0x0010},
+{0x000CCE, 0x0001},
+{0x000CD5, 0x0010},
+{0x000CD7, 0x0001},
+{0x000CDD, 0x0004},
+{0x000CDF, 0x0001},
+{0x000CE0, 0x0004},
+{0x000CE2, 0x0010},
+{0x000CE4, 0x0001},
+{0x000CE6, 0x0002},
+{0x000CF0, 0x0001},
+{0x000CF1, 0x0004},
+{0x000CF3, 0x0010},
+{0x000CF4, 0x0001},
+{0x000D00, 0x0010},
+{0x000D04, 0x0004},
+{0x000D0D, 0x0001},
+{0x000D0E, 0x0004},
+{0x000D11, 0x0001},
+{0x000D12, 0x0004},
+{0x000D3B, 0x0010},
+{0x000D3D, 0x0004},
+{0x000D3E, 0x0010},
+{0x000D45, 0x0001},
+{0x000D46, 0x0010},
+{0x000D49, 0x0001},
+{0x000D4A, 0x0010},
+{0x000D4E, 0x0004},
+{0x000D4F, 0x0040},
+{0x000D50, 0x0001},
+{0x000D54, 0x0004},
+{0x000D57, 0x0010},
+{0x000D58, 0x0002},
+{0x000D5F, 0x0004},
+{0x000D62, 0x0010},
+{0x000D64, 0x0001},
+{0x000D66, 0x0002},
+{0x000D79, 0x0040},
+{0x000D7A, 0x0004},
+{0x000D80, 0x0001},
+{0x000D81, 0x0010},
+{0x000D84, 0x0001},
+{0x000D85, 0x0004},
+{0x000D97, 0x0001},
+{0x000D9A, 0x0004},
+{0x000DB2, 0x0001},
+{0x000DB3, 0x0004},
+{0x000DBC, 0x0001},
+{0x000DBD, 0x0004},
+{0x000DBE, 0x0001},
+{0x000DC0, 0x0004},
+{0x000DC7, 0x0001},
+{0x000DCA, 0x0010},
+{0x000DCB, 0x0001},
+{0x000DCF, 0x0010},
+{0x000DD5, 0x0001},
+{0x000DD6, 0x0010},
+{0x000DD7, 0x0001},
+{0x000DD8, 0x0010},
+{0x000DE0, 0x0001},
+{0x000DE6, 0x0002},
+{0x000DF0, 0x0001},
+{0x000DF2, 0x0010},
+{0x000DF4, 0x0020},
+{0x000DF5, 0x0001},
+{0x000E01, 0x0004},
+{0x000E31, 0x0010},
+{0x000E32, 0x0004},
+{0x000E34, 0x0010},
+{0x000E3B, 0x0001},
+{0x000E3F, 0x0040},
+{0x000E40, 0x0004},
+{0x000E47, 0x0010},
+{0x000E4F, 0x0020},
+{0x000E50, 0x0002},
+{0x000E5A, 0x0020},
+{0x000E5C, 0x0001},
+{0x000E81, 0x0004},
+{0x000E83, 0x0001},
+{0x000E84, 0x0004},
+{0x000E85, 0x0001},
+{0x000E86, 0x0004},
+{0x000E8B, 0x0001},
+{0x000E8C, 0x0004},
+{0x000EA4, 0x0001},
+{0x000EA5, 0x0004},
+{0x000EA6, 0x0001},
+{0x000EA7, 0x0004},
+{0x000EB1, 0x0010},
+{0x000EB2, 0x0004},
+{0x000EB4, 0x0010},
+{0x000EBD, 0x0004},
+{0x000EBE, 0x0001},
+{0x000EC0, 0x0004},
+{0x000EC5, 0x0001},
+{0x000EC6, 0x0004},
+{0x000EC7, 0x0001},
+{0x000EC8, 0x0010},
+{0x000ECF, 0x0001},
+{0x000ED0, 0x0002},
+{0x000EDA, 0x0001},
+{0x000EDC, 0x0004},
+{0x000EE0, 0x0001},
+{0x000F00, 0x0004},
+{0x000F01, 0x0040},
+{0x000F04, 0x0020},
+{0x000F13, 0x0040},
+{0x000F14, 0x0020},
+{0x000F15, 0x0040},
+{0x000F18, 0x0010},
+{0x000F1A, 0x0040},
+{0x000F20, 0x0002},
+{0x000F34, 0x0040},
+{0x000F35, 0x0010},
+{0x000F36, 0x0040},
+{0x000F37, 0x0010},
+{0x000F38, 0x0040},
+{0x000F39, 0x0010},
+{0x000F3A, 0x0020},
+{0x000F3E, 0x0010},
+{0x000F40, 0x0004},
+{0x000F48, 0x0001},
+{0x000F49, 0x0004},
+{0x000F6D, 0x0001},
+{0x000F71, 0x0010},
+{0x000F85, 0x0020},
+{0x000F86, 0x0010},
+{0x000F88, 0x0004},
+{0x000F8D, 0x0010},
+{0x000F98, 0x0001},
+{0x000F99, 0x0010},
+{0x000FBD, 0x0001},
+{0x000FBE, 0x0040},
+{0x000FC6, 0x0010},
+{0x000FC7, 0x0040},
+{0x000FCD, 0x0001},
+{0x000FCE, 0x0040},
+{0x000FD0, 0x0020},
+{0x000FD5, 0x0040},
+{0x000FD9, 0x0020},
+{0x000FDB, 0x0001},
+{0x001000, 0x0004},
+{0x00102B, 0x0010},
+{0x00103F, 0x0004},
+{0x001040, 0x0002},
+{0x00104A, 0x0020},
+{0x001050, 0x0004},
+{0x001056, 0x0010},
+{0x00105A, 0x0004},
+{0x00105E, 0x0010},
+{0x001061, 0x0004},
+{0x001062, 0x0010},
+{0x001065, 0x0004},
+{0x001067, 0x0010},
+{0x00106E, 0x0004},
+{0x001071, 0x0010},
+{0x001075, 0x0004},
+{0x001082, 0x0010},
+{0x00108E, 0x0004},
+{0x00108F, 0x0010},
+{0x001090, 0x0002},
+{0x00109A, 0x0010},
+{0x00109E, 0x0040},
+{0x0010A0, 0x0004},
+{0x0010C6, 0x0001},
+{0x0010C7, 0x0004},
+{0x0010C8, 0x0001},
+{0x0010CD, 0x0004},
+{0x0010CE, 0x0001},
+{0x0010D0, 0x0004},
+{0x0010FB, 0x0020},
+{0x0010FC, 0x0004},
+{0x001249, 0x0001},
+{0x00124A, 0x0004},
+{0x00124E, 0x0001},
+{0x001250, 0x0004},
+{0x001257, 0x0001},
+{0x001258, 0x0004},
+{0x001259, 0x0001},
+{0x00125A, 0x0004},
+{0x00125E, 0x0001},
+{0x001260, 0x0004},
+{0x001289, 0x0001},
+{0x00128A, 0x0004},
+{0x00128E, 0x0001},
+{0x001290, 0x0004},
+{0x0012B1, 0x0001},
+{0x0012B2, 0x0004},
+{0x0012B6, 0x0001},
+{0x0012B8, 0x0004},
+{0x0012BF, 0x0001},
+{0x0012C0, 0x0004},
+{0x0012C1, 0x0001},
+{0x0012C2, 0x0004},
+{0x0012C6, 0x0001},
+{0x0012C8, 0x0004},
+{0x0012D7, 0x0001},
+{0x0012D8, 0x0004},
+{0x001311, 0x0001},
+{0x001312, 0x0004},
+{0x001316, 0x0001},
+{0x001318, 0x0004},
+{0x00135B, 0x0001},
+{0x00135D, 0x0010},
+{0x001360, 0x0020},
+{0x001369, 0x0002},
+{0x00137D, 0x0001},
+{0x001380, 0x0004},
+{0x001390, 0x0040},
+{0x00139A, 0x0001},
+{0x0013A0, 0x0004},
+{0x0013F6, 0x0001},
+{0x0013F8, 0x0004},
+{0x0013FE, 0x0001},
+{0x001400, 0x0020},
+{0x001401, 0x0004},
+{0x00166D, 0x0040},
+{0x00166E, 0x0020},
+{0x00166F, 0x0004},
+{0x001680, 0x0008},
+{0x001681, 0x0004},
+{0x00169B, 0x0020},
+{0x00169D, 0x0001},
+{0x0016A0, 0x0004},
+{0x0016EB, 0x0020},
+{0x0016EE, 0x0002},
+{0x0016F1, 0x0004},
+{0x0016F9, 0x0001},
+{0x001700, 0x0004},
+{0x001712, 0x0010},
+{0x001716, 0x0001},
+{0x00171F, 0x0004},
+{0x001732, 0x0010},
+{0x001735, 0x0020},
+{0x001737, 0x0001},
+{0x001740, 0x0004},
+{0x001752, 0x0010},
+{0x001754, 0x0001},
+{0x001760, 0x0004},
+{0x00176D, 0x0001},
+{0x00176E, 0x0004},
+{0x001771, 0x0001},
+{0x001772, 0x0010},
+{0x001774, 0x0001},
+{0x001780, 0x0004},
+{0x0017B4, 0x0010},
+{0x0017D4, 0x0020},
+{0x0017D7, 0x0004},
+{0x0017D8, 0x0020},
+{0x0017DB, 0x0040},
+{0x0017DC, 0x0004},
+{0x0017DD, 0x0010},
+{0x0017DE, 0x0001},
+{0x0017E0, 0x0002},
+{0x0017EA, 0x0001},
+{0x0017F0, 0x0002},
+{0x0017FA, 0x0001},
+{0x001800, 0x0020},
+{0x00180B, 0x0010},
+{0x00180E, 0x0080},
+{0x00180F, 0x0010},
+{0x001810, 0x0002},
+{0x00181A, 0x0001},
+{0x001820, 0x0004},
+{0x001879, 0x0001},
+{0x001880, 0x0004},
+{0x001885, 0x0010},
+{0x001887, 0x0004},
+{0x0018A9, 0x0010},
+{0x0018AA, 0x0004},
+{0x0018AB, 0x0001},
+{0x0018B0, 0x0004},
+{0x0018F6, 0x0001},
+{0x001900, 0x0004},
+{0x00191F, 0x0001},
+{0x001920, 0x0010},
+{0x00192C, 0x0001},
+{0x001930, 0x0010},
+{0x00193C, 0x0001},
+{0x001940, 0x0040},
+{0x001941, 0x0001},
+{0x001944, 0x0020},
+{0x001946, 0x0002},
+{0x001950, 0x0004},
+{0x00196E, 0x0001},
+{0x001970, 0x0004},
+{0x001975, 0x0001},
+{0x001980, 0x0004},
+{0x0019AC, 0x0001},
+{0x0019B0, 0x0004},
+{0x0019CA, 0x0001},
+{0x0019D0, 0x0002},
+{0x0019DB, 0x0001},
+{0x0019DE, 0x0040},
+{0x001A00, 0x0004},
+{0x001A17, 0x0010},
+{0x001A1C, 0x0001},
+{0x001A1E, 0x0020},
+{0x001A20, 0x0004},
+{0x001A55, 0x0010},
+{0x001A5F, 0x0001},
+{0x001A60, 0x0010},
+{0x001A7D, 0x0001},
+{0x001A7F, 0x0010},
+{0x001A80, 0x0002},
+{0x001A8A, 0x0001},
+{0x001A90, 0x0002},
+{0x001A9A, 0x0001},
+{0x001AA0, 0x0020},
+{0x001AA7, 0x0004},
+{0x001AA8, 0x0020},
+{0x001AAE, 0x0001},
+{0x001AB0, 0x0010},
+{0x001ACF, 0x0001},
+{0x001B00, 0x0010},
+{0x001B05, 0x0004},
+{0x001B34, 0x0010},
+{0x001B45, 0x0004},
+{0x001B4D, 0x0001},
+{0x001B50, 0x0002},
+{0x001B5A, 0x0020},
+{0x001B61, 0x0040},
+{0x001B6B, 0x0010},
+{0x001B74, 0x0040},
+{0x001B7D, 0x0020},
+{0x001B7F, 0x0001},
+{0x001B80, 0x0010},
+{0x001B83, 0x0004},
+{0x001BA1, 0x0010},
+{0x001BAE, 0x0004},
+{0x001BB0, 0x0002},
+{0x001BBA, 0x0004},
+{0x001BE6, 0x0010},
+{0x001BF4, 0x0001},
+{0x001BFC, 0x0020},
+{0x001C00, 0x0004},
+{0x001C24, 0x0010},
+{0x001C38, 0x0001},
+{0x001C3B, 0x0020},
+{0x001C40, 0x0002},
+{0x001C4A, 0x0001},
+{0x001C4D, 0x0004},
+{0x001C50, 0x0002},
+{0x001C5A, 0x0004},
+{0x001C7E, 0x0020},
+{0x001C80, 0x0004},
+{0x001C89, 0x0001},
+{0x001C90, 0x0004},
+{0x001CBB, 0x0001},
+{0x001CBD, 0x0004},
+{0x001CC0, 0x0020},
+{0x001CC8, 0x0001},
+{0x001CD0, 0x0010},
+{0x001CD3, 0x0020},
+{0x001CD4, 0x0010},
+{0x001CE9, 0x0004},
+{0x001CED, 0x0010},
+{0x001CEE, 0x0004},
+{0x001CF4, 0x0010},
+{0x001CF5, 0x0004},
+{0x001CF7, 0x0010},
+{0x001CFA, 0x0004},
+{0x001CFB, 0x0001},
+{0x001D00, 0x0004},
+{0x001DC0, 0x0010},
+{0x001E00, 0x0004},
+{0x001F16, 0x0001},
+{0x001F18, 0x0004},
+{0x001F1E, 0x0001},
+{0x001F20, 0x0004},
+{0x001F46, 0x0001},
+{0x001F48, 0x0004},
+{0x001F4E, 0x0001},
+{0x001F50, 0x0004},
+{0x001F58, 0x0001},
+{0x001F59, 0x0004},
+{0x001F5A, 0x0001},
+{0x001F5B, 0x0004},
+{0x001F5C, 0x0001},
+{0x001F5D, 0x0004},
+{0x001F5E, 0x0001},
+{0x001F5F, 0x0004},
+{0x001F7E, 0x0001},
+{0x001F80, 0x0004},
+{0x001FB5, 0x0001},
+{0x001FB6, 0x0004},
+{0x001FBD, 0x0040},
+{0x001FBE, 0x0004},
+{0x001FBF, 0x0040},
+{0x001FC2, 0x0004},
+{0x001FC5, 0x0001},
+{0x001FC6, 0x0004},
+{0x001FCD, 0x0040},
+{0x001FD0, 0x0004},
+{0x001FD4, 0x0001},
+{0x001FD6, 0x0004},
+{0x001FDC, 0x0001},
+{0x001FDD, 0x0040},
+{0x001FE0, 0x0004},
+{0x001FED, 0x0040},
+{0x001FF0, 0x0001},
+{0x001FF2, 0x0004},
+{0x001FF5, 0x0001},
+{0x001FF6, 0x0004},
+{0x001FFD, 0x0040},
+{0x001FFF, 0x0001},
+{0x002000, 0x0008},
+{0x00200B, 0x0080},
+{0x002010, 0x0020},
+{0x002028, 0x0008},
+{0x00202A, 0x0080},
+{0x00202F, 0x0008},
+{0x002030, 0x0020},
+{0x002044, 0x0040},
+{0x002045, 0x0020},
+{0x002052, 0x0040},
+{0x002053, 0x0020},
+{0x00205F, 0x0008},
+{0x002060, 0x0080},
+{0x002065, 0x0001},
+{0x002066, 0x0080},
+{0x002070, 0x0002},
+{0x002071, 0x0004},
+{0x002072, 0x0001},
+{0x002074, 0x0002},
+{0x00207A, 0x0040},
+{0x00207D, 0x0020},
+{0x00207F, 0x0004},
+{0x002080, 0x0002},
+{0x00208A, 0x0040},
+{0x00208D, 0x0020},
+{0x00208F, 0x0001},
+{0x002090, 0x0004},
+{0x00209D, 0x0001},
+{0x0020A0, 0x0040},
+{0x0020C1, 0x0001},
+{0x0020D0, 0x0010},
+{0x0020F1, 0x0001},
+{0x002100, 0x0040},
+{0x002102, 0x0004},
+{0x002103, 0x0040},
+{0x002107, 0x0004},
+{0x002108, 0x0040},
+{0x00210A, 0x0004},
+{0x002114, 0x0040},
+{0x002115, 0x0004},
+{0x002116, 0x0040},
+{0x002119, 0x0004},
+{0x00211E, 0x0040},
+{0x002124, 0x0004},
+{0x002125, 0x0040},
+{0x002126, 0x0004},
+{0x002127, 0x0040},
+{0x002128, 0x0004},
+{0x002129, 0x0040},
+{0x00212A, 0x0004},
+{0x00212E, 0x0040},
+{0x00212F, 0x0004},
+{0x00213A, 0x0040},
+{0x00213C, 0x0004},
+{0x002140, 0x0040},
+{0x002145, 0x0004},
+{0x00214A, 0x0040},
+{0x00214E, 0x0004},
+{0x00214F, 0x0040},
+{0x002150, 0x0002},
+{0x002183, 0x0004},
+{0x002185, 0x0002},
+{0x00218A, 0x0040},
+{0x00218C, 0x0001},
+{0x002190, 0x0040},
+{0x002308, 0x0020},
+{0x00230C, 0x0040},
+{0x002329, 0x0020},
+{0x00232B, 0x0040},
+{0x002427, 0x0001},
+{0x002440, 0x0040},
+{0x00244B, 0x0001},
+{0x002460, 0x0002},
+{0x00249C, 0x0040},
+{0x0024EA, 0x0002},
+{0x002500, 0x0040},
+{0x002768, 0x0020},
+{0x002776, 0x0002},
+{0x002794, 0x0040},
+{0x0027C5, 0x0020},
+{0x0027C7, 0x0040},
+{0x0027E6, 0x0020},
+{0x0027F0, 0x0040},
+{0x002983, 0x0020},
+{0x002999, 0x0040},
+{0x0029D8, 0x0020},
+{0x0029DC, 0x0040},
+{0x0029FC, 0x0020},
+{0x0029FE, 0x0040},
+{0x002B74, 0x0001},
+{0x002B76, 0x0040},
+{0x002B96, 0x0001},
+{0x002B97, 0x0040},
+{0x002C00, 0x0004},
+{0x002CE5, 0x0040},
+{0x002CEB, 0x0004},
+{0x002CEF, 0x0010},
+{0x002CF2, 0x0004},
+{0x002CF4, 0x0001},
+{0x002CF9, 0x0020},
+{0x002CFD, 0x0002},
+{0x002CFE, 0x0020},
+{0x002D00, 0x0004},
+{0x002D26, 0x0001},
+{0x002D27, 0x0004},
+{0x002D28, 0x0001},
+{0x002D2D, 0x0004},
+{0x002D2E, 0x0001},
+{0x002D30, 0x0004},
+{0x002D68, 0x0001},
+{0x002D6F, 0x0004},
+{0x002D70, 0x0020},
+{0x002D71, 0x0001},
+{0x002D7F, 0x0010},
+{0x002D80, 0x0004},
+{0x002D97, 0x0001},
+{0x002DA0, 0x0004},
+{0x002DA7, 0x0001},
+{0x002DA8, 0x0004},
+{0x002DAF, 0x0001},
+{0x002DB0, 0x0004},
+{0x002DB7, 0x0001},
+{0x002DB8, 0x0004},
+{0x002DBF, 0x0001},
+{0x002DC0, 0x0004},
+{0x002DC7, 0x0001},
+{0x002DC8, 0x0004},
+{0x002DCF, 0x0001},
+{0x002DD0, 0x0004},
+{0x002DD7, 0x0001},
+{0x002DD8, 0x0004},
+{0x002DDF, 0x0001},
+{0x002DE0, 0x0010},
+{0x002E00, 0x0020},
+{0x002E2F, 0x0004},
+{0x002E30, 0x0020},
+{0x002E50, 0x0040},
+{0x002E52, 0x0020},
+{0x002E5E, 0x0001},
+{0x002E80, 0x0040},
+{0x002E9A, 0x0001},
+{0x002E9B, 0x0040},
+{0x002EF4, 0x0001},
+{0x002F00, 0x0040},
+{0x002FD6, 0x0001},
+{0x002FF0, 0x0040},
+{0x003000, 0x0008},
+{0x003001, 0x0020},
+{0x003004, 0x0040},
+{0x003005, 0x0004},
+{0x003007, 0x0002},
+{0x003008, 0x0020},
+{0x003012, 0x0040},
+{0x003014, 0x0020},
+{0x003020, 0x0040},
+{0x003021, 0x0002},
+{0x00302A, 0x0010},
+{0x003030, 0x0020},
+{0x003031, 0x0004},
+{0x003036, 0x0040},
+{0x003038, 0x0002},
+{0x00303B, 0x0004},
+{0x00303D, 0x0020},
+{0x00303E, 0x0040},
+{0x003040, 0x0001},
+{0x003041, 0x0004},
+{0x003097, 0x0001},
+{0x003099, 0x0010},
+{0x00309B, 0x0040},
+{0x00309D, 0x0004},
+{0x0030A0, 0x0020},
+{0x0030A1, 0x0004},
+{0x0030FB, 0x0020},
+{0x0030FC, 0x0004},
+{0x003100, 0x0001},
+{0x003105, 0x0004},
+{0x003130, 0x0001},
+{0x003131, 0x0004},
+{0x00318F, 0x0001},
+{0x003190, 0x0040},
+{0x003192, 0x0002},
+{0x003196, 0x0040},
+{0x0031A0, 0x0004},
+{0x0031C0, 0x0040},
+{0x0031E4, 0x0001},
+{0x0031EF, 0x0040},
+{0x0031F0, 0x0004},
+{0x003200, 0x0040},
+{0x00321F, 0x0001},
+{0x003220, 0x0002},
+{0x00322A, 0x0040},
+{0x003248, 0x0002},
+{0x003250, 0x0040},
+{0x003251, 0x0002},
+{0x003260, 0x0040},
+{0x003280, 0x0002},
+{0x00328A, 0x0040},
+{0x0032B1, 0x0002},
+{0x0032C0, 0x0040},
+{0x003400, 0x0004},
+{0x004DC0, 0x0040},
+{0x004E00, 0x0004},
+{0x00A48D, 0x0001},
+{0x00A490, 0x0040},
+{0x00A4C7, 0x0001},
+{0x00A4D0, 0x0004},
+{0x00A4FE, 0x0020},
+{0x00A500, 0x0004},
+{0x00A60D, 0x0020},
+{0x00A610, 0x0004},
+{0x00A620, 0x0002},
+{0x00A62A, 0x0004},
+{0x00A62C, 0x0001},
+{0x00A640, 0x0004},
+{0x00A66F, 0x0010},
+{0x00A673, 0x0020},
+{0x00A674, 0x0010},
+{0x00A67E, 0x0020},
+{0x00A67F, 0x0004},
+{0x00A69E, 0x0010},
+{0x00A6A0, 0x0004},
+{0x00A6E6, 0x0002},
+{0x00A6F0, 0x0010},
+{0x00A6F2, 0x0020},
+{0x00A6F8, 0x0001},
+{0x00A700, 0x0040},
+{0x00A717, 0x0004},
+{0x00A720, 0x0040},
+{0x00A722, 0x0004},
+{0x00A789, 0x0040},
+{0x00A78B, 0x0004},
+{0x00A7CB, 0x0001},
+{0x00A7D0, 0x0004},
+{0x00A7D2, 0x0001},
+{0x00A7D3, 0x0004},
+{0x00A7D4, 0x0001},
+{0x00A7D5, 0x0004},
+{0x00A7DA, 0x0001},
+{0x00A7F2, 0x0004},
+{0x00A802, 0x0010},
+{0x00A803, 0x0004},
+{0x00A806, 0x0010},
+{0x00A807, 0x0004},
+{0x00A80B, 0x0010},
+{0x00A80C, 0x0004},
+{0x00A823, 0x0010},
+{0x00A828, 0x0040},
+{0x00A82C, 0x0010},
+{0x00A82D, 0x0001},
+{0x00A830, 0x0002},
+{0x00A836, 0x0040},
+{0x00A83A, 0x0001},
+{0x00A840, 0x0004},
+{0x00A874, 0x0020},
+{0x00A878, 0x0001},
+{0x00A880, 0x0010},
+{0x00A882, 0x0004},
+{0x00A8B4, 0x0010},
+{0x00A8C6, 0x0001},
+{0x00A8CE, 0x0020},
+{0x00A8D0, 0x0002},
+{0x00A8DA, 0x0001},
+{0x00A8E0, 0x0010},
+{0x00A8F2, 0x0004},
+{0x00A8F8, 0x0020},
+{0x00A8FB, 0x0004},
+{0x00A8FC, 0x0020},
+{0x00A8FD, 0x0004},
+{0x00A8FF, 0x0010},
+{0x00A900, 0x0002},
+{0x00A90A, 0x0004},
+{0x00A926, 0x0010},
+{0x00A92E, 0x0020},
+{0x00A930, 0x0004},
+{0x00A947, 0x0010},
+{0x00A954, 0x0001},
+{0x00A95F, 0x0020},
+{0x00A960, 0x0004},
+{0x00A97D, 0x0001},
+{0x00A980, 0x0010},
+{0x00A984, 0x0004},
+{0x00A9B3, 0x0010},
+{0x00A9C1, 0x0020},
+{0x00A9CE, 0x0001},
+{0x00A9CF, 0x0004},
+{0x00A9D0, 0x0002},
+{0x00A9DA, 0x0001},
+{0x00A9DE, 0x0020},
+{0x00A9E0, 0x0004},
+{0x00A9E5, 0x0010},
+{0x00A9E6, 0x0004},
+{0x00A9F0, 0x0002},
+{0x00A9FA, 0x0004},
+{0x00A9FF, 0x0001},
+{0x00AA00, 0x0004},
+{0x00AA29, 0x0010},
+{0x00AA37, 0x0001},
+{0x00AA40, 0x0004},
+{0x00AA43, 0x0010},
+{0x00AA44, 0x0004},
+{0x00AA4C, 0x0010},
+{0x00AA4E, 0x0001},
+{0x00AA50, 0x0002},
+{0x00AA5A, 0x0001},
+{0x00AA5C, 0x0020},
+{0x00AA60, 0x0004},
+{0x00AA77, 0x0040},
+{0x00AA7A, 0x0004},
+{0x00AA7B, 0x0010},
+{0x00AA7E, 0x0004},
+{0x00AAB0, 0x0010},
+{0x00AAB1, 0x0004},
+{0x00AAB2, 0x0010},
+{0x00AAB5, 0x0004},
+{0x00AAB7, 0x0010},
+{0x00AAB9, 0x0004},
+{0x00AABE, 0x0010},
+{0x00AAC0, 0x0004},
+{0x00AAC1, 0x0010},
+{0x00AAC2, 0x0004},
+{0x00AAC3, 0x0001},
+{0x00AADB, 0x0004},
+{0x00AADE, 0x0020},
+{0x00AAE0, 0x0004},
+{0x00AAEB, 0x0010},
+{0x00AAF0, 0x0020},
+{0x00AAF2, 0x0004},
+{0x00AAF5, 0x0010},
+{0x00AAF7, 0x0001},
+{0x00AB01, 0x0004},
+{0x00AB07, 0x0001},
+{0x00AB09, 0x0004},
+{0x00AB0F, 0x0001},
+{0x00AB11, 0x0004},
+{0x00AB17, 0x0001},
+{0x00AB20, 0x0004},
+{0x00AB27, 0x0001},
+{0x00AB28, 0x0004},
+{0x00AB2F, 0x0001},
+{0x00AB30, 0x0004},
+{0x00AB5B, 0x0040},
+{0x00AB5C, 0x0004},
+{0x00AB6A, 0x0040},
+{0x00AB6C, 0x0001},
+{0x00AB70, 0x0004},
+{0x00ABE3, 0x0010},
+{0x00ABEB, 0x0020},
+{0x00ABEC, 0x0010},
+{0x00ABEE, 0x0001},
+{0x00ABF0, 0x0002},
+{0x00ABFA, 0x0001},
+{0x00AC00, 0x0004},
+{0x00D7A4, 0x0001},
+{0x00D7B0, 0x0004},
+{0x00D7C7, 0x0001},
+{0x00D7CB, 0x0004},
+{0x00D7FC, 0x0001},
+{0x00D800, 0x0080},
+{0x00F900, 0x0004},
+{0x00FA6E, 0x0001},
+{0x00FA70, 0x0004},
+{0x00FADA, 0x0001},
+{0x00FB00, 0x0004},
+{0x00FB07, 0x0001},
+{0x00FB13, 0x0004},
+{0x00FB18, 0x0001},
+{0x00FB1D, 0x0004},
+{0x00FB1E, 0x0010},
+{0x00FB1F, 0x0004},
+{0x00FB29, 0x0040},
+{0x00FB2A, 0x0004},
+{0x00FB37, 0x0001},
+{0x00FB38, 0x0004},
+{0x00FB3D, 0x0001},
+{0x00FB3E, 0x0004},
+{0x00FB3F, 0x0001},
+{0x00FB40, 0x0004},
+{0x00FB42, 0x0001},
+{0x00FB43, 0x0004},
+{0x00FB45, 0x0001},
+{0x00FB46, 0x0004},
+{0x00FBB2, 0x0040},
+{0x00FBC3, 0x0001},
+{0x00FBD3, 0x0004},
+{0x00FD3E, 0x0020},
+{0x00FD40, 0x0040},
+{0x00FD50, 0x0004},
+{0x00FD90, 0x0001},
+{0x00FD92, 0x0004},
+{0x00FDC8, 0x0001},
+{0x00FDCF, 0x0040},
+{0x00FDD0, 0x0001},
+{0x00FDF0, 0x0004},
+{0x00FDFC, 0x0040},
+{0x00FE00, 0x0010},
+{0x00FE10, 0x0020},
+{0x00FE1A, 0x0001},
+{0x00FE20, 0x0010},
+{0x00FE30, 0x0020},
+{0x00FE53, 0x0001},
+{0x00FE54, 0x0020},
+{0x00FE62, 0x0040},
+{0x00FE63, 0x0020},
+{0x00FE64, 0x0040},
+{0x00FE67, 0x0001},
+{0x00FE68, 0x0020},
+{0x00FE69, 0x0040},
+{0x00FE6A, 0x0020},
+{0x00FE6C, 0x0001},
+{0x00FE70, 0x0004},
+{0x00FE75, 0x0001},
+{0x00FE76, 0x0004},
+{0x00FEFD, 0x0001},
+{0x00FEFF, 0x0080},
+{0x00FF00, 0x0001},
+{0x00FF01, 0x0020},
+{0x00FF04, 0x0040},
+{0x00FF05, 0x0020},
+{0x00FF0B, 0x0040},
+{0x00FF0C, 0x0020},
+{0x00FF10, 0x0002},
+{0x00FF1A, 0x0020},
+{0x00FF1C, 0x0040},
+{0x00FF1F, 0x0020},
+{0x00FF21, 0x0004},
+{0x00FF3B, 0x0020},
+{0x00FF3E, 0x0040},
+{0x00FF3F, 0x0020},
+{0x00FF40, 0x0040},
+{0x00FF41, 0x0004},
+{0x00FF5B, 0x0020},
+{0x00FF5C, 0x0040},
+{0x00FF5D, 0x0020},
+{0x00FF5E, 0x0040},
+{0x00FF5F, 0x0020},
+{0x00FF66, 0x0004},
+{0x00FFBF, 0x0001},
+{0x00FFC2, 0x0004},
+{0x00FFC8, 0x0001},
+{0x00FFCA, 0x0004},
+{0x00FFD0, 0x0001},
+{0x00FFD2, 0x0004},
+{0x00FFD8, 0x0001},
+{0x00FFDA, 0x0004},
+{0x00FFDD, 0x0001},
+{0x00FFE0, 0x0040},
+{0x00FFE7, 0x0001},
+{0x00FFE8, 0x0040},
+{0x00FFEF, 0x0001},
+{0x00FFF9, 0x0080},
+{0x00FFFC, 0x0040},
+{0x00FFFE, 0x0001},
+{0x010000, 0x0004},
+{0x01000C, 0x0001},
+{0x01000D, 0x0004},
+{0x010027, 0x0001},
+{0x010028, 0x0004},
+{0x01003B, 0x0001},
+{0x01003C, 0x0004},
+{0x01003E, 0x0001},
+{0x01003F, 0x0004},
+{0x01004E, 0x0001},
+{0x010050, 0x0004},
+{0x01005E, 0x0001},
+{0x010080, 0x0004},
+{0x0100FB, 0x0001},
+{0x010100, 0x0020},
+{0x010103, 0x0001},
+{0x010107, 0x0002},
+{0x010134, 0x0001},
+{0x010137, 0x0040},
+{0x010140, 0x0002},
+{0x010179, 0x0040},
+{0x01018A, 0x0002},
+{0x01018C, 0x0040},
+{0x01018F, 0x0001},
+{0x010190, 0x0040},
+{0x01019D, 0x0001},
+{0x0101A0, 0x0040},
+{0x0101A1, 0x0001},
+{0x0101D0, 0x0040},
+{0x0101FD, 0x0010},
+{0x0101FE, 0x0001},
+{0x010280, 0x0004},
+{0x01029D, 0x0001},
+{0x0102A0, 0x0004},
+{0x0102D1, 0x0001},
+{0x0102E0, 0x0010},
+{0x0102E1, 0x0002},
+{0x0102FC, 0x0001},
+{0x010300, 0x0004},
+{0x010320, 0x0002},
+{0x010324, 0x0001},
+{0x01032D, 0x0004},
+{0x010341, 0x0002},
+{0x010342, 0x0004},
+{0x01034A, 0x0002},
+{0x01034B, 0x0001},
+{0x010350, 0x0004},
+{0x010376, 0x0010},
+{0x01037B, 0x0001},
+{0x010380, 0x0004},
+{0x01039E, 0x0001},
+{0x01039F, 0x0020},
+{0x0103A0, 0x0004},
+{0x0103C4, 0x0001},
+{0x0103C8, 0x0004},
+{0x0103D0, 0x0020},
+{0x0103D1, 0x0002},
+{0x0103D6, 0x0001},
+{0x010400, 0x0004},
+{0x01049E, 0x0001},
+{0x0104A0, 0x0002},
+{0x0104AA, 0x0001},
+{0x0104B0, 0x0004},
+{0x0104D4, 0x0001},
+{0x0104D8, 0x0004},
+{0x0104FC, 0x0001},
+{0x010500, 0x0004},
+{0x010528, 0x0001},
+{0x010530, 0x0004},
+{0x010564, 0x0001},
+{0x01056F, 0x0020},
+{0x010570, 0x0004},
+{0x01057B, 0x0001},
+{0x01057C, 0x0004},
+{0x01058B, 0x0001},
+{0x01058C, 0x0004},
+{0x010593, 0x0001},
+{0x010594, 0x0004},
+{0x010596, 0x0001},
+{0x010597, 0x0004},
+{0x0105A2, 0x0001},
+{0x0105A3, 0x0004},
+{0x0105B2, 0x0001},
+{0x0105B3, 0x0004},
+{0x0105BA, 0x0001},
+{0x0105BB, 0x0004},
+{0x0105BD, 0x0001},
+{0x010600, 0x0004},
+{0x010737, 0x0001},
+{0x010740, 0x0004},
+{0x010756, 0x0001},
+{0x010760, 0x0004},
+{0x010768, 0x0001},
+{0x010780, 0x0004},
+{0x010786, 0x0001},
+{0x010787, 0x0004},
+{0x0107B1, 0x0001},
+{0x0107B2, 0x0004},
+{0x0107BB, 0x0001},
+{0x010800, 0x0004},
+{0x010806, 0x0001},
+{0x010808, 0x0004},
+{0x010809, 0x0001},
+{0x01080A, 0x0004},
+{0x010836, 0x0001},
+{0x010837, 0x0004},
+{0x010839, 0x0001},
+{0x01083C, 0x0004},
+{0x01083D, 0x0001},
+{0x01083F, 0x0004},
+{0x010856, 0x0001},
+{0x010857, 0x0020},
+{0x010858, 0x0002},
+{0x010860, 0x0004},
+{0x010877, 0x0040},
+{0x010879, 0x0002},
+{0x010880, 0x0004},
+{0x01089F, 0x0001},
+{0x0108A7, 0x0002},
+{0x0108B0, 0x0001},
+{0x0108E0, 0x0004},
+{0x0108F3, 0x0001},
+{0x0108F4, 0x0004},
+{0x0108F6, 0x0001},
+{0x0108FB, 0x0002},
+{0x010900, 0x0004},
+{0x010916, 0x0002},
+{0x01091C, 0x0001},
+{0x01091F, 0x0020},
+{0x010920, 0x0004},
+{0x01093A, 0x0001},
+{0x01093F, 0x0020},
+{0x010940, 0x0001},
+{0x010980, 0x0004},
+{0x0109B8, 0x0001},
+{0x0109BC, 0x0002},
+{0x0109BE, 0x0004},
+{0x0109C0, 0x0002},
+{0x0109D0, 0x0001},
+{0x0109D2, 0x0002},
+{0x010A00, 0x0004},
+{0x010A01, 0x0010},
+{0x010A04, 0x0001},
+{0x010A05, 0x0010},
+{0x010A07, 0x0001},
+{0x010A0C, 0x0010},
+{0x010A10, 0x0004},
+{0x010A14, 0x0001},
+{0x010A15, 0x0004},
+{0x010A18, 0x0001},
+{0x010A19, 0x0004},
+{0x010A36, 0x0001},
+{0x010A38, 0x0010},
+{0x010A3B, 0x0001},
+{0x010A3F, 0x0010},
+{0x010A40, 0x0002},
+{0x010A49, 0x0001},
+{0x010A50, 0x0020},
+{0x010A59, 0x0001},
+{0x010A60, 0x0004},
+{0x010A7D, 0x0002},
+{0x010A7F, 0x0020},
+{0x010A80, 0x0004},
+{0x010A9D, 0x0002},
+{0x010AA0, 0x0001},
+{0x010AC0, 0x0004},
+{0x010AC8, 0x0040},
+{0x010AC9, 0x0004},
+{0x010AE5, 0x0010},
+{0x010AE7, 0x0001},
+{0x010AEB, 0x0002},
+{0x010AF0, 0x0020},
+{0x010AF7, 0x0001},
+{0x010B00, 0x0004},
+{0x010B36, 0x0001},
+{0x010B39, 0x0020},
+{0x010B40, 0x0004},
+{0x010B56, 0x0001},
+{0x010B58, 0x0002},
+{0x010B60, 0x0004},
+{0x010B73, 0x0001},
+{0x010B78, 0x0002},
+{0x010B80, 0x0004},
+{0x010B92, 0x0001},
+{0x010B99, 0x0020},
+{0x010B9D, 0x0001},
+{0x010BA9, 0x0002},
+{0x010BB0, 0x0001},
+{0x010C00, 0x0004},
+{0x010C49, 0x0001},
+{0x010C80, 0x0004},
+{0x010CB3, 0x0001},
+{0x010CC0, 0x0004},
+{0x010CF3, 0x0001},
+{0x010CFA, 0x0002},
+{0x010D00, 0x0004},
+{0x010D24, 0x0010},
+{0x010D28, 0x0001},
+{0x010D30, 0x0002},
+{0x010D3A, 0x0001},
+{0x010E60, 0x0002},
+{0x010E7F, 0x0001},
+{0x010E80, 0x0004},
+{0x010EAA, 0x0001},
+{0x010EAB, 0x0010},
+{0x010EAD, 0x0020},
+{0x010EAE, 0x0001},
+{0x010EB0, 0x0004},
+{0x010EB2, 0x0001},
+{0x010EFD, 0x0010},
+{0x010F00, 0x0004},
+{0x010F1D, 0x0002},
+{0x010F27, 0x0004},
+{0x010F28, 0x0001},
+{0x010F30, 0x0004},
+{0x010F46, 0x0010},
+{0x010F51, 0x0002},
+{0x010F55, 0x0020},
+{0x010F5A, 0x0001},
+{0x010F70, 0x0004},
+{0x010F82, 0x0010},
+{0x010F86, 0x0020},
+{0x010F8A, 0x0001},
+{0x010FB0, 0x0004},
+{0x010FC5, 0x0002},
+{0x010FCC, 0x0001},
+{0x010FE0, 0x0004},
+{0x010FF7, 0x0001},
+{0x011000, 0x0010},
+{0x011003, 0x0004},
+{0x011038, 0x0010},
+{0x011047, 0x0020},
+{0x01104E, 0x0001},
+{0x011052, 0x0002},
+{0x011070, 0x0010},
+{0x011071, 0x0004},
+{0x011073, 0x0010},
+{0x011075, 0x0004},
+{0x011076, 0x0001},
+{0x01107F, 0x0010},
+{0x011083, 0x0004},
+{0x0110B0, 0x0010},
+{0x0110BB, 0x0020},
+{0x0110BD, 0x0080},
+{0x0110BE, 0x0020},
+{0x0110C2, 0x0010},
+{0x0110C3, 0x0001},
+{0x0110CD, 0x0080},
+{0x0110CE, 0x0001},
+{0x0110D0, 0x0004},
+{0x0110E9, 0x0001},
+{0x0110F0, 0x0002},
+{0x0110FA, 0x0001},
+{0x011100, 0x0010},
+{0x011103, 0x0004},
+{0x011127, 0x0010},
+{0x011135, 0x0001},
+{0x011136, 0x0002},
+{0x011140, 0x0020},
+{0x011144, 0x0004},
+{0x011145, 0x0010},
+{0x011147, 0x0004},
+{0x011148, 0x0001},
+{0x011150, 0x0004},
+{0x011173, 0x0010},
+{0x011174, 0x0020},
+{0x011176, 0x0004},
+{0x011177, 0x0001},
+{0x011180, 0x0010},
+{0x011183, 0x0004},
+{0x0111B3, 0x0010},
+{0x0111C1, 0x0004},
+{0x0111C5, 0x0020},
+{0x0111C9, 0x0010},
+{0x0111CD, 0x0020},
+{0x0111CE, 0x0010},
+{0x0111D0, 0x0002},
+{0x0111DA, 0x0004},
+{0x0111DB, 0x0020},
+{0x0111DC, 0x0004},
+{0x0111DD, 0x0020},
+{0x0111E0, 0x0001},
+{0x0111E1, 0x0002},
+{0x0111F5, 0x0001},
+{0x011200, 0x0004},
+{0x011212, 0x0001},
+{0x011213, 0x0004},
+{0x01122C, 0x0010},
+{0x011238, 0x0020},
+{0x01123E, 0x0010},
+{0x01123F, 0x0004},
+{0x011241, 0x0010},
+{0x011242, 0x0001},
+{0x011280, 0x0004},
+{0x011287, 0x0001},
+{0x011288, 0x0004},
+{0x011289, 0x0001},
+{0x01128A, 0x0004},
+{0x01128E, 0x0001},
+{0x01128F, 0x0004},
+{0x01129E, 0x0001},
+{0x01129F, 0x0004},
+{0x0112A9, 0x0020},
+{0x0112AA, 0x0001},
+{0x0112B0, 0x0004},
+{0x0112DF, 0x0010},
+{0x0112EB, 0x0001},
+{0x0112F0, 0x0002},
+{0x0112FA, 0x0001},
+{0x011300, 0x0010},
+{0x011304, 0x0001},
+{0x011305, 0x0004},
+{0x01130D, 0x0001},
+{0x01130F, 0x0004},
+{0x011311, 0x0001},
+{0x011313, 0x0004},
+{0x011329, 0x0001},
+{0x01132A, 0x0004},
+{0x011331, 0x0001},
+{0x011332, 0x0004},
+{0x011334, 0x0001},
+{0x011335, 0x0004},
+{0x01133A, 0x0001},
+{0x01133B, 0x0010},
+{0x01133D, 0x0004},
+{0x01133E, 0x0010},
+{0x011345, 0x0001},
+{0x011347, 0x0010},
+{0x011349, 0x0001},
+{0x01134B, 0x0010},
+{0x01134E, 0x0001},
+{0x011350, 0x0004},
+{0x011351, 0x0001},
+{0x011357, 0x0010},
+{0x011358, 0x0001},
+{0x01135D, 0x0004},
+{0x011362, 0x0010},
+{0x011364, 0x0001},
+{0x011366, 0x0010},
+{0x01136D, 0x0001},
+{0x011370, 0x0010},
+{0x011375, 0x0001},
+{0x011400, 0x0004},
+{0x011435, 0x0010},
+{0x011447, 0x0004},
+{0x01144B, 0x0020},
+{0x011450, 0x0002},
+{0x01145A, 0x0020},
+{0x01145C, 0x0001},
+{0x01145D, 0x0020},
+{0x01145E, 0x0010},
+{0x01145F, 0x0004},
+{0x011462, 0x0001},
+{0x011480, 0x0004},
+{0x0114B0, 0x0010},
+{0x0114C4, 0x0004},
+{0x0114C6, 0x0020},
+{0x0114C7, 0x0004},
+{0x0114C8, 0x0001},
+{0x0114D0, 0x0002},
+{0x0114DA, 0x0001},
+{0x011580, 0x0004},
+{0x0115AF, 0x0010},
+{0x0115B6, 0x0001},
+{0x0115B8, 0x0010},
+{0x0115C1, 0x0020},
+{0x0115D8, 0x0004},
+{0x0115DC, 0x0010},
+{0x0115DE, 0x0001},
+{0x011600, 0x0004},
+{0x011630, 0x0010},
+{0x011641, 0x0020},
+{0x011644, 0x0004},
+{0x011645, 0x0001},
+{0x011650, 0x0002},
+{0x01165A, 0x0001},
+{0x011660, 0x0020},
+{0x01166D, 0x0001},
+{0x011680, 0x0004},
+{0x0116AB, 0x0010},
+{0x0116B8, 0x0004},
+{0x0116B9, 0x0020},
+{0x0116BA, 0x0001},
+{0x0116C0, 0x0002},
+{0x0116CA, 0x0001},
+{0x011700, 0x0004},
+{0x01171B, 0x0001},
+{0x01171D, 0x0010},
+{0x01172C, 0x0001},
+{0x011730, 0x0002},
+{0x01173C, 0x0020},
+{0x01173F, 0x0040},
+{0x011740, 0x0004},
+{0x011747, 0x0001},
+{0x011800, 0x0004},
+{0x01182C, 0x0010},
+{0x01183B, 0x0020},
+{0x01183C, 0x0001},
+{0x0118A0, 0x0004},
+{0x0118E0, 0x0002},
+{0x0118F3, 0x0001},
+{0x0118FF, 0x0004},
+{0x011907, 0x0001},
+{0x011909, 0x0004},
+{0x01190A, 0x0001},
+{0x01190C, 0x0004},
+{0x011914, 0x0001},
+{0x011915, 0x0004},
+{0x011917, 0x0001},
+{0x011918, 0x0004},
+{0x011930, 0x0010},
+{0x011936, 0x0001},
+{0x011937, 0x0010},
+{0x011939, 0x0001},
+{0x01193B, 0x0010},
+{0x01193F, 0x0004},
+{0x011940, 0x0010},
+{0x011941, 0x0004},
+{0x011942, 0x0010},
+{0x011944, 0x0020},
+{0x011947, 0x0001},
+{0x011950, 0x0002},
+{0x01195A, 0x0001},
+{0x0119A0, 0x0004},
+{0x0119A8, 0x0001},
+{0x0119AA, 0x0004},
+{0x0119D1, 0x0010},
+{0x0119D8, 0x0001},
+{0x0119DA, 0x0010},
+{0x0119E1, 0x0004},
+{0x0119E2, 0x0020},
+{0x0119E3, 0x0004},
+{0x0119E4, 0x0010},
+{0x0119E5, 0x0001},
+{0x011A00, 0x0004},
+{0x011A01, 0x0010},
+{0x011A0B, 0x0004},
+{0x011A33, 0x0010},
+{0x011A3A, 0x0004},
+{0x011A3B, 0x0010},
+{0x011A3F, 0x0020},
+{0x011A47, 0x0010},
+{0x011A48, 0x0001},
+{0x011A50, 0x0004},
+{0x011A51, 0x0010},
+{0x011A5C, 0x0004},
+{0x011A8A, 0x0010},
+{0x011A9A, 0x0020},
+{0x011A9D, 0x0004},
+{0x011A9E, 0x0020},
+{0x011AA3, 0x0001},
+{0x011AB0, 0x0004},
+{0x011AF9, 0x0001},
+{0x011B00, 0x0020},
+{0x011B0A, 0x0001},
+{0x011C00, 0x0004},
+{0x011C09, 0x0001},
+{0x011C0A, 0x0004},
+{0x011C2F, 0x0010},
+{0x011C37, 0x0001},
+{0x011C38, 0x0010},
+{0x011C40, 0x0004},
+{0x011C41, 0x0020},
+{0x011C46, 0x0001},
+{0x011C50, 0x0002},
+{0x011C6D, 0x0001},
+{0x011C70, 0x0020},
+{0x011C72, 0x0004},
+{0x011C90, 0x0001},
+{0x011C92, 0x0010},
+{0x011CA8, 0x0001},
+{0x011CA9, 0x0010},
+{0x011CB7, 0x0001},
+{0x011D00, 0x0004},
+{0x011D07, 0x0001},
+{0x011D08, 0x0004},
+{0x011D0A, 0x0001},
+{0x011D0B, 0x0004},
+{0x011D31, 0x0010},
+{0x011D37, 0x0001},
+{0x011D3A, 0x0010},
+{0x011D3B, 0x0001},
+{0x011D3C, 0x0010},
+{0x011D3E, 0x0001},
+{0x011D3F, 0x0010},
+{0x011D46, 0x0004},
+{0x011D47, 0x0010},
+{0x011D48, 0x0001},
+{0x011D50, 0x0002},
+{0x011D5A, 0x0001},
+{0x011D60, 0x0004},
+{0x011D66, 0x0001},
+{0x011D67, 0x0004},
+{0x011D69, 0x0001},
+{0x011D6A, 0x0004},
+{0x011D8A, 0x0010},
+{0x011D8F, 0x0001},
+{0x011D90, 0x0010},
+{0x011D92, 0x0001},
+{0x011D93, 0x0010},
+{0x011D98, 0x0004},
+{0x011D99, 0x0001},
+{0x011DA0, 0x0002},
+{0x011DAA, 0x0001},
+{0x011EE0, 0x0004},
+{0x011EF3, 0x0010},
+{0x011EF7, 0x0020},
+{0x011EF9, 0x0001},
+{0x011F00, 0x0010},
+{0x011F02, 0x0004},
+{0x011F03, 0x0010},
+{0x011F04, 0x0004},
+{0x011F11, 0x0001},
+{0x011F12, 0x0004},
+{0x011F34, 0x0010},
+{0x011F3B, 0x0001},
+{0x011F3E, 0x0010},
+{0x011F43, 0x0020},
+{0x011F50, 0x0002},
+{0x011F5A, 0x0001},
+{0x011FB0, 0x0004},
+{0x011FB1, 0x0001},
+{0x011FC0, 0x0002},
+{0x011FD5, 0x0040},
+{0x011FF2, 0x0001},
+{0x011FFF, 0x0020},
+{0x012000, 0x0004},
+{0x01239A, 0x0001},
+{0x012400, 0x0002},
+{0x01246F, 0x0001},
+{0x012470, 0x0020},
+{0x012475, 0x0001},
+{0x012480, 0x0004},
+{0x012544, 0x0001},
+{0x012F90, 0x0004},
+{0x012FF1, 0x0020},
+{0x012FF3, 0x0001},
+{0x013000, 0x0004},
+{0x013430, 0x0080},
+{0x013440, 0x0010},
+{0x013441, 0x0004},
+{0x013447, 0x0010},
+{0x013456, 0x0001},
+{0x014400, 0x0004},
+{0x014647, 0x0001},
+{0x016800, 0x0004},
+{0x016A39, 0x0001},
+{0x016A40, 0x0004},
+{0x016A5F, 0x0001},
+{0x016A60, 0x0002},
+{0x016A6A, 0x0001},
+{0x016A6E, 0x0020},
+{0x016A70, 0x0004},
+{0x016ABF, 0x0001},
+{0x016AC0, 0x0002},
+{0x016ACA, 0x0001},
+{0x016AD0, 0x0004},
+{0x016AEE, 0x0001},
+{0x016AF0, 0x0010},
+{0x016AF5, 0x0020},
+{0x016AF6, 0x0001},
+{0x016B00, 0x0004},
+{0x016B30, 0x0010},
+{0x016B37, 0x0020},
+{0x016B3C, 0x0040},
+{0x016B40, 0x0004},
+{0x016B44, 0x0020},
+{0x016B45, 0x0040},
+{0x016B46, 0x0001},
+{0x016B50, 0x0002},
+{0x016B5A, 0x0001},
+{0x016B5B, 0x0002},
+{0x016B62, 0x0001},
+{0x016B63, 0x0004},
+{0x016B78, 0x0001},
+{0x016B7D, 0x0004},
+{0x016B90, 0x0001},
+{0x016E40, 0x0004},
+{0x016E80, 0x0002},
+{0x016E97, 0x0020},
+{0x016E9B, 0x0001},
+{0x016F00, 0x0004},
+{0x016F4B, 0x0001},
+{0x016F4F, 0x0010},
+{0x016F50, 0x0004},
+{0x016F51, 0x0010},
+{0x016F88, 0x0001},
+{0x016F8F, 0x0010},
+{0x016F93, 0x0004},
+{0x016FA0, 0x0001},
+{0x016FE0, 0x0004},
+{0x016FE2, 0x0020},
+{0x016FE3, 0x0004},
+{0x016FE4, 0x0010},
+{0x016FE5, 0x0001},
+{0x016FF0, 0x0010},
+{0x016FF2, 0x0001},
+{0x017000, 0x0004},
+{0x0187F8, 0x0001},
+{0x018800, 0x0004},
+{0x018CD6, 0x0001},
+{0x018D00, 0x0004},
+{0x018D09, 0x0001},
+{0x01AFF0, 0x0004},
+{0x01AFF4, 0x0001},
+{0x01AFF5, 0x0004},
+{0x01AFFC, 0x0001},
+{0x01AFFD, 0x0004},
+{0x01AFFF, 0x0001},
+{0x01B000, 0x0004},
+{0x01B123, 0x0001},
+{0x01B132, 0x0004},
+{0x01B133, 0x0001},
+{0x01B150, 0x0004},
+{0x01B153, 0x0001},
+{0x01B155, 0x0004},
+{0x01B156, 0x0001},
+{0x01B164, 0x0004},
+{0x01B168, 0x0001},
+{0x01B170, 0x0004},
+{0x01B2FC, 0x0001},
+{0x01BC00, 0x0004},
+{0x01BC6B, 0x0001},
+{0x01BC70, 0x0004},
+{0x01BC7D, 0x0001},
+{0x01BC80, 0x0004},
+{0x01BC89, 0x0001},
+{0x01BC90, 0x0004},
+{0x01BC9A, 0x0001},
+{0x01BC9C, 0x0040},
+{0x01BC9D, 0x0010},
+{0x01BC9F, 0x0020},
+{0x01BCA0, 0x0080},
+{0x01BCA4, 0x0001},
+{0x01CF00, 0x0010},
+{0x01CF2E, 0x0001},
+{0x01CF30, 0x0010},
+{0x01CF47, 0x0001},
+{0x01CF50, 0x0040},
+{0x01CFC4, 0x0001},
+{0x01D000, 0x0040},
+{0x01D0F6, 0x0001},
+{0x01D100, 0x0040},
+{0x01D127, 0x0001},
+{0x01D129, 0x0040},
+{0x01D165, 0x0010},
+{0x01D16A, 0x0040},
+{0x01D16D, 0x0010},
+{0x01D173, 0x0080},
+{0x01D17B, 0x0010},
+{0x01D183, 0x0040},
+{0x01D185, 0x0010},
+{0x01D18C, 0x0040},
+{0x01D1AA, 0x0010},
+{0x01D1AE, 0x0040},
+{0x01D1EB, 0x0001},
+{0x01D200, 0x0040},
+{0x01D242, 0x0010},
+{0x01D245, 0x0040},
+{0x01D246, 0x0001},
+{0x01D2C0, 0x0002},
+{0x01D2D4, 0x0001},
+{0x01D2E0, 0x0002},
+{0x01D2F4, 0x0001},
+{0x01D300, 0x0040},
+{0x01D357, 0x0001},
+{0x01D360, 0x0002},
+{0x01D379, 0x0001},
+{0x01D400, 0x0004},
+{0x01D455, 0x0001},
+{0x01D456, 0x0004},
+{0x01D49D, 0x0001},
+{0x01D49E, 0x0004},
+{0x01D4A0, 0x0001},
+{0x01D4A2, 0x0004},
+{0x01D4A3, 0x0001},
+{0x01D4A5, 0x0004},
+{0x01D4A7, 0x0001},
+{0x01D4A9, 0x0004},
+{0x01D4AD, 0x0001},
+{0x01D4AE, 0x0004},
+{0x01D4BA, 0x0001},
+{0x01D4BB, 0x0004},
+{0x01D4BC, 0x0001},
+{0x01D4BD, 0x0004},
+{0x01D4C4, 0x0001},
+{0x01D4C5, 0x0004},
+{0x01D506, 0x0001},
+{0x01D507, 0x0004},
+{0x01D50B, 0x0001},
+{0x01D50D, 0x0004},
+{0x01D515, 0x0001},
+{0x01D516, 0x0004},
+{0x01D51D, 0x0001},
+{0x01D51E, 0x0004},
+{0x01D53A, 0x0001},
+{0x01D53B, 0x0004},
+{0x01D53F, 0x0001},
+{0x01D540, 0x0004},
+{0x01D545, 0x0001},
+{0x01D546, 0x0004},
+{0x01D547, 0x0001},
+{0x01D54A, 0x0004},
+{0x01D551, 0x0001},
+{0x01D552, 0x0004},
+{0x01D6A6, 0x0001},
+{0x01D6A8, 0x0004},
+{0x01D6C1, 0x0040},
+{0x01D6C2, 0x0004},
+{0x01D6DB, 0x0040},
+{0x01D6DC, 0x0004},
+{0x01D6FB, 0x0040},
+{0x01D6FC, 0x0004},
+{0x01D715, 0x0040},
+{0x01D716, 0x0004},
+{0x01D735, 0x0040},
+{0x01D736, 0x0004},
+{0x01D74F, 0x0040},
+{0x01D750, 0x0004},
+{0x01D76F, 0x0040},
+{0x01D770, 0x0004},
+{0x01D789, 0x0040},
+{0x01D78A, 0x0004},
+{0x01D7A9, 0x0040},
+{0x01D7AA, 0x0004},
+{0x01D7C3, 0x0040},
+{0x01D7C4, 0x0004},
+{0x01D7CC, 0x0001},
+{0x01D7CE, 0x0002},
+{0x01D800, 0x0040},
+{0x01DA00, 0x0010},
+{0x01DA37, 0x0040},
+{0x01DA3B, 0x0010},
+{0x01DA6D, 0x0040},
+{0x01DA75, 0x0010},
+{0x01DA76, 0x0040},
+{0x01DA84, 0x0010},
+{0x01DA85, 0x0040},
+{0x01DA87, 0x0020},
+{0x01DA8C, 0x0001},
+{0x01DA9B, 0x0010},
+{0x01DAA0, 0x0001},
+{0x01DAA1, 0x0010},
+{0x01DAB0, 0x0001},
+{0x01DF00, 0x0004},
+{0x01DF1F, 0x0001},
+{0x01DF25, 0x0004},
+{0x01DF2B, 0x0001},
+{0x01E000, 0x0010},
+{0x01E007, 0x0001},
+{0x01E008, 0x0010},
+{0x01E019, 0x0001},
+{0x01E01B, 0x0010},
+{0x01E022, 0x0001},
+{0x01E023, 0x0010},
+{0x01E025, 0x0001},
+{0x01E026, 0x0010},
+{0x01E02B, 0x0001},
+{0x01E030, 0x0004},
+{0x01E06E, 0x0001},
+{0x01E08F, 0x0010},
+{0x01E090, 0x0001},
+{0x01E100, 0x0004},
+{0x01E12D, 0x0001},
+{0x01E130, 0x0010},
+{0x01E137, 0x0004},
+{0x01E13E, 0x0001},
+{0x01E140, 0x0002},
+{0x01E14A, 0x0001},
+{0x01E14E, 0x0004},
+{0x01E14F, 0x0040},
+{0x01E150, 0x0001},
+{0x01E290, 0x0004},
+{0x01E2AE, 0x0010},
+{0x01E2AF, 0x0001},
+{0x01E2C0, 0x0004},
+{0x01E2EC, 0x0010},
+{0x01E2F0, 0x0002},
+{0x01E2FA, 0x0001},
+{0x01E2FF, 0x0040},
+{0x01E300, 0x0001},
+{0x01E4D0, 0x0004},
+{0x01E4EC, 0x0010},
+{0x01E4F0, 0x0002},
+{0x01E4FA, 0x0001},
+{0x01E7E0, 0x0004},
+{0x01E7E7, 0x0001},
+{0x01E7E8, 0x0004},
+{0x01E7EC, 0x0001},
+{0x01E7ED, 0x0004},
+{0x01E7EF, 0x0001},
+{0x01E7F0, 0x0004},
+{0x01E7FF, 0x0001},
+{0x01E800, 0x0004},
+{0x01E8C5, 0x0001},
+{0x01E8C7, 0x0002},
+{0x01E8D0, 0x0010},
+{0x01E8D7, 0x0001},
+{0x01E900, 0x0004},
+{0x01E944, 0x0010},
+{0x01E94B, 0x0004},
+{0x01E94C, 0x0001},
+{0x01E950, 0x0002},
+{0x01E95A, 0x0001},
+{0x01E95E, 0x0020},
+{0x01E960, 0x0001},
+{0x01EC71, 0x0002},
+{0x01ECAC, 0x0040},
+{0x01ECAD, 0x0002},
+{0x01ECB0, 0x0040},
+{0x01ECB1, 0x0002},
+{0x01ECB5, 0x0001},
+{0x01ED01, 0x0002},
+{0x01ED2E, 0x0040},
+{0x01ED2F, 0x0002},
+{0x01ED3E, 0x0001},
+{0x01EE00, 0x0004},
+{0x01EE04, 0x0001},
+{0x01EE05, 0x0004},
+{0x01EE20, 0x0001},
+{0x01EE21, 0x0004},
+{0x01EE23, 0x0001},
+{0x01EE24, 0x0004},
+{0x01EE25, 0x0001},
+{0x01EE27, 0x0004},
+{0x01EE28, 0x0001},
+{0x01EE29, 0x0004},
+{0x01EE33, 0x0001},
+{0x01EE34, 0x0004},
+{0x01EE38, 0x0001},
+{0x01EE39, 0x0004},
+{0x01EE3A, 0x0001},
+{0x01EE3B, 0x0004},
+{0x01EE3C, 0x0001},
+{0x01EE42, 0x0004},
+{0x01EE43, 0x0001},
+{0x01EE47, 0x0004},
+{0x01EE48, 0x0001},
+{0x01EE49, 0x0004},
+{0x01EE4A, 0x0001},
+{0x01EE4B, 0x0004},
+{0x01EE4C, 0x0001},
+{0x01EE4D, 0x0004},
+{0x01EE50, 0x0001},
+{0x01EE51, 0x0004},
+{0x01EE53, 0x0001},
+{0x01EE54, 0x0004},
+{0x01EE55, 0x0001},
+{0x01EE57, 0x0004},
+{0x01EE58, 0x0001},
+{0x01EE59, 0x0004},
+{0x01EE5A, 0x0001},
+{0x01EE5B, 0x0004},
+{0x01EE5C, 0x0001},
+{0x01EE5D, 0x0004},
+{0x01EE5E, 0x0001},
+{0x01EE5F, 0x0004},
+{0x01EE60, 0x0001},
+{0x01EE61, 0x0004},
+{0x01EE63, 0x0001},
+{0x01EE64, 0x0004},
+{0x01EE65, 0x0001},
+{0x01EE67, 0x0004},
+{0x01EE6B, 0x0001},
+{0x01EE6C, 0x0004},
+{0x01EE73, 0x0001},
+{0x01EE74, 0x0004},
+{0x01EE78, 0x0001},
+{0x01EE79, 0x0004},
+{0x01EE7D, 0x0001},
+{0x01EE7E, 0x0004},
+{0x01EE7F, 0x0001},
+{0x01EE80, 0x0004},
+{0x01EE8A, 0x0001},
+{0x01EE8B, 0x0004},
+{0x01EE9C, 0x0001},
+{0x01EEA1, 0x0004},
+{0x01EEA4, 0x0001},
+{0x01EEA5, 0x0004},
+{0x01EEAA, 0x0001},
+{0x01EEAB, 0x0004},
+{0x01EEBC, 0x0001},
+{0x01EEF0, 0x0040},
+{0x01EEF2, 0x0001},
+{0x01F000, 0x0040},
+{0x01F02C, 0x0001},
+{0x01F030, 0x0040},
+{0x01F094, 0x0001},
+{0x01F0A0, 0x0040},
+{0x01F0AF, 0x0001},
+{0x01F0B1, 0x0040},
+{0x01F0C0, 0x0001},
+{0x01F0C1, 0x0040},
+{0x01F0D0, 0x0001},
+{0x01F0D1, 0x0040},
+{0x01F0F6, 0x0001},
+{0x01F100, 0x0002},
+{0x01F10D, 0x0040},
+{0x01F1AE, 0x0001},
+{0x01F1E6, 0x0040},
+{0x01F203, 0x0001},
+{0x01F210, 0x0040},
+{0x01F23C, 0x0001},
+{0x01F240, 0x0040},
+{0x01F249, 0x0001},
+{0x01F250, 0x0040},
+{0x01F252, 0x0001},
+{0x01F260, 0x0040},
+{0x01F266, 0x0001},
+{0x01F300, 0x0040},
+{0x01F6D8, 0x0001},
+{0x01F6DC, 0x0040},
+{0x01F6ED, 0x0001},
+{0x01F6F0, 0x0040},
+{0x01F6FD, 0x0001},
+{0x01F700, 0x0040},
+{0x01F777, 0x0001},
+{0x01F77B, 0x0040},
+{0x01F7DA, 0x0001},
+{0x01F7E0, 0x0040},
+{0x01F7EC, 0x0001},
+{0x01F7F0, 0x0040},
+{0x01F7F1, 0x0001},
+{0x01F800, 0x0040},
+{0x01F80C, 0x0001},
+{0x01F810, 0x0040},
+{0x01F848, 0x0001},
+{0x01F850, 0x0040},
+{0x01F85A, 0x0001},
+{0x01F860, 0x0040},
+{0x01F888, 0x0001},
+{0x01F890, 0x0040},
+{0x01F8AE, 0x0001},
+{0x01F8B0, 0x0040},
+{0x01F8B2, 0x0001},
+{0x01F900, 0x0040},
+{0x01FA54, 0x0001},
+{0x01FA60, 0x0040},
+{0x01FA6E, 0x0001},
+{0x01FA70, 0x0040},
+{0x01FA7D, 0x0001},
+{0x01FA80, 0x0040},
+{0x01FA89, 0x0001},
+{0x01FA90, 0x0040},
+{0x01FABE, 0x0001},
+{0x01FABF, 0x0040},
+{0x01FAC6, 0x0001},
+{0x01FACE, 0x0040},
+{0x01FADC, 0x0001},
+{0x01FAE0, 0x0040},
+{0x01FAE9, 0x0001},
+{0x01FAF0, 0x0040},
+{0x01FAF9, 0x0001},
+{0x01FB00, 0x0040},
+{0x01FB93, 0x0001},
+{0x01FB94, 0x0040},
+{0x01FBCB, 0x0001},
+{0x01FBF0, 0x0002},
+{0x01FBFA, 0x0001},
+{0x020000, 0x0004},
+{0x02A6E0, 0x0001},
+{0x02A700, 0x0004},
+{0x02B73A, 0x0001},
+{0x02B740, 0x0004},
+{0x02B81E, 0x0001},
+{0x02B820, 0x0004},
+{0x02CEA2, 0x0001},
+{0x02CEB0, 0x0004},
+{0x02EBE1, 0x0001},
+{0x02EBF0, 0x0004},
+{0x02EE5E, 0x0001},
+{0x02F800, 0x0004},
+{0x02FA1E, 0x0001},
+{0x030000, 0x0004},
+{0x03134B, 0x0001},
+{0x031350, 0x0004},
+{0x0323B0, 0x0001},
+{0x0E0001, 0x0080},
+{0x0E0002, 0x0001},
+{0x0E0020, 0x0080},
+{0x0E0080, 0x0001},
+{0x0E0100, 0x0010},
+{0x0E01F0, 0x0001},
+{0x0F0000, 0x0080},
+{0x0FFFFE, 0x0001},
+{0x100000, 0x0080},
+{0x10FFFE, 0x0001},
+{0x110000, 0x0000},
+};
+
+const std::unordered_set unicode_set_whitespace = {
+0x000009,
+0x00000A,
+0x00000B,
+0x00000C,
+0x00000D,
+0x000020,
+0x000085,
+0x0000A0,
+0x001680,
+0x002000,
+0x002001,
+0x002002,
+0x002003,
+0x002004,
+0x002005,
+0x002006,
+0x002007,
+0x002008,
+0x002009,
+0x00200A,
+0x002028,
+0x002029,
+0x00202F,
+0x00205F,
+0x003000,
+};
+
+// list is always in ascending order, to enable binary search
+const std::initializer_list> unicode_map_lowercase = {
+{0x000041, 0x000061},
+{0x000042, 0x000062},
+{0x000043, 0x000063},
+{0x000044, 0x000064},
+{0x000045, 0x000065},
+{0x000046, 0x000066},
+{0x000047, 0x000067},
+{0x000048, 0x000068},
+{0x000049, 0x000069},
+{0x00004A, 0x00006A},
+{0x00004B, 0x00006B},
+{0x00004C, 0x00006C},
+{0x00004D, 0x00006D},
+{0x00004E, 0x00006E},
+{0x00004F, 0x00006F},
+{0x000050, 0x000070},
+{0x000051, 0x000071},
+{0x000052, 0x000072},
+{0x000053, 0x000073},
+{0x000054, 0x000074},
+{0x000055, 0x000075},
+{0x000056, 0x000076},
+{0x000057, 0x000077},
+{0x000058, 0x000078},
+{0x000059, 0x000079},
+{0x00005A, 0x00007A},
+{0x0000C0, 0x0000E0},
+{0x0000C1, 0x0000E1},
+{0x0000C2, 0x0000E2},
+{0x0000C3, 0x0000E3},
+{0x0000C4, 0x0000E4},
+{0x0000C5, 0x0000E5},
+{0x0000C6, 0x0000E6},
+{0x0000C7, 0x0000E7},
+{0x0000C8, 0x0000E8},
+{0x0000C9, 0x0000E9},
+{0x0000CA, 0x0000EA},
+{0x0000CB, 0x0000EB},
+{0x0000CC, 0x0000EC},
+{0x0000CD, 0x0000ED},
+{0x0000CE, 0x0000EE},
+{0x0000CF, 0x0000EF},
+{0x0000D0, 0x0000F0},
+{0x0000D1, 0x0000F1},
+{0x0000D2, 0x0000F2},
+{0x0000D3, 0x0000F3},
+{0x0000D4, 0x0000F4},
+{0x0000D5, 0x0000F5},
+{0x0000D6, 0x0000F6},
+{0x0000D8, 0x0000F8},
+{0x0000D9, 0x0000F9},
+{0x0000DA, 0x0000FA},
+{0x0000DB, 0x0000FB},
+{0x0000DC, 0x0000FC},
+{0x0000DD, 0x0000FD},
+{0x0000DE, 0x0000FE},
+{0x000100, 0x000101},
+{0x000102, 0x000103},
+{0x000104, 0x000105},
+{0x000106, 0x000107},
+{0x000108, 0x000109},
+{0x00010A, 0x00010B},
+{0x00010C, 0x00010D},
+{0x00010E, 0x00010F},
+{0x000110, 0x000111},
+{0x000112, 0x000113},
+{0x000114, 0x000115},
+{0x000116, 0x000117},
+{0x000118, 0x000119},
+{0x00011A, 0x00011B},
+{0x00011C, 0x00011D},
+{0x00011E, 0x00011F},
+{0x000120, 0x000121},
+{0x000122, 0x000123},
+{0x000124, 0x000125},
+{0x000126, 0x000127},
+{0x000128, 0x000129},
+{0x00012A, 0x00012B},
+{0x00012C, 0x00012D},
+{0x00012E, 0x00012F},
+{0x000130, 0x000069},
+{0x000132, 0x000133},
+{0x000134, 0x000135},
+{0x000136, 0x000137},
+{0x000139, 0x00013A},
+{0x00013B, 0x00013C},
+{0x00013D, 0x00013E},
+{0x00013F, 0x000140},
+{0x000141, 0x000142},
+{0x000143, 0x000144},
+{0x000145, 0x000146},
+{0x000147, 0x000148},
+{0x00014A, 0x00014B},
+{0x00014C, 0x00014D},
+{0x00014E, 0x00014F},
+{0x000150, 0x000151},
+{0x000152, 0x000153},
+{0x000154, 0x000155},
+{0x000156, 0x000157},
+{0x000158, 0x000159},
+{0x00015A, 0x00015B},
+{0x00015C, 0x00015D},
+{0x00015E, 0x00015F},
+{0x000160, 0x000161},
+{0x000162, 0x000163},
+{0x000164, 0x000165},
+{0x000166, 0x000167},
+{0x000168, 0x000169},
+{0x00016A, 0x00016B},
+{0x00016C, 0x00016D},
+{0x00016E, 0x00016F},
+{0x000170, 0x000171},
+{0x000172, 0x000173},
+{0x000174, 0x000175},
+{0x000176, 0x000177},
+{0x000178, 0x0000FF},
+{0x000179, 0x00017A},
+{0x00017B, 0x00017C},
+{0x00017D, 0x00017E},
+{0x000181, 0x000253},
+{0x000182, 0x000183},
+{0x000184, 0x000185},
+{0x000186, 0x000254},
+{0x000187, 0x000188},
+{0x000189, 0x000256},
+{0x00018A, 0x000257},
+{0x00018B, 0x00018C},
+{0x00018E, 0x0001DD},
+{0x00018F, 0x000259},
+{0x000190, 0x00025B},
+{0x000191, 0x000192},
+{0x000193, 0x000260},
+{0x000194, 0x000263},
+{0x000196, 0x000269},
+{0x000197, 0x000268},
+{0x000198, 0x000199},
+{0x00019C, 0x00026F},
+{0x00019D, 0x000272},
+{0x00019F, 0x000275},
+{0x0001A0, 0x0001A1},
+{0x0001A2, 0x0001A3},
+{0x0001A4, 0x0001A5},
+{0x0001A6, 0x000280},
+{0x0001A7, 0x0001A8},
+{0x0001A9, 0x000283},
+{0x0001AC, 0x0001AD},
+{0x0001AE, 0x000288},
+{0x0001AF, 0x0001B0},
+{0x0001B1, 0x00028A},
+{0x0001B2, 0x00028B},
+{0x0001B3, 0x0001B4},
+{0x0001B5, 0x0001B6},
+{0x0001B7, 0x000292},
+{0x0001B8, 0x0001B9},
+{0x0001BC, 0x0001BD},
+{0x0001C4, 0x0001C6},
+{0x0001C5, 0x0001C6},
+{0x0001C7, 0x0001C9},
+{0x0001C8, 0x0001C9},
+{0x0001CA, 0x0001CC},
+{0x0001CB, 0x0001CC},
+{0x0001CD, 0x0001CE},
+{0x0001CF, 0x0001D0},
+{0x0001D1, 0x0001D2},
+{0x0001D3, 0x0001D4},
+{0x0001D5, 0x0001D6},
+{0x0001D7, 0x0001D8},
+{0x0001D9, 0x0001DA},
+{0x0001DB, 0x0001DC},
+{0x0001DE, 0x0001DF},
+{0x0001E0, 0x0001E1},
+{0x0001E2, 0x0001E3},
+{0x0001E4, 0x0001E5},
+{0x0001E6, 0x0001E7},
+{0x0001E8, 0x0001E9},
+{0x0001EA, 0x0001EB},
+{0x0001EC, 0x0001ED},
+{0x0001EE, 0x0001EF},
+{0x0001F1, 0x0001F3},
+{0x0001F2, 0x0001F3},
+{0x0001F4, 0x0001F5},
+{0x0001F6, 0x000195},
+{0x0001F7, 0x0001BF},
+{0x0001F8, 0x0001F9},
+{0x0001FA, 0x0001FB},
+{0x0001FC, 0x0001FD},
+{0x0001FE, 0x0001FF},
+{0x000200, 0x000201},
+{0x000202, 0x000203},
+{0x000204, 0x000205},
+{0x000206, 0x000207},
+{0x000208, 0x000209},
+{0x00020A, 0x00020B},
+{0x00020C, 0x00020D},
+{0x00020E, 0x00020F},
+{0x000210, 0x000211},
+{0x000212, 0x000213},
+{0x000214, 0x000215},
+{0x000216, 0x000217},
+{0x000218, 0x000219},
+{0x00021A, 0x00021B},
+{0x00021C, 0x00021D},
+{0x00021E, 0x00021F},
+{0x000220, 0x00019E},
+{0x000222, 0x000223},
+{0x000224, 0x000225},
+{0x000226, 0x000227},
+{0x000228, 0x000229},
+{0x00022A, 0x00022B},
+{0x00022C, 0x00022D},
+{0x00022E, 0x00022F},
+{0x000230, 0x000231},
+{0x000232, 0x000233},
+{0x00023A, 0x002C65},
+{0x00023B, 0x00023C},
+{0x00023D, 0x00019A},
+{0x00023E, 0x002C66},
+{0x000241, 0x000242},
+{0x000243, 0x000180},
+{0x000244, 0x000289},
+{0x000245, 0x00028C},
+{0x000246, 0x000247},
+{0x000248, 0x000249},
+{0x00024A, 0x00024B},
+{0x00024C, 0x00024D},
+{0x00024E, 0x00024F},
+{0x000370, 0x000371},
+{0x000372, 0x000373},
+{0x000376, 0x000377},
+{0x00037F, 0x0003F3},
+{0x000386, 0x0003AC},
+{0x000388, 0x0003AD},
+{0x000389, 0x0003AE},
+{0x00038A, 0x0003AF},
+{0x00038C, 0x0003CC},
+{0x00038E, 0x0003CD},
+{0x00038F, 0x0003CE},
+{0x000391, 0x0003B1},
+{0x000392, 0x0003B2},
+{0x000393, 0x0003B3},
+{0x000394, 0x0003B4},
+{0x000395, 0x0003B5},
+{0x000396, 0x0003B6},
+{0x000397, 0x0003B7},
+{0x000398, 0x0003B8},
+{0x000399, 0x0003B9},
+{0x00039A, 0x0003BA},
+{0x00039B, 0x0003BB},
+{0x00039C, 0x0003BC},
+{0x00039D, 0x0003BD},
+{0x00039E, 0x0003BE},
+{0x00039F, 0x0003BF},
+{0x0003A0, 0x0003C0},
+{0x0003A1, 0x0003C1},
+{0x0003A3, 0x0003C3},
+{0x0003A4, 0x0003C4},
+{0x0003A5, 0x0003C5},
+{0x0003A6, 0x0003C6},
+{0x0003A7, 0x0003C7},
+{0x0003A8, 0x0003C8},
+{0x0003A9, 0x0003C9},
+{0x0003AA, 0x0003CA},
+{0x0003AB, 0x0003CB},
+{0x0003CF, 0x0003D7},
+{0x0003D8, 0x0003D9},
+{0x0003DA, 0x0003DB},
+{0x0003DC, 0x0003DD},
+{0x0003DE, 0x0003DF},
+{0x0003E0, 0x0003E1},
+{0x0003E2, 0x0003E3},
+{0x0003E4, 0x0003E5},
+{0x0003E6, 0x0003E7},
+{0x0003E8, 0x0003E9},
+{0x0003EA, 0x0003EB},
+{0x0003EC, 0x0003ED},
+{0x0003EE, 0x0003EF},
+{0x0003F4, 0x0003B8},
+{0x0003F7, 0x0003F8},
+{0x0003F9, 0x0003F2},
+{0x0003FA, 0x0003FB},
+{0x0003FD, 0x00037B},
+{0x0003FE, 0x00037C},
+{0x0003FF, 0x00037D},
+{0x000400, 0x000450},
+{0x000401, 0x000451},
+{0x000402, 0x000452},
+{0x000403, 0x000453},
+{0x000404, 0x000454},
+{0x000405, 0x000455},
+{0x000406, 0x000456},
+{0x000407, 0x000457},
+{0x000408, 0x000458},
+{0x000409, 0x000459},
+{0x00040A, 0x00045A},
+{0x00040B, 0x00045B},
+{0x00040C, 0x00045C},
+{0x00040D, 0x00045D},
+{0x00040E, 0x00045E},
+{0x00040F, 0x00045F},
+{0x000410, 0x000430},
+{0x000411, 0x000431},
+{0x000412, 0x000432},
+{0x000413, 0x000433},
+{0x000414, 0x000434},
+{0x000415, 0x000435},
+{0x000416, 0x000436},
+{0x000417, 0x000437},
+{0x000418, 0x000438},
+{0x000419, 0x000439},
+{0x00041A, 0x00043A},
+{0x00041B, 0x00043B},
+{0x00041C, 0x00043C},
+{0x00041D, 0x00043D},
+{0x00041E, 0x00043E},
+{0x00041F, 0x00043F},
+{0x000420, 0x000440},
+{0x000421, 0x000441},
+{0x000422, 0x000442},
+{0x000423, 0x000443},
+{0x000424, 0x000444},
+{0x000425, 0x000445},
+{0x000426, 0x000446},
+{0x000427, 0x000447},
+{0x000428, 0x000448},
+{0x000429, 0x000449},
+{0x00042A, 0x00044A},
+{0x00042B, 0x00044B},
+{0x00042C, 0x00044C},
+{0x00042D, 0x00044D},
+{0x00042E, 0x00044E},
+{0x00042F, 0x00044F},
+{0x000460, 0x000461},
+{0x000462, 0x000463},
+{0x000464, 0x000465},
+{0x000466, 0x000467},
+{0x000468, 0x000469},
+{0x00046A, 0x00046B},
+{0x00046C, 0x00046D},
+{0x00046E, 0x00046F},
+{0x000470, 0x000471},
+{0x000472, 0x000473},
+{0x000474, 0x000475},
+{0x000476, 0x000477},
+{0x000478, 0x000479},
+{0x00047A, 0x00047B},
+{0x00047C, 0x00047D},
+{0x00047E, 0x00047F},
+{0x000480, 0x000481},
+{0x00048A, 0x00048B},
+{0x00048C, 0x00048D},
+{0x00048E, 0x00048F},
+{0x000490, 0x000491},
+{0x000492, 0x000493},
+{0x000494, 0x000495},
+{0x000496, 0x000497},
+{0x000498, 0x000499},
+{0x00049A, 0x00049B},
+{0x00049C, 0x00049D},
+{0x00049E, 0x00049F},
+{0x0004A0, 0x0004A1},
+{0x0004A2, 0x0004A3},
+{0x0004A4, 0x0004A5},
+{0x0004A6, 0x0004A7},
+{0x0004A8, 0x0004A9},
+{0x0004AA, 0x0004AB},
+{0x0004AC, 0x0004AD},
+{0x0004AE, 0x0004AF},
+{0x0004B0, 0x0004B1},
+{0x0004B2, 0x0004B3},
+{0x0004B4, 0x0004B5},
+{0x0004B6, 0x0004B7},
+{0x0004B8, 0x0004B9},
+{0x0004BA, 0x0004BB},
+{0x0004BC, 0x0004BD},
+{0x0004BE, 0x0004BF},
+{0x0004C0, 0x0004CF},
+{0x0004C1, 0x0004C2},
+{0x0004C3, 0x0004C4},
+{0x0004C5, 0x0004C6},
+{0x0004C7, 0x0004C8},
+{0x0004C9, 0x0004CA},
+{0x0004CB, 0x0004CC},
+{0x0004CD, 0x0004CE},
+{0x0004D0, 0x0004D1},
+{0x0004D2, 0x0004D3},
+{0x0004D4, 0x0004D5},
+{0x0004D6, 0x0004D7},
+{0x0004D8, 0x0004D9},
+{0x0004DA, 0x0004DB},
+{0x0004DC, 0x0004DD},
+{0x0004DE, 0x0004DF},
+{0x0004E0, 0x0004E1},
+{0x0004E2, 0x0004E3},
+{0x0004E4, 0x0004E5},
+{0x0004E6, 0x0004E7},
+{0x0004E8, 0x0004E9},
+{0x0004EA, 0x0004EB},
+{0x0004EC, 0x0004ED},
+{0x0004EE, 0x0004EF},
+{0x0004F0, 0x0004F1},
+{0x0004F2, 0x0004F3},
+{0x0004F4, 0x0004F5},
+{0x0004F6, 0x0004F7},
+{0x0004F8, 0x0004F9},
+{0x0004FA, 0x0004FB},
+{0x0004FC, 0x0004FD},
+{0x0004FE, 0x0004FF},
+{0x000500, 0x000501},
+{0x000502, 0x000503},
+{0x000504, 0x000505},
+{0x000506, 0x000507},
+{0x000508, 0x000509},
+{0x00050A, 0x00050B},
+{0x00050C, 0x00050D},
+{0x00050E, 0x00050F},
+{0x000510, 0x000511},
+{0x000512, 0x000513},
+{0x000514, 0x000515},
+{0x000516, 0x000517},
+{0x000518, 0x000519},
+{0x00051A, 0x00051B},
+{0x00051C, 0x00051D},
+{0x00051E, 0x00051F},
+{0x000520, 0x000521},
+{0x000522, 0x000523},
+{0x000524, 0x000525},
+{0x000526, 0x000527},
+{0x000528, 0x000529},
+{0x00052A, 0x00052B},
+{0x00052C, 0x00052D},
+{0x00052E, 0x00052F},
+{0x000531, 0x000561},
+{0x000532, 0x000562},
+{0x000533, 0x000563},
+{0x000534, 0x000564},
+{0x000535, 0x000565},
+{0x000536, 0x000566},
+{0x000537, 0x000567},
+{0x000538, 0x000568},
+{0x000539, 0x000569},
+{0x00053A, 0x00056A},
+{0x00053B, 0x00056B},
+{0x00053C, 0x00056C},
+{0x00053D, 0x00056D},
+{0x00053E, 0x00056E},
+{0x00053F, 0x00056F},
+{0x000540, 0x000570},
+{0x000541, 0x000571},
+{0x000542, 0x000572},
+{0x000543, 0x000573},
+{0x000544, 0x000574},
+{0x000545, 0x000575},
+{0x000546, 0x000576},
+{0x000547, 0x000577},
+{0x000548, 0x000578},
+{0x000549, 0x000579},
+{0x00054A, 0x00057A},
+{0x00054B, 0x00057B},
+{0x00054C, 0x00057C},
+{0x00054D, 0x00057D},
+{0x00054E, 0x00057E},
+{0x00054F, 0x00057F},
+{0x000550, 0x000580},
+{0x000551, 0x000581},
+{0x000552, 0x000582},
+{0x000553, 0x000583},
+{0x000554, 0x000584},
+{0x000555, 0x000585},
+{0x000556, 0x000586},
+{0x0010A0, 0x002D00},
+{0x0010A1, 0x002D01},
+{0x0010A2, 0x002D02},
+{0x0010A3, 0x002D03},
+{0x0010A4, 0x002D04},
+{0x0010A5, 0x002D05},
+{0x0010A6, 0x002D06},
+{0x0010A7, 0x002D07},
+{0x0010A8, 0x002D08},
+{0x0010A9, 0x002D09},
+{0x0010AA, 0x002D0A},
+{0x0010AB, 0x002D0B},
+{0x0010AC, 0x002D0C},
+{0x0010AD, 0x002D0D},
+{0x0010AE, 0x002D0E},
+{0x0010AF, 0x002D0F},
+{0x0010B0, 0x002D10},
+{0x0010B1, 0x002D11},
+{0x0010B2, 0x002D12},
+{0x0010B3, 0x002D13},
+{0x0010B4, 0x002D14},
+{0x0010B5, 0x002D15},
+{0x0010B6, 0x002D16},
+{0x0010B7, 0x002D17},
+{0x0010B8, 0x002D18},
+{0x0010B9, 0x002D19},
+{0x0010BA, 0x002D1A},
+{0x0010BB, 0x002D1B},
+{0x0010BC, 0x002D1C},
+{0x0010BD, 0x002D1D},
+{0x0010BE, 0x002D1E},
+{0x0010BF, 0x002D1F},
+{0x0010C0, 0x002D20},
+{0x0010C1, 0x002D21},
+{0x0010C2, 0x002D22},
+{0x0010C3, 0x002D23},
+{0x0010C4, 0x002D24},
+{0x0010C5, 0x002D25},
+{0x0010C7, 0x002D27},
+{0x0010CD, 0x002D2D},
+{0x0013A0, 0x00AB70},
+{0x0013A1, 0x00AB71},
+{0x0013A2, 0x00AB72},
+{0x0013A3, 0x00AB73},
+{0x0013A4, 0x00AB74},
+{0x0013A5, 0x00AB75},
+{0x0013A6, 0x00AB76},
+{0x0013A7, 0x00AB77},
+{0x0013A8, 0x00AB78},
+{0x0013A9, 0x00AB79},
+{0x0013AA, 0x00AB7A},
+{0x0013AB, 0x00AB7B},
+{0x0013AC, 0x00AB7C},
+{0x0013AD, 0x00AB7D},
+{0x0013AE, 0x00AB7E},
+{0x0013AF, 0x00AB7F},
+{0x0013B0, 0x00AB80},
+{0x0013B1, 0x00AB81},
+{0x0013B2, 0x00AB82},
+{0x0013B3, 0x00AB83},
+{0x0013B4, 0x00AB84},
+{0x0013B5, 0x00AB85},
+{0x0013B6, 0x00AB86},
+{0x0013B7, 0x00AB87},
+{0x0013B8, 0x00AB88},
+{0x0013B9, 0x00AB89},
+{0x0013BA, 0x00AB8A},
+{0x0013BB, 0x00AB8B},
+{0x0013BC, 0x00AB8C},
+{0x0013BD, 0x00AB8D},
+{0x0013BE, 0x00AB8E},
+{0x0013BF, 0x00AB8F},
+{0x0013C0, 0x00AB90},
+{0x0013C1, 0x00AB91},
+{0x0013C2, 0x00AB92},
+{0x0013C3, 0x00AB93},
+{0x0013C4, 0x00AB94},
+{0x0013C5, 0x00AB95},
+{0x0013C6, 0x00AB96},
+{0x0013C7, 0x00AB97},
+{0x0013C8, 0x00AB98},
+{0x0013C9, 0x00AB99},
+{0x0013CA, 0x00AB9A},
+{0x0013CB, 0x00AB9B},
+{0x0013CC, 0x00AB9C},
+{0x0013CD, 0x00AB9D},
+{0x0013CE, 0x00AB9E},
+{0x0013CF, 0x00AB9F},
+{0x0013D0, 0x00ABA0},
+{0x0013D1, 0x00ABA1},
+{0x0013D2, 0x00ABA2},
+{0x0013D3, 0x00ABA3},
+{0x0013D4, 0x00ABA4},
+{0x0013D5, 0x00ABA5},
+{0x0013D6, 0x00ABA6},
+{0x0013D7, 0x00ABA7},
+{0x0013D8, 0x00ABA8},
+{0x0013D9, 0x00ABA9},
+{0x0013DA, 0x00ABAA},
+{0x0013DB, 0x00ABAB},
+{0x0013DC, 0x00ABAC},
+{0x0013DD, 0x00ABAD},
+{0x0013DE, 0x00ABAE},
+{0x0013DF, 0x00ABAF},
+{0x0013E0, 0x00ABB0},
+{0x0013E1, 0x00ABB1},
+{0x0013E2, 0x00ABB2},
+{0x0013E3, 0x00ABB3},
+{0x0013E4, 0x00ABB4},
+{0x0013E5, 0x00ABB5},
+{0x0013E6, 0x00ABB6},
+{0x0013E7, 0x00ABB7},
+{0x0013E8, 0x00ABB8},
+{0x0013E9, 0x00ABB9},
+{0x0013EA, 0x00ABBA},
+{0x0013EB, 0x00ABBB},
+{0x0013EC, 0x00ABBC},
+{0x0013ED, 0x00ABBD},
+{0x0013EE, 0x00ABBE},
+{0x0013EF, 0x00ABBF},
+{0x0013F0, 0x0013F8},
+{0x0013F1, 0x0013F9},
+{0x0013F2, 0x0013FA},
+{0x0013F3, 0x0013FB},
+{0x0013F4, 0x0013FC},
+{0x0013F5, 0x0013FD},
+{0x001C90, 0x0010D0},
+{0x001C91, 0x0010D1},
+{0x001C92, 0x0010D2},
+{0x001C93, 0x0010D3},
+{0x001C94, 0x0010D4},
+{0x001C95, 0x0010D5},
+{0x001C96, 0x0010D6},
+{0x001C97, 0x0010D7},
+{0x001C98, 0x0010D8},
+{0x001C99, 0x0010D9},
+{0x001C9A, 0x0010DA},
+{0x001C9B, 0x0010DB},
+{0x001C9C, 0x0010DC},
+{0x001C9D, 0x0010DD},
+{0x001C9E, 0x0010DE},
+{0x001C9F, 0x0010DF},
+{0x001CA0, 0x0010E0},
+{0x001CA1, 0x0010E1},
+{0x001CA2, 0x0010E2},
+{0x001CA3, 0x0010E3},
+{0x001CA4, 0x0010E4},
+{0x001CA5, 0x0010E5},
+{0x001CA6, 0x0010E6},
+{0x001CA7, 0x0010E7},
+{0x001CA8, 0x0010E8},
+{0x001CA9, 0x0010E9},
+{0x001CAA, 0x0010EA},
+{0x001CAB, 0x0010EB},
+{0x001CAC, 0x0010EC},
+{0x001CAD, 0x0010ED},
+{0x001CAE, 0x0010EE},
+{0x001CAF, 0x0010EF},
+{0x001CB0, 0x0010F0},
+{0x001CB1, 0x0010F1},
+{0x001CB2, 0x0010F2},
+{0x001CB3, 0x0010F3},
+{0x001CB4, 0x0010F4},
+{0x001CB5, 0x0010F5},
+{0x001CB6, 0x0010F6},
+{0x001CB7, 0x0010F7},
+{0x001CB8, 0x0010F8},
+{0x001CB9, 0x0010F9},
+{0x001CBA, 0x0010FA},
+{0x001CBD, 0x0010FD},
+{0x001CBE, 0x0010FE},
+{0x001CBF, 0x0010FF},
+{0x001E00, 0x001E01},
+{0x001E02, 0x001E03},
+{0x001E04, 0x001E05},
+{0x001E06, 0x001E07},
+{0x001E08, 0x001E09},
+{0x001E0A, 0x001E0B},
+{0x001E0C, 0x001E0D},
+{0x001E0E, 0x001E0F},
+{0x001E10, 0x001E11},
+{0x001E12, 0x001E13},
+{0x001E14, 0x001E15},
+{0x001E16, 0x001E17},
+{0x001E18, 0x001E19},
+{0x001E1A, 0x001E1B},
+{0x001E1C, 0x001E1D},
+{0x001E1E, 0x001E1F},
+{0x001E20, 0x001E21},
+{0x001E22, 0x001E23},
+{0x001E24, 0x001E25},
+{0x001E26, 0x001E27},
+{0x001E28, 0x001E29},
+{0x001E2A, 0x001E2B},
+{0x001E2C, 0x001E2D},
+{0x001E2E, 0x001E2F},
+{0x001E30, 0x001E31},
+{0x001E32, 0x001E33},
+{0x001E34, 0x001E35},
+{0x001E36, 0x001E37},
+{0x001E38, 0x001E39},
+{0x001E3A, 0x001E3B},
+{0x001E3C, 0x001E3D},
+{0x001E3E, 0x001E3F},
+{0x001E40, 0x001E41},
+{0x001E42, 0x001E43},
+{0x001E44, 0x001E45},
+{0x001E46, 0x001E47},
+{0x001E48, 0x001E49},
+{0x001E4A, 0x001E4B},
+{0x001E4C, 0x001E4D},
+{0x001E4E, 0x001E4F},
+{0x001E50, 0x001E51},
+{0x001E52, 0x001E53},
+{0x001E54, 0x001E55},
+{0x001E56, 0x001E57},
+{0x001E58, 0x001E59},
+{0x001E5A, 0x001E5B},
+{0x001E5C, 0x001E5D},
+{0x001E5E, 0x001E5F},
+{0x001E60, 0x001E61},
+{0x001E62, 0x001E63},
+{0x001E64, 0x001E65},
+{0x001E66, 0x001E67},
+{0x001E68, 0x001E69},
+{0x001E6A, 0x001E6B},
+{0x001E6C, 0x001E6D},
+{0x001E6E, 0x001E6F},
+{0x001E70, 0x001E71},
+{0x001E72, 0x001E73},
+{0x001E74, 0x001E75},
+{0x001E76, 0x001E77},
+{0x001E78, 0x001E79},
+{0x001E7A, 0x001E7B},
+{0x001E7C, 0x001E7D},
+{0x001E7E, 0x001E7F},
+{0x001E80, 0x001E81},
+{0x001E82, 0x001E83},
+{0x001E84, 0x001E85},
+{0x001E86, 0x001E87},
+{0x001E88, 0x001E89},
+{0x001E8A, 0x001E8B},
+{0x001E8C, 0x001E8D},
+{0x001E8E, 0x001E8F},
+{0x001E90, 0x001E91},
+{0x001E92, 0x001E93},
+{0x001E94, 0x001E95},
+{0x001E9E, 0x0000DF},
+{0x001EA0, 0x001EA1},
+{0x001EA2, 0x001EA3},
+{0x001EA4, 0x001EA5},
+{0x001EA6, 0x001EA7},
+{0x001EA8, 0x001EA9},
+{0x001EAA, 0x001EAB},
+{0x001EAC, 0x001EAD},
+{0x001EAE, 0x001EAF},
+{0x001EB0, 0x001EB1},
+{0x001EB2, 0x001EB3},
+{0x001EB4, 0x001EB5},
+{0x001EB6, 0x001EB7},
+{0x001EB8, 0x001EB9},
+{0x001EBA, 0x001EBB},
+{0x001EBC, 0x001EBD},
+{0x001EBE, 0x001EBF},
+{0x001EC0, 0x001EC1},
+{0x001EC2, 0x001EC3},
+{0x001EC4, 0x001EC5},
+{0x001EC6, 0x001EC7},
+{0x001EC8, 0x001EC9},
+{0x001ECA, 0x001ECB},
+{0x001ECC, 0x001ECD},
+{0x001ECE, 0x001ECF},
+{0x001ED0, 0x001ED1},
+{0x001ED2, 0x001ED3},
+{0x001ED4, 0x001ED5},
+{0x001ED6, 0x001ED7},
+{0x001ED8, 0x001ED9},
+{0x001EDA, 0x001EDB},
+{0x001EDC, 0x001EDD},
+{0x001EDE, 0x001EDF},
+{0x001EE0, 0x001EE1},
+{0x001EE2, 0x001EE3},
+{0x001EE4, 0x001EE5},
+{0x001EE6, 0x001EE7},
+{0x001EE8, 0x001EE9},
+{0x001EEA, 0x001EEB},
+{0x001EEC, 0x001EED},
+{0x001EEE, 0x001EEF},
+{0x001EF0, 0x001EF1},
+{0x001EF2, 0x001EF3},
+{0x001EF4, 0x001EF5},
+{0x001EF6, 0x001EF7},
+{0x001EF8, 0x001EF9},
+{0x001EFA, 0x001EFB},
+{0x001EFC, 0x001EFD},
+{0x001EFE, 0x001EFF},
+{0x001F08, 0x001F00},
+{0x001F09, 0x001F01},
+{0x001F0A, 0x001F02},
+{0x001F0B, 0x001F03},
+{0x001F0C, 0x001F04},
+{0x001F0D, 0x001F05},
+{0x001F0E, 0x001F06},
+{0x001F0F, 0x001F07},
+{0x001F18, 0x001F10},
+{0x001F19, 0x001F11},
+{0x001F1A, 0x001F12},
+{0x001F1B, 0x001F13},
+{0x001F1C, 0x001F14},
+{0x001F1D, 0x001F15},
+{0x001F28, 0x001F20},
+{0x001F29, 0x001F21},
+{0x001F2A, 0x001F22},
+{0x001F2B, 0x001F23},
+{0x001F2C, 0x001F24},
+{0x001F2D, 0x001F25},
+{0x001F2E, 0x001F26},
+{0x001F2F, 0x001F27},
+{0x001F38, 0x001F30},
+{0x001F39, 0x001F31},
+{0x001F3A, 0x001F32},
+{0x001F3B, 0x001F33},
+{0x001F3C, 0x001F34},
+{0x001F3D, 0x001F35},
+{0x001F3E, 0x001F36},
+{0x001F3F, 0x001F37},
+{0x001F48, 0x001F40},
+{0x001F49, 0x001F41},
+{0x001F4A, 0x001F42},
+{0x001F4B, 0x001F43},
+{0x001F4C, 0x001F44},
+{0x001F4D, 0x001F45},
+{0x001F59, 0x001F51},
+{0x001F5B, 0x001F53},
+{0x001F5D, 0x001F55},
+{0x001F5F, 0x001F57},
+{0x001F68, 0x001F60},
+{0x001F69, 0x001F61},
+{0x001F6A, 0x001F62},
+{0x001F6B, 0x001F63},
+{0x001F6C, 0x001F64},
+{0x001F6D, 0x001F65},
+{0x001F6E, 0x001F66},
+{0x001F6F, 0x001F67},
+{0x001F88, 0x001F80},
+{0x001F89, 0x001F81},
+{0x001F8A, 0x001F82},
+{0x001F8B, 0x001F83},
+{0x001F8C, 0x001F84},
+{0x001F8D, 0x001F85},
+{0x001F8E, 0x001F86},
+{0x001F8F, 0x001F87},
+{0x001F98, 0x001F90},
+{0x001F99, 0x001F91},
+{0x001F9A, 0x001F92},
+{0x001F9B, 0x001F93},
+{0x001F9C, 0x001F94},
+{0x001F9D, 0x001F95},
+{0x001F9E, 0x001F96},
+{0x001F9F, 0x001F97},
+{0x001FA8, 0x001FA0},
+{0x001FA9, 0x001FA1},
+{0x001FAA, 0x001FA2},
+{0x001FAB, 0x001FA3},
+{0x001FAC, 0x001FA4},
+{0x001FAD, 0x001FA5},
+{0x001FAE, 0x001FA6},
+{0x001FAF, 0x001FA7},
+{0x001FB8, 0x001FB0},
+{0x001FB9, 0x001FB1},
+{0x001FBA, 0x001F70},
+{0x001FBB, 0x001F71},
+{0x001FBC, 0x001FB3},
+{0x001FC8, 0x001F72},
+{0x001FC9, 0x001F73},
+{0x001FCA, 0x001F74},
+{0x001FCB, 0x001F75},
+{0x001FCC, 0x001FC3},
+{0x001FD8, 0x001FD0},
+{0x001FD9, 0x001FD1},
+{0x001FDA, 0x001F76},
+{0x001FDB, 0x001F77},
+{0x001FE8, 0x001FE0},
+{0x001FE9, 0x001FE1},
+{0x001FEA, 0x001F7A},
+{0x001FEB, 0x001F7B},
+{0x001FEC, 0x001FE5},
+{0x001FF8, 0x001F78},
+{0x001FF9, 0x001F79},
+{0x001FFA, 0x001F7C},
+{0x001FFB, 0x001F7D},
+{0x001FFC, 0x001FF3},
+{0x002126, 0x0003C9},
+{0x00212A, 0x00006B},
+{0x00212B, 0x0000E5},
+{0x002132, 0x00214E},
+{0x002160, 0x002170},
+{0x002161, 0x002171},
+{0x002162, 0x002172},
+{0x002163, 0x002173},
+{0x002164, 0x002174},
+{0x002165, 0x002175},
+{0x002166, 0x002176},
+{0x002167, 0x002177},
+{0x002168, 0x002178},
+{0x002169, 0x002179},
+{0x00216A, 0x00217A},
+{0x00216B, 0x00217B},
+{0x00216C, 0x00217C},
+{0x00216D, 0x00217D},
+{0x00216E, 0x00217E},
+{0x00216F, 0x00217F},
+{0x002183, 0x002184},
+{0x0024B6, 0x0024D0},
+{0x0024B7, 0x0024D1},
+{0x0024B8, 0x0024D2},
+{0x0024B9, 0x0024D3},
+{0x0024BA, 0x0024D4},
+{0x0024BB, 0x0024D5},
+{0x0024BC, 0x0024D6},
+{0x0024BD, 0x0024D7},
+{0x0024BE, 0x0024D8},
+{0x0024BF, 0x0024D9},
+{0x0024C0, 0x0024DA},
+{0x0024C1, 0x0024DB},
+{0x0024C2, 0x0024DC},
+{0x0024C3, 0x0024DD},
+{0x0024C4, 0x0024DE},
+{0x0024C5, 0x0024DF},
+{0x0024C6, 0x0024E0},
+{0x0024C7, 0x0024E1},
+{0x0024C8, 0x0024E2},
+{0x0024C9, 0x0024E3},
+{0x0024CA, 0x0024E4},
+{0x0024CB, 0x0024E5},
+{0x0024CC, 0x0024E6},
+{0x0024CD, 0x0024E7},
+{0x0024CE, 0x0024E8},
+{0x0024CF, 0x0024E9},
+{0x002C00, 0x002C30},
+{0x002C01, 0x002C31},
+{0x002C02, 0x002C32},
+{0x002C03, 0x002C33},
+{0x002C04, 0x002C34},
+{0x002C05, 0x002C35},
+{0x002C06, 0x002C36},
+{0x002C07, 0x002C37},
+{0x002C08, 0x002C38},
+{0x002C09, 0x002C39},
+{0x002C0A, 0x002C3A},
+{0x002C0B, 0x002C3B},
+{0x002C0C, 0x002C3C},
+{0x002C0D, 0x002C3D},
+{0x002C0E, 0x002C3E},
+{0x002C0F, 0x002C3F},
+{0x002C10, 0x002C40},
+{0x002C11, 0x002C41},
+{0x002C12, 0x002C42},
+{0x002C13, 0x002C43},
+{0x002C14, 0x002C44},
+{0x002C15, 0x002C45},
+{0x002C16, 0x002C46},
+{0x002C17, 0x002C47},
+{0x002C18, 0x002C48},
+{0x002C19, 0x002C49},
+{0x002C1A, 0x002C4A},
+{0x002C1B, 0x002C4B},
+{0x002C1C, 0x002C4C},
+{0x002C1D, 0x002C4D},
+{0x002C1E, 0x002C4E},
+{0x002C1F, 0x002C4F},
+{0x002C20, 0x002C50},
+{0x002C21, 0x002C51},
+{0x002C22, 0x002C52},
+{0x002C23, 0x002C53},
+{0x002C24, 0x002C54},
+{0x002C25, 0x002C55},
+{0x002C26, 0x002C56},
+{0x002C27, 0x002C57},
+{0x002C28, 0x002C58},
+{0x002C29, 0x002C59},
+{0x002C2A, 0x002C5A},
+{0x002C2B, 0x002C5B},
+{0x002C2C, 0x002C5C},
+{0x002C2D, 0x002C5D},
+{0x002C2E, 0x002C5E},
+{0x002C2F, 0x002C5F},
+{0x002C60, 0x002C61},
+{0x002C62, 0x00026B},
+{0x002C63, 0x001D7D},
+{0x002C64, 0x00027D},
+{0x002C67, 0x002C68},
+{0x002C69, 0x002C6A},
+{0x002C6B, 0x002C6C},
+{0x002C6D, 0x000251},
+{0x002C6E, 0x000271},
+{0x002C6F, 0x000250},
+{0x002C70, 0x000252},
+{0x002C72, 0x002C73},
+{0x002C75, 0x002C76},
+{0x002C7E, 0x00023F},
+{0x002C7F, 0x000240},
+{0x002C80, 0x002C81},
+{0x002C82, 0x002C83},
+{0x002C84, 0x002C85},
+{0x002C86, 0x002C87},
+{0x002C88, 0x002C89},
+{0x002C8A, 0x002C8B},
+{0x002C8C, 0x002C8D},
+{0x002C8E, 0x002C8F},
+{0x002C90, 0x002C91},
+{0x002C92, 0x002C93},
+{0x002C94, 0x002C95},
+{0x002C96, 0x002C97},
+{0x002C98, 0x002C99},
+{0x002C9A, 0x002C9B},
+{0x002C9C, 0x002C9D},
+{0x002C9E, 0x002C9F},
+{0x002CA0, 0x002CA1},
+{0x002CA2, 0x002CA3},
+{0x002CA4, 0x002CA5},
+{0x002CA6, 0x002CA7},
+{0x002CA8, 0x002CA9},
+{0x002CAA, 0x002CAB},
+{0x002CAC, 0x002CAD},
+{0x002CAE, 0x002CAF},
+{0x002CB0, 0x002CB1},
+{0x002CB2, 0x002CB3},
+{0x002CB4, 0x002CB5},
+{0x002CB6, 0x002CB7},
+{0x002CB8, 0x002CB9},
+{0x002CBA, 0x002CBB},
+{0x002CBC, 0x002CBD},
+{0x002CBE, 0x002CBF},
+{0x002CC0, 0x002CC1},
+{0x002CC2, 0x002CC3},
+{0x002CC4, 0x002CC5},
+{0x002CC6, 0x002CC7},
+{0x002CC8, 0x002CC9},
+{0x002CCA, 0x002CCB},
+{0x002CCC, 0x002CCD},
+{0x002CCE, 0x002CCF},
+{0x002CD0, 0x002CD1},
+{0x002CD2, 0x002CD3},
+{0x002CD4, 0x002CD5},
+{0x002CD6, 0x002CD7},
+{0x002CD8, 0x002CD9},
+{0x002CDA, 0x002CDB},
+{0x002CDC, 0x002CDD},
+{0x002CDE, 0x002CDF},
+{0x002CE0, 0x002CE1},
+{0x002CE2, 0x002CE3},
+{0x002CEB, 0x002CEC},
+{0x002CED, 0x002CEE},
+{0x002CF2, 0x002CF3},
+{0x00A640, 0x00A641},
+{0x00A642, 0x00A643},
+{0x00A644, 0x00A645},
+{0x00A646, 0x00A647},
+{0x00A648, 0x00A649},
+{0x00A64A, 0x00A64B},
+{0x00A64C, 0x00A64D},
+{0x00A64E, 0x00A64F},
+{0x00A650, 0x00A651},
+{0x00A652, 0x00A653},
+{0x00A654, 0x00A655},
+{0x00A656, 0x00A657},
+{0x00A658, 0x00A659},
+{0x00A65A, 0x00A65B},
+{0x00A65C, 0x00A65D},
+{0x00A65E, 0x00A65F},
+{0x00A660, 0x00A661},
+{0x00A662, 0x00A663},
+{0x00A664, 0x00A665},
+{0x00A666, 0x00A667},
+{0x00A668, 0x00A669},
+{0x00A66A, 0x00A66B},
+{0x00A66C, 0x00A66D},
+{0x00A680, 0x00A681},
+{0x00A682, 0x00A683},
+{0x00A684, 0x00A685},
+{0x00A686, 0x00A687},
+{0x00A688, 0x00A689},
+{0x00A68A, 0x00A68B},
+{0x00A68C, 0x00A68D},
+{0x00A68E, 0x00A68F},
+{0x00A690, 0x00A691},
+{0x00A692, 0x00A693},
+{0x00A694, 0x00A695},
+{0x00A696, 0x00A697},
+{0x00A698, 0x00A699},
+{0x00A69A, 0x00A69B},
+{0x00A722, 0x00A723},
+{0x00A724, 0x00A725},
+{0x00A726, 0x00A727},
+{0x00A728, 0x00A729},
+{0x00A72A, 0x00A72B},
+{0x00A72C, 0x00A72D},
+{0x00A72E, 0x00A72F},
+{0x00A732, 0x00A733},
+{0x00A734, 0x00A735},
+{0x00A736, 0x00A737},
+{0x00A738, 0x00A739},
+{0x00A73A, 0x00A73B},
+{0x00A73C, 0x00A73D},
+{0x00A73E, 0x00A73F},
+{0x00A740, 0x00A741},
+{0x00A742, 0x00A743},
+{0x00A744, 0x00A745},
+{0x00A746, 0x00A747},
+{0x00A748, 0x00A749},
+{0x00A74A, 0x00A74B},
+{0x00A74C, 0x00A74D},
+{0x00A74E, 0x00A74F},
+{0x00A750, 0x00A751},
+{0x00A752, 0x00A753},
+{0x00A754, 0x00A755},
+{0x00A756, 0x00A757},
+{0x00A758, 0x00A759},
+{0x00A75A, 0x00A75B},
+{0x00A75C, 0x00A75D},
+{0x00A75E, 0x00A75F},
+{0x00A760, 0x00A761},
+{0x00A762, 0x00A763},
+{0x00A764, 0x00A765},
+{0x00A766, 0x00A767},
+{0x00A768, 0x00A769},
+{0x00A76A, 0x00A76B},
+{0x00A76C, 0x00A76D},
+{0x00A76E, 0x00A76F},
+{0x00A779, 0x00A77A},
+{0x00A77B, 0x00A77C},
+{0x00A77D, 0x001D79},
+{0x00A77E, 0x00A77F},
+{0x00A780, 0x00A781},
+{0x00A782, 0x00A783},
+{0x00A784, 0x00A785},
+{0x00A786, 0x00A787},
+{0x00A78B, 0x00A78C},
+{0x00A78D, 0x000265},
+{0x00A790, 0x00A791},
+{0x00A792, 0x00A793},
+{0x00A796, 0x00A797},
+{0x00A798, 0x00A799},
+{0x00A79A, 0x00A79B},
+{0x00A79C, 0x00A79D},
+{0x00A79E, 0x00A79F},
+{0x00A7A0, 0x00A7A1},
+{0x00A7A2, 0x00A7A3},
+{0x00A7A4, 0x00A7A5},
+{0x00A7A6, 0x00A7A7},
+{0x00A7A8, 0x00A7A9},
+{0x00A7AA, 0x000266},
+{0x00A7AB, 0x00025C},
+{0x00A7AC, 0x000261},
+{0x00A7AD, 0x00026C},
+{0x00A7AE, 0x00026A},
+{0x00A7B0, 0x00029E},
+{0x00A7B1, 0x000287},
+{0x00A7B2, 0x00029D},
+{0x00A7B3, 0x00AB53},
+{0x00A7B4, 0x00A7B5},
+{0x00A7B6, 0x00A7B7},
+{0x00A7B8, 0x00A7B9},
+{0x00A7BA, 0x00A7BB},
+{0x00A7BC, 0x00A7BD},
+{0x00A7BE, 0x00A7BF},
+{0x00A7C0, 0x00A7C1},
+{0x00A7C2, 0x00A7C3},
+{0x00A7C4, 0x00A794},
+{0x00A7C5, 0x000282},
+{0x00A7C6, 0x001D8E},
+{0x00A7C7, 0x00A7C8},
+{0x00A7C9, 0x00A7CA},
+{0x00A7D0, 0x00A7D1},
+{0x00A7D6, 0x00A7D7},
+{0x00A7D8, 0x00A7D9},
+{0x00A7F5, 0x00A7F6},
+{0x00FF21, 0x00FF41},
+{0x00FF22, 0x00FF42},
+{0x00FF23, 0x00FF43},
+{0x00FF24, 0x00FF44},
+{0x00FF25, 0x00FF45},
+{0x00FF26, 0x00FF46},
+{0x00FF27, 0x00FF47},
+{0x00FF28, 0x00FF48},
+{0x00FF29, 0x00FF49},
+{0x00FF2A, 0x00FF4A},
+{0x00FF2B, 0x00FF4B},
+{0x00FF2C, 0x00FF4C},
+{0x00FF2D, 0x00FF4D},
+{0x00FF2E, 0x00FF4E},
+{0x00FF2F, 0x00FF4F},
+{0x00FF30, 0x00FF50},
+{0x00FF31, 0x00FF51},
+{0x00FF32, 0x00FF52},
+{0x00FF33, 0x00FF53},
+{0x00FF34, 0x00FF54},
+{0x00FF35, 0x00FF55},
+{0x00FF36, 0x00FF56},
+{0x00FF37, 0x00FF57},
+{0x00FF38, 0x00FF58},
+{0x00FF39, 0x00FF59},
+{0x00FF3A, 0x00FF5A},
+{0x010400, 0x010428},
+{0x010401, 0x010429},
+{0x010402, 0x01042A},
+{0x010403, 0x01042B},
+{0x010404, 0x01042C},
+{0x010405, 0x01042D},
+{0x010406, 0x01042E},
+{0x010407, 0x01042F},
+{0x010408, 0x010430},
+{0x010409, 0x010431},
+{0x01040A, 0x010432},
+{0x01040B, 0x010433},
+{0x01040C, 0x010434},
+{0x01040D, 0x010435},
+{0x01040E, 0x010436},
+{0x01040F, 0x010437},
+{0x010410, 0x010438},
+{0x010411, 0x010439},
+{0x010412, 0x01043A},
+{0x010413, 0x01043B},
+{0x010414, 0x01043C},
+{0x010415, 0x01043D},
+{0x010416, 0x01043E},
+{0x010417, 0x01043F},
+{0x010418, 0x010440},
+{0x010419, 0x010441},
+{0x01041A, 0x010442},
+{0x01041B, 0x010443},
+{0x01041C, 0x010444},
+{0x01041D, 0x010445},
+{0x01041E, 0x010446},
+{0x01041F, 0x010447},
+{0x010420, 0x010448},
+{0x010421, 0x010449},
+{0x010422, 0x01044A},
+{0x010423, 0x01044B},
+{0x010424, 0x01044C},
+{0x010425, 0x01044D},
+{0x010426, 0x01044E},
+{0x010427, 0x01044F},
+{0x0104B0, 0x0104D8},
+{0x0104B1, 0x0104D9},
+{0x0104B2, 0x0104DA},
+{0x0104B3, 0x0104DB},
+{0x0104B4, 0x0104DC},
+{0x0104B5, 0x0104DD},
+{0x0104B6, 0x0104DE},
+{0x0104B7, 0x0104DF},
+{0x0104B8, 0x0104E0},
+{0x0104B9, 0x0104E1},
+{0x0104BA, 0x0104E2},
+{0x0104BB, 0x0104E3},
+{0x0104BC, 0x0104E4},
+{0x0104BD, 0x0104E5},
+{0x0104BE, 0x0104E6},
+{0x0104BF, 0x0104E7},
+{0x0104C0, 0x0104E8},
+{0x0104C1, 0x0104E9},
+{0x0104C2, 0x0104EA},
+{0x0104C3, 0x0104EB},
+{0x0104C4, 0x0104EC},
+{0x0104C5, 0x0104ED},
+{0x0104C6, 0x0104EE},
+{0x0104C7, 0x0104EF},
+{0x0104C8, 0x0104F0},
+{0x0104C9, 0x0104F1},
+{0x0104CA, 0x0104F2},
+{0x0104CB, 0x0104F3},
+{0x0104CC, 0x0104F4},
+{0x0104CD, 0x0104F5},
+{0x0104CE, 0x0104F6},
+{0x0104CF, 0x0104F7},
+{0x0104D0, 0x0104F8},
+{0x0104D1, 0x0104F9},
+{0x0104D2, 0x0104FA},
+{0x0104D3, 0x0104FB},
+{0x010570, 0x010597},
+{0x010571, 0x010598},
+{0x010572, 0x010599},
+{0x010573, 0x01059A},
+{0x010574, 0x01059B},
+{0x010575, 0x01059C},
+{0x010576, 0x01059D},
+{0x010577, 0x01059E},
+{0x010578, 0x01059F},
+{0x010579, 0x0105A0},
+{0x01057A, 0x0105A1},
+{0x01057C, 0x0105A3},
+{0x01057D, 0x0105A4},
+{0x01057E, 0x0105A5},
+{0x01057F, 0x0105A6},
+{0x010580, 0x0105A7},
+{0x010581, 0x0105A8},
+{0x010582, 0x0105A9},
+{0x010583, 0x0105AA},
+{0x010584, 0x0105AB},
+{0x010585, 0x0105AC},
+{0x010586, 0x0105AD},
+{0x010587, 0x0105AE},
+{0x010588, 0x0105AF},
+{0x010589, 0x0105B0},
+{0x01058A, 0x0105B1},
+{0x01058C, 0x0105B3},
+{0x01058D, 0x0105B4},
+{0x01058E, 0x0105B5},
+{0x01058F, 0x0105B6},
+{0x010590, 0x0105B7},
+{0x010591, 0x0105B8},
+{0x010592, 0x0105B9},
+{0x010594, 0x0105BB},
+{0x010595, 0x0105BC},
+{0x010C80, 0x010CC0},
+{0x010C81, 0x010CC1},
+{0x010C82, 0x010CC2},
+{0x010C83, 0x010CC3},
+{0x010C84, 0x010CC4},
+{0x010C85, 0x010CC5},
+{0x010C86, 0x010CC6},
+{0x010C87, 0x010CC7},
+{0x010C88, 0x010CC8},
+{0x010C89, 0x010CC9},
+{0x010C8A, 0x010CCA},
+{0x010C8B, 0x010CCB},
+{0x010C8C, 0x010CCC},
+{0x010C8D, 0x010CCD},
+{0x010C8E, 0x010CCE},
+{0x010C8F, 0x010CCF},
+{0x010C90, 0x010CD0},
+{0x010C91, 0x010CD1},
+{0x010C92, 0x010CD2},
+{0x010C93, 0x010CD3},
+{0x010C94, 0x010CD4},
+{0x010C95, 0x010CD5},
+{0x010C96, 0x010CD6},
+{0x010C97, 0x010CD7},
+{0x010C98, 0x010CD8},
+{0x010C99, 0x010CD9},
+{0x010C9A, 0x010CDA},
+{0x010C9B, 0x010CDB},
+{0x010C9C, 0x010CDC},
+{0x010C9D, 0x010CDD},
+{0x010C9E, 0x010CDE},
+{0x010C9F, 0x010CDF},
+{0x010CA0, 0x010CE0},
+{0x010CA1, 0x010CE1},
+{0x010CA2, 0x010CE2},
+{0x010CA3, 0x010CE3},
+{0x010CA4, 0x010CE4},
+{0x010CA5, 0x010CE5},
+{0x010CA6, 0x010CE6},
+{0x010CA7, 0x010CE7},
+{0x010CA8, 0x010CE8},
+{0x010CA9, 0x010CE9},
+{0x010CAA, 0x010CEA},
+{0x010CAB, 0x010CEB},
+{0x010CAC, 0x010CEC},
+{0x010CAD, 0x010CED},
+{0x010CAE, 0x010CEE},
+{0x010CAF, 0x010CEF},
+{0x010CB0, 0x010CF0},
+{0x010CB1, 0x010CF1},
+{0x010CB2, 0x010CF2},
+{0x0118A0, 0x0118C0},
+{0x0118A1, 0x0118C1},
+{0x0118A2, 0x0118C2},
+{0x0118A3, 0x0118C3},
+{0x0118A4, 0x0118C4},
+{0x0118A5, 0x0118C5},
+{0x0118A6, 0x0118C6},
+{0x0118A7, 0x0118C7},
+{0x0118A8, 0x0118C8},
+{0x0118A9, 0x0118C9},
+{0x0118AA, 0x0118CA},
+{0x0118AB, 0x0118CB},
+{0x0118AC, 0x0118CC},
+{0x0118AD, 0x0118CD},
+{0x0118AE, 0x0118CE},
+{0x0118AF, 0x0118CF},
+{0x0118B0, 0x0118D0},
+{0x0118B1, 0x0118D1},
+{0x0118B2, 0x0118D2},
+{0x0118B3, 0x0118D3},
+{0x0118B4, 0x0118D4},
+{0x0118B5, 0x0118D5},
+{0x0118B6, 0x0118D6},
+{0x0118B7, 0x0118D7},
+{0x0118B8, 0x0118D8},
+{0x0118B9, 0x0118D9},
+{0x0118BA, 0x0118DA},
+{0x0118BB, 0x0118DB},
+{0x0118BC, 0x0118DC},
+{0x0118BD, 0x0118DD},
+{0x0118BE, 0x0118DE},
+{0x0118BF, 0x0118DF},
+{0x016E40, 0x016E60},
+{0x016E41, 0x016E61},
+{0x016E42, 0x016E62},
+{0x016E43, 0x016E63},
+{0x016E44, 0x016E64},
+{0x016E45, 0x016E65},
+{0x016E46, 0x016E66},
+{0x016E47, 0x016E67},
+{0x016E48, 0x016E68},
+{0x016E49, 0x016E69},
+{0x016E4A, 0x016E6A},
+{0x016E4B, 0x016E6B},
+{0x016E4C, 0x016E6C},
+{0x016E4D, 0x016E6D},
+{0x016E4E, 0x016E6E},
+{0x016E4F, 0x016E6F},
+{0x016E50, 0x016E70},
+{0x016E51, 0x016E71},
+{0x016E52, 0x016E72},
+{0x016E53, 0x016E73},
+{0x016E54, 0x016E74},
+{0x016E55, 0x016E75},
+{0x016E56, 0x016E76},
+{0x016E57, 0x016E77},
+{0x016E58, 0x016E78},
+{0x016E59, 0x016E79},
+{0x016E5A, 0x016E7A},
+{0x016E5B, 0x016E7B},
+{0x016E5C, 0x016E7C},
+{0x016E5D, 0x016E7D},
+{0x016E5E, 0x016E7E},
+{0x016E5F, 0x016E7F},
+{0x01E900, 0x01E922},
+{0x01E901, 0x01E923},
+{0x01E902, 0x01E924},
+{0x01E903, 0x01E925},
+{0x01E904, 0x01E926},
+{0x01E905, 0x01E927},
+{0x01E906, 0x01E928},
+{0x01E907, 0x01E929},
+{0x01E908, 0x01E92A},
+{0x01E909, 0x01E92B},
+{0x01E90A, 0x01E92C},
+{0x01E90B, 0x01E92D},
+{0x01E90C, 0x01E92E},
+{0x01E90D, 0x01E92F},
+{0x01E90E, 0x01E930},
+{0x01E90F, 0x01E931},
+{0x01E910, 0x01E932},
+{0x01E911, 0x01E933},
+{0x01E912, 0x01E934},
+{0x01E913, 0x01E935},
+{0x01E914, 0x01E936},
+{0x01E915, 0x01E937},
+{0x01E916, 0x01E938},
+{0x01E917, 0x01E939},
+{0x01E918, 0x01E93A},
+{0x01E919, 0x01E93B},
+{0x01E91A, 0x01E93C},
+{0x01E91B, 0x01E93D},
+{0x01E91C, 0x01E93E},
+{0x01E91D, 0x01E93F},
+{0x01E91E, 0x01E940},
+{0x01E91F, 0x01E941},
+{0x01E920, 0x01E942},
+{0x01E921, 0x01E943},
+};
+
+// list is always in ascending order, to enable binary search
+const std::initializer_list> unicode_map_uppercase = {
+{0x000061, 0x000041},
+{0x000062, 0x000042},
+{0x000063, 0x000043},
+{0x000064, 0x000044},
+{0x000065, 0x000045},
+{0x000066, 0x000046},
+{0x000067, 0x000047},
+{0x000068, 0x000048},
+{0x000069, 0x000049},
+{0x00006A, 0x00004A},
+{0x00006B, 0x00004B},
+{0x00006C, 0x00004C},
+{0x00006D, 0x00004D},
+{0x00006E, 0x00004E},
+{0x00006F, 0x00004F},
+{0x000070, 0x000050},
+{0x000071, 0x000051},
+{0x000072, 0x000052},
+{0x000073, 0x000053},
+{0x000074, 0x000054},
+{0x000075, 0x000055},
+{0x000076, 0x000056},
+{0x000077, 0x000057},
+{0x000078, 0x000058},
+{0x000079, 0x000059},
+{0x00007A, 0x00005A},
+{0x0000B5, 0x00039C},
+{0x0000E0, 0x0000C0},
+{0x0000E1, 0x0000C1},
+{0x0000E2, 0x0000C2},
+{0x0000E3, 0x0000C3},
+{0x0000E4, 0x0000C4},
+{0x0000E5, 0x0000C5},
+{0x0000E6, 0x0000C6},
+{0x0000E7, 0x0000C7},
+{0x0000E8, 0x0000C8},
+{0x0000E9, 0x0000C9},
+{0x0000EA, 0x0000CA},
+{0x0000EB, 0x0000CB},
+{0x0000EC, 0x0000CC},
+{0x0000ED, 0x0000CD},
+{0x0000EE, 0x0000CE},
+{0x0000EF, 0x0000CF},
+{0x0000F0, 0x0000D0},
+{0x0000F1, 0x0000D1},
+{0x0000F2, 0x0000D2},
+{0x0000F3, 0x0000D3},
+{0x0000F4, 0x0000D4},
+{0x0000F5, 0x0000D5},
+{0x0000F6, 0x0000D6},
+{0x0000F8, 0x0000D8},
+{0x0000F9, 0x0000D9},
+{0x0000FA, 0x0000DA},
+{0x0000FB, 0x0000DB},
+{0x0000FC, 0x0000DC},
+{0x0000FD, 0x0000DD},
+{0x0000FE, 0x0000DE},
+{0x0000FF, 0x000178},
+{0x000101, 0x000100},
+{0x000103, 0x000102},
+{0x000105, 0x000104},
+{0x000107, 0x000106},
+{0x000109, 0x000108},
+{0x00010B, 0x00010A},
+{0x00010D, 0x00010C},
+{0x00010F, 0x00010E},
+{0x000111, 0x000110},
+{0x000113, 0x000112},
+{0x000115, 0x000114},
+{0x000117, 0x000116},
+{0x000119, 0x000118},
+{0x00011B, 0x00011A},
+{0x00011D, 0x00011C},
+{0x00011F, 0x00011E},
+{0x000121, 0x000120},
+{0x000123, 0x000122},
+{0x000125, 0x000124},
+{0x000127, 0x000126},
+{0x000129, 0x000128},
+{0x00012B, 0x00012A},
+{0x00012D, 0x00012C},
+{0x00012F, 0x00012E},
+{0x000131, 0x000049},
+{0x000133, 0x000132},
+{0x000135, 0x000134},
+{0x000137, 0x000136},
+{0x00013A, 0x000139},
+{0x00013C, 0x00013B},
+{0x00013E, 0x00013D},
+{0x000140, 0x00013F},
+{0x000142, 0x000141},
+{0x000144, 0x000143},
+{0x000146, 0x000145},
+{0x000148, 0x000147},
+{0x00014B, 0x00014A},
+{0x00014D, 0x00014C},
+{0x00014F, 0x00014E},
+{0x000151, 0x000150},
+{0x000153, 0x000152},
+{0x000155, 0x000154},
+{0x000157, 0x000156},
+{0x000159, 0x000158},
+{0x00015B, 0x00015A},
+{0x00015D, 0x00015C},
+{0x00015F, 0x00015E},
+{0x000161, 0x000160},
+{0x000163, 0x000162},
+{0x000165, 0x000164},
+{0x000167, 0x000166},
+{0x000169, 0x000168},
+{0x00016B, 0x00016A},
+{0x00016D, 0x00016C},
+{0x00016F, 0x00016E},
+{0x000171, 0x000170},
+{0x000173, 0x000172},
+{0x000175, 0x000174},
+{0x000177, 0x000176},
+{0x00017A, 0x000179},
+{0x00017C, 0x00017B},
+{0x00017E, 0x00017D},
+{0x00017F, 0x000053},
+{0x000180, 0x000243},
+{0x000183, 0x000182},
+{0x000185, 0x000184},
+{0x000188, 0x000187},
+{0x00018C, 0x00018B},
+{0x000192, 0x000191},
+{0x000195, 0x0001F6},
+{0x000199, 0x000198},
+{0x00019A, 0x00023D},
+{0x00019E, 0x000220},
+{0x0001A1, 0x0001A0},
+{0x0001A3, 0x0001A2},
+{0x0001A5, 0x0001A4},
+{0x0001A8, 0x0001A7},
+{0x0001AD, 0x0001AC},
+{0x0001B0, 0x0001AF},
+{0x0001B4, 0x0001B3},
+{0x0001B6, 0x0001B5},
+{0x0001B9, 0x0001B8},
+{0x0001BD, 0x0001BC},
+{0x0001BF, 0x0001F7},
+{0x0001C5, 0x0001C4},
+{0x0001C6, 0x0001C4},
+{0x0001C8, 0x0001C7},
+{0x0001C9, 0x0001C7},
+{0x0001CB, 0x0001CA},
+{0x0001CC, 0x0001CA},
+{0x0001CE, 0x0001CD},
+{0x0001D0, 0x0001CF},
+{0x0001D2, 0x0001D1},
+{0x0001D4, 0x0001D3},
+{0x0001D6, 0x0001D5},
+{0x0001D8, 0x0001D7},
+{0x0001DA, 0x0001D9},
+{0x0001DC, 0x0001DB},
+{0x0001DD, 0x00018E},
+{0x0001DF, 0x0001DE},
+{0x0001E1, 0x0001E0},
+{0x0001E3, 0x0001E2},
+{0x0001E5, 0x0001E4},
+{0x0001E7, 0x0001E6},
+{0x0001E9, 0x0001E8},
+{0x0001EB, 0x0001EA},
+{0x0001ED, 0x0001EC},
+{0x0001EF, 0x0001EE},
+{0x0001F2, 0x0001F1},
+{0x0001F3, 0x0001F1},
+{0x0001F5, 0x0001F4},
+{0x0001F9, 0x0001F8},
+{0x0001FB, 0x0001FA},
+{0x0001FD, 0x0001FC},
+{0x0001FF, 0x0001FE},
+{0x000201, 0x000200},
+{0x000203, 0x000202},
+{0x000205, 0x000204},
+{0x000207, 0x000206},
+{0x000209, 0x000208},
+{0x00020B, 0x00020A},
+{0x00020D, 0x00020C},
+{0x00020F, 0x00020E},
+{0x000211, 0x000210},
+{0x000213, 0x000212},
+{0x000215, 0x000214},
+{0x000217, 0x000216},
+{0x000219, 0x000218},
+{0x00021B, 0x00021A},
+{0x00021D, 0x00021C},
+{0x00021F, 0x00021E},
+{0x000223, 0x000222},
+{0x000225, 0x000224},
+{0x000227, 0x000226},
+{0x000229, 0x000228},
+{0x00022B, 0x00022A},
+{0x00022D, 0x00022C},
+{0x00022F, 0x00022E},
+{0x000231, 0x000230},
+{0x000233, 0x000232},
+{0x00023C, 0x00023B},
+{0x00023F, 0x002C7E},
+{0x000240, 0x002C7F},
+{0x000242, 0x000241},
+{0x000247, 0x000246},
+{0x000249, 0x000248},
+{0x00024B, 0x00024A},
+{0x00024D, 0x00024C},
+{0x00024F, 0x00024E},
+{0x000250, 0x002C6F},
+{0x000251, 0x002C6D},
+{0x000252, 0x002C70},
+{0x000253, 0x000181},
+{0x000254, 0x000186},
+{0x000256, 0x000189},
+{0x000257, 0x00018A},
+{0x000259, 0x00018F},
+{0x00025B, 0x000190},
+{0x00025C, 0x00A7AB},
+{0x000260, 0x000193},
+{0x000261, 0x00A7AC},
+{0x000263, 0x000194},
+{0x000265, 0x00A78D},
+{0x000266, 0x00A7AA},
+{0x000268, 0x000197},
+{0x000269, 0x000196},
+{0x00026A, 0x00A7AE},
+{0x00026B, 0x002C62},
+{0x00026C, 0x00A7AD},
+{0x00026F, 0x00019C},
+{0x000271, 0x002C6E},
+{0x000272, 0x00019D},
+{0x000275, 0x00019F},
+{0x00027D, 0x002C64},
+{0x000280, 0x0001A6},
+{0x000282, 0x00A7C5},
+{0x000283, 0x0001A9},
+{0x000287, 0x00A7B1},
+{0x000288, 0x0001AE},
+{0x000289, 0x000244},
+{0x00028A, 0x0001B1},
+{0x00028B, 0x0001B2},
+{0x00028C, 0x000245},
+{0x000292, 0x0001B7},
+{0x00029D, 0x00A7B2},
+{0x00029E, 0x00A7B0},
+{0x000345, 0x000399},
+{0x000371, 0x000370},
+{0x000373, 0x000372},
+{0x000377, 0x000376},
+{0x00037B, 0x0003FD},
+{0x00037C, 0x0003FE},
+{0x00037D, 0x0003FF},
+{0x0003AC, 0x000386},
+{0x0003AD, 0x000388},
+{0x0003AE, 0x000389},
+{0x0003AF, 0x00038A},
+{0x0003B1, 0x000391},
+{0x0003B2, 0x000392},
+{0x0003B3, 0x000393},
+{0x0003B4, 0x000394},
+{0x0003B5, 0x000395},
+{0x0003B6, 0x000396},
+{0x0003B7, 0x000397},
+{0x0003B8, 0x000398},
+{0x0003B9, 0x000399},
+{0x0003BA, 0x00039A},
+{0x0003BB, 0x00039B},
+{0x0003BC, 0x00039C},
+{0x0003BD, 0x00039D},
+{0x0003BE, 0x00039E},
+{0x0003BF, 0x00039F},
+{0x0003C0, 0x0003A0},
+{0x0003C1, 0x0003A1},
+{0x0003C2, 0x0003A3},
+{0x0003C3, 0x0003A3},
+{0x0003C4, 0x0003A4},
+{0x0003C5, 0x0003A5},
+{0x0003C6, 0x0003A6},
+{0x0003C7, 0x0003A7},
+{0x0003C8, 0x0003A8},
+{0x0003C9, 0x0003A9},
+{0x0003CA, 0x0003AA},
+{0x0003CB, 0x0003AB},
+{0x0003CC, 0x00038C},
+{0x0003CD, 0x00038E},
+{0x0003CE, 0x00038F},
+{0x0003D0, 0x000392},
+{0x0003D1, 0x000398},
+{0x0003D5, 0x0003A6},
+{0x0003D6, 0x0003A0},
+{0x0003D7, 0x0003CF},
+{0x0003D9, 0x0003D8},
+{0x0003DB, 0x0003DA},
+{0x0003DD, 0x0003DC},
+{0x0003DF, 0x0003DE},
+{0x0003E1, 0x0003E0},
+{0x0003E3, 0x0003E2},
+{0x0003E5, 0x0003E4},
+{0x0003E7, 0x0003E6},
+{0x0003E9, 0x0003E8},
+{0x0003EB, 0x0003EA},
+{0x0003ED, 0x0003EC},
+{0x0003EF, 0x0003EE},
+{0x0003F0, 0x00039A},
+{0x0003F1, 0x0003A1},
+{0x0003F2, 0x0003F9},
+{0x0003F3, 0x00037F},
+{0x0003F5, 0x000395},
+{0x0003F8, 0x0003F7},
+{0x0003FB, 0x0003FA},
+{0x000430, 0x000410},
+{0x000431, 0x000411},
+{0x000432, 0x000412},
+{0x000433, 0x000413},
+{0x000434, 0x000414},
+{0x000435, 0x000415},
+{0x000436, 0x000416},
+{0x000437, 0x000417},
+{0x000438, 0x000418},
+{0x000439, 0x000419},
+{0x00043A, 0x00041A},
+{0x00043B, 0x00041B},
+{0x00043C, 0x00041C},
+{0x00043D, 0x00041D},
+{0x00043E, 0x00041E},
+{0x00043F, 0x00041F},
+{0x000440, 0x000420},
+{0x000441, 0x000421},
+{0x000442, 0x000422},
+{0x000443, 0x000423},
+{0x000444, 0x000424},
+{0x000445, 0x000425},
+{0x000446, 0x000426},
+{0x000447, 0x000427},
+{0x000448, 0x000428},
+{0x000449, 0x000429},
+{0x00044A, 0x00042A},
+{0x00044B, 0x00042B},
+{0x00044C, 0x00042C},
+{0x00044D, 0x00042D},
+{0x00044E, 0x00042E},
+{0x00044F, 0x00042F},
+{0x000450, 0x000400},
+{0x000451, 0x000401},
+{0x000452, 0x000402},
+{0x000453, 0x000403},
+{0x000454, 0x000404},
+{0x000455, 0x000405},
+{0x000456, 0x000406},
+{0x000457, 0x000407},
+{0x000458, 0x000408},
+{0x000459, 0x000409},
+{0x00045A, 0x00040A},
+{0x00045B, 0x00040B},
+{0x00045C, 0x00040C},
+{0x00045D, 0x00040D},
+{0x00045E, 0x00040E},
+{0x00045F, 0x00040F},
+{0x000461, 0x000460},
+{0x000463, 0x000462},
+{0x000465, 0x000464},
+{0x000467, 0x000466},
+{0x000469, 0x000468},
+{0x00046B, 0x00046A},
+{0x00046D, 0x00046C},
+{0x00046F, 0x00046E},
+{0x000471, 0x000470},
+{0x000473, 0x000472},
+{0x000475, 0x000474},
+{0x000477, 0x000476},
+{0x000479, 0x000478},
+{0x00047B, 0x00047A},
+{0x00047D, 0x00047C},
+{0x00047F, 0x00047E},
+{0x000481, 0x000480},
+{0x00048B, 0x00048A},
+{0x00048D, 0x00048C},
+{0x00048F, 0x00048E},
+{0x000491, 0x000490},
+{0x000493, 0x000492},
+{0x000495, 0x000494},
+{0x000497, 0x000496},
+{0x000499, 0x000498},
+{0x00049B, 0x00049A},
+{0x00049D, 0x00049C},
+{0x00049F, 0x00049E},
+{0x0004A1, 0x0004A0},
+{0x0004A3, 0x0004A2},
+{0x0004A5, 0x0004A4},
+{0x0004A7, 0x0004A6},
+{0x0004A9, 0x0004A8},
+{0x0004AB, 0x0004AA},
+{0x0004AD, 0x0004AC},
+{0x0004AF, 0x0004AE},
+{0x0004B1, 0x0004B0},
+{0x0004B3, 0x0004B2},
+{0x0004B5, 0x0004B4},
+{0x0004B7, 0x0004B6},
+{0x0004B9, 0x0004B8},
+{0x0004BB, 0x0004BA},
+{0x0004BD, 0x0004BC},
+{0x0004BF, 0x0004BE},
+{0x0004C2, 0x0004C1},
+{0x0004C4, 0x0004C3},
+{0x0004C6, 0x0004C5},
+{0x0004C8, 0x0004C7},
+{0x0004CA, 0x0004C9},
+{0x0004CC, 0x0004CB},
+{0x0004CE, 0x0004CD},
+{0x0004CF, 0x0004C0},
+{0x0004D1, 0x0004D0},
+{0x0004D3, 0x0004D2},
+{0x0004D5, 0x0004D4},
+{0x0004D7, 0x0004D6},
+{0x0004D9, 0x0004D8},
+{0x0004DB, 0x0004DA},
+{0x0004DD, 0x0004DC},
+{0x0004DF, 0x0004DE},
+{0x0004E1, 0x0004E0},
+{0x0004E3, 0x0004E2},
+{0x0004E5, 0x0004E4},
+{0x0004E7, 0x0004E6},
+{0x0004E9, 0x0004E8},
+{0x0004EB, 0x0004EA},
+{0x0004ED, 0x0004EC},
+{0x0004EF, 0x0004EE},
+{0x0004F1, 0x0004F0},
+{0x0004F3, 0x0004F2},
+{0x0004F5, 0x0004F4},
+{0x0004F7, 0x0004F6},
+{0x0004F9, 0x0004F8},
+{0x0004FB, 0x0004FA},
+{0x0004FD, 0x0004FC},
+{0x0004FF, 0x0004FE},
+{0x000501, 0x000500},
+{0x000503, 0x000502},
+{0x000505, 0x000504},
+{0x000507, 0x000506},
+{0x000509, 0x000508},
+{0x00050B, 0x00050A},
+{0x00050D, 0x00050C},
+{0x00050F, 0x00050E},
+{0x000511, 0x000510},
+{0x000513, 0x000512},
+{0x000515, 0x000514},
+{0x000517, 0x000516},
+{0x000519, 0x000518},
+{0x00051B, 0x00051A},
+{0x00051D, 0x00051C},
+{0x00051F, 0x00051E},
+{0x000521, 0x000520},
+{0x000523, 0x000522},
+{0x000525, 0x000524},
+{0x000527, 0x000526},
+{0x000529, 0x000528},
+{0x00052B, 0x00052A},
+{0x00052D, 0x00052C},
+{0x00052F, 0x00052E},
+{0x000561, 0x000531},
+{0x000562, 0x000532},
+{0x000563, 0x000533},
+{0x000564, 0x000534},
+{0x000565, 0x000535},
+{0x000566, 0x000536},
+{0x000567, 0x000537},
+{0x000568, 0x000538},
+{0x000569, 0x000539},
+{0x00056A, 0x00053A},
+{0x00056B, 0x00053B},
+{0x00056C, 0x00053C},
+{0x00056D, 0x00053D},
+{0x00056E, 0x00053E},
+{0x00056F, 0x00053F},
+{0x000570, 0x000540},
+{0x000571, 0x000541},
+{0x000572, 0x000542},
+{0x000573, 0x000543},
+{0x000574, 0x000544},
+{0x000575, 0x000545},
+{0x000576, 0x000546},
+{0x000577, 0x000547},
+{0x000578, 0x000548},
+{0x000579, 0x000549},
+{0x00057A, 0x00054A},
+{0x00057B, 0x00054B},
+{0x00057C, 0x00054C},
+{0x00057D, 0x00054D},
+{0x00057E, 0x00054E},
+{0x00057F, 0x00054F},
+{0x000580, 0x000550},
+{0x000581, 0x000551},
+{0x000582, 0x000552},
+{0x000583, 0x000553},
+{0x000584, 0x000554},
+{0x000585, 0x000555},
+{0x000586, 0x000556},
+{0x0010D0, 0x001C90},
+{0x0010D1, 0x001C91},
+{0x0010D2, 0x001C92},
+{0x0010D3, 0x001C93},
+{0x0010D4, 0x001C94},
+{0x0010D5, 0x001C95},
+{0x0010D6, 0x001C96},
+{0x0010D7, 0x001C97},
+{0x0010D8, 0x001C98},
+{0x0010D9, 0x001C99},
+{0x0010DA, 0x001C9A},
+{0x0010DB, 0x001C9B},
+{0x0010DC, 0x001C9C},
+{0x0010DD, 0x001C9D},
+{0x0010DE, 0x001C9E},
+{0x0010DF, 0x001C9F},
+{0x0010E0, 0x001CA0},
+{0x0010E1, 0x001CA1},
+{0x0010E2, 0x001CA2},
+{0x0010E3, 0x001CA3},
+{0x0010E4, 0x001CA4},
+{0x0010E5, 0x001CA5},
+{0x0010E6, 0x001CA6},
+{0x0010E7, 0x001CA7},
+{0x0010E8, 0x001CA8},
+{0x0010E9, 0x001CA9},
+{0x0010EA, 0x001CAA},
+{0x0010EB, 0x001CAB},
+{0x0010EC, 0x001CAC},
+{0x0010ED, 0x001CAD},
+{0x0010EE, 0x001CAE},
+{0x0010EF, 0x001CAF},
+{0x0010F0, 0x001CB0},
+{0x0010F1, 0x001CB1},
+{0x0010F2, 0x001CB2},
+{0x0010F3, 0x001CB3},
+{0x0010F4, 0x001CB4},
+{0x0010F5, 0x001CB5},
+{0x0010F6, 0x001CB6},
+{0x0010F7, 0x001CB7},
+{0x0010F8, 0x001CB8},
+{0x0010F9, 0x001CB9},
+{0x0010FA, 0x001CBA},
+{0x0010FD, 0x001CBD},
+{0x0010FE, 0x001CBE},
+{0x0010FF, 0x001CBF},
+{0x0013F8, 0x0013F0},
+{0x0013F9, 0x0013F1},
+{0x0013FA, 0x0013F2},
+{0x0013FB, 0x0013F3},
+{0x0013FC, 0x0013F4},
+{0x0013FD, 0x0013F5},
+{0x001C80, 0x000412},
+{0x001C81, 0x000414},
+{0x001C82, 0x00041E},
+{0x001C83, 0x000421},
+{0x001C84, 0x000422},
+{0x001C85, 0x000422},
+{0x001C86, 0x00042A},
+{0x001C87, 0x000462},
+{0x001C88, 0x00A64A},
+{0x001D79, 0x00A77D},
+{0x001D7D, 0x002C63},
+{0x001D8E, 0x00A7C6},
+{0x001E01, 0x001E00},
+{0x001E03, 0x001E02},
+{0x001E05, 0x001E04},
+{0x001E07, 0x001E06},
+{0x001E09, 0x001E08},
+{0x001E0B, 0x001E0A},
+{0x001E0D, 0x001E0C},
+{0x001E0F, 0x001E0E},
+{0x001E11, 0x001E10},
+{0x001E13, 0x001E12},
+{0x001E15, 0x001E14},
+{0x001E17, 0x001E16},
+{0x001E19, 0x001E18},
+{0x001E1B, 0x001E1A},
+{0x001E1D, 0x001E1C},
+{0x001E1F, 0x001E1E},
+{0x001E21, 0x001E20},
+{0x001E23, 0x001E22},
+{0x001E25, 0x001E24},
+{0x001E27, 0x001E26},
+{0x001E29, 0x001E28},
+{0x001E2B, 0x001E2A},
+{0x001E2D, 0x001E2C},
+{0x001E2F, 0x001E2E},
+{0x001E31, 0x001E30},
+{0x001E33, 0x001E32},
+{0x001E35, 0x001E34},
+{0x001E37, 0x001E36},
+{0x001E39, 0x001E38},
+{0x001E3B, 0x001E3A},
+{0x001E3D, 0x001E3C},
+{0x001E3F, 0x001E3E},
+{0x001E41, 0x001E40},
+{0x001E43, 0x001E42},
+{0x001E45, 0x001E44},
+{0x001E47, 0x001E46},
+{0x001E49, 0x001E48},
+{0x001E4B, 0x001E4A},
+{0x001E4D, 0x001E4C},
+{0x001E4F, 0x001E4E},
+{0x001E51, 0x001E50},
+{0x001E53, 0x001E52},
+{0x001E55, 0x001E54},
+{0x001E57, 0x001E56},
+{0x001E59, 0x001E58},
+{0x001E5B, 0x001E5A},
+{0x001E5D, 0x001E5C},
+{0x001E5F, 0x001E5E},
+{0x001E61, 0x001E60},
+{0x001E63, 0x001E62},
+{0x001E65, 0x001E64},
+{0x001E67, 0x001E66},
+{0x001E69, 0x001E68},
+{0x001E6B, 0x001E6A},
+{0x001E6D, 0x001E6C},
+{0x001E6F, 0x001E6E},
+{0x001E71, 0x001E70},
+{0x001E73, 0x001E72},
+{0x001E75, 0x001E74},
+{0x001E77, 0x001E76},
+{0x001E79, 0x001E78},
+{0x001E7B, 0x001E7A},
+{0x001E7D, 0x001E7C},
+{0x001E7F, 0x001E7E},
+{0x001E81, 0x001E80},
+{0x001E83, 0x001E82},
+{0x001E85, 0x001E84},
+{0x001E87, 0x001E86},
+{0x001E89, 0x001E88},
+{0x001E8B, 0x001E8A},
+{0x001E8D, 0x001E8C},
+{0x001E8F, 0x001E8E},
+{0x001E91, 0x001E90},
+{0x001E93, 0x001E92},
+{0x001E95, 0x001E94},
+{0x001E9B, 0x001E60},
+{0x001EA1, 0x001EA0},
+{0x001EA3, 0x001EA2},
+{0x001EA5, 0x001EA4},
+{0x001EA7, 0x001EA6},
+{0x001EA9, 0x001EA8},
+{0x001EAB, 0x001EAA},
+{0x001EAD, 0x001EAC},
+{0x001EAF, 0x001EAE},
+{0x001EB1, 0x001EB0},
+{0x001EB3, 0x001EB2},
+{0x001EB5, 0x001EB4},
+{0x001EB7, 0x001EB6},
+{0x001EB9, 0x001EB8},
+{0x001EBB, 0x001EBA},
+{0x001EBD, 0x001EBC},
+{0x001EBF, 0x001EBE},
+{0x001EC1, 0x001EC0},
+{0x001EC3, 0x001EC2},
+{0x001EC5, 0x001EC4},
+{0x001EC7, 0x001EC6},
+{0x001EC9, 0x001EC8},
+{0x001ECB, 0x001ECA},
+{0x001ECD, 0x001ECC},
+{0x001ECF, 0x001ECE},
+{0x001ED1, 0x001ED0},
+{0x001ED3, 0x001ED2},
+{0x001ED5, 0x001ED4},
+{0x001ED7, 0x001ED6},
+{0x001ED9, 0x001ED8},
+{0x001EDB, 0x001EDA},
+{0x001EDD, 0x001EDC},
+{0x001EDF, 0x001EDE},
+{0x001EE1, 0x001EE0},
+{0x001EE3, 0x001EE2},
+{0x001EE5, 0x001EE4},
+{0x001EE7, 0x001EE6},
+{0x001EE9, 0x001EE8},
+{0x001EEB, 0x001EEA},
+{0x001EED, 0x001EEC},
+{0x001EEF, 0x001EEE},
+{0x001EF1, 0x001EF0},
+{0x001EF3, 0x001EF2},
+{0x001EF5, 0x001EF4},
+{0x001EF7, 0x001EF6},
+{0x001EF9, 0x001EF8},
+{0x001EFB, 0x001EFA},
+{0x001EFD, 0x001EFC},
+{0x001EFF, 0x001EFE},
+{0x001F00, 0x001F08},
+{0x001F01, 0x001F09},
+{0x001F02, 0x001F0A},
+{0x001F03, 0x001F0B},
+{0x001F04, 0x001F0C},
+{0x001F05, 0x001F0D},
+{0x001F06, 0x001F0E},
+{0x001F07, 0x001F0F},
+{0x001F10, 0x001F18},
+{0x001F11, 0x001F19},
+{0x001F12, 0x001F1A},
+{0x001F13, 0x001F1B},
+{0x001F14, 0x001F1C},
+{0x001F15, 0x001F1D},
+{0x001F20, 0x001F28},
+{0x001F21, 0x001F29},
+{0x001F22, 0x001F2A},
+{0x001F23, 0x001F2B},
+{0x001F24, 0x001F2C},
+{0x001F25, 0x001F2D},
+{0x001F26, 0x001F2E},
+{0x001F27, 0x001F2F},
+{0x001F30, 0x001F38},
+{0x001F31, 0x001F39},
+{0x001F32, 0x001F3A},
+{0x001F33, 0x001F3B},
+{0x001F34, 0x001F3C},
+{0x001F35, 0x001F3D},
+{0x001F36, 0x001F3E},
+{0x001F37, 0x001F3F},
+{0x001F40, 0x001F48},
+{0x001F41, 0x001F49},
+{0x001F42, 0x001F4A},
+{0x001F43, 0x001F4B},
+{0x001F44, 0x001F4C},
+{0x001F45, 0x001F4D},
+{0x001F51, 0x001F59},
+{0x001F53, 0x001F5B},
+{0x001F55, 0x001F5D},
+{0x001F57, 0x001F5F},
+{0x001F60, 0x001F68},
+{0x001F61, 0x001F69},
+{0x001F62, 0x001F6A},
+{0x001F63, 0x001F6B},
+{0x001F64, 0x001F6C},
+{0x001F65, 0x001F6D},
+{0x001F66, 0x001F6E},
+{0x001F67, 0x001F6F},
+{0x001F70, 0x001FBA},
+{0x001F71, 0x001FBB},
+{0x001F72, 0x001FC8},
+{0x001F73, 0x001FC9},
+{0x001F74, 0x001FCA},
+{0x001F75, 0x001FCB},
+{0x001F76, 0x001FDA},
+{0x001F77, 0x001FDB},
+{0x001F78, 0x001FF8},
+{0x001F79, 0x001FF9},
+{0x001F7A, 0x001FEA},
+{0x001F7B, 0x001FEB},
+{0x001F7C, 0x001FFA},
+{0x001F7D, 0x001FFB},
+{0x001F80, 0x001F88},
+{0x001F81, 0x001F89},
+{0x001F82, 0x001F8A},
+{0x001F83, 0x001F8B},
+{0x001F84, 0x001F8C},
+{0x001F85, 0x001F8D},
+{0x001F86, 0x001F8E},
+{0x001F87, 0x001F8F},
+{0x001F90, 0x001F98},
+{0x001F91, 0x001F99},
+{0x001F92, 0x001F9A},
+{0x001F93, 0x001F9B},
+{0x001F94, 0x001F9C},
+{0x001F95, 0x001F9D},
+{0x001F96, 0x001F9E},
+{0x001F97, 0x001F9F},
+{0x001FA0, 0x001FA8},
+{0x001FA1, 0x001FA9},
+{0x001FA2, 0x001FAA},
+{0x001FA3, 0x001FAB},
+{0x001FA4, 0x001FAC},
+{0x001FA5, 0x001FAD},
+{0x001FA6, 0x001FAE},
+{0x001FA7, 0x001FAF},
+{0x001FB0, 0x001FB8},
+{0x001FB1, 0x001FB9},
+{0x001FB3, 0x001FBC},
+{0x001FBE, 0x000399},
+{0x001FC3, 0x001FCC},
+{0x001FD0, 0x001FD8},
+{0x001FD1, 0x001FD9},
+{0x001FE0, 0x001FE8},
+{0x001FE1, 0x001FE9},
+{0x001FE5, 0x001FEC},
+{0x001FF3, 0x001FFC},
+{0x00214E, 0x002132},
+{0x002170, 0x002160},
+{0x002171, 0x002161},
+{0x002172, 0x002162},
+{0x002173, 0x002163},
+{0x002174, 0x002164},
+{0x002175, 0x002165},
+{0x002176, 0x002166},
+{0x002177, 0x002167},
+{0x002178, 0x002168},
+{0x002179, 0x002169},
+{0x00217A, 0x00216A},
+{0x00217B, 0x00216B},
+{0x00217C, 0x00216C},
+{0x00217D, 0x00216D},
+{0x00217E, 0x00216E},
+{0x00217F, 0x00216F},
+{0x002184, 0x002183},
+{0x0024D0, 0x0024B6},
+{0x0024D1, 0x0024B7},
+{0x0024D2, 0x0024B8},
+{0x0024D3, 0x0024B9},
+{0x0024D4, 0x0024BA},
+{0x0024D5, 0x0024BB},
+{0x0024D6, 0x0024BC},
+{0x0024D7, 0x0024BD},
+{0x0024D8, 0x0024BE},
+{0x0024D9, 0x0024BF},
+{0x0024DA, 0x0024C0},
+{0x0024DB, 0x0024C1},
+{0x0024DC, 0x0024C2},
+{0x0024DD, 0x0024C3},
+{0x0024DE, 0x0024C4},
+{0x0024DF, 0x0024C5},
+{0x0024E0, 0x0024C6},
+{0x0024E1, 0x0024C7},
+{0x0024E2, 0x0024C8},
+{0x0024E3, 0x0024C9},
+{0x0024E4, 0x0024CA},
+{0x0024E5, 0x0024CB},
+{0x0024E6, 0x0024CC},
+{0x0024E7, 0x0024CD},
+{0x0024E8, 0x0024CE},
+{0x0024E9, 0x0024CF},
+{0x002C30, 0x002C00},
+{0x002C31, 0x002C01},
+{0x002C32, 0x002C02},
+{0x002C33, 0x002C03},
+{0x002C34, 0x002C04},
+{0x002C35, 0x002C05},
+{0x002C36, 0x002C06},
+{0x002C37, 0x002C07},
+{0x002C38, 0x002C08},
+{0x002C39, 0x002C09},
+{0x002C3A, 0x002C0A},
+{0x002C3B, 0x002C0B},
+{0x002C3C, 0x002C0C},
+{0x002C3D, 0x002C0D},
+{0x002C3E, 0x002C0E},
+{0x002C3F, 0x002C0F},
+{0x002C40, 0x002C10},
+{0x002C41, 0x002C11},
+{0x002C42, 0x002C12},
+{0x002C43, 0x002C13},
+{0x002C44, 0x002C14},
+{0x002C45, 0x002C15},
+{0x002C46, 0x002C16},
+{0x002C47, 0x002C17},
+{0x002C48, 0x002C18},
+{0x002C49, 0x002C19},
+{0x002C4A, 0x002C1A},
+{0x002C4B, 0x002C1B},
+{0x002C4C, 0x002C1C},
+{0x002C4D, 0x002C1D},
+{0x002C4E, 0x002C1E},
+{0x002C4F, 0x002C1F},
+{0x002C50, 0x002C20},
+{0x002C51, 0x002C21},
+{0x002C52, 0x002C22},
+{0x002C53, 0x002C23},
+{0x002C54, 0x002C24},
+{0x002C55, 0x002C25},
+{0x002C56, 0x002C26},
+{0x002C57, 0x002C27},
+{0x002C58, 0x002C28},
+{0x002C59, 0x002C29},
+{0x002C5A, 0x002C2A},
+{0x002C5B, 0x002C2B},
+{0x002C5C, 0x002C2C},
+{0x002C5D, 0x002C2D},
+{0x002C5E, 0x002C2E},
+{0x002C5F, 0x002C2F},
+{0x002C61, 0x002C60},
+{0x002C65, 0x00023A},
+{0x002C66, 0x00023E},
+{0x002C68, 0x002C67},
+{0x002C6A, 0x002C69},
+{0x002C6C, 0x002C6B},
+{0x002C73, 0x002C72},
+{0x002C76, 0x002C75},
+{0x002C81, 0x002C80},
+{0x002C83, 0x002C82},
+{0x002C85, 0x002C84},
+{0x002C87, 0x002C86},
+{0x002C89, 0x002C88},
+{0x002C8B, 0x002C8A},
+{0x002C8D, 0x002C8C},
+{0x002C8F, 0x002C8E},
+{0x002C91, 0x002C90},
+{0x002C93, 0x002C92},
+{0x002C95, 0x002C94},
+{0x002C97, 0x002C96},
+{0x002C99, 0x002C98},
+{0x002C9B, 0x002C9A},
+{0x002C9D, 0x002C9C},
+{0x002C9F, 0x002C9E},
+{0x002CA1, 0x002CA0},
+{0x002CA3, 0x002CA2},
+{0x002CA5, 0x002CA4},
+{0x002CA7, 0x002CA6},
+{0x002CA9, 0x002CA8},
+{0x002CAB, 0x002CAA},
+{0x002CAD, 0x002CAC},
+{0x002CAF, 0x002CAE},
+{0x002CB1, 0x002CB0},
+{0x002CB3, 0x002CB2},
+{0x002CB5, 0x002CB4},
+{0x002CB7, 0x002CB6},
+{0x002CB9, 0x002CB8},
+{0x002CBB, 0x002CBA},
+{0x002CBD, 0x002CBC},
+{0x002CBF, 0x002CBE},
+{0x002CC1, 0x002CC0},
+{0x002CC3, 0x002CC2},
+{0x002CC5, 0x002CC4},
+{0x002CC7, 0x002CC6},
+{0x002CC9, 0x002CC8},
+{0x002CCB, 0x002CCA},
+{0x002CCD, 0x002CCC},
+{0x002CCF, 0x002CCE},
+{0x002CD1, 0x002CD0},
+{0x002CD3, 0x002CD2},
+{0x002CD5, 0x002CD4},
+{0x002CD7, 0x002CD6},
+{0x002CD9, 0x002CD8},
+{0x002CDB, 0x002CDA},
+{0x002CDD, 0x002CDC},
+{0x002CDF, 0x002CDE},
+{0x002CE1, 0x002CE0},
+{0x002CE3, 0x002CE2},
+{0x002CEC, 0x002CEB},
+{0x002CEE, 0x002CED},
+{0x002CF3, 0x002CF2},
+{0x002D00, 0x0010A0},
+{0x002D01, 0x0010A1},
+{0x002D02, 0x0010A2},
+{0x002D03, 0x0010A3},
+{0x002D04, 0x0010A4},
+{0x002D05, 0x0010A5},
+{0x002D06, 0x0010A6},
+{0x002D07, 0x0010A7},
+{0x002D08, 0x0010A8},
+{0x002D09, 0x0010A9},
+{0x002D0A, 0x0010AA},
+{0x002D0B, 0x0010AB},
+{0x002D0C, 0x0010AC},
+{0x002D0D, 0x0010AD},
+{0x002D0E, 0x0010AE},
+{0x002D0F, 0x0010AF},
+{0x002D10, 0x0010B0},
+{0x002D11, 0x0010B1},
+{0x002D12, 0x0010B2},
+{0x002D13, 0x0010B3},
+{0x002D14, 0x0010B4},
+{0x002D15, 0x0010B5},
+{0x002D16, 0x0010B6},
+{0x002D17, 0x0010B7},
+{0x002D18, 0x0010B8},
+{0x002D19, 0x0010B9},
+{0x002D1A, 0x0010BA},
+{0x002D1B, 0x0010BB},
+{0x002D1C, 0x0010BC},
+{0x002D1D, 0x0010BD},
+{0x002D1E, 0x0010BE},
+{0x002D1F, 0x0010BF},
+{0x002D20, 0x0010C0},
+{0x002D21, 0x0010C1},
+{0x002D22, 0x0010C2},
+{0x002D23, 0x0010C3},
+{0x002D24, 0x0010C4},
+{0x002D25, 0x0010C5},
+{0x002D27, 0x0010C7},
+{0x002D2D, 0x0010CD},
+{0x00A641, 0x00A640},
+{0x00A643, 0x00A642},
+{0x00A645, 0x00A644},
+{0x00A647, 0x00A646},
+{0x00A649, 0x00A648},
+{0x00A64B, 0x00A64A},
+{0x00A64D, 0x00A64C},
+{0x00A64F, 0x00A64E},
+{0x00A651, 0x00A650},
+{0x00A653, 0x00A652},
+{0x00A655, 0x00A654},
+{0x00A657, 0x00A656},
+{0x00A659, 0x00A658},
+{0x00A65B, 0x00A65A},
+{0x00A65D, 0x00A65C},
+{0x00A65F, 0x00A65E},
+{0x00A661, 0x00A660},
+{0x00A663, 0x00A662},
+{0x00A665, 0x00A664},
+{0x00A667, 0x00A666},
+{0x00A669, 0x00A668},
+{0x00A66B, 0x00A66A},
+{0x00A66D, 0x00A66C},
+{0x00A681, 0x00A680},
+{0x00A683, 0x00A682},
+{0x00A685, 0x00A684},
+{0x00A687, 0x00A686},
+{0x00A689, 0x00A688},
+{0x00A68B, 0x00A68A},
+{0x00A68D, 0x00A68C},
+{0x00A68F, 0x00A68E},
+{0x00A691, 0x00A690},
+{0x00A693, 0x00A692},
+{0x00A695, 0x00A694},
+{0x00A697, 0x00A696},
+{0x00A699, 0x00A698},
+{0x00A69B, 0x00A69A},
+{0x00A723, 0x00A722},
+{0x00A725, 0x00A724},
+{0x00A727, 0x00A726},
+{0x00A729, 0x00A728},
+{0x00A72B, 0x00A72A},
+{0x00A72D, 0x00A72C},
+{0x00A72F, 0x00A72E},
+{0x00A733, 0x00A732},
+{0x00A735, 0x00A734},
+{0x00A737, 0x00A736},
+{0x00A739, 0x00A738},
+{0x00A73B, 0x00A73A},
+{0x00A73D, 0x00A73C},
+{0x00A73F, 0x00A73E},
+{0x00A741, 0x00A740},
+{0x00A743, 0x00A742},
+{0x00A745, 0x00A744},
+{0x00A747, 0x00A746},
+{0x00A749, 0x00A748},
+{0x00A74B, 0x00A74A},
+{0x00A74D, 0x00A74C},
+{0x00A74F, 0x00A74E},
+{0x00A751, 0x00A750},
+{0x00A753, 0x00A752},
+{0x00A755, 0x00A754},
+{0x00A757, 0x00A756},
+{0x00A759, 0x00A758},
+{0x00A75B, 0x00A75A},
+{0x00A75D, 0x00A75C},
+{0x00A75F, 0x00A75E},
+{0x00A761, 0x00A760},
+{0x00A763, 0x00A762},
+{0x00A765, 0x00A764},
+{0x00A767, 0x00A766},
+{0x00A769, 0x00A768},
+{0x00A76B, 0x00A76A},
+{0x00A76D, 0x00A76C},
+{0x00A76F, 0x00A76E},
+{0x00A77A, 0x00A779},
+{0x00A77C, 0x00A77B},
+{0x00A77F, 0x00A77E},
+{0x00A781, 0x00A780},
+{0x00A783, 0x00A782},
+{0x00A785, 0x00A784},
+{0x00A787, 0x00A786},
+{0x00A78C, 0x00A78B},
+{0x00A791, 0x00A790},
+{0x00A793, 0x00A792},
+{0x00A794, 0x00A7C4},
+{0x00A797, 0x00A796},
+{0x00A799, 0x00A798},
+{0x00A79B, 0x00A79A},
+{0x00A79D, 0x00A79C},
+{0x00A79F, 0x00A79E},
+{0x00A7A1, 0x00A7A0},
+{0x00A7A3, 0x00A7A2},
+{0x00A7A5, 0x00A7A4},
+{0x00A7A7, 0x00A7A6},
+{0x00A7A9, 0x00A7A8},
+{0x00A7B5, 0x00A7B4},
+{0x00A7B7, 0x00A7B6},
+{0x00A7B9, 0x00A7B8},
+{0x00A7BB, 0x00A7BA},
+{0x00A7BD, 0x00A7BC},
+{0x00A7BF, 0x00A7BE},
+{0x00A7C1, 0x00A7C0},
+{0x00A7C3, 0x00A7C2},
+{0x00A7C8, 0x00A7C7},
+{0x00A7CA, 0x00A7C9},
+{0x00A7D1, 0x00A7D0},
+{0x00A7D7, 0x00A7D6},
+{0x00A7D9, 0x00A7D8},
+{0x00A7F6, 0x00A7F5},
+{0x00AB53, 0x00A7B3},
+{0x00AB70, 0x0013A0},
+{0x00AB71, 0x0013A1},
+{0x00AB72, 0x0013A2},
+{0x00AB73, 0x0013A3},
+{0x00AB74, 0x0013A4},
+{0x00AB75, 0x0013A5},
+{0x00AB76, 0x0013A6},
+{0x00AB77, 0x0013A7},
+{0x00AB78, 0x0013A8},
+{0x00AB79, 0x0013A9},
+{0x00AB7A, 0x0013AA},
+{0x00AB7B, 0x0013AB},
+{0x00AB7C, 0x0013AC},
+{0x00AB7D, 0x0013AD},
+{0x00AB7E, 0x0013AE},
+{0x00AB7F, 0x0013AF},
+{0x00AB80, 0x0013B0},
+{0x00AB81, 0x0013B1},
+{0x00AB82, 0x0013B2},
+{0x00AB83, 0x0013B3},
+{0x00AB84, 0x0013B4},
+{0x00AB85, 0x0013B5},
+{0x00AB86, 0x0013B6},
+{0x00AB87, 0x0013B7},
+{0x00AB88, 0x0013B8},
+{0x00AB89, 0x0013B9},
+{0x00AB8A, 0x0013BA},
+{0x00AB8B, 0x0013BB},
+{0x00AB8C, 0x0013BC},
+{0x00AB8D, 0x0013BD},
+{0x00AB8E, 0x0013BE},
+{0x00AB8F, 0x0013BF},
+{0x00AB90, 0x0013C0},
+{0x00AB91, 0x0013C1},
+{0x00AB92, 0x0013C2},
+{0x00AB93, 0x0013C3},
+{0x00AB94, 0x0013C4},
+{0x00AB95, 0x0013C5},
+{0x00AB96, 0x0013C6},
+{0x00AB97, 0x0013C7},
+{0x00AB98, 0x0013C8},
+{0x00AB99, 0x0013C9},
+{0x00AB9A, 0x0013CA},
+{0x00AB9B, 0x0013CB},
+{0x00AB9C, 0x0013CC},
+{0x00AB9D, 0x0013CD},
+{0x00AB9E, 0x0013CE},
+{0x00AB9F, 0x0013CF},
+{0x00ABA0, 0x0013D0},
+{0x00ABA1, 0x0013D1},
+{0x00ABA2, 0x0013D2},
+{0x00ABA3, 0x0013D3},
+{0x00ABA4, 0x0013D4},
+{0x00ABA5, 0x0013D5},
+{0x00ABA6, 0x0013D6},
+{0x00ABA7, 0x0013D7},
+{0x00ABA8, 0x0013D8},
+{0x00ABA9, 0x0013D9},
+{0x00ABAA, 0x0013DA},
+{0x00ABAB, 0x0013DB},
+{0x00ABAC, 0x0013DC},
+{0x00ABAD, 0x0013DD},
+{0x00ABAE, 0x0013DE},
+{0x00ABAF, 0x0013DF},
+{0x00ABB0, 0x0013E0},
+{0x00ABB1, 0x0013E1},
+{0x00ABB2, 0x0013E2},
+{0x00ABB3, 0x0013E3},
+{0x00ABB4, 0x0013E4},
+{0x00ABB5, 0x0013E5},
+{0x00ABB6, 0x0013E6},
+{0x00ABB7, 0x0013E7},
+{0x00ABB8, 0x0013E8},
+{0x00ABB9, 0x0013E9},
+{0x00ABBA, 0x0013EA},
+{0x00ABBB, 0x0013EB},
+{0x00ABBC, 0x0013EC},
+{0x00ABBD, 0x0013ED},
+{0x00ABBE, 0x0013EE},
+{0x00ABBF, 0x0013EF},
+{0x00FF41, 0x00FF21},
+{0x00FF42, 0x00FF22},
+{0x00FF43, 0x00FF23},
+{0x00FF44, 0x00FF24},
+{0x00FF45, 0x00FF25},
+{0x00FF46, 0x00FF26},
+{0x00FF47, 0x00FF27},
+{0x00FF48, 0x00FF28},
+{0x00FF49, 0x00FF29},
+{0x00FF4A, 0x00FF2A},
+{0x00FF4B, 0x00FF2B},
+{0x00FF4C, 0x00FF2C},
+{0x00FF4D, 0x00FF2D},
+{0x00FF4E, 0x00FF2E},
+{0x00FF4F, 0x00FF2F},
+{0x00FF50, 0x00FF30},
+{0x00FF51, 0x00FF31},
+{0x00FF52, 0x00FF32},
+{0x00FF53, 0x00FF33},
+{0x00FF54, 0x00FF34},
+{0x00FF55, 0x00FF35},
+{0x00FF56, 0x00FF36},
+{0x00FF57, 0x00FF37},
+{0x00FF58, 0x00FF38},
+{0x00FF59, 0x00FF39},
+{0x00FF5A, 0x00FF3A},
+{0x010428, 0x010400},
+{0x010429, 0x010401},
+{0x01042A, 0x010402},
+{0x01042B, 0x010403},
+{0x01042C, 0x010404},
+{0x01042D, 0x010405},
+{0x01042E, 0x010406},
+{0x01042F, 0x010407},
+{0x010430, 0x010408},
+{0x010431, 0x010409},
+{0x010432, 0x01040A},
+{0x010433, 0x01040B},
+{0x010434, 0x01040C},
+{0x010435, 0x01040D},
+{0x010436, 0x01040E},
+{0x010437, 0x01040F},
+{0x010438, 0x010410},
+{0x010439, 0x010411},
+{0x01043A, 0x010412},
+{0x01043B, 0x010413},
+{0x01043C, 0x010414},
+{0x01043D, 0x010415},
+{0x01043E, 0x010416},
+{0x01043F, 0x010417},
+{0x010440, 0x010418},
+{0x010441, 0x010419},
+{0x010442, 0x01041A},
+{0x010443, 0x01041B},
+{0x010444, 0x01041C},
+{0x010445, 0x01041D},
+{0x010446, 0x01041E},
+{0x010447, 0x01041F},
+{0x010448, 0x010420},
+{0x010449, 0x010421},
+{0x01044A, 0x010422},
+{0x01044B, 0x010423},
+{0x01044C, 0x010424},
+{0x01044D, 0x010425},
+{0x01044E, 0x010426},
+{0x01044F, 0x010427},
+{0x0104D8, 0x0104B0},
+{0x0104D9, 0x0104B1},
+{0x0104DA, 0x0104B2},
+{0x0104DB, 0x0104B3},
+{0x0104DC, 0x0104B4},
+{0x0104DD, 0x0104B5},
+{0x0104DE, 0x0104B6},
+{0x0104DF, 0x0104B7},
+{0x0104E0, 0x0104B8},
+{0x0104E1, 0x0104B9},
+{0x0104E2, 0x0104BA},
+{0x0104E3, 0x0104BB},
+{0x0104E4, 0x0104BC},
+{0x0104E5, 0x0104BD},
+{0x0104E6, 0x0104BE},
+{0x0104E7, 0x0104BF},
+{0x0104E8, 0x0104C0},
+{0x0104E9, 0x0104C1},
+{0x0104EA, 0x0104C2},
+{0x0104EB, 0x0104C3},
+{0x0104EC, 0x0104C4},
+{0x0104ED, 0x0104C5},
+{0x0104EE, 0x0104C6},
+{0x0104EF, 0x0104C7},
+{0x0104F0, 0x0104C8},
+{0x0104F1, 0x0104C9},
+{0x0104F2, 0x0104CA},
+{0x0104F3, 0x0104CB},
+{0x0104F4, 0x0104CC},
+{0x0104F5, 0x0104CD},
+{0x0104F6, 0x0104CE},
+{0x0104F7, 0x0104CF},
+{0x0104F8, 0x0104D0},
+{0x0104F9, 0x0104D1},
+{0x0104FA, 0x0104D2},
+{0x0104FB, 0x0104D3},
+{0x010597, 0x010570},
+{0x010598, 0x010571},
+{0x010599, 0x010572},
+{0x01059A, 0x010573},
+{0x01059B, 0x010574},
+{0x01059C, 0x010575},
+{0x01059D, 0x010576},
+{0x01059E, 0x010577},
+{0x01059F, 0x010578},
+{0x0105A0, 0x010579},
+{0x0105A1, 0x01057A},
+{0x0105A3, 0x01057C},
+{0x0105A4, 0x01057D},
+{0x0105A5, 0x01057E},
+{0x0105A6, 0x01057F},
+{0x0105A7, 0x010580},
+{0x0105A8, 0x010581},
+{0x0105A9, 0x010582},
+{0x0105AA, 0x010583},
+{0x0105AB, 0x010584},
+{0x0105AC, 0x010585},
+{0x0105AD, 0x010586},
+{0x0105AE, 0x010587},
+{0x0105AF, 0x010588},
+{0x0105B0, 0x010589},
+{0x0105B1, 0x01058A},
+{0x0105B3, 0x01058C},
+{0x0105B4, 0x01058D},
+{0x0105B5, 0x01058E},
+{0x0105B6, 0x01058F},
+{0x0105B7, 0x010590},
+{0x0105B8, 0x010591},
+{0x0105B9, 0x010592},
+{0x0105BB, 0x010594},
+{0x0105BC, 0x010595},
+{0x010CC0, 0x010C80},
+{0x010CC1, 0x010C81},
+{0x010CC2, 0x010C82},
+{0x010CC3, 0x010C83},
+{0x010CC4, 0x010C84},
+{0x010CC5, 0x010C85},
+{0x010CC6, 0x010C86},
+{0x010CC7, 0x010C87},
+{0x010CC8, 0x010C88},
+{0x010CC9, 0x010C89},
+{0x010CCA, 0x010C8A},
+{0x010CCB, 0x010C8B},
+{0x010CCC, 0x010C8C},
+{0x010CCD, 0x010C8D},
+{0x010CCE, 0x010C8E},
+{0x010CCF, 0x010C8F},
+{0x010CD0, 0x010C90},
+{0x010CD1, 0x010C91},
+{0x010CD2, 0x010C92},
+{0x010CD3, 0x010C93},
+{0x010CD4, 0x010C94},
+{0x010CD5, 0x010C95},
+{0x010CD6, 0x010C96},
+{0x010CD7, 0x010C97},
+{0x010CD8, 0x010C98},
+{0x010CD9, 0x010C99},
+{0x010CDA, 0x010C9A},
+{0x010CDB, 0x010C9B},
+{0x010CDC, 0x010C9C},
+{0x010CDD, 0x010C9D},
+{0x010CDE, 0x010C9E},
+{0x010CDF, 0x010C9F},
+{0x010CE0, 0x010CA0},
+{0x010CE1, 0x010CA1},
+{0x010CE2, 0x010CA2},
+{0x010CE3, 0x010CA3},
+{0x010CE4, 0x010CA4},
+{0x010CE5, 0x010CA5},
+{0x010CE6, 0x010CA6},
+{0x010CE7, 0x010CA7},
+{0x010CE8, 0x010CA8},
+{0x010CE9, 0x010CA9},
+{0x010CEA, 0x010CAA},
+{0x010CEB, 0x010CAB},
+{0x010CEC, 0x010CAC},
+{0x010CED, 0x010CAD},
+{0x010CEE, 0x010CAE},
+{0x010CEF, 0x010CAF},
+{0x010CF0, 0x010CB0},
+{0x010CF1, 0x010CB1},
+{0x010CF2, 0x010CB2},
+{0x0118C0, 0x0118A0},
+{0x0118C1, 0x0118A1},
+{0x0118C2, 0x0118A2},
+{0x0118C3, 0x0118A3},
+{0x0118C4, 0x0118A4},
+{0x0118C5, 0x0118A5},
+{0x0118C6, 0x0118A6},
+{0x0118C7, 0x0118A7},
+{0x0118C8, 0x0118A8},
+{0x0118C9, 0x0118A9},
+{0x0118CA, 0x0118AA},
+{0x0118CB, 0x0118AB},
+{0x0118CC, 0x0118AC},
+{0x0118CD, 0x0118AD},
+{0x0118CE, 0x0118AE},
+{0x0118CF, 0x0118AF},
+{0x0118D0, 0x0118B0},
+{0x0118D1, 0x0118B1},
+{0x0118D2, 0x0118B2},
+{0x0118D3, 0x0118B3},
+{0x0118D4, 0x0118B4},
+{0x0118D5, 0x0118B5},
+{0x0118D6, 0x0118B6},
+{0x0118D7, 0x0118B7},
+{0x0118D8, 0x0118B8},
+{0x0118D9, 0x0118B9},
+{0x0118DA, 0x0118BA},
+{0x0118DB, 0x0118BB},
+{0x0118DC, 0x0118BC},
+{0x0118DD, 0x0118BD},
+{0x0118DE, 0x0118BE},
+{0x0118DF, 0x0118BF},
+{0x016E60, 0x016E40},
+{0x016E61, 0x016E41},
+{0x016E62, 0x016E42},
+{0x016E63, 0x016E43},
+{0x016E64, 0x016E44},
+{0x016E65, 0x016E45},
+{0x016E66, 0x016E46},
+{0x016E67, 0x016E47},
+{0x016E68, 0x016E48},
+{0x016E69, 0x016E49},
+{0x016E6A, 0x016E4A},
+{0x016E6B, 0x016E4B},
+{0x016E6C, 0x016E4C},
+{0x016E6D, 0x016E4D},
+{0x016E6E, 0x016E4E},
+{0x016E6F, 0x016E4F},
+{0x016E70, 0x016E50},
+{0x016E71, 0x016E51},
+{0x016E72, 0x016E52},
+{0x016E73, 0x016E53},
+{0x016E74, 0x016E54},
+{0x016E75, 0x016E55},
+{0x016E76, 0x016E56},
+{0x016E77, 0x016E57},
+{0x016E78, 0x016E58},
+{0x016E79, 0x016E59},
+{0x016E7A, 0x016E5A},
+{0x016E7B, 0x016E5B},
+{0x016E7C, 0x016E5C},
+{0x016E7D, 0x016E5D},
+{0x016E7E, 0x016E5E},
+{0x016E7F, 0x016E5F},
+{0x01E922, 0x01E900},
+{0x01E923, 0x01E901},
+{0x01E924, 0x01E902},
+{0x01E925, 0x01E903},
+{0x01E926, 0x01E904},
+{0x01E927, 0x01E905},
+{0x01E928, 0x01E906},
+{0x01E929, 0x01E907},
+{0x01E92A, 0x01E908},
+{0x01E92B, 0x01E909},
+{0x01E92C, 0x01E90A},
+{0x01E92D, 0x01E90B},
+{0x01E92E, 0x01E90C},
+{0x01E92F, 0x01E90D},
+{0x01E930, 0x01E90E},
+{0x01E931, 0x01E90F},
+{0x01E932, 0x01E910},
+{0x01E933, 0x01E911},
+{0x01E934, 0x01E912},
+{0x01E935, 0x01E913},
+{0x01E936, 0x01E914},
+{0x01E937, 0x01E915},
+{0x01E938, 0x01E916},
+{0x01E939, 0x01E917},
+{0x01E93A, 0x01E918},
+{0x01E93B, 0x01E919},
+{0x01E93C, 0x01E91A},
+{0x01E93D, 0x01E91B},
+{0x01E93E, 0x01E91C},
+{0x01E93F, 0x01E91D},
+{0x01E940, 0x01E91E},
+{0x01E941, 0x01E91F},
+{0x01E942, 0x01E920},
+{0x01E943, 0x01E921},
+};
+
+const std::initializer_list unicode_ranges_nfd = {  // start, last, nfd
+{0x000000, 0x000000, 0x000000},
+{0x0000C0, 0x0000C5, 0x000041},
+{0x0000C7, 0x0000C7, 0x000043},
+{0x0000C8, 0x0000CB, 0x000045},
+{0x0000CC, 0x0000CF, 0x000049},
+{0x0000D1, 0x0000D1, 0x00004E},
+{0x0000D2, 0x0000D6, 0x00004F},
+{0x0000D9, 0x0000DC, 0x000055},
+{0x0000DD, 0x0000DD, 0x000059},
+{0x0000E0, 0x0000E5, 0x000061},
+{0x0000E7, 0x0000E7, 0x000063},
+{0x0000E8, 0x0000EB, 0x000065},
+{0x0000EC, 0x0000EF, 0x000069},
+{0x0000F1, 0x0000F1, 0x00006E},
+{0x0000F2, 0x0000F6, 0x00006F},
+{0x0000F9, 0x0000FC, 0x000075},
+{0x0000FD, 0x0000FD, 0x000079},
+{0x0000FF, 0x0000FF, 0x000079},
+{0x000100, 0x000100, 0x000041},
+{0x000101, 0x000101, 0x000061},
+{0x000102, 0x000102, 0x000041},
+{0x000103, 0x000103, 0x000061},
+{0x000104, 0x000104, 0x000041},
+{0x000105, 0x000105, 0x000061},
+{0x000106, 0x000106, 0x000043},
+{0x000107, 0x000107, 0x000063},
+{0x000108, 0x000108, 0x000043},
+{0x000109, 0x000109, 0x000063},
+{0x00010A, 0x00010A, 0x000043},
+{0x00010B, 0x00010B, 0x000063},
+{0x00010C, 0x00010C, 0x000043},
+{0x00010D, 0x00010D, 0x000063},
+{0x00010E, 0x00010E, 0x000044},
+{0x00010F, 0x00010F, 0x000064},
+{0x000112, 0x000112, 0x000045},
+{0x000113, 0x000113, 0x000065},
+{0x000114, 0x000114, 0x000045},
+{0x000115, 0x000115, 0x000065},
+{0x000116, 0x000116, 0x000045},
+{0x000117, 0x000117, 0x000065},
+{0x000118, 0x000118, 0x000045},
+{0x000119, 0x000119, 0x000065},
+{0x00011A, 0x00011A, 0x000045},
+{0x00011B, 0x00011B, 0x000065},
+{0x00011C, 0x00011C, 0x000047},
+{0x00011D, 0x00011D, 0x000067},
+{0x00011E, 0x00011E, 0x000047},
+{0x00011F, 0x00011F, 0x000067},
+{0x000120, 0x000120, 0x000047},
+{0x000121, 0x000121, 0x000067},
+{0x000122, 0x000122, 0x000047},
+{0x000123, 0x000123, 0x000067},
+{0x000124, 0x000124, 0x000048},
+{0x000125, 0x000125, 0x000068},
+{0x000128, 0x000128, 0x000049},
+{0x000129, 0x000129, 0x000069},
+{0x00012A, 0x00012A, 0x000049},
+{0x00012B, 0x00012B, 0x000069},
+{0x00012C, 0x00012C, 0x000049},
+{0x00012D, 0x00012D, 0x000069},
+{0x00012E, 0x00012E, 0x000049},
+{0x00012F, 0x00012F, 0x000069},
+{0x000130, 0x000130, 0x000049},
+{0x000134, 0x000134, 0x00004A},
+{0x000135, 0x000135, 0x00006A},
+{0x000136, 0x000136, 0x00004B},
+{0x000137, 0x000137, 0x00006B},
+{0x000139, 0x000139, 0x00004C},
+{0x00013A, 0x00013A, 0x00006C},
+{0x00013B, 0x00013B, 0x00004C},
+{0x00013C, 0x00013C, 0x00006C},
+{0x00013D, 0x00013D, 0x00004C},
+{0x00013E, 0x00013E, 0x00006C},
+{0x000143, 0x000143, 0x00004E},
+{0x000144, 0x000144, 0x00006E},
+{0x000145, 0x000145, 0x00004E},
+{0x000146, 0x000146, 0x00006E},
+{0x000147, 0x000147, 0x00004E},
+{0x000148, 0x000148, 0x00006E},
+{0x00014C, 0x00014C, 0x00004F},
+{0x00014D, 0x00014D, 0x00006F},
+{0x00014E, 0x00014E, 0x00004F},
+{0x00014F, 0x00014F, 0x00006F},
+{0x000150, 0x000150, 0x00004F},
+{0x000151, 0x000151, 0x00006F},
+{0x000154, 0x000154, 0x000052},
+{0x000155, 0x000155, 0x000072},
+{0x000156, 0x000156, 0x000052},
+{0x000157, 0x000157, 0x000072},
+{0x000158, 0x000158, 0x000052},
+{0x000159, 0x000159, 0x000072},
+{0x00015A, 0x00015A, 0x000053},
+{0x00015B, 0x00015B, 0x000073},
+{0x00015C, 0x00015C, 0x000053},
+{0x00015D, 0x00015D, 0x000073},
+{0x00015E, 0x00015E, 0x000053},
+{0x00015F, 0x00015F, 0x000073},
+{0x000160, 0x000160, 0x000053},
+{0x000161, 0x000161, 0x000073},
+{0x000162, 0x000162, 0x000054},
+{0x000163, 0x000163, 0x000074},
+{0x000164, 0x000164, 0x000054},
+{0x000165, 0x000165, 0x000074},
+{0x000168, 0x000168, 0x000055},
+{0x000169, 0x000169, 0x000075},
+{0x00016A, 0x00016A, 0x000055},
+{0x00016B, 0x00016B, 0x000075},
+{0x00016C, 0x00016C, 0x000055},
+{0x00016D, 0x00016D, 0x000075},
+{0x00016E, 0x00016E, 0x000055},
+{0x00016F, 0x00016F, 0x000075},
+{0x000170, 0x000170, 0x000055},
+{0x000171, 0x000171, 0x000075},
+{0x000172, 0x000172, 0x000055},
+{0x000173, 0x000173, 0x000075},
+{0x000174, 0x000174, 0x000057},
+{0x000175, 0x000175, 0x000077},
+{0x000176, 0x000176, 0x000059},
+{0x000177, 0x000177, 0x000079},
+{0x000178, 0x000178, 0x000059},
+{0x000179, 0x000179, 0x00005A},
+{0x00017A, 0x00017A, 0x00007A},
+{0x00017B, 0x00017B, 0x00005A},
+{0x00017C, 0x00017C, 0x00007A},
+{0x00017D, 0x00017D, 0x00005A},
+{0x00017E, 0x00017E, 0x00007A},
+{0x0001A0, 0x0001A0, 0x00004F},
+{0x0001A1, 0x0001A1, 0x00006F},
+{0x0001AF, 0x0001AF, 0x000055},
+{0x0001B0, 0x0001B0, 0x000075},
+{0x0001CD, 0x0001CD, 0x000041},
+{0x0001CE, 0x0001CE, 0x000061},
+{0x0001CF, 0x0001CF, 0x000049},
+{0x0001D0, 0x0001D0, 0x000069},
+{0x0001D1, 0x0001D1, 0x00004F},
+{0x0001D2, 0x0001D2, 0x00006F},
+{0x0001D3, 0x0001D3, 0x000055},
+{0x0001D4, 0x0001D4, 0x000075},
+{0x0001D5, 0x0001D5, 0x000055},
+{0x0001D6, 0x0001D6, 0x000075},
+{0x0001D7, 0x0001D7, 0x000055},
+{0x0001D8, 0x0001D8, 0x000075},
+{0x0001D9, 0x0001D9, 0x000055},
+{0x0001DA, 0x0001DA, 0x000075},
+{0x0001DB, 0x0001DB, 0x000055},
+{0x0001DC, 0x0001DC, 0x000075},
+{0x0001DE, 0x0001DE, 0x000041},
+{0x0001DF, 0x0001DF, 0x000061},
+{0x0001E0, 0x0001E0, 0x000041},
+{0x0001E1, 0x0001E1, 0x000061},
+{0x0001E2, 0x0001E2, 0x0000C6},
+{0x0001E3, 0x0001E3, 0x0000E6},
+{0x0001E6, 0x0001E6, 0x000047},
+{0x0001E7, 0x0001E7, 0x000067},
+{0x0001E8, 0x0001E8, 0x00004B},
+{0x0001E9, 0x0001E9, 0x00006B},
+{0x0001EA, 0x0001EA, 0x00004F},
+{0x0001EB, 0x0001EB, 0x00006F},
+{0x0001EC, 0x0001EC, 0x00004F},
+{0x0001ED, 0x0001ED, 0x00006F},
+{0x0001EE, 0x0001EE, 0x0001B7},
+{0x0001EF, 0x0001EF, 0x000292},
+{0x0001F0, 0x0001F0, 0x00006A},
+{0x0001F4, 0x0001F4, 0x000047},
+{0x0001F5, 0x0001F5, 0x000067},
+{0x0001F8, 0x0001F8, 0x00004E},
+{0x0001F9, 0x0001F9, 0x00006E},
+{0x0001FA, 0x0001FA, 0x000041},
+{0x0001FB, 0x0001FB, 0x000061},
+{0x0001FC, 0x0001FC, 0x0000C6},
+{0x0001FD, 0x0001FD, 0x0000E6},
+{0x0001FE, 0x0001FE, 0x0000D8},
+{0x0001FF, 0x0001FF, 0x0000F8},
+{0x000200, 0x000200, 0x000041},
+{0x000201, 0x000201, 0x000061},
+{0x000202, 0x000202, 0x000041},
+{0x000203, 0x000203, 0x000061},
+{0x000204, 0x000204, 0x000045},
+{0x000205, 0x000205, 0x000065},
+{0x000206, 0x000206, 0x000045},
+{0x000207, 0x000207, 0x000065},
+{0x000208, 0x000208, 0x000049},
+{0x000209, 0x000209, 0x000069},
+{0x00020A, 0x00020A, 0x000049},
+{0x00020B, 0x00020B, 0x000069},
+{0x00020C, 0x00020C, 0x00004F},
+{0x00020D, 0x00020D, 0x00006F},
+{0x00020E, 0x00020E, 0x00004F},
+{0x00020F, 0x00020F, 0x00006F},
+{0x000210, 0x000210, 0x000052},
+{0x000211, 0x000211, 0x000072},
+{0x000212, 0x000212, 0x000052},
+{0x000213, 0x000213, 0x000072},
+{0x000214, 0x000214, 0x000055},
+{0x000215, 0x000215, 0x000075},
+{0x000216, 0x000216, 0x000055},
+{0x000217, 0x000217, 0x000075},
+{0x000218, 0x000218, 0x000053},
+{0x000219, 0x000219, 0x000073},
+{0x00021A, 0x00021A, 0x000054},
+{0x00021B, 0x00021B, 0x000074},
+{0x00021E, 0x00021E, 0x000048},
+{0x00021F, 0x00021F, 0x000068},
+{0x000226, 0x000226, 0x000041},
+{0x000227, 0x000227, 0x000061},
+{0x000228, 0x000228, 0x000045},
+{0x000229, 0x000229, 0x000065},
+{0x00022A, 0x00022A, 0x00004F},
+{0x00022B, 0x00022B, 0x00006F},
+{0x00022C, 0x00022C, 0x00004F},
+{0x00022D, 0x00022D, 0x00006F},
+{0x00022E, 0x00022E, 0x00004F},
+{0x00022F, 0x00022F, 0x00006F},
+{0x000230, 0x000230, 0x00004F},
+{0x000231, 0x000231, 0x00006F},
+{0x000232, 0x000232, 0x000059},
+{0x000233, 0x000233, 0x000079},
+{0x000340, 0x000340, 0x000300},
+{0x000341, 0x000341, 0x000301},
+{0x000343, 0x000343, 0x000313},
+{0x000344, 0x000344, 0x000308},
+{0x000374, 0x000374, 0x0002B9},
+{0x00037E, 0x00037E, 0x00003B},
+{0x000385, 0x000385, 0x0000A8},
+{0x000386, 0x000386, 0x000391},
+{0x000387, 0x000387, 0x0000B7},
+{0x000388, 0x000388, 0x000395},
+{0x000389, 0x000389, 0x000397},
+{0x00038A, 0x00038A, 0x000399},
+{0x00038C, 0x00038C, 0x00039F},
+{0x00038E, 0x00038E, 0x0003A5},
+{0x00038F, 0x00038F, 0x0003A9},
+{0x000390, 0x000390, 0x0003B9},
+{0x0003AA, 0x0003AA, 0x000399},
+{0x0003AB, 0x0003AB, 0x0003A5},
+{0x0003AC, 0x0003AC, 0x0003B1},
+{0x0003AD, 0x0003AD, 0x0003B5},
+{0x0003AE, 0x0003AE, 0x0003B7},
+{0x0003AF, 0x0003AF, 0x0003B9},
+{0x0003B0, 0x0003B0, 0x0003C5},
+{0x0003CA, 0x0003CA, 0x0003B9},
+{0x0003CB, 0x0003CB, 0x0003C5},
+{0x0003CC, 0x0003CC, 0x0003BF},
+{0x0003CD, 0x0003CD, 0x0003C5},
+{0x0003CE, 0x0003CE, 0x0003C9},
+{0x0003D3, 0x0003D4, 0x0003D2},
+{0x000400, 0x000401, 0x000415},
+{0x000403, 0x000403, 0x000413},
+{0x000407, 0x000407, 0x000406},
+{0x00040C, 0x00040C, 0x00041A},
+{0x00040D, 0x00040D, 0x000418},
+{0x00040E, 0x00040E, 0x000423},
+{0x000419, 0x000419, 0x000418},
+{0x000439, 0x000439, 0x000438},
+{0x000450, 0x000451, 0x000435},
+{0x000453, 0x000453, 0x000433},
+{0x000457, 0x000457, 0x000456},
+{0x00045C, 0x00045C, 0x00043A},
+{0x00045D, 0x00045D, 0x000438},
+{0x00045E, 0x00045E, 0x000443},
+{0x000476, 0x000476, 0x000474},
+{0x000477, 0x000477, 0x000475},
+{0x0004C1, 0x0004C1, 0x000416},
+{0x0004C2, 0x0004C2, 0x000436},
+{0x0004D0, 0x0004D0, 0x000410},
+{0x0004D1, 0x0004D1, 0x000430},
+{0x0004D2, 0x0004D2, 0x000410},
+{0x0004D3, 0x0004D3, 0x000430},
+{0x0004D6, 0x0004D6, 0x000415},
+{0x0004D7, 0x0004D7, 0x000435},
+{0x0004DA, 0x0004DA, 0x0004D8},
+{0x0004DB, 0x0004DB, 0x0004D9},
+{0x0004DC, 0x0004DC, 0x000416},
+{0x0004DD, 0x0004DD, 0x000436},
+{0x0004DE, 0x0004DE, 0x000417},
+{0x0004DF, 0x0004DF, 0x000437},
+{0x0004E2, 0x0004E2, 0x000418},
+{0x0004E3, 0x0004E3, 0x000438},
+{0x0004E4, 0x0004E4, 0x000418},
+{0x0004E5, 0x0004E5, 0x000438},
+{0x0004E6, 0x0004E6, 0x00041E},
+{0x0004E7, 0x0004E7, 0x00043E},
+{0x0004EA, 0x0004EA, 0x0004E8},
+{0x0004EB, 0x0004EB, 0x0004E9},
+{0x0004EC, 0x0004EC, 0x00042D},
+{0x0004ED, 0x0004ED, 0x00044D},
+{0x0004EE, 0x0004EE, 0x000423},
+{0x0004EF, 0x0004EF, 0x000443},
+{0x0004F0, 0x0004F0, 0x000423},
+{0x0004F1, 0x0004F1, 0x000443},
+{0x0004F2, 0x0004F2, 0x000423},
+{0x0004F3, 0x0004F3, 0x000443},
+{0x0004F4, 0x0004F4, 0x000427},
+{0x0004F5, 0x0004F5, 0x000447},
+{0x0004F8, 0x0004F8, 0x00042B},
+{0x0004F9, 0x0004F9, 0x00044B},
+{0x000622, 0x000623, 0x000627},
+{0x000624, 0x000624, 0x000648},
+{0x000625, 0x000625, 0x000627},
+{0x000626, 0x000626, 0x00064A},
+{0x0006C0, 0x0006C0, 0x0006D5},
+{0x0006C2, 0x0006C2, 0x0006C1},
+{0x0006D3, 0x0006D3, 0x0006D2},
+{0x000929, 0x000929, 0x000928},
+{0x000931, 0x000931, 0x000930},
+{0x000934, 0x000934, 0x000933},
+{0x000958, 0x000958, 0x000915},
+{0x000959, 0x000959, 0x000916},
+{0x00095A, 0x00095A, 0x000917},
+{0x00095B, 0x00095B, 0x00091C},
+{0x00095C, 0x00095C, 0x000921},
+{0x00095D, 0x00095D, 0x000922},
+{0x00095E, 0x00095E, 0x00092B},
+{0x00095F, 0x00095F, 0x00092F},
+{0x0009CB, 0x0009CC, 0x0009C7},
+{0x0009DC, 0x0009DC, 0x0009A1},
+{0x0009DD, 0x0009DD, 0x0009A2},
+{0x0009DF, 0x0009DF, 0x0009AF},
+{0x000A33, 0x000A33, 0x000A32},
+{0x000A36, 0x000A36, 0x000A38},
+{0x000A59, 0x000A59, 0x000A16},
+{0x000A5A, 0x000A5A, 0x000A17},
+{0x000A5B, 0x000A5B, 0x000A1C},
+{0x000A5E, 0x000A5E, 0x000A2B},
+{0x000B48, 0x000B48, 0x000B47},
+{0x000B4B, 0x000B4C, 0x000B47},
+{0x000B5C, 0x000B5C, 0x000B21},
+{0x000B5D, 0x000B5D, 0x000B22},
+{0x000B94, 0x000B94, 0x000B92},
+{0x000BCA, 0x000BCA, 0x000BC6},
+{0x000BCB, 0x000BCB, 0x000BC7},
+{0x000BCC, 0x000BCC, 0x000BC6},
+{0x000C48, 0x000C48, 0x000C46},
+{0x000CC0, 0x000CC0, 0x000CBF},
+{0x000CC7, 0x000CC8, 0x000CC6},
+{0x000CCA, 0x000CCB, 0x000CC6},
+{0x000D4A, 0x000D4A, 0x000D46},
+{0x000D4B, 0x000D4B, 0x000D47},
+{0x000D4C, 0x000D4C, 0x000D46},
+{0x000DDA, 0x000DDA, 0x000DD9},
+{0x000DDC, 0x000DDE, 0x000DD9},
+{0x000F43, 0x000F43, 0x000F42},
+{0x000F4D, 0x000F4D, 0x000F4C},
+{0x000F52, 0x000F52, 0x000F51},
+{0x000F57, 0x000F57, 0x000F56},
+{0x000F5C, 0x000F5C, 0x000F5B},
+{0x000F69, 0x000F69, 0x000F40},
+{0x000F73, 0x000F73, 0x000F71},
+{0x000F75, 0x000F75, 0x000F71},
+{0x000F76, 0x000F76, 0x000FB2},
+{0x000F78, 0x000F78, 0x000FB3},
+{0x000F81, 0x000F81, 0x000F71},
+{0x000F93, 0x000F93, 0x000F92},
+{0x000F9D, 0x000F9D, 0x000F9C},
+{0x000FA2, 0x000FA2, 0x000FA1},
+{0x000FA7, 0x000FA7, 0x000FA6},
+{0x000FAC, 0x000FAC, 0x000FAB},
+{0x000FB9, 0x000FB9, 0x000F90},
+{0x001026, 0x001026, 0x001025},
+{0x001B06, 0x001B06, 0x001B05},
+{0x001B08, 0x001B08, 0x001B07},
+{0x001B0A, 0x001B0A, 0x001B09},
+{0x001B0C, 0x001B0C, 0x001B0B},
+{0x001B0E, 0x001B0E, 0x001B0D},
+{0x001B12, 0x001B12, 0x001B11},
+{0x001B3B, 0x001B3B, 0x001B3A},
+{0x001B3D, 0x001B3D, 0x001B3C},
+{0x001B40, 0x001B40, 0x001B3E},
+{0x001B41, 0x001B41, 0x001B3F},
+{0x001B43, 0x001B43, 0x001B42},
+{0x001E00, 0x001E00, 0x000041},
+{0x001E01, 0x001E01, 0x000061},
+{0x001E02, 0x001E02, 0x000042},
+{0x001E03, 0x001E03, 0x000062},
+{0x001E04, 0x001E04, 0x000042},
+{0x001E05, 0x001E05, 0x000062},
+{0x001E06, 0x001E06, 0x000042},
+{0x001E07, 0x001E07, 0x000062},
+{0x001E08, 0x001E08, 0x000043},
+{0x001E09, 0x001E09, 0x000063},
+{0x001E0A, 0x001E0A, 0x000044},
+{0x001E0B, 0x001E0B, 0x000064},
+{0x001E0C, 0x001E0C, 0x000044},
+{0x001E0D, 0x001E0D, 0x000064},
+{0x001E0E, 0x001E0E, 0x000044},
+{0x001E0F, 0x001E0F, 0x000064},
+{0x001E10, 0x001E10, 0x000044},
+{0x001E11, 0x001E11, 0x000064},
+{0x001E12, 0x001E12, 0x000044},
+{0x001E13, 0x001E13, 0x000064},
+{0x001E14, 0x001E14, 0x000045},
+{0x001E15, 0x001E15, 0x000065},
+{0x001E16, 0x001E16, 0x000045},
+{0x001E17, 0x001E17, 0x000065},
+{0x001E18, 0x001E18, 0x000045},
+{0x001E19, 0x001E19, 0x000065},
+{0x001E1A, 0x001E1A, 0x000045},
+{0x001E1B, 0x001E1B, 0x000065},
+{0x001E1C, 0x001E1C, 0x000045},
+{0x001E1D, 0x001E1D, 0x000065},
+{0x001E1E, 0x001E1E, 0x000046},
+{0x001E1F, 0x001E1F, 0x000066},
+{0x001E20, 0x001E20, 0x000047},
+{0x001E21, 0x001E21, 0x000067},
+{0x001E22, 0x001E22, 0x000048},
+{0x001E23, 0x001E23, 0x000068},
+{0x001E24, 0x001E24, 0x000048},
+{0x001E25, 0x001E25, 0x000068},
+{0x001E26, 0x001E26, 0x000048},
+{0x001E27, 0x001E27, 0x000068},
+{0x001E28, 0x001E28, 0x000048},
+{0x001E29, 0x001E29, 0x000068},
+{0x001E2A, 0x001E2A, 0x000048},
+{0x001E2B, 0x001E2B, 0x000068},
+{0x001E2C, 0x001E2C, 0x000049},
+{0x001E2D, 0x001E2D, 0x000069},
+{0x001E2E, 0x001E2E, 0x000049},
+{0x001E2F, 0x001E2F, 0x000069},
+{0x001E30, 0x001E30, 0x00004B},
+{0x001E31, 0x001E31, 0x00006B},
+{0x001E32, 0x001E32, 0x00004B},
+{0x001E33, 0x001E33, 0x00006B},
+{0x001E34, 0x001E34, 0x00004B},
+{0x001E35, 0x001E35, 0x00006B},
+{0x001E36, 0x001E36, 0x00004C},
+{0x001E37, 0x001E37, 0x00006C},
+{0x001E38, 0x001E38, 0x00004C},
+{0x001E39, 0x001E39, 0x00006C},
+{0x001E3A, 0x001E3A, 0x00004C},
+{0x001E3B, 0x001E3B, 0x00006C},
+{0x001E3C, 0x001E3C, 0x00004C},
+{0x001E3D, 0x001E3D, 0x00006C},
+{0x001E3E, 0x001E3E, 0x00004D},
+{0x001E3F, 0x001E3F, 0x00006D},
+{0x001E40, 0x001E40, 0x00004D},
+{0x001E41, 0x001E41, 0x00006D},
+{0x001E42, 0x001E42, 0x00004D},
+{0x001E43, 0x001E43, 0x00006D},
+{0x001E44, 0x001E44, 0x00004E},
+{0x001E45, 0x001E45, 0x00006E},
+{0x001E46, 0x001E46, 0x00004E},
+{0x001E47, 0x001E47, 0x00006E},
+{0x001E48, 0x001E48, 0x00004E},
+{0x001E49, 0x001E49, 0x00006E},
+{0x001E4A, 0x001E4A, 0x00004E},
+{0x001E4B, 0x001E4B, 0x00006E},
+{0x001E4C, 0x001E4C, 0x00004F},
+{0x001E4D, 0x001E4D, 0x00006F},
+{0x001E4E, 0x001E4E, 0x00004F},
+{0x001E4F, 0x001E4F, 0x00006F},
+{0x001E50, 0x001E50, 0x00004F},
+{0x001E51, 0x001E51, 0x00006F},
+{0x001E52, 0x001E52, 0x00004F},
+{0x001E53, 0x001E53, 0x00006F},
+{0x001E54, 0x001E54, 0x000050},
+{0x001E55, 0x001E55, 0x000070},
+{0x001E56, 0x001E56, 0x000050},
+{0x001E57, 0x001E57, 0x000070},
+{0x001E58, 0x001E58, 0x000052},
+{0x001E59, 0x001E59, 0x000072},
+{0x001E5A, 0x001E5A, 0x000052},
+{0x001E5B, 0x001E5B, 0x000072},
+{0x001E5C, 0x001E5C, 0x000052},
+{0x001E5D, 0x001E5D, 0x000072},
+{0x001E5E, 0x001E5E, 0x000052},
+{0x001E5F, 0x001E5F, 0x000072},
+{0x001E60, 0x001E60, 0x000053},
+{0x001E61, 0x001E61, 0x000073},
+{0x001E62, 0x001E62, 0x000053},
+{0x001E63, 0x001E63, 0x000073},
+{0x001E64, 0x001E64, 0x000053},
+{0x001E65, 0x001E65, 0x000073},
+{0x001E66, 0x001E66, 0x000053},
+{0x001E67, 0x001E67, 0x000073},
+{0x001E68, 0x001E68, 0x000053},
+{0x001E69, 0x001E69, 0x000073},
+{0x001E6A, 0x001E6A, 0x000054},
+{0x001E6B, 0x001E6B, 0x000074},
+{0x001E6C, 0x001E6C, 0x000054},
+{0x001E6D, 0x001E6D, 0x000074},
+{0x001E6E, 0x001E6E, 0x000054},
+{0x001E6F, 0x001E6F, 0x000074},
+{0x001E70, 0x001E70, 0x000054},
+{0x001E71, 0x001E71, 0x000074},
+{0x001E72, 0x001E72, 0x000055},
+{0x001E73, 0x001E73, 0x000075},
+{0x001E74, 0x001E74, 0x000055},
+{0x001E75, 0x001E75, 0x000075},
+{0x001E76, 0x001E76, 0x000055},
+{0x001E77, 0x001E77, 0x000075},
+{0x001E78, 0x001E78, 0x000055},
+{0x001E79, 0x001E79, 0x000075},
+{0x001E7A, 0x001E7A, 0x000055},
+{0x001E7B, 0x001E7B, 0x000075},
+{0x001E7C, 0x001E7C, 0x000056},
+{0x001E7D, 0x001E7D, 0x000076},
+{0x001E7E, 0x001E7E, 0x000056},
+{0x001E7F, 0x001E7F, 0x000076},
+{0x001E80, 0x001E80, 0x000057},
+{0x001E81, 0x001E81, 0x000077},
+{0x001E82, 0x001E82, 0x000057},
+{0x001E83, 0x001E83, 0x000077},
+{0x001E84, 0x001E84, 0x000057},
+{0x001E85, 0x001E85, 0x000077},
+{0x001E86, 0x001E86, 0x000057},
+{0x001E87, 0x001E87, 0x000077},
+{0x001E88, 0x001E88, 0x000057},
+{0x001E89, 0x001E89, 0x000077},
+{0x001E8A, 0x001E8A, 0x000058},
+{0x001E8B, 0x001E8B, 0x000078},
+{0x001E8C, 0x001E8C, 0x000058},
+{0x001E8D, 0x001E8D, 0x000078},
+{0x001E8E, 0x001E8E, 0x000059},
+{0x001E8F, 0x001E8F, 0x000079},
+{0x001E90, 0x001E90, 0x00005A},
+{0x001E91, 0x001E91, 0x00007A},
+{0x001E92, 0x001E92, 0x00005A},
+{0x001E93, 0x001E93, 0x00007A},
+{0x001E94, 0x001E94, 0x00005A},
+{0x001E95, 0x001E95, 0x00007A},
+{0x001E96, 0x001E96, 0x000068},
+{0x001E97, 0x001E97, 0x000074},
+{0x001E98, 0x001E98, 0x000077},
+{0x001E99, 0x001E99, 0x000079},
+{0x001E9B, 0x001E9B, 0x00017F},
+{0x001EA0, 0x001EA0, 0x000041},
+{0x001EA1, 0x001EA1, 0x000061},
+{0x001EA2, 0x001EA2, 0x000041},
+{0x001EA3, 0x001EA3, 0x000061},
+{0x001EA4, 0x001EA4, 0x000041},
+{0x001EA5, 0x001EA5, 0x000061},
+{0x001EA6, 0x001EA6, 0x000041},
+{0x001EA7, 0x001EA7, 0x000061},
+{0x001EA8, 0x001EA8, 0x000041},
+{0x001EA9, 0x001EA9, 0x000061},
+{0x001EAA, 0x001EAA, 0x000041},
+{0x001EAB, 0x001EAB, 0x000061},
+{0x001EAC, 0x001EAC, 0x000041},
+{0x001EAD, 0x001EAD, 0x000061},
+{0x001EAE, 0x001EAE, 0x000041},
+{0x001EAF, 0x001EAF, 0x000061},
+{0x001EB0, 0x001EB0, 0x000041},
+{0x001EB1, 0x001EB1, 0x000061},
+{0x001EB2, 0x001EB2, 0x000041},
+{0x001EB3, 0x001EB3, 0x000061},
+{0x001EB4, 0x001EB4, 0x000041},
+{0x001EB5, 0x001EB5, 0x000061},
+{0x001EB6, 0x001EB6, 0x000041},
+{0x001EB7, 0x001EB7, 0x000061},
+{0x001EB8, 0x001EB8, 0x000045},
+{0x001EB9, 0x001EB9, 0x000065},
+{0x001EBA, 0x001EBA, 0x000045},
+{0x001EBB, 0x001EBB, 0x000065},
+{0x001EBC, 0x001EBC, 0x000045},
+{0x001EBD, 0x001EBD, 0x000065},
+{0x001EBE, 0x001EBE, 0x000045},
+{0x001EBF, 0x001EBF, 0x000065},
+{0x001EC0, 0x001EC0, 0x000045},
+{0x001EC1, 0x001EC1, 0x000065},
+{0x001EC2, 0x001EC2, 0x000045},
+{0x001EC3, 0x001EC3, 0x000065},
+{0x001EC4, 0x001EC4, 0x000045},
+{0x001EC5, 0x001EC5, 0x000065},
+{0x001EC6, 0x001EC6, 0x000045},
+{0x001EC7, 0x001EC7, 0x000065},
+{0x001EC8, 0x001EC8, 0x000049},
+{0x001EC9, 0x001EC9, 0x000069},
+{0x001ECA, 0x001ECA, 0x000049},
+{0x001ECB, 0x001ECB, 0x000069},
+{0x001ECC, 0x001ECC, 0x00004F},
+{0x001ECD, 0x001ECD, 0x00006F},
+{0x001ECE, 0x001ECE, 0x00004F},
+{0x001ECF, 0x001ECF, 0x00006F},
+{0x001ED0, 0x001ED0, 0x00004F},
+{0x001ED1, 0x001ED1, 0x00006F},
+{0x001ED2, 0x001ED2, 0x00004F},
+{0x001ED3, 0x001ED3, 0x00006F},
+{0x001ED4, 0x001ED4, 0x00004F},
+{0x001ED5, 0x001ED5, 0x00006F},
+{0x001ED6, 0x001ED6, 0x00004F},
+{0x001ED7, 0x001ED7, 0x00006F},
+{0x001ED8, 0x001ED8, 0x00004F},
+{0x001ED9, 0x001ED9, 0x00006F},
+{0x001EDA, 0x001EDA, 0x00004F},
+{0x001EDB, 0x001EDB, 0x00006F},
+{0x001EDC, 0x001EDC, 0x00004F},
+{0x001EDD, 0x001EDD, 0x00006F},
+{0x001EDE, 0x001EDE, 0x00004F},
+{0x001EDF, 0x001EDF, 0x00006F},
+{0x001EE0, 0x001EE0, 0x00004F},
+{0x001EE1, 0x001EE1, 0x00006F},
+{0x001EE2, 0x001EE2, 0x00004F},
+{0x001EE3, 0x001EE3, 0x00006F},
+{0x001EE4, 0x001EE4, 0x000055},
+{0x001EE5, 0x001EE5, 0x000075},
+{0x001EE6, 0x001EE6, 0x000055},
+{0x001EE7, 0x001EE7, 0x000075},
+{0x001EE8, 0x001EE8, 0x000055},
+{0x001EE9, 0x001EE9, 0x000075},
+{0x001EEA, 0x001EEA, 0x000055},
+{0x001EEB, 0x001EEB, 0x000075},
+{0x001EEC, 0x001EEC, 0x000055},
+{0x001EED, 0x001EED, 0x000075},
+{0x001EEE, 0x001EEE, 0x000055},
+{0x001EEF, 0x001EEF, 0x000075},
+{0x001EF0, 0x001EF0, 0x000055},
+{0x001EF1, 0x001EF1, 0x000075},
+{0x001EF2, 0x001EF2, 0x000059},
+{0x001EF3, 0x001EF3, 0x000079},
+{0x001EF4, 0x001EF4, 0x000059},
+{0x001EF5, 0x001EF5, 0x000079},
+{0x001EF6, 0x001EF6, 0x000059},
+{0x001EF7, 0x001EF7, 0x000079},
+{0x001EF8, 0x001EF8, 0x000059},
+{0x001EF9, 0x001EF9, 0x000079},
+{0x001F00, 0x001F07, 0x0003B1},
+{0x001F08, 0x001F0F, 0x000391},
+{0x001F10, 0x001F15, 0x0003B5},
+{0x001F18, 0x001F1D, 0x000395},
+{0x001F20, 0x001F27, 0x0003B7},
+{0x001F28, 0x001F2F, 0x000397},
+{0x001F30, 0x001F37, 0x0003B9},
+{0x001F38, 0x001F3F, 0x000399},
+{0x001F40, 0x001F45, 0x0003BF},
+{0x001F48, 0x001F4D, 0x00039F},
+{0x001F50, 0x001F57, 0x0003C5},
+{0x001F59, 0x001F59, 0x0003A5},
+{0x001F5B, 0x001F5B, 0x0003A5},
+{0x001F5D, 0x001F5D, 0x0003A5},
+{0x001F5F, 0x001F5F, 0x0003A5},
+{0x001F60, 0x001F67, 0x0003C9},
+{0x001F68, 0x001F6F, 0x0003A9},
+{0x001F70, 0x001F71, 0x0003B1},
+{0x001F72, 0x001F73, 0x0003B5},
+{0x001F74, 0x001F75, 0x0003B7},
+{0x001F76, 0x001F77, 0x0003B9},
+{0x001F78, 0x001F79, 0x0003BF},
+{0x001F7A, 0x001F7B, 0x0003C5},
+{0x001F7C, 0x001F7D, 0x0003C9},
+{0x001F80, 0x001F87, 0x0003B1},
+{0x001F88, 0x001F8F, 0x000391},
+{0x001F90, 0x001F97, 0x0003B7},
+{0x001F98, 0x001F9F, 0x000397},
+{0x001FA0, 0x001FA7, 0x0003C9},
+{0x001FA8, 0x001FAF, 0x0003A9},
+{0x001FB0, 0x001FB4, 0x0003B1},
+{0x001FB6, 0x001FB7, 0x0003B1},
+{0x001FB8, 0x001FBC, 0x000391},
+{0x001FBE, 0x001FBE, 0x0003B9},
+{0x001FC1, 0x001FC1, 0x0000A8},
+{0x001FC2, 0x001FC4, 0x0003B7},
+{0x001FC6, 0x001FC7, 0x0003B7},
+{0x001FC8, 0x001FC9, 0x000395},
+{0x001FCA, 0x001FCC, 0x000397},
+{0x001FCD, 0x001FCF, 0x001FBF},
+{0x001FD0, 0x001FD3, 0x0003B9},
+{0x001FD6, 0x001FD7, 0x0003B9},
+{0x001FD8, 0x001FDB, 0x000399},
+{0x001FDD, 0x001FDF, 0x001FFE},
+{0x001FE0, 0x001FE3, 0x0003C5},
+{0x001FE4, 0x001FE5, 0x0003C1},
+{0x001FE6, 0x001FE7, 0x0003C5},
+{0x001FE8, 0x001FEB, 0x0003A5},
+{0x001FEC, 0x001FEC, 0x0003A1},
+{0x001FED, 0x001FEE, 0x0000A8},
+{0x001FEF, 0x001FEF, 0x000060},
+{0x001FF2, 0x001FF4, 0x0003C9},
+{0x001FF6, 0x001FF7, 0x0003C9},
+{0x001FF8, 0x001FF9, 0x00039F},
+{0x001FFA, 0x001FFC, 0x0003A9},
+{0x001FFD, 0x001FFD, 0x0000B4},
+{0x002000, 0x002000, 0x002002},
+{0x002001, 0x002001, 0x002003},
+{0x002126, 0x002126, 0x0003A9},
+{0x00212A, 0x00212A, 0x00004B},
+{0x00212B, 0x00212B, 0x000041},
+{0x00219A, 0x00219A, 0x002190},
+{0x00219B, 0x00219B, 0x002192},
+{0x0021AE, 0x0021AE, 0x002194},
+{0x0021CD, 0x0021CD, 0x0021D0},
+{0x0021CE, 0x0021CE, 0x0021D4},
+{0x0021CF, 0x0021CF, 0x0021D2},
+{0x002204, 0x002204, 0x002203},
+{0x002209, 0x002209, 0x002208},
+{0x00220C, 0x00220C, 0x00220B},
+{0x002224, 0x002224, 0x002223},
+{0x002226, 0x002226, 0x002225},
+{0x002241, 0x002241, 0x00223C},
+{0x002244, 0x002244, 0x002243},
+{0x002247, 0x002247, 0x002245},
+{0x002249, 0x002249, 0x002248},
+{0x002260, 0x002260, 0x00003D},
+{0x002262, 0x002262, 0x002261},
+{0x00226D, 0x00226D, 0x00224D},
+{0x00226E, 0x00226E, 0x00003C},
+{0x00226F, 0x00226F, 0x00003E},
+{0x002270, 0x002270, 0x002264},
+{0x002271, 0x002271, 0x002265},
+{0x002274, 0x002274, 0x002272},
+{0x002275, 0x002275, 0x002273},
+{0x002278, 0x002278, 0x002276},
+{0x002279, 0x002279, 0x002277},
+{0x002280, 0x002280, 0x00227A},
+{0x002281, 0x002281, 0x00227B},
+{0x002284, 0x002284, 0x002282},
+{0x002285, 0x002285, 0x002283},
+{0x002288, 0x002288, 0x002286},
+{0x002289, 0x002289, 0x002287},
+{0x0022AC, 0x0022AC, 0x0022A2},
+{0x0022AD, 0x0022AD, 0x0022A8},
+{0x0022AE, 0x0022AE, 0x0022A9},
+{0x0022AF, 0x0022AF, 0x0022AB},
+{0x0022E0, 0x0022E0, 0x00227C},
+{0x0022E1, 0x0022E1, 0x00227D},
+{0x0022E2, 0x0022E2, 0x002291},
+{0x0022E3, 0x0022E3, 0x002292},
+{0x0022EA, 0x0022EA, 0x0022B2},
+{0x0022EB, 0x0022EB, 0x0022B3},
+{0x0022EC, 0x0022EC, 0x0022B4},
+{0x0022ED, 0x0022ED, 0x0022B5},
+{0x002329, 0x002329, 0x003008},
+{0x00232A, 0x00232A, 0x003009},
+{0x002ADC, 0x002ADC, 0x002ADD},
+{0x00304C, 0x00304C, 0x00304B},
+{0x00304E, 0x00304E, 0x00304D},
+{0x003050, 0x003050, 0x00304F},
+{0x003052, 0x003052, 0x003051},
+{0x003054, 0x003054, 0x003053},
+{0x003056, 0x003056, 0x003055},
+{0x003058, 0x003058, 0x003057},
+{0x00305A, 0x00305A, 0x003059},
+{0x00305C, 0x00305C, 0x00305B},
+{0x00305E, 0x00305E, 0x00305D},
+{0x003060, 0x003060, 0x00305F},
+{0x003062, 0x003062, 0x003061},
+{0x003065, 0x003065, 0x003064},
+{0x003067, 0x003067, 0x003066},
+{0x003069, 0x003069, 0x003068},
+{0x003070, 0x003071, 0x00306F},
+{0x003073, 0x003074, 0x003072},
+{0x003076, 0x003077, 0x003075},
+{0x003079, 0x00307A, 0x003078},
+{0x00307C, 0x00307D, 0x00307B},
+{0x003094, 0x003094, 0x003046},
+{0x00309E, 0x00309E, 0x00309D},
+{0x0030AC, 0x0030AC, 0x0030AB},
+{0x0030AE, 0x0030AE, 0x0030AD},
+{0x0030B0, 0x0030B0, 0x0030AF},
+{0x0030B2, 0x0030B2, 0x0030B1},
+{0x0030B4, 0x0030B4, 0x0030B3},
+{0x0030B6, 0x0030B6, 0x0030B5},
+{0x0030B8, 0x0030B8, 0x0030B7},
+{0x0030BA, 0x0030BA, 0x0030B9},
+{0x0030BC, 0x0030BC, 0x0030BB},
+{0x0030BE, 0x0030BE, 0x0030BD},
+{0x0030C0, 0x0030C0, 0x0030BF},
+{0x0030C2, 0x0030C2, 0x0030C1},
+{0x0030C5, 0x0030C5, 0x0030C4},
+{0x0030C7, 0x0030C7, 0x0030C6},
+{0x0030C9, 0x0030C9, 0x0030C8},
+{0x0030D0, 0x0030D1, 0x0030CF},
+{0x0030D3, 0x0030D4, 0x0030D2},
+{0x0030D6, 0x0030D7, 0x0030D5},
+{0x0030D9, 0x0030DA, 0x0030D8},
+{0x0030DC, 0x0030DD, 0x0030DB},
+{0x0030F4, 0x0030F4, 0x0030A6},
+{0x0030F7, 0x0030F7, 0x0030EF},
+{0x0030F8, 0x0030F8, 0x0030F0},
+{0x0030F9, 0x0030F9, 0x0030F1},
+{0x0030FA, 0x0030FA, 0x0030F2},
+{0x0030FE, 0x0030FE, 0x0030FD},
+{0x00AC00, 0x00AE4B, 0x001100},
+{0x00AE4C, 0x00B097, 0x001101},
+{0x00B098, 0x00B2E3, 0x001102},
+{0x00B2E4, 0x00B52F, 0x001103},
+{0x00B530, 0x00B77B, 0x001104},
+{0x00B77C, 0x00B9C7, 0x001105},
+{0x00B9C8, 0x00BC13, 0x001106},
+{0x00BC14, 0x00BE5F, 0x001107},
+{0x00BE60, 0x00C0AB, 0x001108},
+{0x00C0AC, 0x00C2F7, 0x001109},
+{0x00C2F8, 0x00C543, 0x00110A},
+{0x00C544, 0x00C78F, 0x00110B},
+{0x00C790, 0x00C9DB, 0x00110C},
+{0x00C9DC, 0x00CC27, 0x00110D},
+{0x00CC28, 0x00CE73, 0x00110E},
+{0x00CE74, 0x00D0BF, 0x00110F},
+{0x00D0C0, 0x00D30B, 0x001110},
+{0x00D30C, 0x00D557, 0x001111},
+{0x00D558, 0x00D7A3, 0x001112},
+{0x00F900, 0x00F900, 0x008C48},
+{0x00F901, 0x00F901, 0x0066F4},
+{0x00F902, 0x00F902, 0x008ECA},
+{0x00F903, 0x00F903, 0x008CC8},
+{0x00F904, 0x00F904, 0x006ED1},
+{0x00F905, 0x00F905, 0x004E32},
+{0x00F906, 0x00F906, 0x0053E5},
+{0x00F907, 0x00F908, 0x009F9C},
+{0x00F909, 0x00F909, 0x005951},
+{0x00F90A, 0x00F90A, 0x0091D1},
+{0x00F90B, 0x00F90B, 0x005587},
+{0x00F90C, 0x00F90C, 0x005948},
+{0x00F90D, 0x00F90D, 0x0061F6},
+{0x00F90E, 0x00F90E, 0x007669},
+{0x00F90F, 0x00F90F, 0x007F85},
+{0x00F910, 0x00F910, 0x00863F},
+{0x00F911, 0x00F911, 0x0087BA},
+{0x00F912, 0x00F912, 0x0088F8},
+{0x00F913, 0x00F913, 0x00908F},
+{0x00F914, 0x00F914, 0x006A02},
+{0x00F915, 0x00F915, 0x006D1B},
+{0x00F916, 0x00F916, 0x0070D9},
+{0x00F917, 0x00F917, 0x0073DE},
+{0x00F918, 0x00F918, 0x00843D},
+{0x00F919, 0x00F919, 0x00916A},
+{0x00F91A, 0x00F91A, 0x0099F1},
+{0x00F91B, 0x00F91B, 0x004E82},
+{0x00F91C, 0x00F91C, 0x005375},
+{0x00F91D, 0x00F91D, 0x006B04},
+{0x00F91E, 0x00F91E, 0x00721B},
+{0x00F91F, 0x00F91F, 0x00862D},
+{0x00F920, 0x00F920, 0x009E1E},
+{0x00F921, 0x00F921, 0x005D50},
+{0x00F922, 0x00F922, 0x006FEB},
+{0x00F923, 0x00F923, 0x0085CD},
+{0x00F924, 0x00F924, 0x008964},
+{0x00F925, 0x00F925, 0x0062C9},
+{0x00F926, 0x00F926, 0x0081D8},
+{0x00F927, 0x00F927, 0x00881F},
+{0x00F928, 0x00F928, 0x005ECA},
+{0x00F929, 0x00F929, 0x006717},
+{0x00F92A, 0x00F92A, 0x006D6A},
+{0x00F92B, 0x00F92B, 0x0072FC},
+{0x00F92C, 0x00F92C, 0x0090CE},
+{0x00F92D, 0x00F92D, 0x004F86},
+{0x00F92E, 0x00F92E, 0x0051B7},
+{0x00F92F, 0x00F92F, 0x0052DE},
+{0x00F930, 0x00F930, 0x0064C4},
+{0x00F931, 0x00F931, 0x006AD3},
+{0x00F932, 0x00F932, 0x007210},
+{0x00F933, 0x00F933, 0x0076E7},
+{0x00F934, 0x00F934, 0x008001},
+{0x00F935, 0x00F935, 0x008606},
+{0x00F936, 0x00F936, 0x00865C},
+{0x00F937, 0x00F937, 0x008DEF},
+{0x00F938, 0x00F938, 0x009732},
+{0x00F939, 0x00F939, 0x009B6F},
+{0x00F93A, 0x00F93A, 0x009DFA},
+{0x00F93B, 0x00F93B, 0x00788C},
+{0x00F93C, 0x00F93C, 0x00797F},
+{0x00F93D, 0x00F93D, 0x007DA0},
+{0x00F93E, 0x00F93E, 0x0083C9},
+{0x00F93F, 0x00F93F, 0x009304},
+{0x00F940, 0x00F940, 0x009E7F},
+{0x00F941, 0x00F941, 0x008AD6},
+{0x00F942, 0x00F942, 0x0058DF},
+{0x00F943, 0x00F943, 0x005F04},
+{0x00F944, 0x00F944, 0x007C60},
+{0x00F945, 0x00F945, 0x00807E},
+{0x00F946, 0x00F946, 0x007262},
+{0x00F947, 0x00F947, 0x0078CA},
+{0x00F948, 0x00F948, 0x008CC2},
+{0x00F949, 0x00F949, 0x0096F7},
+{0x00F94A, 0x00F94A, 0x0058D8},
+{0x00F94B, 0x00F94B, 0x005C62},
+{0x00F94C, 0x00F94C, 0x006A13},
+{0x00F94D, 0x00F94D, 0x006DDA},
+{0x00F94E, 0x00F94E, 0x006F0F},
+{0x00F94F, 0x00F94F, 0x007D2F},
+{0x00F950, 0x00F950, 0x007E37},
+{0x00F951, 0x00F951, 0x00964B},
+{0x00F952, 0x00F952, 0x0052D2},
+{0x00F953, 0x00F953, 0x00808B},
+{0x00F954, 0x00F954, 0x0051DC},
+{0x00F955, 0x00F955, 0x0051CC},
+{0x00F956, 0x00F956, 0x007A1C},
+{0x00F957, 0x00F957, 0x007DBE},
+{0x00F958, 0x00F958, 0x0083F1},
+{0x00F959, 0x00F959, 0x009675},
+{0x00F95A, 0x00F95A, 0x008B80},
+{0x00F95B, 0x00F95B, 0x0062CF},
+{0x00F95C, 0x00F95C, 0x006A02},
+{0x00F95D, 0x00F95D, 0x008AFE},
+{0x00F95E, 0x00F95E, 0x004E39},
+{0x00F95F, 0x00F95F, 0x005BE7},
+{0x00F960, 0x00F960, 0x006012},
+{0x00F961, 0x00F961, 0x007387},
+{0x00F962, 0x00F962, 0x007570},
+{0x00F963, 0x00F963, 0x005317},
+{0x00F964, 0x00F964, 0x0078FB},
+{0x00F965, 0x00F965, 0x004FBF},
+{0x00F966, 0x00F966, 0x005FA9},
+{0x00F967, 0x00F967, 0x004E0D},
+{0x00F968, 0x00F968, 0x006CCC},
+{0x00F969, 0x00F969, 0x006578},
+{0x00F96A, 0x00F96A, 0x007D22},
+{0x00F96B, 0x00F96B, 0x0053C3},
+{0x00F96C, 0x00F96C, 0x00585E},
+{0x00F96D, 0x00F96D, 0x007701},
+{0x00F96E, 0x00F96E, 0x008449},
+{0x00F96F, 0x00F96F, 0x008AAA},
+{0x00F970, 0x00F970, 0x006BBA},
+{0x00F971, 0x00F971, 0x008FB0},
+{0x00F972, 0x00F972, 0x006C88},
+{0x00F973, 0x00F973, 0x0062FE},
+{0x00F974, 0x00F974, 0x0082E5},
+{0x00F975, 0x00F975, 0x0063A0},
+{0x00F976, 0x00F976, 0x007565},
+{0x00F977, 0x00F977, 0x004EAE},
+{0x00F978, 0x00F978, 0x005169},
+{0x00F979, 0x00F979, 0x0051C9},
+{0x00F97A, 0x00F97A, 0x006881},
+{0x00F97B, 0x00F97B, 0x007CE7},
+{0x00F97C, 0x00F97C, 0x00826F},
+{0x00F97D, 0x00F97D, 0x008AD2},
+{0x00F97E, 0x00F97E, 0x0091CF},
+{0x00F97F, 0x00F97F, 0x0052F5},
+{0x00F980, 0x00F980, 0x005442},
+{0x00F981, 0x00F981, 0x005973},
+{0x00F982, 0x00F982, 0x005EEC},
+{0x00F983, 0x00F983, 0x0065C5},
+{0x00F984, 0x00F984, 0x006FFE},
+{0x00F985, 0x00F985, 0x00792A},
+{0x00F986, 0x00F986, 0x0095AD},
+{0x00F987, 0x00F987, 0x009A6A},
+{0x00F988, 0x00F988, 0x009E97},
+{0x00F989, 0x00F989, 0x009ECE},
+{0x00F98A, 0x00F98A, 0x00529B},
+{0x00F98B, 0x00F98B, 0x0066C6},
+{0x00F98C, 0x00F98C, 0x006B77},
+{0x00F98D, 0x00F98D, 0x008F62},
+{0x00F98E, 0x00F98E, 0x005E74},
+{0x00F98F, 0x00F98F, 0x006190},
+{0x00F990, 0x00F990, 0x006200},
+{0x00F991, 0x00F991, 0x00649A},
+{0x00F992, 0x00F992, 0x006F23},
+{0x00F993, 0x00F993, 0x007149},
+{0x00F994, 0x00F994, 0x007489},
+{0x00F995, 0x00F995, 0x0079CA},
+{0x00F996, 0x00F996, 0x007DF4},
+{0x00F997, 0x00F997, 0x00806F},
+{0x00F998, 0x00F998, 0x008F26},
+{0x00F999, 0x00F999, 0x0084EE},
+{0x00F99A, 0x00F99A, 0x009023},
+{0x00F99B, 0x00F99B, 0x00934A},
+{0x00F99C, 0x00F99C, 0x005217},
+{0x00F99D, 0x00F99D, 0x0052A3},
+{0x00F99E, 0x00F99E, 0x0054BD},
+{0x00F99F, 0x00F99F, 0x0070C8},
+{0x00F9A0, 0x00F9A0, 0x0088C2},
+{0x00F9A1, 0x00F9A1, 0x008AAA},
+{0x00F9A2, 0x00F9A2, 0x005EC9},
+{0x00F9A3, 0x00F9A3, 0x005FF5},
+{0x00F9A4, 0x00F9A4, 0x00637B},
+{0x00F9A5, 0x00F9A5, 0x006BAE},
+{0x00F9A6, 0x00F9A6, 0x007C3E},
+{0x00F9A7, 0x00F9A7, 0x007375},
+{0x00F9A8, 0x00F9A8, 0x004EE4},
+{0x00F9A9, 0x00F9A9, 0x0056F9},
+{0x00F9AA, 0x00F9AA, 0x005BE7},
+{0x00F9AB, 0x00F9AB, 0x005DBA},
+{0x00F9AC, 0x00F9AC, 0x00601C},
+{0x00F9AD, 0x00F9AD, 0x0073B2},
+{0x00F9AE, 0x00F9AE, 0x007469},
+{0x00F9AF, 0x00F9AF, 0x007F9A},
+{0x00F9B0, 0x00F9B0, 0x008046},
+{0x00F9B1, 0x00F9B1, 0x009234},
+{0x00F9B2, 0x00F9B2, 0x0096F6},
+{0x00F9B3, 0x00F9B3, 0x009748},
+{0x00F9B4, 0x00F9B4, 0x009818},
+{0x00F9B5, 0x00F9B5, 0x004F8B},
+{0x00F9B6, 0x00F9B6, 0x0079AE},
+{0x00F9B7, 0x00F9B7, 0x0091B4},
+{0x00F9B8, 0x00F9B8, 0x0096B8},
+{0x00F9B9, 0x00F9B9, 0x0060E1},
+{0x00F9BA, 0x00F9BA, 0x004E86},
+{0x00F9BB, 0x00F9BB, 0x0050DA},
+{0x00F9BC, 0x00F9BC, 0x005BEE},
+{0x00F9BD, 0x00F9BD, 0x005C3F},
+{0x00F9BE, 0x00F9BE, 0x006599},
+{0x00F9BF, 0x00F9BF, 0x006A02},
+{0x00F9C0, 0x00F9C0, 0x0071CE},
+{0x00F9C1, 0x00F9C1, 0x007642},
+{0x00F9C2, 0x00F9C2, 0x0084FC},
+{0x00F9C3, 0x00F9C3, 0x00907C},
+{0x00F9C4, 0x00F9C4, 0x009F8D},
+{0x00F9C5, 0x00F9C5, 0x006688},
+{0x00F9C6, 0x00F9C6, 0x00962E},
+{0x00F9C7, 0x00F9C7, 0x005289},
+{0x00F9C8, 0x00F9C8, 0x00677B},
+{0x00F9C9, 0x00F9C9, 0x0067F3},
+{0x00F9CA, 0x00F9CA, 0x006D41},
+{0x00F9CB, 0x00F9CB, 0x006E9C},
+{0x00F9CC, 0x00F9CC, 0x007409},
+{0x00F9CD, 0x00F9CD, 0x007559},
+{0x00F9CE, 0x00F9CE, 0x00786B},
+{0x00F9CF, 0x00F9CF, 0x007D10},
+{0x00F9D0, 0x00F9D0, 0x00985E},
+{0x00F9D1, 0x00F9D1, 0x00516D},
+{0x00F9D2, 0x00F9D2, 0x00622E},
+{0x00F9D3, 0x00F9D3, 0x009678},
+{0x00F9D4, 0x00F9D4, 0x00502B},
+{0x00F9D5, 0x00F9D5, 0x005D19},
+{0x00F9D6, 0x00F9D6, 0x006DEA},
+{0x00F9D7, 0x00F9D7, 0x008F2A},
+{0x00F9D8, 0x00F9D8, 0x005F8B},
+{0x00F9D9, 0x00F9D9, 0x006144},
+{0x00F9DA, 0x00F9DA, 0x006817},
+{0x00F9DB, 0x00F9DB, 0x007387},
+{0x00F9DC, 0x00F9DC, 0x009686},
+{0x00F9DD, 0x00F9DD, 0x005229},
+{0x00F9DE, 0x00F9DE, 0x00540F},
+{0x00F9DF, 0x00F9DF, 0x005C65},
+{0x00F9E0, 0x00F9E0, 0x006613},
+{0x00F9E1, 0x00F9E1, 0x00674E},
+{0x00F9E2, 0x00F9E2, 0x0068A8},
+{0x00F9E3, 0x00F9E3, 0x006CE5},
+{0x00F9E4, 0x00F9E4, 0x007406},
+{0x00F9E5, 0x00F9E5, 0x0075E2},
+{0x00F9E6, 0x00F9E6, 0x007F79},
+{0x00F9E7, 0x00F9E7, 0x0088CF},
+{0x00F9E8, 0x00F9E8, 0x0088E1},
+{0x00F9E9, 0x00F9E9, 0x0091CC},
+{0x00F9EA, 0x00F9EA, 0x0096E2},
+{0x00F9EB, 0x00F9EB, 0x00533F},
+{0x00F9EC, 0x00F9EC, 0x006EBA},
+{0x00F9ED, 0x00F9ED, 0x00541D},
+{0x00F9EE, 0x00F9EE, 0x0071D0},
+{0x00F9EF, 0x00F9EF, 0x007498},
+{0x00F9F0, 0x00F9F0, 0x0085FA},
+{0x00F9F1, 0x00F9F1, 0x0096A3},
+{0x00F9F2, 0x00F9F2, 0x009C57},
+{0x00F9F3, 0x00F9F3, 0x009E9F},
+{0x00F9F4, 0x00F9F4, 0x006797},
+{0x00F9F5, 0x00F9F5, 0x006DCB},
+{0x00F9F6, 0x00F9F6, 0x0081E8},
+{0x00F9F7, 0x00F9F7, 0x007ACB},
+{0x00F9F8, 0x00F9F8, 0x007B20},
+{0x00F9F9, 0x00F9F9, 0x007C92},
+{0x00F9FA, 0x00F9FA, 0x0072C0},
+{0x00F9FB, 0x00F9FB, 0x007099},
+{0x00F9FC, 0x00F9FC, 0x008B58},
+{0x00F9FD, 0x00F9FD, 0x004EC0},
+{0x00F9FE, 0x00F9FE, 0x008336},
+{0x00F9FF, 0x00F9FF, 0x00523A},
+{0x00FA00, 0x00FA00, 0x005207},
+{0x00FA01, 0x00FA01, 0x005EA6},
+{0x00FA02, 0x00FA02, 0x0062D3},
+{0x00FA03, 0x00FA03, 0x007CD6},
+{0x00FA04, 0x00FA04, 0x005B85},
+{0x00FA05, 0x00FA05, 0x006D1E},
+{0x00FA06, 0x00FA06, 0x0066B4},
+{0x00FA07, 0x00FA07, 0x008F3B},
+{0x00FA08, 0x00FA08, 0x00884C},
+{0x00FA09, 0x00FA09, 0x00964D},
+{0x00FA0A, 0x00FA0A, 0x00898B},
+{0x00FA0B, 0x00FA0B, 0x005ED3},
+{0x00FA0C, 0x00FA0C, 0x005140},
+{0x00FA0D, 0x00FA0D, 0x0055C0},
+{0x00FA10, 0x00FA10, 0x00585A},
+{0x00FA12, 0x00FA12, 0x006674},
+{0x00FA15, 0x00FA15, 0x0051DE},
+{0x00FA16, 0x00FA16, 0x00732A},
+{0x00FA17, 0x00FA17, 0x0076CA},
+{0x00FA18, 0x00FA18, 0x00793C},
+{0x00FA19, 0x00FA19, 0x00795E},
+{0x00FA1A, 0x00FA1A, 0x007965},
+{0x00FA1B, 0x00FA1B, 0x00798F},
+{0x00FA1C, 0x00FA1C, 0x009756},
+{0x00FA1D, 0x00FA1D, 0x007CBE},
+{0x00FA1E, 0x00FA1E, 0x007FBD},
+{0x00FA20, 0x00FA20, 0x008612},
+{0x00FA22, 0x00FA22, 0x008AF8},
+{0x00FA25, 0x00FA25, 0x009038},
+{0x00FA26, 0x00FA26, 0x0090FD},
+{0x00FA2A, 0x00FA2A, 0x0098EF},
+{0x00FA2B, 0x00FA2B, 0x0098FC},
+{0x00FA2C, 0x00FA2C, 0x009928},
+{0x00FA2D, 0x00FA2D, 0x009DB4},
+{0x00FA2E, 0x00FA2E, 0x0090DE},
+{0x00FA2F, 0x00FA2F, 0x0096B7},
+{0x00FA30, 0x00FA30, 0x004FAE},
+{0x00FA31, 0x00FA31, 0x0050E7},
+{0x00FA32, 0x00FA32, 0x00514D},
+{0x00FA33, 0x00FA33, 0x0052C9},
+{0x00FA34, 0x00FA34, 0x0052E4},
+{0x00FA35, 0x00FA35, 0x005351},
+{0x00FA36, 0x00FA36, 0x00559D},
+{0x00FA37, 0x00FA37, 0x005606},
+{0x00FA38, 0x00FA38, 0x005668},
+{0x00FA39, 0x00FA39, 0x005840},
+{0x00FA3A, 0x00FA3A, 0x0058A8},
+{0x00FA3B, 0x00FA3B, 0x005C64},
+{0x00FA3C, 0x00FA3C, 0x005C6E},
+{0x00FA3D, 0x00FA3D, 0x006094},
+{0x00FA3E, 0x00FA3E, 0x006168},
+{0x00FA3F, 0x00FA3F, 0x00618E},
+{0x00FA40, 0x00FA40, 0x0061F2},
+{0x00FA41, 0x00FA41, 0x00654F},
+{0x00FA42, 0x00FA42, 0x0065E2},
+{0x00FA43, 0x00FA43, 0x006691},
+{0x00FA44, 0x00FA44, 0x006885},
+{0x00FA45, 0x00FA45, 0x006D77},
+{0x00FA46, 0x00FA46, 0x006E1A},
+{0x00FA47, 0x00FA47, 0x006F22},
+{0x00FA48, 0x00FA48, 0x00716E},
+{0x00FA49, 0x00FA49, 0x00722B},
+{0x00FA4A, 0x00FA4A, 0x007422},
+{0x00FA4B, 0x00FA4B, 0x007891},
+{0x00FA4C, 0x00FA4C, 0x00793E},
+{0x00FA4D, 0x00FA4D, 0x007949},
+{0x00FA4E, 0x00FA4E, 0x007948},
+{0x00FA4F, 0x00FA4F, 0x007950},
+{0x00FA50, 0x00FA50, 0x007956},
+{0x00FA51, 0x00FA51, 0x00795D},
+{0x00FA52, 0x00FA52, 0x00798D},
+{0x00FA53, 0x00FA53, 0x00798E},
+{0x00FA54, 0x00FA54, 0x007A40},
+{0x00FA55, 0x00FA55, 0x007A81},
+{0x00FA56, 0x00FA56, 0x007BC0},
+{0x00FA57, 0x00FA57, 0x007DF4},
+{0x00FA58, 0x00FA58, 0x007E09},
+{0x00FA59, 0x00FA59, 0x007E41},
+{0x00FA5A, 0x00FA5A, 0x007F72},
+{0x00FA5B, 0x00FA5B, 0x008005},
+{0x00FA5C, 0x00FA5C, 0x0081ED},
+{0x00FA5D, 0x00FA5E, 0x008279},
+{0x00FA5F, 0x00FA5F, 0x008457},
+{0x00FA60, 0x00FA60, 0x008910},
+{0x00FA61, 0x00FA61, 0x008996},
+{0x00FA62, 0x00FA62, 0x008B01},
+{0x00FA63, 0x00FA63, 0x008B39},
+{0x00FA64, 0x00FA64, 0x008CD3},
+{0x00FA65, 0x00FA65, 0x008D08},
+{0x00FA66, 0x00FA66, 0x008FB6},
+{0x00FA67, 0x00FA67, 0x009038},
+{0x00FA68, 0x00FA68, 0x0096E3},
+{0x00FA69, 0x00FA69, 0x0097FF},
+{0x00FA6A, 0x00FA6A, 0x00983B},
+{0x00FA6B, 0x00FA6B, 0x006075},
+{0x00FA6C, 0x00FA6C, 0x0242EE},
+{0x00FA6D, 0x00FA6D, 0x008218},
+{0x00FA70, 0x00FA70, 0x004E26},
+{0x00FA71, 0x00FA71, 0x0051B5},
+{0x00FA72, 0x00FA72, 0x005168},
+{0x00FA73, 0x00FA73, 0x004F80},
+{0x00FA74, 0x00FA74, 0x005145},
+{0x00FA75, 0x00FA75, 0x005180},
+{0x00FA76, 0x00FA76, 0x0052C7},
+{0x00FA77, 0x00FA77, 0x0052FA},
+{0x00FA78, 0x00FA78, 0x00559D},
+{0x00FA79, 0x00FA79, 0x005555},
+{0x00FA7A, 0x00FA7A, 0x005599},
+{0x00FA7B, 0x00FA7B, 0x0055E2},
+{0x00FA7C, 0x00FA7C, 0x00585A},
+{0x00FA7D, 0x00FA7D, 0x0058B3},
+{0x00FA7E, 0x00FA7E, 0x005944},
+{0x00FA7F, 0x00FA7F, 0x005954},
+{0x00FA80, 0x00FA80, 0x005A62},
+{0x00FA81, 0x00FA81, 0x005B28},
+{0x00FA82, 0x00FA82, 0x005ED2},
+{0x00FA83, 0x00FA83, 0x005ED9},
+{0x00FA84, 0x00FA84, 0x005F69},
+{0x00FA85, 0x00FA85, 0x005FAD},
+{0x00FA86, 0x00FA86, 0x0060D8},
+{0x00FA87, 0x00FA87, 0x00614E},
+{0x00FA88, 0x00FA88, 0x006108},
+{0x00FA89, 0x00FA89, 0x00618E},
+{0x00FA8A, 0x00FA8A, 0x006160},
+{0x00FA8B, 0x00FA8B, 0x0061F2},
+{0x00FA8C, 0x00FA8C, 0x006234},
+{0x00FA8D, 0x00FA8D, 0x0063C4},
+{0x00FA8E, 0x00FA8E, 0x00641C},
+{0x00FA8F, 0x00FA8F, 0x006452},
+{0x00FA90, 0x00FA90, 0x006556},
+{0x00FA91, 0x00FA91, 0x006674},
+{0x00FA92, 0x00FA92, 0x006717},
+{0x00FA93, 0x00FA93, 0x00671B},
+{0x00FA94, 0x00FA94, 0x006756},
+{0x00FA95, 0x00FA95, 0x006B79},
+{0x00FA96, 0x00FA96, 0x006BBA},
+{0x00FA97, 0x00FA97, 0x006D41},
+{0x00FA98, 0x00FA98, 0x006EDB},
+{0x00FA99, 0x00FA99, 0x006ECB},
+{0x00FA9A, 0x00FA9A, 0x006F22},
+{0x00FA9B, 0x00FA9B, 0x00701E},
+{0x00FA9C, 0x00FA9C, 0x00716E},
+{0x00FA9D, 0x00FA9D, 0x0077A7},
+{0x00FA9E, 0x00FA9E, 0x007235},
+{0x00FA9F, 0x00FA9F, 0x0072AF},
+{0x00FAA0, 0x00FAA0, 0x00732A},
+{0x00FAA1, 0x00FAA1, 0x007471},
+{0x00FAA2, 0x00FAA2, 0x007506},
+{0x00FAA3, 0x00FAA3, 0x00753B},
+{0x00FAA4, 0x00FAA4, 0x00761D},
+{0x00FAA5, 0x00FAA5, 0x00761F},
+{0x00FAA6, 0x00FAA6, 0x0076CA},
+{0x00FAA7, 0x00FAA7, 0x0076DB},
+{0x00FAA8, 0x00FAA8, 0x0076F4},
+{0x00FAA9, 0x00FAA9, 0x00774A},
+{0x00FAAA, 0x00FAAA, 0x007740},
+{0x00FAAB, 0x00FAAB, 0x0078CC},
+{0x00FAAC, 0x00FAAC, 0x007AB1},
+{0x00FAAD, 0x00FAAD, 0x007BC0},
+{0x00FAAE, 0x00FAAE, 0x007C7B},
+{0x00FAAF, 0x00FAAF, 0x007D5B},
+{0x00FAB0, 0x00FAB0, 0x007DF4},
+{0x00FAB1, 0x00FAB1, 0x007F3E},
+{0x00FAB2, 0x00FAB2, 0x008005},
+{0x00FAB3, 0x00FAB3, 0x008352},
+{0x00FAB4, 0x00FAB4, 0x0083EF},
+{0x00FAB5, 0x00FAB5, 0x008779},
+{0x00FAB6, 0x00FAB6, 0x008941},
+{0x00FAB7, 0x00FAB7, 0x008986},
+{0x00FAB8, 0x00FAB8, 0x008996},
+{0x00FAB9, 0x00FAB9, 0x008ABF},
+{0x00FABA, 0x00FABA, 0x008AF8},
+{0x00FABB, 0x00FABB, 0x008ACB},
+{0x00FABC, 0x00FABC, 0x008B01},
+{0x00FABD, 0x00FABD, 0x008AFE},
+{0x00FABE, 0x00FABE, 0x008AED},
+{0x00FABF, 0x00FABF, 0x008B39},
+{0x00FAC0, 0x00FAC0, 0x008B8A},
+{0x00FAC1, 0x00FAC1, 0x008D08},
+{0x00FAC2, 0x00FAC2, 0x008F38},
+{0x00FAC3, 0x00FAC3, 0x009072},
+{0x00FAC4, 0x00FAC4, 0x009199},
+{0x00FAC5, 0x00FAC5, 0x009276},
+{0x00FAC6, 0x00FAC6, 0x00967C},
+{0x00FAC7, 0x00FAC7, 0x0096E3},
+{0x00FAC8, 0x00FAC8, 0x009756},
+{0x00FAC9, 0x00FAC9, 0x0097DB},
+{0x00FACA, 0x00FACA, 0x0097FF},
+{0x00FACB, 0x00FACB, 0x00980B},
+{0x00FACC, 0x00FACC, 0x00983B},
+{0x00FACD, 0x00FACD, 0x009B12},
+{0x00FACE, 0x00FACE, 0x009F9C},
+{0x00FACF, 0x00FACF, 0x02284A},
+{0x00FAD0, 0x00FAD0, 0x022844},
+{0x00FAD1, 0x00FAD1, 0x0233D5},
+{0x00FAD2, 0x00FAD2, 0x003B9D},
+{0x00FAD3, 0x00FAD3, 0x004018},
+{0x00FAD4, 0x00FAD4, 0x004039},
+{0x00FAD5, 0x00FAD5, 0x025249},
+{0x00FAD6, 0x00FAD6, 0x025CD0},
+{0x00FAD7, 0x00FAD7, 0x027ED3},
+{0x00FAD8, 0x00FAD8, 0x009F43},
+{0x00FAD9, 0x00FAD9, 0x009F8E},
+{0x00FB1D, 0x00FB1D, 0x0005D9},
+{0x00FB1F, 0x00FB1F, 0x0005F2},
+{0x00FB2A, 0x00FB2D, 0x0005E9},
+{0x00FB2E, 0x00FB30, 0x0005D0},
+{0x00FB31, 0x00FB31, 0x0005D1},
+{0x00FB32, 0x00FB32, 0x0005D2},
+{0x00FB33, 0x00FB33, 0x0005D3},
+{0x00FB34, 0x00FB34, 0x0005D4},
+{0x00FB35, 0x00FB35, 0x0005D5},
+{0x00FB36, 0x00FB36, 0x0005D6},
+{0x00FB38, 0x00FB38, 0x0005D8},
+{0x00FB39, 0x00FB39, 0x0005D9},
+{0x00FB3A, 0x00FB3A, 0x0005DA},
+{0x00FB3B, 0x00FB3B, 0x0005DB},
+{0x00FB3C, 0x00FB3C, 0x0005DC},
+{0x00FB3E, 0x00FB3E, 0x0005DE},
+{0x00FB40, 0x00FB40, 0x0005E0},
+{0x00FB41, 0x00FB41, 0x0005E1},
+{0x00FB43, 0x00FB43, 0x0005E3},
+{0x00FB44, 0x00FB44, 0x0005E4},
+{0x00FB46, 0x00FB46, 0x0005E6},
+{0x00FB47, 0x00FB47, 0x0005E7},
+{0x00FB48, 0x00FB48, 0x0005E8},
+{0x00FB49, 0x00FB49, 0x0005E9},
+{0x00FB4A, 0x00FB4A, 0x0005EA},
+{0x00FB4B, 0x00FB4B, 0x0005D5},
+{0x00FB4C, 0x00FB4C, 0x0005D1},
+{0x00FB4D, 0x00FB4D, 0x0005DB},
+{0x00FB4E, 0x00FB4E, 0x0005E4},
+{0x01109A, 0x01109A, 0x011099},
+{0x01109C, 0x01109C, 0x01109B},
+{0x0110AB, 0x0110AB, 0x0110A5},
+{0x01112E, 0x01112E, 0x011131},
+{0x01112F, 0x01112F, 0x011132},
+{0x01134B, 0x01134C, 0x011347},
+{0x0114BB, 0x0114BC, 0x0114B9},
+{0x0114BE, 0x0114BE, 0x0114B9},
+{0x0115BA, 0x0115BA, 0x0115B8},
+{0x0115BB, 0x0115BB, 0x0115B9},
+{0x011938, 0x011938, 0x011935},
+{0x01D15E, 0x01D15E, 0x01D157},
+{0x01D15F, 0x01D164, 0x01D158},
+{0x01D1BB, 0x01D1BB, 0x01D1B9},
+{0x01D1BC, 0x01D1BC, 0x01D1BA},
+{0x01D1BD, 0x01D1BD, 0x01D1B9},
+{0x01D1BE, 0x01D1BE, 0x01D1BA},
+{0x01D1BF, 0x01D1BF, 0x01D1B9},
+{0x01D1C0, 0x01D1C0, 0x01D1BA},
+{0x02F800, 0x02F800, 0x004E3D},
+{0x02F801, 0x02F801, 0x004E38},
+{0x02F802, 0x02F802, 0x004E41},
+{0x02F803, 0x02F803, 0x020122},
+{0x02F804, 0x02F804, 0x004F60},
+{0x02F805, 0x02F805, 0x004FAE},
+{0x02F806, 0x02F806, 0x004FBB},
+{0x02F807, 0x02F807, 0x005002},
+{0x02F808, 0x02F808, 0x00507A},
+{0x02F809, 0x02F809, 0x005099},
+{0x02F80A, 0x02F80A, 0x0050E7},
+{0x02F80B, 0x02F80B, 0x0050CF},
+{0x02F80C, 0x02F80C, 0x00349E},
+{0x02F80D, 0x02F80D, 0x02063A},
+{0x02F80E, 0x02F80E, 0x00514D},
+{0x02F80F, 0x02F80F, 0x005154},
+{0x02F810, 0x02F810, 0x005164},
+{0x02F811, 0x02F811, 0x005177},
+{0x02F812, 0x02F812, 0x02051C},
+{0x02F813, 0x02F813, 0x0034B9},
+{0x02F814, 0x02F814, 0x005167},
+{0x02F815, 0x02F815, 0x00518D},
+{0x02F816, 0x02F816, 0x02054B},
+{0x02F817, 0x02F817, 0x005197},
+{0x02F818, 0x02F818, 0x0051A4},
+{0x02F819, 0x02F819, 0x004ECC},
+{0x02F81A, 0x02F81A, 0x0051AC},
+{0x02F81B, 0x02F81B, 0x0051B5},
+{0x02F81C, 0x02F81C, 0x0291DF},
+{0x02F81D, 0x02F81D, 0x0051F5},
+{0x02F81E, 0x02F81E, 0x005203},
+{0x02F81F, 0x02F81F, 0x0034DF},
+{0x02F820, 0x02F820, 0x00523B},
+{0x02F821, 0x02F821, 0x005246},
+{0x02F822, 0x02F822, 0x005272},
+{0x02F823, 0x02F823, 0x005277},
+{0x02F824, 0x02F824, 0x003515},
+{0x02F825, 0x02F825, 0x0052C7},
+{0x02F826, 0x02F826, 0x0052C9},
+{0x02F827, 0x02F827, 0x0052E4},
+{0x02F828, 0x02F828, 0x0052FA},
+{0x02F829, 0x02F829, 0x005305},
+{0x02F82A, 0x02F82A, 0x005306},
+{0x02F82B, 0x02F82B, 0x005317},
+{0x02F82C, 0x02F82C, 0x005349},
+{0x02F82D, 0x02F82D, 0x005351},
+{0x02F82E, 0x02F82E, 0x00535A},
+{0x02F82F, 0x02F82F, 0x005373},
+{0x02F830, 0x02F830, 0x00537D},
+{0x02F831, 0x02F833, 0x00537F},
+{0x02F834, 0x02F834, 0x020A2C},
+{0x02F835, 0x02F835, 0x007070},
+{0x02F836, 0x02F836, 0x0053CA},
+{0x02F837, 0x02F837, 0x0053DF},
+{0x02F838, 0x02F838, 0x020B63},
+{0x02F839, 0x02F839, 0x0053EB},
+{0x02F83A, 0x02F83A, 0x0053F1},
+{0x02F83B, 0x02F83B, 0x005406},
+{0x02F83C, 0x02F83C, 0x00549E},
+{0x02F83D, 0x02F83D, 0x005438},
+{0x02F83E, 0x02F83E, 0x005448},
+{0x02F83F, 0x02F83F, 0x005468},
+{0x02F840, 0x02F840, 0x0054A2},
+{0x02F841, 0x02F841, 0x0054F6},
+{0x02F842, 0x02F842, 0x005510},
+{0x02F843, 0x02F843, 0x005553},
+{0x02F844, 0x02F844, 0x005563},
+{0x02F845, 0x02F846, 0x005584},
+{0x02F847, 0x02F847, 0x005599},
+{0x02F848, 0x02F848, 0x0055AB},
+{0x02F849, 0x02F849, 0x0055B3},
+{0x02F84A, 0x02F84A, 0x0055C2},
+{0x02F84B, 0x02F84B, 0x005716},
+{0x02F84C, 0x02F84C, 0x005606},
+{0x02F84D, 0x02F84D, 0x005717},
+{0x02F84E, 0x02F84E, 0x005651},
+{0x02F84F, 0x02F84F, 0x005674},
+{0x02F850, 0x02F850, 0x005207},
+{0x02F851, 0x02F851, 0x0058EE},
+{0x02F852, 0x02F852, 0x0057CE},
+{0x02F853, 0x02F853, 0x0057F4},
+{0x02F854, 0x02F854, 0x00580D},
+{0x02F855, 0x02F855, 0x00578B},
+{0x02F856, 0x02F856, 0x005832},
+{0x02F857, 0x02F857, 0x005831},
+{0x02F858, 0x02F858, 0x0058AC},
+{0x02F859, 0x02F859, 0x0214E4},
+{0x02F85A, 0x02F85A, 0x0058F2},
+{0x02F85B, 0x02F85B, 0x0058F7},
+{0x02F85C, 0x02F85C, 0x005906},
+{0x02F85D, 0x02F85D, 0x00591A},
+{0x02F85E, 0x02F85E, 0x005922},
+{0x02F85F, 0x02F85F, 0x005962},
+{0x02F860, 0x02F860, 0x0216A8},
+{0x02F861, 0x02F861, 0x0216EA},
+{0x02F862, 0x02F862, 0x0059EC},
+{0x02F863, 0x02F863, 0x005A1B},
+{0x02F864, 0x02F864, 0x005A27},
+{0x02F865, 0x02F865, 0x0059D8},
+{0x02F866, 0x02F866, 0x005A66},
+{0x02F867, 0x02F867, 0x0036EE},
+{0x02F868, 0x02F868, 0x0036FC},
+{0x02F869, 0x02F869, 0x005B08},
+{0x02F86A, 0x02F86B, 0x005B3E},
+{0x02F86C, 0x02F86C, 0x0219C8},
+{0x02F86D, 0x02F86D, 0x005BC3},
+{0x02F86E, 0x02F86E, 0x005BD8},
+{0x02F86F, 0x02F86F, 0x005BE7},
+{0x02F870, 0x02F870, 0x005BF3},
+{0x02F871, 0x02F871, 0x021B18},
+{0x02F872, 0x02F872, 0x005BFF},
+{0x02F873, 0x02F873, 0x005C06},
+{0x02F874, 0x02F874, 0x005F53},
+{0x02F875, 0x02F875, 0x005C22},
+{0x02F876, 0x02F876, 0x003781},
+{0x02F877, 0x02F877, 0x005C60},
+{0x02F878, 0x02F878, 0x005C6E},
+{0x02F879, 0x02F879, 0x005CC0},
+{0x02F87A, 0x02F87A, 0x005C8D},
+{0x02F87B, 0x02F87B, 0x021DE4},
+{0x02F87C, 0x02F87C, 0x005D43},
+{0x02F87D, 0x02F87D, 0x021DE6},
+{0x02F87E, 0x02F87E, 0x005D6E},
+{0x02F87F, 0x02F87F, 0x005D6B},
+{0x02F880, 0x02F880, 0x005D7C},
+{0x02F881, 0x02F881, 0x005DE1},
+{0x02F882, 0x02F882, 0x005DE2},
+{0x02F883, 0x02F883, 0x00382F},
+{0x02F884, 0x02F884, 0x005DFD},
+{0x02F885, 0x02F885, 0x005E28},
+{0x02F886, 0x02F886, 0x005E3D},
+{0x02F887, 0x02F887, 0x005E69},
+{0x02F888, 0x02F888, 0x003862},
+{0x02F889, 0x02F889, 0x022183},
+{0x02F88A, 0x02F88A, 0x00387C},
+{0x02F88B, 0x02F88B, 0x005EB0},
+{0x02F88C, 0x02F88C, 0x005EB3},
+{0x02F88D, 0x02F88D, 0x005EB6},
+{0x02F88E, 0x02F88E, 0x005ECA},
+{0x02F88F, 0x02F88F, 0x02A392},
+{0x02F890, 0x02F890, 0x005EFE},
+{0x02F891, 0x02F892, 0x022331},
+{0x02F893, 0x02F893, 0x008201},
+{0x02F894, 0x02F895, 0x005F22},
+{0x02F896, 0x02F896, 0x0038C7},
+{0x02F897, 0x02F897, 0x0232B8},
+{0x02F898, 0x02F898, 0x0261DA},
+{0x02F899, 0x02F899, 0x005F62},
+{0x02F89A, 0x02F89A, 0x005F6B},
+{0x02F89B, 0x02F89B, 0x0038E3},
+{0x02F89C, 0x02F89C, 0x005F9A},
+{0x02F89D, 0x02F89D, 0x005FCD},
+{0x02F89E, 0x02F89E, 0x005FD7},
+{0x02F89F, 0x02F89F, 0x005FF9},
+{0x02F8A0, 0x02F8A0, 0x006081},
+{0x02F8A1, 0x02F8A1, 0x00393A},
+{0x02F8A2, 0x02F8A2, 0x00391C},
+{0x02F8A3, 0x02F8A3, 0x006094},
+{0x02F8A4, 0x02F8A4, 0x0226D4},
+{0x02F8A5, 0x02F8A5, 0x0060C7},
+{0x02F8A6, 0x02F8A6, 0x006148},
+{0x02F8A7, 0x02F8A7, 0x00614C},
+{0x02F8A8, 0x02F8A8, 0x00614E},
+{0x02F8A9, 0x02F8A9, 0x00614C},
+{0x02F8AA, 0x02F8AA, 0x00617A},
+{0x02F8AB, 0x02F8AB, 0x00618E},
+{0x02F8AC, 0x02F8AC, 0x0061B2},
+{0x02F8AD, 0x02F8AD, 0x0061A4},
+{0x02F8AE, 0x02F8AE, 0x0061AF},
+{0x02F8AF, 0x02F8AF, 0x0061DE},
+{0x02F8B0, 0x02F8B0, 0x0061F2},
+{0x02F8B1, 0x02F8B1, 0x0061F6},
+{0x02F8B2, 0x02F8B2, 0x006210},
+{0x02F8B3, 0x02F8B3, 0x00621B},
+{0x02F8B4, 0x02F8B4, 0x00625D},
+{0x02F8B5, 0x02F8B5, 0x0062B1},
+{0x02F8B6, 0x02F8B6, 0x0062D4},
+{0x02F8B7, 0x02F8B7, 0x006350},
+{0x02F8B8, 0x02F8B8, 0x022B0C},
+{0x02F8B9, 0x02F8B9, 0x00633D},
+{0x02F8BA, 0x02F8BA, 0x0062FC},
+{0x02F8BB, 0x02F8BB, 0x006368},
+{0x02F8BC, 0x02F8BC, 0x006383},
+{0x02F8BD, 0x02F8BD, 0x0063E4},
+{0x02F8BE, 0x02F8BE, 0x022BF1},
+{0x02F8BF, 0x02F8BF, 0x006422},
+{0x02F8C0, 0x02F8C0, 0x0063C5},
+{0x02F8C1, 0x02F8C1, 0x0063A9},
+{0x02F8C2, 0x02F8C2, 0x003A2E},
+{0x02F8C3, 0x02F8C3, 0x006469},
+{0x02F8C4, 0x02F8C4, 0x00647E},
+{0x02F8C5, 0x02F8C5, 0x00649D},
+{0x02F8C6, 0x02F8C6, 0x006477},
+{0x02F8C7, 0x02F8C7, 0x003A6C},
+{0x02F8C8, 0x02F8C8, 0x00654F},
+{0x02F8C9, 0x02F8C9, 0x00656C},
+{0x02F8CA, 0x02F8CA, 0x02300A},
+{0x02F8CB, 0x02F8CB, 0x0065E3},
+{0x02F8CC, 0x02F8CC, 0x0066F8},
+{0x02F8CD, 0x02F8CD, 0x006649},
+{0x02F8CE, 0x02F8CE, 0x003B19},
+{0x02F8CF, 0x02F8CF, 0x006691},
+{0x02F8D0, 0x02F8D0, 0x003B08},
+{0x02F8D1, 0x02F8D1, 0x003AE4},
+{0x02F8D2, 0x02F8D2, 0x005192},
+{0x02F8D3, 0x02F8D3, 0x005195},
+{0x02F8D4, 0x02F8D4, 0x006700},
+{0x02F8D5, 0x02F8D5, 0x00669C},
+{0x02F8D6, 0x02F8D6, 0x0080AD},
+{0x02F8D7, 0x02F8D7, 0x0043D9},
+{0x02F8D8, 0x02F8D8, 0x006717},
+{0x02F8D9, 0x02F8D9, 0x00671B},
+{0x02F8DA, 0x02F8DA, 0x006721},
+{0x02F8DB, 0x02F8DB, 0x00675E},
+{0x02F8DC, 0x02F8DC, 0x006753},
+{0x02F8DD, 0x02F8DD, 0x0233C3},
+{0x02F8DE, 0x02F8DE, 0x003B49},
+{0x02F8DF, 0x02F8DF, 0x0067FA},
+{0x02F8E0, 0x02F8E0, 0x006785},
+{0x02F8E1, 0x02F8E1, 0x006852},
+{0x02F8E2, 0x02F8E2, 0x006885},
+{0x02F8E3, 0x02F8E3, 0x02346D},
+{0x02F8E4, 0x02F8E4, 0x00688E},
+{0x02F8E5, 0x02F8E5, 0x00681F},
+{0x02F8E6, 0x02F8E6, 0x006914},
+{0x02F8E7, 0x02F8E7, 0x003B9D},
+{0x02F8E8, 0x02F8E8, 0x006942},
+{0x02F8E9, 0x02F8E9, 0x0069A3},
+{0x02F8EA, 0x02F8EA, 0x0069EA},
+{0x02F8EB, 0x02F8EB, 0x006AA8},
+{0x02F8EC, 0x02F8EC, 0x0236A3},
+{0x02F8ED, 0x02F8ED, 0x006ADB},
+{0x02F8EE, 0x02F8EE, 0x003C18},
+{0x02F8EF, 0x02F8EF, 0x006B21},
+{0x02F8F0, 0x02F8F0, 0x0238A7},
+{0x02F8F1, 0x02F8F1, 0x006B54},
+{0x02F8F2, 0x02F8F2, 0x003C4E},
+{0x02F8F3, 0x02F8F3, 0x006B72},
+{0x02F8F4, 0x02F8F4, 0x006B9F},
+{0x02F8F5, 0x02F8F5, 0x006BBA},
+{0x02F8F6, 0x02F8F6, 0x006BBB},
+{0x02F8F7, 0x02F8F7, 0x023A8D},
+{0x02F8F8, 0x02F8F8, 0x021D0B},
+{0x02F8F9, 0x02F8F9, 0x023AFA},
+{0x02F8FA, 0x02F8FA, 0x006C4E},
+{0x02F8FB, 0x02F8FB, 0x023CBC},
+{0x02F8FC, 0x02F8FC, 0x006CBF},
+{0x02F8FD, 0x02F8FD, 0x006CCD},
+{0x02F8FE, 0x02F8FE, 0x006C67},
+{0x02F8FF, 0x02F8FF, 0x006D16},
+{0x02F900, 0x02F900, 0x006D3E},
+{0x02F901, 0x02F901, 0x006D77},
+{0x02F902, 0x02F902, 0x006D41},
+{0x02F903, 0x02F903, 0x006D69},
+{0x02F904, 0x02F904, 0x006D78},
+{0x02F905, 0x02F905, 0x006D85},
+{0x02F906, 0x02F906, 0x023D1E},
+{0x02F907, 0x02F907, 0x006D34},
+{0x02F908, 0x02F908, 0x006E2F},
+{0x02F909, 0x02F909, 0x006E6E},
+{0x02F90A, 0x02F90A, 0x003D33},
+{0x02F90B, 0x02F90B, 0x006ECB},
+{0x02F90C, 0x02F90C, 0x006EC7},
+{0x02F90D, 0x02F90D, 0x023ED1},
+{0x02F90E, 0x02F90E, 0x006DF9},
+{0x02F90F, 0x02F90F, 0x006F6E},
+{0x02F910, 0x02F910, 0x023F5E},
+{0x02F911, 0x02F911, 0x023F8E},
+{0x02F912, 0x02F912, 0x006FC6},
+{0x02F913, 0x02F913, 0x007039},
+{0x02F914, 0x02F914, 0x00701E},
+{0x02F915, 0x02F915, 0x00701B},
+{0x02F916, 0x02F916, 0x003D96},
+{0x02F917, 0x02F917, 0x00704A},
+{0x02F918, 0x02F918, 0x00707D},
+{0x02F919, 0x02F919, 0x007077},
+{0x02F91A, 0x02F91A, 0x0070AD},
+{0x02F91B, 0x02F91B, 0x020525},
+{0x02F91C, 0x02F91C, 0x007145},
+{0x02F91D, 0x02F91D, 0x024263},
+{0x02F91E, 0x02F91E, 0x00719C},
+{0x02F91F, 0x02F91F, 0x0243AB},
+{0x02F920, 0x02F920, 0x007228},
+{0x02F921, 0x02F921, 0x007235},
+{0x02F922, 0x02F922, 0x007250},
+{0x02F923, 0x02F923, 0x024608},
+{0x02F924, 0x02F924, 0x007280},
+{0x02F925, 0x02F925, 0x007295},
+{0x02F926, 0x02F926, 0x024735},
+{0x02F927, 0x02F927, 0x024814},
+{0x02F928, 0x02F928, 0x00737A},
+{0x02F929, 0x02F929, 0x00738B},
+{0x02F92A, 0x02F92A, 0x003EAC},
+{0x02F92B, 0x02F92B, 0x0073A5},
+{0x02F92C, 0x02F92D, 0x003EB8},
+{0x02F92E, 0x02F92E, 0x007447},
+{0x02F92F, 0x02F92F, 0x00745C},
+{0x02F930, 0x02F930, 0x007471},
+{0x02F931, 0x02F931, 0x007485},
+{0x02F932, 0x02F932, 0x0074CA},
+{0x02F933, 0x02F933, 0x003F1B},
+{0x02F934, 0x02F934, 0x007524},
+{0x02F935, 0x02F935, 0x024C36},
+{0x02F936, 0x02F936, 0x00753E},
+{0x02F937, 0x02F937, 0x024C92},
+{0x02F938, 0x02F938, 0x007570},
+{0x02F939, 0x02F939, 0x02219F},
+{0x02F93A, 0x02F93A, 0x007610},
+{0x02F93B, 0x02F93B, 0x024FA1},
+{0x02F93C, 0x02F93C, 0x024FB8},
+{0x02F93D, 0x02F93D, 0x025044},
+{0x02F93E, 0x02F93E, 0x003FFC},
+{0x02F93F, 0x02F93F, 0x004008},
+{0x02F940, 0x02F940, 0x0076F4},
+{0x02F941, 0x02F941, 0x0250F3},
+{0x02F942, 0x02F942, 0x0250F2},
+{0x02F943, 0x02F943, 0x025119},
+{0x02F944, 0x02F944, 0x025133},
+{0x02F945, 0x02F945, 0x00771E},
+{0x02F946, 0x02F947, 0x00771F},
+{0x02F948, 0x02F948, 0x00774A},
+{0x02F949, 0x02F949, 0x004039},
+{0x02F94A, 0x02F94A, 0x00778B},
+{0x02F94B, 0x02F94B, 0x004046},
+{0x02F94C, 0x02F94C, 0x004096},
+{0x02F94D, 0x02F94D, 0x02541D},
+{0x02F94E, 0x02F94E, 0x00784E},
+{0x02F94F, 0x02F94F, 0x00788C},
+{0x02F950, 0x02F950, 0x0078CC},
+{0x02F951, 0x02F951, 0x0040E3},
+{0x02F952, 0x02F952, 0x025626},
+{0x02F953, 0x02F953, 0x007956},
+{0x02F954, 0x02F954, 0x02569A},
+{0x02F955, 0x02F955, 0x0256C5},
+{0x02F956, 0x02F956, 0x00798F},
+{0x02F957, 0x02F957, 0x0079EB},
+{0x02F958, 0x02F958, 0x00412F},
+{0x02F959, 0x02F959, 0x007A40},
+{0x02F95A, 0x02F95A, 0x007A4A},
+{0x02F95B, 0x02F95B, 0x007A4F},
+{0x02F95C, 0x02F95C, 0x02597C},
+{0x02F95D, 0x02F95E, 0x025AA7},
+{0x02F95F, 0x02F95F, 0x007AEE},
+{0x02F960, 0x02F960, 0x004202},
+{0x02F961, 0x02F961, 0x025BAB},
+{0x02F962, 0x02F962, 0x007BC6},
+{0x02F963, 0x02F963, 0x007BC9},
+{0x02F964, 0x02F964, 0x004227},
+{0x02F965, 0x02F965, 0x025C80},
+{0x02F966, 0x02F966, 0x007CD2},
+{0x02F967, 0x02F967, 0x0042A0},
+{0x02F968, 0x02F968, 0x007CE8},
+{0x02F969, 0x02F969, 0x007CE3},
+{0x02F96A, 0x02F96A, 0x007D00},
+{0x02F96B, 0x02F96B, 0x025F86},
+{0x02F96C, 0x02F96C, 0x007D63},
+{0x02F96D, 0x02F96D, 0x004301},
+{0x02F96E, 0x02F96E, 0x007DC7},
+{0x02F96F, 0x02F96F, 0x007E02},
+{0x02F970, 0x02F970, 0x007E45},
+{0x02F971, 0x02F971, 0x004334},
+{0x02F972, 0x02F972, 0x026228},
+{0x02F973, 0x02F973, 0x026247},
+{0x02F974, 0x02F974, 0x004359},
+{0x02F975, 0x02F975, 0x0262D9},
+{0x02F976, 0x02F976, 0x007F7A},
+{0x02F977, 0x02F977, 0x02633E},
+{0x02F978, 0x02F978, 0x007F95},
+{0x02F979, 0x02F979, 0x007FFA},
+{0x02F97A, 0x02F97A, 0x008005},
+{0x02F97B, 0x02F97B, 0x0264DA},
+{0x02F97C, 0x02F97C, 0x026523},
+{0x02F97D, 0x02F97D, 0x008060},
+{0x02F97E, 0x02F97E, 0x0265A8},
+{0x02F97F, 0x02F97F, 0x008070},
+{0x02F980, 0x02F980, 0x02335F},
+{0x02F981, 0x02F981, 0x0043D5},
+{0x02F982, 0x02F982, 0x0080B2},
+{0x02F983, 0x02F983, 0x008103},
+{0x02F984, 0x02F984, 0x00440B},
+{0x02F985, 0x02F985, 0x00813E},
+{0x02F986, 0x02F986, 0x005AB5},
+{0x02F987, 0x02F987, 0x0267A7},
+{0x02F988, 0x02F988, 0x0267B5},
+{0x02F989, 0x02F989, 0x023393},
+{0x02F98A, 0x02F98A, 0x02339C},
+{0x02F98B, 0x02F98B, 0x008201},
+{0x02F98C, 0x02F98C, 0x008204},
+{0x02F98D, 0x02F98D, 0x008F9E},
+{0x02F98E, 0x02F98E, 0x00446B},
+{0x02F98F, 0x02F98F, 0x008291},
+{0x02F990, 0x02F990, 0x00828B},
+{0x02F991, 0x02F991, 0x00829D},
+{0x02F992, 0x02F992, 0x0052B3},
+{0x02F993, 0x02F993, 0x0082B1},
+{0x02F994, 0x02F994, 0x0082B3},
+{0x02F995, 0x02F995, 0x0082BD},
+{0x02F996, 0x02F996, 0x0082E6},
+{0x02F997, 0x02F997, 0x026B3C},
+{0x02F998, 0x02F998, 0x0082E5},
+{0x02F999, 0x02F999, 0x00831D},
+{0x02F99A, 0x02F99A, 0x008363},
+{0x02F99B, 0x02F99B, 0x0083AD},
+{0x02F99C, 0x02F99C, 0x008323},
+{0x02F99D, 0x02F99D, 0x0083BD},
+{0x02F99E, 0x02F99E, 0x0083E7},
+{0x02F99F, 0x02F99F, 0x008457},
+{0x02F9A0, 0x02F9A0, 0x008353},
+{0x02F9A1, 0x02F9A1, 0x0083CA},
+{0x02F9A2, 0x02F9A2, 0x0083CC},
+{0x02F9A3, 0x02F9A3, 0x0083DC},
+{0x02F9A4, 0x02F9A4, 0x026C36},
+{0x02F9A5, 0x02F9A5, 0x026D6B},
+{0x02F9A6, 0x02F9A6, 0x026CD5},
+{0x02F9A7, 0x02F9A7, 0x00452B},
+{0x02F9A8, 0x02F9A8, 0x0084F1},
+{0x02F9A9, 0x02F9A9, 0x0084F3},
+{0x02F9AA, 0x02F9AA, 0x008516},
+{0x02F9AB, 0x02F9AB, 0x0273CA},
+{0x02F9AC, 0x02F9AC, 0x008564},
+{0x02F9AD, 0x02F9AD, 0x026F2C},
+{0x02F9AE, 0x02F9AE, 0x00455D},
+{0x02F9AF, 0x02F9AF, 0x004561},
+{0x02F9B0, 0x02F9B0, 0x026FB1},
+{0x02F9B1, 0x02F9B1, 0x0270D2},
+{0x02F9B2, 0x02F9B2, 0x00456B},
+{0x02F9B3, 0x02F9B3, 0x008650},
+{0x02F9B4, 0x02F9B4, 0x00865C},
+{0x02F9B5, 0x02F9B5, 0x008667},
+{0x02F9B6, 0x02F9B6, 0x008669},
+{0x02F9B7, 0x02F9B7, 0x0086A9},
+{0x02F9B8, 0x02F9B8, 0x008688},
+{0x02F9B9, 0x02F9B9, 0x00870E},
+{0x02F9BA, 0x02F9BA, 0x0086E2},
+{0x02F9BB, 0x02F9BB, 0x008779},
+{0x02F9BC, 0x02F9BC, 0x008728},
+{0x02F9BD, 0x02F9BD, 0x00876B},
+{0x02F9BE, 0x02F9BE, 0x008786},
+{0x02F9BF, 0x02F9BF, 0x0045D7},
+{0x02F9C0, 0x02F9C0, 0x0087E1},
+{0x02F9C1, 0x02F9C1, 0x008801},
+{0x02F9C2, 0x02F9C2, 0x0045F9},
+{0x02F9C3, 0x02F9C3, 0x008860},
+{0x02F9C4, 0x02F9C4, 0x008863},
+{0x02F9C5, 0x02F9C5, 0x027667},
+{0x02F9C6, 0x02F9C6, 0x0088D7},
+{0x02F9C7, 0x02F9C7, 0x0088DE},
+{0x02F9C8, 0x02F9C8, 0x004635},
+{0x02F9C9, 0x02F9C9, 0x0088FA},
+{0x02F9CA, 0x02F9CA, 0x0034BB},
+{0x02F9CB, 0x02F9CB, 0x0278AE},
+{0x02F9CC, 0x02F9CC, 0x027966},
+{0x02F9CD, 0x02F9CD, 0x0046BE},
+{0x02F9CE, 0x02F9CE, 0x0046C7},
+{0x02F9CF, 0x02F9CF, 0x008AA0},
+{0x02F9D0, 0x02F9D0, 0x008AED},
+{0x02F9D1, 0x02F9D1, 0x008B8A},
+{0x02F9D2, 0x02F9D2, 0x008C55},
+{0x02F9D3, 0x02F9D3, 0x027CA8},
+{0x02F9D4, 0x02F9D4, 0x008CAB},
+{0x02F9D5, 0x02F9D5, 0x008CC1},
+{0x02F9D6, 0x02F9D6, 0x008D1B},
+{0x02F9D7, 0x02F9D7, 0x008D77},
+{0x02F9D8, 0x02F9D8, 0x027F2F},
+{0x02F9D9, 0x02F9D9, 0x020804},
+{0x02F9DA, 0x02F9DA, 0x008DCB},
+{0x02F9DB, 0x02F9DB, 0x008DBC},
+{0x02F9DC, 0x02F9DC, 0x008DF0},
+{0x02F9DD, 0x02F9DD, 0x0208DE},
+{0x02F9DE, 0x02F9DE, 0x008ED4},
+{0x02F9DF, 0x02F9DF, 0x008F38},
+{0x02F9E0, 0x02F9E0, 0x0285D2},
+{0x02F9E1, 0x02F9E1, 0x0285ED},
+{0x02F9E2, 0x02F9E2, 0x009094},
+{0x02F9E3, 0x02F9E3, 0x0090F1},
+{0x02F9E4, 0x02F9E4, 0x009111},
+{0x02F9E5, 0x02F9E5, 0x02872E},
+{0x02F9E6, 0x02F9E6, 0x00911B},
+{0x02F9E7, 0x02F9E7, 0x009238},
+{0x02F9E8, 0x02F9E8, 0x0092D7},
+{0x02F9E9, 0x02F9E9, 0x0092D8},
+{0x02F9EA, 0x02F9EA, 0x00927C},
+{0x02F9EB, 0x02F9EB, 0x0093F9},
+{0x02F9EC, 0x02F9EC, 0x009415},
+{0x02F9ED, 0x02F9ED, 0x028BFA},
+{0x02F9EE, 0x02F9EE, 0x00958B},
+{0x02F9EF, 0x02F9EF, 0x004995},
+{0x02F9F0, 0x02F9F0, 0x0095B7},
+{0x02F9F1, 0x02F9F1, 0x028D77},
+{0x02F9F2, 0x02F9F2, 0x0049E6},
+{0x02F9F3, 0x02F9F3, 0x0096C3},
+{0x02F9F4, 0x02F9F4, 0x005DB2},
+{0x02F9F5, 0x02F9F5, 0x009723},
+{0x02F9F6, 0x02F9F6, 0x029145},
+{0x02F9F7, 0x02F9F7, 0x02921A},
+{0x02F9F8, 0x02F9F8, 0x004A6E},
+{0x02F9F9, 0x02F9F9, 0x004A76},
+{0x02F9FA, 0x02F9FA, 0x0097E0},
+{0x02F9FB, 0x02F9FB, 0x02940A},
+{0x02F9FC, 0x02F9FC, 0x004AB2},
+{0x02F9FD, 0x02F9FD, 0x029496},
+{0x02F9FE, 0x02F9FF, 0x00980B},
+{0x02FA00, 0x02FA00, 0x009829},
+{0x02FA01, 0x02FA01, 0x0295B6},
+{0x02FA02, 0x02FA02, 0x0098E2},
+{0x02FA03, 0x02FA03, 0x004B33},
+{0x02FA04, 0x02FA04, 0x009929},
+{0x02FA05, 0x02FA05, 0x0099A7},
+{0x02FA06, 0x02FA06, 0x0099C2},
+{0x02FA07, 0x02FA07, 0x0099FE},
+{0x02FA08, 0x02FA08, 0x004BCE},
+{0x02FA09, 0x02FA09, 0x029B30},
+{0x02FA0A, 0x02FA0A, 0x009B12},
+{0x02FA0B, 0x02FA0B, 0x009C40},
+{0x02FA0C, 0x02FA0C, 0x009CFD},
+{0x02FA0D, 0x02FA0D, 0x004CCE},
+{0x02FA0E, 0x02FA0E, 0x004CED},
+{0x02FA0F, 0x02FA0F, 0x009D67},
+{0x02FA10, 0x02FA10, 0x02A0CE},
+{0x02FA11, 0x02FA11, 0x004CF8},
+{0x02FA12, 0x02FA12, 0x02A105},
+{0x02FA13, 0x02FA13, 0x02A20E},
+{0x02FA14, 0x02FA14, 0x02A291},
+{0x02FA15, 0x02FA15, 0x009EBB},
+{0x02FA16, 0x02FA16, 0x004D56},
+{0x02FA17, 0x02FA17, 0x009EF9},
+{0x02FA18, 0x02FA18, 0x009EFE},
+{0x02FA19, 0x02FA19, 0x009F05},
+{0x02FA1A, 0x02FA1A, 0x009F0F},
+{0x02FA1B, 0x02FA1B, 0x009F16},
+{0x02FA1C, 0x02FA1C, 0x009F3B},
+{0x02FA1D, 0x02FA1D, 0x02A600},
+};
diff --git a/src/unicode-data.h b/src/unicode-data.h
new file mode 100644
index 0000000000000000000000000000000000000000..f6973ebd2e3506e7ace1dd8cfe7d252ce6d25c3d
--- /dev/null
+++ b/src/unicode-data.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+struct range_nfd {
+    uint32_t first;
+    uint32_t last;
+    uint32_t nfd;
+};
+
+static const uint32_t MAX_CODEPOINTS = 0x110000;
+
+extern const std::initializer_list> unicode_ranges_flags;
+extern const std::unordered_set unicode_set_whitespace;
+extern const std::initializer_list> unicode_map_lowercase;
+extern const std::initializer_list> unicode_map_uppercase;
+extern const std::initializer_list unicode_ranges_nfd;
diff --git a/src/unicode.cpp b/src/unicode.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e63bb4ab085d68a88861cd9e3c7e9da367b6516b
--- /dev/null
+++ b/src/unicode.cpp
@@ -0,0 +1,849 @@
+#if defined(_MSC_VER)
+#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
+#endif
+
+#include "unicode.h"
+#include "unicode-data.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+size_t unicode_len_utf8(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast(src) >> 4;
+    return lookup[highbits];
+}
+
+static std::string unicode_cpts_to_utf8(const std::vector & cps) {
+    std::string result;
+    for (size_t i = 0; i < cps.size(); ++i) {
+        result.append(unicode_cpt_to_utf8(cps[i]));
+    }
+    return result;
+}
+
+uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
+    assert(offset < utf8.size());
+    if (!(utf8[offset + 0] & 0x80)) {
+        auto result = utf8[offset + 0];
+        offset += 1;
+        return result;
+    }
+    if (!(utf8[offset + 0] & 0x40)) {
+        throw std::invalid_argument("invalid character");
+    }
+    if (!(utf8[offset + 0] & 0x20)) {
+        if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
+            throw std::invalid_argument("invalid character");
+        }
+        auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
+        offset += 2;
+        return result;
+    }
+    if (!(utf8[offset + 0] & 0x10)) {
+        if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
+            throw std::invalid_argument("invalid character");
+        }
+        auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
+        offset += 3;
+        return result;
+    }
+    if (!(utf8[offset + 0] & 0x08)) {
+        if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
+            throw std::invalid_argument("invalid character");
+        }
+        auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
+        offset += 4;
+        return result;
+    }
+    throw std::invalid_argument("failed to convert utf8 to codepoint");
+}
+
+//static std::vector unicode_cpt_to_utf16(uint32_t cpt) {
+//    std::vector result;
+//    if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
+//        result.emplace_back(cpt);
+//        return result;
+//    }
+//    if (0x10000 <= cpt && cpt <= 0x10ffff) {
+//        result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
+//        result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
+//        return result;
+//    }
+//    throw std::invalid_argument("failed to convert codepoint to utf16");
+//}
+
+//static std::vector unicode_cpts_to_utf16(const std::vector & cps) {
+//    std::vector result;
+//    for (size_t i = 0; i < cps.size(); ++i) {
+//        auto temp = unicode_cpt_to_utf16(cps[i]);
+//        result.insert(result.end(), temp.begin(), temp.end());
+//    }
+//    return result;
+//}
+
+//static uint32_t unicode_cpt_from_utf16(const std::vector & utf16, size_t & offset) {
+//    assert(offset < utf16.size());
+//    if (((utf16[0] >> 10) << 10) != 0xd800) {
+//        auto result = utf16[offset + 0];
+//        offset += 1;
+//        return result;
+//    }
+//
+//    if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
+//        throw std::invalid_argument("invalid character");
+//    }
+//
+//    auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
+//    offset += 2;
+//    return result;
+//}
+
+//static std::vector unicode_cpts_from_utf16(const std::vector & utf16) {
+//    std::vector result;
+//    size_t offset = 0;
+//    while (offset < utf16.size()) {
+//        result.push_back(unicode_cpt_from_utf16(utf16, offset));
+//    }
+//    return result;
+//}
+
+static std::vector unicode_cpt_flags_array() {
+    std::vector cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
+
+    assert (unicode_ranges_flags.begin()[0].first == 0);
+    assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
+    for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
+        const auto range_ini = unicode_ranges_flags.begin()[i-1];  // codepoint_ini, flags
+        const auto range_end = unicode_ranges_flags.begin()[i];    // codepoint_end, flags
+        for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
+            cpt_flags[cpt] = range_ini.second;
+        }
+    }
+
+    for (auto cpt : unicode_set_whitespace) {
+        cpt_flags[cpt].is_whitespace = true;
+    }
+
+    for (auto p : unicode_map_lowercase) {
+        cpt_flags[p.second].is_lowercase = true;
+    }
+
+    for (auto p : unicode_map_uppercase) {
+        cpt_flags[p.second].is_uppercase = true;
+    }
+
+    for (auto &range : unicode_ranges_nfd) {  // start, last, nfd
+        cpt_flags[range.nfd].is_nfd = true;
+    }
+
+    return cpt_flags;
+}
+
+static std::unordered_map unicode_byte_to_utf8_map() {
+    std::unordered_map map;
+    for (int ch = 0x21; ch <= 0x7E; ++ch) {  // u'!' to u'~'
+        assert(0 <= ch && ch < 256);
+        map[ch] = unicode_cpt_to_utf8(ch);
+    }
+    for (int ch = 0xA1; ch <= 0xAC; ++ch) {  // u'¡' to u'¬'
+        assert(0 <= ch && ch < 256);
+        map[ch] = unicode_cpt_to_utf8(ch);
+    }
+    for (int ch = 0xAE; ch <= 0xFF; ++ch) {  // u'®' to u'ÿ'
+        assert(0 <= ch && ch < 256);
+        map[ch] = unicode_cpt_to_utf8(ch);
+    }
+    auto n = 0;
+    for (int ch = 0; ch < 256; ++ch) {
+        if (map.find(ch) == map.end()) {
+            map[ch] = unicode_cpt_to_utf8(256 + n);
+            ++n;
+        }
+    }
+    return map;
+}
+
+static std::unordered_map unicode_utf8_to_byte_map() {
+    std::unordered_map map;
+    for (int ch = 0x21; ch <= 0x7E; ++ch) {  // u'!' to u'~'
+        assert(0 <= ch && ch < 256);
+        map[unicode_cpt_to_utf8(ch)] = ch;
+    }
+    for (int ch = 0xA1; ch <= 0xAC; ++ch) {  // u'¡' to u'¬'
+        assert(0 <= ch && ch < 256);
+        map[unicode_cpt_to_utf8(ch)] = ch;
+    }
+    for (int ch = 0xAE; ch <= 0xFF; ++ch) {  // u'®' to u'ÿ'
+        assert(0 <= ch && ch < 256);
+        map[unicode_cpt_to_utf8(ch)] = ch;
+    }
+    auto n = 0;
+    for (int ch = 0; ch < 256; ++ch) {
+        if (map.find(unicode_cpt_to_utf8(ch)) == map.end()) {
+            map[unicode_cpt_to_utf8(256 + n)] = ch;
+            ++n;
+        }
+    }
+    return map;
+}
+
+static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
+#if defined(__clang__)
+    // disable C++17 deprecation warning for std::codecvt_utf8
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
+    std::wstring_convert> conv;
+
+#if defined(__clang__)
+#    pragma clang diagnostic pop
+#endif
+
+    return conv.from_bytes(s);
+}
+
+static std::vector unicode_byte_encoding_process(const std::vector & bpe_words) {
+    std::vector bpe_encoded_words;
+    for (const auto & word : bpe_words) {
+        std::string text_utf;
+        auto utf_word =  unicode_cpts_from_utf8(word);
+        for (size_t i = 0; i < utf_word.size(); ++i) {
+            text_utf += unicode_cpt_to_utf8(utf_word[i]);
+        }
+
+        std::string encoded_token;
+        for (char & c : text_utf) {
+            encoded_token += unicode_byte_to_utf8(c);
+        }
+        bpe_encoded_words.emplace_back(encoded_token);
+    }
+    return bpe_encoded_words;
+}
+
+// GPT2 system regex:  's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
+static std::vector unicode_regex_split_custom_gpt2(const std::string & text, const std::vector & offsets) {
+    std::vector bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    size_t start = 0;
+    for (auto offset : offsets) {
+        const size_t offset_ini = start;
+        const size_t offset_end = start + offset;
+        assert(offset_end <= cpts.size());
+        start = offset_end;
+
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
+        };
+
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
+        };
+
+        size_t _prev_end = offset_ini;
+        auto _add_token = [&] (const size_t end) -> size_t {
+            assert(_prev_end <= end && end <= offset_end);
+            size_t len = end - _prev_end;
+            if (len > 0) {
+                bpe_offsets.push_back(len);
+            }
+            _prev_end = end;
+            //if (len > 0) {
+            //    std::string s = "";
+            //    for(size_t p = end-len; p < end; p++)
+            //        s += unicode_cpt_to_utf8(cpts[p]);
+            //    printf(">>> '%s'\n", s.c_str());
+            //}
+            return len;
+        };
+
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
+            const uint32_t cpt = _get_cpt(pos);
+            const auto flags = _get_flags(pos);
+
+            // regex: 's|'t|'re|'ve|'m|'ll|'d
+            if (cpt == '\'' && pos+1 < offset_end) {
+                uint32_t cpt_next = _get_cpt(pos+1);
+                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                    pos += _add_token(pos+2);
+                    continue;
+                }
+                if (pos+2 < offset_end) {
+                    uint32_t cpt_next_next = _get_cpt(pos+2);
+                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                        (cpt_next == 'v' && cpt_next_next == 'e') ||
+                        (cpt_next == 'l' && cpt_next_next == 'l')) {
+                        pos += _add_token(pos+3);
+                        continue;
+                    }
+                }
+            }
+
+            auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
+            // regex: ?\p{L}+
+            if (flags2.is_letter) {
+                pos += (cpt == ' ');
+                while (flags2.is_letter) {
+                    flags2 = _get_flags(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+            // regex: ?\p{N}+
+            if (flags2.is_number) {
+                pos += (cpt == ' ');
+                while (flags2.is_number) {
+                    flags2 = _get_flags(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+            // regex: ?[^\s\p{L}\p{N}]+
+            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
+                pos += (cpt == ' ');
+                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
+                    flags2 = _get_flags(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            size_t num_whitespaces = 0;
+            while (_get_flags(pos+num_whitespaces).is_whitespace) {
+                num_whitespaces++;
+            }
+
+            // regex: \s+(?!\S)
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
+                pos += num_whitespaces - 1;
+                _add_token(pos);
+                continue;
+            }
+
+            // regex: \s+
+            if (num_whitespaces > 0) {
+                pos += num_whitespaces;
+                _add_token(pos);
+                continue;
+            }
+
+            // no matches
+            _add_token(++pos);
+        }
+    }
+
+    return bpe_offsets;
+}
+
+// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+static std::vector unicode_regex_split_custom_llama3(const std::string & text, const std::vector & offsets) {
+    std::vector bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    size_t start = 0;
+    for (auto offset : offsets) {
+        const size_t offset_ini = start;
+        const size_t offset_end = start + offset;
+        assert(offset_end <= cpts.size());
+        start = offset_end;
+
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
+        };
+
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
+        };
+
+        size_t _prev_end = offset_ini;
+        auto _add_token = [&] (const size_t end) -> size_t {
+            assert(_prev_end <= end && end <= offset_end);
+            size_t len = end - _prev_end;
+            if (len > 0) {
+                bpe_offsets.push_back(len);
+            }
+            _prev_end = end;
+            //if (len > 0) {
+            //    std::string s = "";
+            //    for(size_t p = end-len; p < end; p++)
+            //        s += unicode_cpt_to_utf8(cpts[p]);
+            //    printf(">>> '%s'\n", s.c_str());
+            //}
+            return len;
+        };
+
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
+            const uint32_t cpt = _get_cpt(pos);
+            const auto flags = _get_flags(pos);
+
+            // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
+            if (cpt == '\'' && pos+1 < offset_end) {
+                uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
+                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                    pos += _add_token(pos+2);
+                    continue;
+                }
+                if (pos+2 < offset_end) {
+                    uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
+                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                        (cpt_next == 'v' && cpt_next_next == 'e') ||
+                        (cpt_next == 'l' && cpt_next_next == 'l')) {
+                        pos += _add_token(pos+3);
+                        continue;
+                    }
+                }
+            }
+
+            // regex: [^\r\n\p{L}\p{N}]?\p{L}+
+            if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
+                if (flags.is_letter || _get_flags(pos+1).is_letter) {  // one or more letters
+                    pos++;
+                    while (_get_flags(pos).is_letter) {
+                        pos++;
+                    }
+                    _add_token(pos);
+                    continue;
+                }
+            }
+
+            // regex: \p{N}{1,3}
+            if (flags.is_number) {
+                size_t ini = pos;
+                while (_get_flags(pos).is_number) {
+                    if (++pos - ini >= 3 ) {
+                        _add_token(pos);
+                        ini = pos;
+                    }
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // regex: ?[^\s\p{L}\p{N}]+[\r\n]*
+            auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
+            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
+                pos += (cpt == ' ');
+                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
+                    flags2 = _get_flags(++pos);
+                }
+                uint32_t cpt2 = _get_cpt(pos);
+                while (cpt2 == '\r' || cpt2 == '\n') {
+                    cpt2 = _get_cpt(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            size_t num_whitespaces = 0;
+            size_t last_end_r_or_n = 0;
+            while (_get_flags(pos+num_whitespaces).is_whitespace) {
+                uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
+                if (cpt2 == '\r' || cpt2 == '\n') {
+                    last_end_r_or_n = pos + num_whitespaces + 1;
+                }
+                num_whitespaces++;
+            }
+
+            // regex: \s*[\r\n]+
+            if (last_end_r_or_n > 0) {
+                pos = last_end_r_or_n;
+                _add_token(pos);
+                continue;
+            }
+
+            // regex: \s+(?!\S)
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
+                pos += num_whitespaces - 1;
+                _add_token(pos);
+                continue;
+            }
+
+            // regex: \s+
+            if (num_whitespaces > 0) {
+                pos += num_whitespaces;
+                _add_token(pos);
+                continue;
+            }
+
+            // no matches
+            _add_token(++pos);
+        }
+    }
+
+    return bpe_offsets;
+}
+
+// use std::wregex to split the text
+static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) {
+    std::wregex expr(regex_expr);
+    std::vector bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
+        std::wcregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::wcmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+// use std::regex to split the text
+static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) {
+    std::regex expr(regex_expr);
+    std::vector bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
+        std::cregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::cmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) {
+    std::vector bpe_offsets;
+
+    if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
+        bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
+    } else if (
+            regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
+            regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
+
+        bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
+    }
+
+    return bpe_offsets;
+}
+
+//
+// interface
+//
+
+std::string unicode_cpt_to_utf8(uint32_t cpt) {
+    std::string result;
+
+    if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
+        result.push_back(cpt);
+        return result;
+    }
+    if (0x80 <= cpt && cpt <= 0x7ff) {
+        result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
+        result.push_back(0x80 | (cpt & 0x3f));
+        return result;
+    }
+    if (0x800 <= cpt && cpt <= 0xffff) {
+        result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
+        result.push_back(0x80 | ((cpt >> 6) & 0x3f));
+        result.push_back(0x80 | (cpt & 0x3f));
+        return result;
+    }
+    if (0x10000 <= cpt && cpt <= 0x10ffff) {
+        result.push_back(0xf0 | ((cpt >> 18) & 0x07));
+        result.push_back(0x80 | ((cpt >> 12) & 0x3f));
+        result.push_back(0x80 | ((cpt >> 6) & 0x3f));
+        result.push_back(0x80 | (cpt & 0x3f));
+        return result;
+    }
+
+    throw std::invalid_argument("invalid codepoint");
+}
+
+std::vector unicode_cpts_normalize_nfd(const std::vector & cpts) {
+    auto comp = [] (const uint32_t cpt, const range_nfd & range) {
+        return cpt < range.first;
+    };
+    std::vector result(cpts.size());
+    for (size_t i = 0; i < cpts.size(); ++i) {
+        const uint32_t cpt = cpts[i];
+        auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1;
+        result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
+    }
+    return result;
+}
+
+std::vector unicode_cpts_from_utf8(const std::string & utf8) {
+    std::vector result;
+    result.reserve(utf8.size());
+    size_t offset = 0;
+    while (offset < utf8.size()) {
+        try {
+            result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        }
+        catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.emplace_back(0xFFFD); // replacement character
+        }
+    }
+    return result;
+}
+
+unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
+    static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
+    static const auto cpt_flags = unicode_cpt_flags_array();
+    return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
+}
+
+unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
+    static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
+    if (utf8.empty()) {
+        return undef;  // undefined
+    }
+    size_t offset = 0;
+    return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
+}
+
+std::string unicode_byte_to_utf8(uint8_t byte) {
+    static std::unordered_map map = unicode_byte_to_utf8_map();
+    return map.at(byte);
+}
+
+uint8_t unicode_utf8_to_byte(const std::string & utf8) {
+    static std::unordered_map map = unicode_utf8_to_byte_map();
+    return map.at(utf8);
+}
+
+uint32_t unicode_tolower(uint32_t cpt) {
+    // binary search
+    auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
+        [](const std::pair & pair, uint32_t value) {
+            return pair.first < value;
+        });
+    if (it != unicode_map_lowercase.end() && it->first == cpt) {
+        return it->second;
+    }
+    return cpt;  // Return the original code point if no lowercase mapping is found
+}
+
+std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) {
+    // unicode categories
+    static const std::map k_ucat_enum = {
+        { "\\p{N}", unicode_cpt_flags::NUMBER },
+        { "\\p{L}", unicode_cpt_flags::LETTER },
+        { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
+        { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
+        { "\\p{S}", unicode_cpt_flags::SYMBOL },
+    };
+
+    static const std::map k_ucat_cpt = {
+        { unicode_cpt_flags::NUMBER,      0xD1 },
+        { unicode_cpt_flags::LETTER,      0xD2 },
+        { unicode_cpt_flags::PUNCTUATION, 0xD3 },
+        { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
+        { unicode_cpt_flags::SYMBOL,      0xD5 },
+    };
+
+    static const std::map k_ucat_map = {
+        { unicode_cpt_flags::NUMBER,      "\x30-\x39" }, // 0-9
+        { unicode_cpt_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
+        { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
+        { unicode_cpt_flags::SYMBOL,      "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
+    };
+
+    // compute collapsed codepoints only if needed by at least one regex
+    bool need_collapse = false;
+    for (const auto & regex_expr : regex_exprs) {
+        // search for unicode categories
+        for (const auto & ucat : k_ucat_enum) {
+            if (std::string::npos != regex_expr.find(ucat.first)) {
+                need_collapse = true;
+                break;
+            }
+        }
+    }
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
+    std::string text_collapsed;
+    if (need_collapse) {
+        // collapse all unicode categories
+        text_collapsed.resize(cpts.size());
+
+        for (size_t i = 0; i < cpts.size(); ++i) {
+            // keep single-byte codepoints as is
+            if (cpts[i] < 128) {
+                text_collapsed[i] = cpts[i];
+                continue;
+            }
+
+            const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
+
+            if (flags.is_whitespace) {
+                //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
+                //text_collapsed[i] = (char) 0x85;  //  as whitespace fallback
+                text_collapsed[i] = (char) 0x0B;    //  as whitespace fallback
+            } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
+                text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
+            } else {
+                text_collapsed[i] = (char) 0xD0; // fallback
+            }
+        }
+    }
+
+    std::vector bpe_offsets = { cpts.size() };
+
+    for (const auto & regex_expr : regex_exprs) {
+        // first, see if we have an efficient custom regex implementation
+        auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
+
+        if (!tmp.empty()) {
+            bpe_offsets = std::move(tmp);
+            continue;
+        }
+
+        // fallback to general-purpose std::regex / std::wregex
+        try {
+            // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
+            // with the corresponding collapsed representation
+            bool use_collapsed = false;
+            for (const auto & ucat : k_ucat_enum) {
+                if (std::string::npos != regex_expr.find(ucat.first)) {
+                    use_collapsed = true;
+                    break;
+                }
+            }
+
+            if (use_collapsed) {
+                // sanity-check that the original regex does not contain any non-ASCII characters
+                const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
+                for (size_t i = 0; i < cpts_regex.size(); ++i) {
+                    if (cpts_regex[i] >= 128) {
+                        throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
+                    }
+                }
+
+                // generate a collapsed representation of the regex
+                std::string regex_expr_collapsed;
+
+                // track if we are inside [], because nested [] are not allowed
+                bool inside = false;
+                for (size_t i = 0; i < regex_expr.size(); ++i) {
+                    if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
+                        regex_expr_collapsed += '[';
+                        inside = true;
+                        continue;
+                    }
+
+                    if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
+                        regex_expr_collapsed += ']';
+                        inside = false;
+                        continue;
+                    }
+
+                    if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
+                        regex_expr[i + 1] == 'p' &&
+                        regex_expr[i + 2] == '{' &&
+                        regex_expr[i + 4] == '}') {
+                        const std::string pat = regex_expr.substr(i, 5);
+                        if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
+                            if (!inside) {
+                                regex_expr_collapsed += '[';
+                            }
+                            regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
+                            regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
+                            if (!inside) {
+                                regex_expr_collapsed += ']';
+                            }
+                            i += 4;
+                            continue;
+                        }
+                    }
+
+                    regex_expr_collapsed += regex_expr[i];
+                }
+
+                //printf("text_collapsed: %s\n", text_collapsed.c_str());
+                //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
+                bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
+            } else {
+                // no unicode category used, we can use std::wregex directly
+                const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
+
+                // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
+                std::wstring wtext(cpts.begin(), cpts.end());
+                for (size_t i = 0; i < wtext.size(); ++i) {
+                    if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
+                        wtext[i] = 0x0B;
+                    }
+                }
+
+                //printf("text: %s\n", text.c_str());
+                //printf("regex_expr: %s\n", regex_expr.c_str());
+                bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
+            }
+        } catch (std::regex_error & e) {
+            fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
+            fprintf(stderr, "Regex error: %s\n", e.what());
+            throw std::runtime_error("Failed to process regex");
+        }
+    }
+
+    std::vector bpe_words;
+    bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
+
+    size_t start = 0;
+    for (size_t & offset : bpe_offsets) {
+        bpe_words.emplace_back();
+        for (size_t i = start; i < start + offset; ++i) {
+            bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
+        }
+        start += offset;
+    }
+
+    return unicode_byte_encoding_process(bpe_words);
+}
diff --git a/src/unicode.h b/src/unicode.h
new file mode 100644
index 0000000000000000000000000000000000000000..c27098df7d4bec46cd06b1542d3f99c393755527
--- /dev/null
+++ b/src/unicode.h
@@ -0,0 +1,66 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+struct unicode_cpt_flags {
+    enum {
+        UNDEFINED       = 0x0001,
+        NUMBER          = 0x0002,  // regex: \p{N}
+        LETTER          = 0x0004,  // regex: \p{L}
+        SEPARATOR       = 0x0008,  // regex: \p{Z}
+        ACCENT_MARK     = 0x0010,  // regex: \p{M}
+        PUNCTUATION     = 0x0020,  // regex: \p{P}
+        SYMBOL          = 0x0040,  // regex: \p{S}
+        CONTROL         = 0x0080,  // regex: \p{C}
+        MASK_CATEGORIES = 0x00FF,
+    };
+
+    // codepoint type
+    uint16_t is_undefined   : 1;
+    uint16_t is_number      : 1;  // regex: \p{N}
+    uint16_t is_letter      : 1;  // regex: \p{L}
+    uint16_t is_separator   : 1;  // regex: \p{Z}
+    uint16_t is_accent_mark : 1;  // regex: \p{M}
+    uint16_t is_punctuation : 1;  // regex: \p{P}
+    uint16_t is_symbol      : 1;  // regex: \p{S}
+    uint16_t is_control     : 1;  // regex: \p{C}
+    // helper flags
+    uint16_t is_whitespace  : 1;  // regex: \s
+    uint16_t is_lowercase   : 1;
+    uint16_t is_uppercase   : 1;
+    uint16_t is_nfd         : 1;
+
+    // decode from uint16
+    inline unicode_cpt_flags(const uint16_t flags = 0) {
+        *reinterpret_cast(this) = flags;
+    }
+
+    inline uint16_t as_uint() const {
+        return *reinterpret_cast(this);
+    }
+
+    inline uint16_t category_flag() const {
+        return this->as_uint() & MASK_CATEGORIES;
+    }
+};
+
+size_t unicode_len_utf8(char src);
+
+std::string unicode_cpt_to_utf8  (uint32_t cpt);
+uint32_t    unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
+
+std::vector unicode_cpts_from_utf8(const std::string & utf8);
+
+std::vector unicode_cpts_normalize_nfd(const std::vector & cpts);
+
+unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
+unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
+
+std::string unicode_byte_to_utf8(uint8_t byte);
+uint8_t     unicode_utf8_to_byte(const std::string & utf8);
+
+uint32_t unicode_tolower(uint32_t cpt);
+
+std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs);
diff --git a/tests/.gitignore b/tests/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..620a48ee4449b402705783f40899e70bdcc030eb
--- /dev/null
+++ b/tests/.gitignore
@@ -0,0 +1,4 @@
+*
+!*.*
+*.o
+ggml-common.h
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fc1557a2d406594f55e5c3a36928dbd6620bb44d
--- /dev/null
+++ b/tests/CMakeLists.txt
@@ -0,0 +1,218 @@
+llama_add_compile_flags()
+
+function(llama_build source)
+    if (DEFINED LLAMA_TEST_NAME)
+        set(TEST_TARGET ${LLAMA_TEST_NAME})
+    else()
+        get_filename_component(TEST_TARGET ${source} NAME_WE)
+    endif()
+
+    add_executable(${TEST_TARGET} ${source})
+    target_link_libraries(${TEST_TARGET} PRIVATE common)
+    install(TARGETS ${TEST_TARGET} RUNTIME)
+endfunction()
+
+function(llama_test target)
+    include(CMakeParseArguments)
+    set(options)
+    set(oneValueArgs NAME LABEL WORKING_DIRECTORY)
+    set(multiValueArgs ARGS)
+    cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+
+    if (NOT DEFINED LLAMA_TEST_LABEL)
+        set(LLAMA_TEST_LABEL "main")
+    endif()
+    if (NOT DEFINED LLAMA_TEST_WORKING_DIRECTORY)
+        set(LLAMA_TEST_WORKING_DIRECTORY .)
+    endif()
+    if (DEFINED LLAMA_TEST_NAME)
+        set(TEST_NAME ${LLAMA_TEST_NAME})
+    else()
+        set(TEST_NAME ${target})
+    endif()
+
+    set(TEST_TARGET ${target})
+
+    add_test(
+        NAME ${TEST_NAME}
+        WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}
+        COMMAND $
+        ${LLAMA_TEST_ARGS})
+
+    set_property(TEST ${TEST_NAME} PROPERTY LABELS ${LLAMA_TEST_LABEL})
+endfunction()
+
+function(llama_test_cmd target)
+    include(CMakeParseArguments)
+    set(options)
+    set(oneValueArgs NAME LABEL WORKING_DIRECTORY)
+    set(multiValueArgs ARGS)
+    cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+
+    if (NOT DEFINED LLAMA_TEST_LABEL)
+        set(LLAMA_TEST_LABEL "main")
+    endif()
+    if (NOT DEFINED LLAMA_TEST_WORKING_DIRECTORY)
+        set(LLAMA_TEST_WORKING_DIRECTORY .)
+    endif()
+    if (DEFINED LLAMA_TEST_NAME)
+        set(TEST_NAME ${LLAMA_TEST_NAME})
+    else()
+        set(TEST_NAME ${target})
+    endif()
+
+    add_test(
+        NAME ${TEST_NAME}
+        WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}
+        COMMAND ${target}
+        ${LLAMA_TEST_ARGS})
+
+    set_property(TEST ${TEST_NAME} PROPERTY LABELS ${LLAMA_TEST_LABEL})
+endfunction()
+
+# Builds and runs a test source file.
+# Optional args:
+# - NAME: name of the executable & test target (defaults to the source file name without extension)
+# - LABEL: label for the test (defaults to main)
+# - ARGS: arguments to pass to the test executable
+# - WORKING_DIRECTORY
+function(llama_build_and_test source)
+    include(CMakeParseArguments)
+    set(options)
+    set(oneValueArgs NAME LABEL WORKING_DIRECTORY)
+    set(multiValueArgs ARGS)
+    cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+
+    if (NOT DEFINED LLAMA_TEST_LABEL)
+        set(LLAMA_TEST_LABEL "main")
+    endif()
+    if (NOT DEFINED LLAMA_TEST_WORKING_DIRECTORY)
+        set(LLAMA_TEST_WORKING_DIRECTORY .)
+    endif()
+    if (DEFINED LLAMA_TEST_NAME)
+        set(TEST_TARGET ${LLAMA_TEST_NAME})
+    else()
+        get_filename_component(TEST_TARGET ${source} NAME_WE)
+    endif()
+
+    add_executable(${TEST_TARGET} ${source} get-model.cpp)
+    install(TARGETS ${TEST_TARGET} RUNTIME)
+    target_link_libraries(${TEST_TARGET} PRIVATE common)
+
+    add_test(
+        NAME ${TEST_TARGET}
+        WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}
+        COMMAND $
+        ${LLAMA_TEST_ARGS})
+
+    set_property(TEST ${TEST_TARGET} PROPERTY LABELS ${LLAMA_TEST_LABEL})
+endfunction()
+
+# build test-tokenizer-0 target once and add many tests
+llama_build(test-tokenizer-0.cpp)
+
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-bert-bge          ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-bert-bge.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r         ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-command-r.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-coder    ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-deepseek-coder.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-llm      ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-deepseek-llm.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-falcon            ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-falcon.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2             ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gpt-2.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-bpe         ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-bpe.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-spm         ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-spm.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-mpt               ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-mpt.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-phi-3             ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-phi-3.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2             ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-qwen2.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact            ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-refact.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder         ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-starcoder.gguf)
+
+if (NOT WIN32)
+    llama_test_cmd(
+        ${CMAKE_CURRENT_SOURCE_DIR}/test-tokenizers-repo.sh
+        NAME test-tokenizers-ggml-vocabs
+        WORKING_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
+        ARGS https://huggingface.co/ggml-org/vocabs ${PROJECT_SOURCE_DIR}/models/ggml-vocabs
+    )
+endif()
+
+if (LLAMA_LLGUIDANCE)
+    llama_build_and_test(test-grammar-llguidance.cpp ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-bpe.gguf)
+endif ()
+
+if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
+    # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API (when building with shared libraries)
+    llama_build_and_test(test-sampling.cpp)
+    llama_build_and_test(test-grammar-parser.cpp)
+    llama_build_and_test(test-grammar-integration.cpp)
+    llama_build_and_test(test-llama-grammar.cpp)
+    llama_build_and_test(test-chat.cpp)
+    # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
+    if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
+        llama_build_and_test(test-json-schema-to-grammar.cpp   WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
+        target_include_directories(test-json-schema-to-grammar PRIVATE ${PROJECT_SOURCE_DIR}/tools/server)
+    endif()
+
+    if (NOT GGML_BACKEND_DL)
+        llama_build(test-quantize-stats.cpp)
+    endif()
+
+    llama_build(test-gbnf-validator.cpp)
+
+    # build test-tokenizer-1-bpe target once and add many tests
+    llama_build(test-tokenizer-1-bpe.cpp)
+
+    # TODO: disabled due to slowness
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila    ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-aquila.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon    ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-falcon.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-2     ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gpt-2.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-neox  ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gpt-neox.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-bpe.gguf --ignore-merges)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt       ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-mpt.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-refact    ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-refact.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-starcoder.gguf)
+
+    # build test-tokenizer-1-spm target once and add many tests
+    llama_build(test-tokenizer-1-spm.cpp)
+
+    llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-llama-spm ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-spm.gguf)
+    #llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-baichuan  ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-baichuan.gguf)
+
+    # llama_build_and_test(test-double-float.cpp) # SLOW
+endif()
+
+llama_build_and_test(test-chat-parser.cpp)
+llama_build_and_test(test-chat-template.cpp)
+llama_build_and_test(test-json-partial.cpp)
+llama_build_and_test(test-log.cpp)
+llama_build_and_test(test-regex-partial.cpp)
+
+llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4)
+
+# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
+if (NOT WIN32)
+    llama_build_and_test(test-arg-parser.cpp)
+endif()
+
+# llama_build_and_test(test-opt.cpp) # SLOW
+llama_build_and_test(test-gguf.cpp)
+llama_build_and_test(test-backend-ops.cpp)
+
+llama_build_and_test(test-model-load-cancel.cpp  LABEL "model")
+llama_build_and_test(test-autorelease.cpp        LABEL "model")
+
+if (NOT GGML_BACKEND_DL)
+    # these tests use the backends directly and cannot be built with dynamic loading
+    llama_build_and_test(test-barrier.cpp)
+    llama_build_and_test(test-quantize-fns.cpp)
+    llama_build_and_test(test-quantize-perf.cpp)
+    llama_build_and_test(test-rope.cpp)
+endif()
+
+# libmtmd
+set(LLAMA_TEST_NAME test-mtmd-c-api)
+llama_build_and_test(test-mtmd-c-api.c)
+target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd)
+
+# dummy executable - not installed
+get_filename_component(TEST_TARGET test-c.c NAME_WE)
+add_executable(${TEST_TARGET} test-c.c)
+target_link_libraries(${TEST_TARGET} PRIVATE llama)
diff --git a/tests/get-model.cpp b/tests/get-model.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4edb685f0fbf79f8b85b0119442d141008c9f2d8
--- /dev/null
+++ b/tests/get-model.cpp
@@ -0,0 +1,21 @@
+#include 
+#include 
+#include 
+
+#include "get-model.h"
+
+char * get_model_or_exit(int argc, char *argv[]) {
+    char * model_path;
+    if (argc > 1) {
+        model_path = argv[1];
+
+    } else {
+        model_path = getenv("LLAMACPP_TEST_MODELFILE");
+        if (!model_path || strlen(model_path) == 0) {
+            fprintf(stderr, "\033[33mWARNING: No model file provided. Skipping this test. Set LLAMACPP_TEST_MODELFILE= to silence this warning and run this test.\n\033[0m");
+            exit(EXIT_SUCCESS);
+        }
+    }
+
+    return model_path;
+}
diff --git a/tests/get-model.h b/tests/get-model.h
new file mode 100644
index 0000000000000000000000000000000000000000..81a3a0fefdabf7dd983767231affd4f38b3653a1
--- /dev/null
+++ b/tests/get-model.h
@@ -0,0 +1,2 @@
+#pragma once
+char * get_model_or_exit(int, char*[]);
diff --git a/tests/run-json-schema-to-grammar.mjs b/tests/run-json-schema-to-grammar.mjs
new file mode 100644
index 0000000000000000000000000000000000000000..450c3dde0abad5063d0631828eff241475c40a8c
--- /dev/null
+++ b/tests/run-json-schema-to-grammar.mjs
@@ -0,0 +1,10 @@
+import { readFileSync } from "fs"
+import { SchemaConverter } from "../tools/server/public_legacy/json-schema-to-grammar.mjs"
+
+const [, , file] = process.argv
+const url = `file://${file}`
+let schema = JSON.parse(readFileSync(file, "utf8"));
+const converter = new SchemaConverter({})
+schema = await converter.resolveRefs(schema, url)
+converter.visit(schema, '')
+console.log(converter.formatGrammar())
diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e2836ca4814b4a74cf6cd5937dc9991cbe84e4c6
--- /dev/null
+++ b/tests/test-arg-parser.cpp
@@ -0,0 +1,178 @@
+#include "arg.h"
+#include "common.h"
+
+#include 
+#include 
+#include 
+#include 
+
+#undef NDEBUG
+#include 
+
+int main(void) {
+    common_params params;
+
+    printf("test-arg-parser: make sure there is no duplicated arguments in any examples\n\n");
+    for (int ex = 0; ex < LLAMA_EXAMPLE_COUNT; ex++) {
+        try {
+            auto ctx_arg = common_params_parser_init(params, (enum llama_example)ex);
+            std::unordered_set seen_args;
+            std::unordered_set seen_env_vars;
+            for (const auto & opt : ctx_arg.options) {
+                // check for args duplications
+                for (const auto & arg : opt.args) {
+                    if (seen_args.find(arg) == seen_args.end()) {
+                        seen_args.insert(arg);
+                    } else {
+                        fprintf(stderr, "test-arg-parser: found different handlers for the same argument: %s", arg);
+                        exit(1);
+                    }
+                }
+                // check for env var duplications
+                if (opt.env) {
+                    if (seen_env_vars.find(opt.env) == seen_env_vars.end()) {
+                        seen_env_vars.insert(opt.env);
+                    } else {
+                        fprintf(stderr, "test-arg-parser: found different handlers for the same env var: %s", opt.env);
+                        exit(1);
+                    }
+                }
+            }
+        } catch (std::exception & e) {
+            printf("%s\n", e.what());
+            assert(false);
+        }
+    }
+
+    auto list_str_to_char = [](std::vector & argv) -> std::vector {
+        std::vector res;
+        for (auto & arg : argv) {
+            res.push_back(const_cast(arg.data()));
+        }
+        return res;
+    };
+
+    std::vector argv;
+
+    printf("test-arg-parser: test invalid usage\n\n");
+
+    // missing value
+    argv = {"binary_name", "-m"};
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+
+    // wrong value (int)
+    argv = {"binary_name", "-ngl", "hello"};
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+
+    // wrong value (enum)
+    argv = {"binary_name", "-sm", "hello"};
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+
+    // non-existence arg in specific example (--draft cannot be used outside llama-speculative)
+    argv = {"binary_name", "--draft", "123"};
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING));
+
+
+    printf("test-arg-parser: test valid usage\n\n");
+
+    argv = {"binary_name", "-m", "model_file.gguf"};
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+    assert(params.model.path == "model_file.gguf");
+
+    argv = {"binary_name", "-t", "1234"};
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+    assert(params.cpuparams.n_threads == 1234);
+
+    argv = {"binary_name", "--verbose"};
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+    assert(params.verbosity > 1);
+
+    argv = {"binary_name", "-m", "abc.gguf", "--predict", "6789", "--batch-size", "9090"};
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+    assert(params.model.path == "abc.gguf");
+    assert(params.n_predict == 6789);
+    assert(params.n_batch == 9090);
+
+    // --draft cannot be used outside llama-speculative
+    argv = {"binary_name", "--draft", "123"};
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE));
+    assert(params.speculative.n_max == 123);
+
+// skip this part on windows, because setenv is not supported
+#ifdef _WIN32
+    printf("test-arg-parser: skip on windows build\n");
+#else
+    printf("test-arg-parser: test environment variables (valid + invalid usages)\n\n");
+
+    setenv("LLAMA_ARG_THREADS", "blah", true);
+    argv = {"binary_name"};
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+
+    setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
+    setenv("LLAMA_ARG_THREADS", "1010", true);
+    argv = {"binary_name"};
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+    assert(params.model.path == "blah.gguf");
+    assert(params.cpuparams.n_threads == 1010);
+
+
+    printf("test-arg-parser: test environment variables being overwritten\n\n");
+
+    setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
+    setenv("LLAMA_ARG_THREADS", "1010", true);
+    argv = {"binary_name", "-m", "overwritten.gguf"};
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+    assert(params.model.path == "overwritten.gguf");
+    assert(params.cpuparams.n_threads == 1010);
+#endif // _WIN32
+
+    if (common_has_curl()) {
+        printf("test-arg-parser: test curl-related functions\n\n");
+        const char * GOOD_URL = "https://ggml.ai/";
+        const char * BAD_URL  = "https://www.google.com/404";
+        const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin";
+
+        {
+            printf("test-arg-parser: test good URL\n\n");
+            auto res = common_remote_get_content(GOOD_URL, {});
+            assert(res.first == 200);
+            assert(res.second.size() > 0);
+            std::string str(res.second.data(), res.second.size());
+            assert(str.find("llama.cpp") != std::string::npos);
+        }
+
+        {
+            printf("test-arg-parser: test bad URL\n\n");
+            auto res = common_remote_get_content(BAD_URL, {});
+            assert(res.first == 404);
+        }
+
+        {
+            printf("test-arg-parser: test max size error\n");
+            common_remote_params params;
+            params.max_size = 1;
+            try {
+                common_remote_get_content(GOOD_URL, params);
+                assert(false && "it should throw an error");
+            } catch (std::exception & e) {
+                printf("  expected error: %s\n\n", e.what());
+            }
+        }
+
+        {
+            printf("test-arg-parser: test timeout error\n");
+            common_remote_params params;
+            params.timeout = 1;
+            try {
+                common_remote_get_content(BIG_FILE, params);
+                assert(false && "it should throw an error");
+            } catch (std::exception & e) {
+                printf("  expected error: %s\n\n", e.what());
+            }
+        }
+    } else {
+        printf("test-arg-parser: no curl, skipping curl-related functions\n");
+    }
+
+    printf("test-arg-parser: all tests OK\n\n");
+}
diff --git a/tests/test-autorelease.cpp b/tests/test-autorelease.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..35b09aaeacac8d8bd1bc06fb5f2ec6e80f0e1076
--- /dev/null
+++ b/tests/test-autorelease.cpp
@@ -0,0 +1,24 @@
+// ref: https://github.com/ggerganov/llama.cpp/issues/4952#issuecomment-1892864763
+
+#include 
+#include 
+#include 
+
+#include "llama.h"
+#include "get-model.h"
+
+// This creates a new context inside a pthread and then tries to exit cleanly.
+int main(int argc, char ** argv) {
+    auto * model_path = get_model_or_exit(argc, argv);
+
+    std::thread([&model_path]() {
+        llama_backend_init();
+        auto * model = llama_model_load_from_file(model_path, llama_model_default_params());
+        auto * ctx = llama_init_from_model(model, llama_context_default_params());
+        llama_free(ctx);
+        llama_model_free(model);
+        llama_backend_free();
+    }).join();
+
+    return 0;
+}
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..509a4b35f57cb1be93b0375c88a7e20dc2c41396
--- /dev/null
+++ b/tests/test-backend-ops.cpp
@@ -0,0 +1,4805 @@
+// This file defines tests for various GGML ops and backends.
+// For the forward pass it asserts that the results of multiple backends computing the same GGML ops are consistent.
+// For the backward pass it asserts that the gradients from backpropagation are consistent
+// with the gradients obtained via the method of finite differences ("grad" mode, this is optional).
+// It is also possible to check the performance ("perf" mode).
+//
+// this file has three sections: Section 1 does general setup, section 2 defines the GGML ops to be tested,
+// and section 3 defines which tests to run.
+// Quick start for adding a new GGML op: Go to section 2 and create a struct that inherits from test_case,
+// then go to section 3 and add an instantiation of your struct.
+
+
+// ##############################
+// ## Section 1: General Setup ##
+// ##############################
+
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
+    size_t nels = ggml_nelements(tensor);
+    std::vector data(nels);
+    {
+        // parallel initialization
+        static const size_t n_threads = std::thread::hardware_concurrency();
+        // static RNG initialization (revisit if n_threads stops being constant)
+        static std::vector generators = []() {
+            std::random_device rd;
+            std::vector vec;
+            vec.reserve(n_threads);
+            //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
+            for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
+            return vec;
+        }();
+
+        auto init_thread = [&](size_t ith, size_t start, size_t end) {
+            std::uniform_real_distribution distribution(min, max);
+            auto & gen = generators[ith];
+            for (size_t i = start; i < end; i++) {
+                data[i] = distribution(gen);
+            }
+        };
+
+        std::vector> tasks;
+        tasks.reserve(n_threads);
+        for (size_t i = 0; i < n_threads; i++) {
+            size_t start =     i*nels/n_threads;
+            size_t end   = (i+1)*nels/n_threads;
+            tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
+        }
+        for (auto & t : tasks) {
+            t.get();
+        }
+    }
+
+    if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
+        ggml_backend_tensor_set(tensor, data.data(), 0, nels * sizeof(float));
+    } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
+        GGML_ASSERT(nels % ggml_blck_size(tensor->type) == 0);
+
+         // dummy importance matrix
+        std::vector imatrix(tensor->ne[0], 1.0f);
+        const float * im = imatrix.data();
+        if (!ggml_quantize_requires_imatrix(tensor->type)) {
+            // when the imatrix is optional, we want to test both quantization with and without imatrix
+            // use one of the random numbers to decide
+            if (data[0] > 0.5f*(min + max)) {
+                im = nullptr;
+            }
+        }
+
+        std::vector dataq(ggml_row_size(tensor->type, nels));
+        {
+            // parallel quantization by block
+            size_t blck_size = ggml_blck_size(tensor->type);
+            size_t n_blocks = nels / blck_size;
+
+            auto quantize_thread = [&](size_t start, size_t end) {
+                ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
+                    start * blck_size, end - start, blck_size, im);
+            };
+
+            const size_t min_blocks_per_thread = 1;
+            const size_t n_threads = std::min(std::thread::hardware_concurrency()/2,
+                                                      std::max(1, n_blocks / min_blocks_per_thread));
+            std::vector> tasks;
+            tasks.reserve(n_threads);
+            for (size_t i = 0; i < n_threads; i++) {
+                size_t start =     i*n_blocks/n_threads;
+                size_t end   = (i+1)*n_blocks/n_threads;
+                tasks.push_back(std::async(std::launch::async, quantize_thread, start, end));
+            }
+            for (auto & t : tasks) {
+                t.get();
+            }
+        }
+        ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
+    } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
+        // This is going to create some weird integers though.
+        ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor));
+    } else if (tensor->type == GGML_TYPE_I64) {
+        // Integers with a size of 8 bytes can be set by mirroring the float data, the specific values are again not really meaningful.
+        const size_t nbytes_half = ggml_nbytes(tensor)/2;
+        ggml_backend_tensor_set(tensor, data.data(), 0*nbytes_half, nbytes_half);
+        ggml_backend_tensor_set(tensor, data.data(), 1*nbytes_half, nbytes_half);
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
+
+static std::vector tensor_to_float(const ggml_tensor * t) {
+    std::vector tv;
+    tv.reserve(ggml_nelements(t));
+
+    std::vector buf(ggml_nbytes(t));
+    ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
+
+    const auto * tt = ggml_get_type_traits(t->type);
+    size_t bs = ggml_blck_size(t->type);
+    std::vector vq(ggml_blck_size(t->type));
+    bool quantized = ggml_is_quantized(t->type);
+
+    // access elements by index to avoid gaps in views
+    for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
+        for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
+            for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
+                for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
+                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
+                    if (t->type == GGML_TYPE_F16) {
+                        tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
+                    } else if (t->type == GGML_TYPE_BF16) {
+                        tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i]));
+                    } else if (t->type == GGML_TYPE_F32) {
+                        tv.push_back(*(float *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I64) {
+                        tv.push_back((float)*(int64_t *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I32) {
+                        tv.push_back((float)*(int32_t *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I16) {
+                        tv.push_back((float)*(int16_t *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I8) {
+                        tv.push_back((float)*(int8_t *) &buf[i]);
+                    } else if (quantized) {
+                        tt->to_float(&buf[i], vq.data(), bs);
+                        tv.insert(tv.end(), vq.begin(), vq.end());
+                    } else {
+                        GGML_ABORT("fatal error");
+                    }
+                }
+            }
+        }
+    }
+
+    return tv;
+}
+
+// normalized mean squared error = mse(a, b) / mse(a, 0)
+static double nmse(const float * a, const float * b, size_t n) {
+    double mse_a_b = 0.0;
+    double mse_a_0 = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        float a_i = a[i];
+        float b_i = b[i];
+
+        mse_a_b += (a_i - b_i) * (a_i - b_i);
+        mse_a_0 += a_i * a_i;
+    }
+
+    return mse_a_b / mse_a_0;
+}
+
+// maximum absolute asymmetry between a and b
+// asymmetry: (a - b) / (a + b)
+// This is more stable than relative error if one of the values fluctuates towards zero.
+// n: number of values to compare.
+// expected_vals: optional vector of expected values for a. If expected_vals is not empty, filter out all comparisons where
+//     a does not match any of the expected values. Needed for noncontinuous gradients where the numerical calculation can fail.
+static double mean_abs_asymm(const float * a, const float * b, const size_t n, const std::vector & expected_vals) {
+    double sum = 0.0f;
+
+    size_t nvalid = 0;
+    for (size_t i = 0; i < n; i++) {
+        if (!expected_vals.empty()) {
+            bool matches_any = false;
+            for (const float & ev : expected_vals) {
+                if (fabsf(a[i] - ev) < 1e-3f) {
+                    matches_any = true;
+                    break;
+                }
+            }
+            if (!matches_any) {
+                continue;
+            }
+        }
+
+        const float asymm = (a[i] - b[i]) / (a[i] + b[i]);
+
+        sum += fabsf(asymm);
+        nvalid++;
+    }
+
+    return sum/nvalid;
+}
+
+// utils for printing the variables of the test cases
+
+template
+static std::string var_to_str(const T & x) {
+    return std::to_string(x);
+}
+
+template
+static std::string var_to_str(const T (&x)[N]) {
+    std::string s = "[";
+    for (size_t i = 0; i < N; i++) {
+        if (i > 0) {
+            s += ",";
+        }
+        s += var_to_str(x[i]);
+    }
+    s += "]";
+    return s;
+}
+
+template
+static std::string var_to_str(const std::array & x) {
+    std::string s = "[";
+    for (size_t i = 0; i < N; i++) {
+        if (i > 0) {
+            s += ",";
+        }
+        s += var_to_str(x[i]);
+    }
+    s += "]";
+    return s;
+}
+
+static std::string var_to_str(ggml_type type) {
+    return ggml_type_name(type);
+}
+
+static std::string var_to_str(ggml_prec prec) {
+    return prec == GGML_PREC_F32 ? "f32" : "def";
+}
+
+static std::string var_to_str(ggml_op_pool pool) {
+    switch (pool) {
+        case GGML_OP_POOL_AVG:  return "avg";
+        case GGML_OP_POOL_MAX:  return "max";
+        default:                return std::to_string(pool);
+    }
+}
+
+static std::string var_to_str(ggml_scale_mode mode) {
+    switch (mode) {
+        case GGML_SCALE_MODE_NEAREST:  return "nearest";
+        case GGML_SCALE_MODE_BILINEAR: return "bilinear";
+        default:                      return std::to_string(mode);
+    }
+}
+
+#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
+
+#define VARS_TO_STR1(a) VAR_TO_STR(a)
+#define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
+#define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
+#define VARS_TO_STR4(a, b, c, d) VAR_TO_STR(a) + "," + VARS_TO_STR3(b, c, d)
+#define VARS_TO_STR5(a, b, c, d, e) VAR_TO_STR(a) + "," + VARS_TO_STR4(b, c, d, e)
+#define VARS_TO_STR6(a, b, c, d, e, f) VAR_TO_STR(a) + "," + VARS_TO_STR5(b, c, d, e, f)
+#define VARS_TO_STR7(a, b, c, d, e, f, g) VAR_TO_STR(a) + "," + VARS_TO_STR6(b, c, d, e, f, g)
+#define VARS_TO_STR8(a, b, c, d, e, f, g, h) VAR_TO_STR(a) + "," + VARS_TO_STR7(b, c, d, e, f, g, h)
+#define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + "," + VARS_TO_STR8(b, c, d, e, f, g, h, i)
+#define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
+#define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
+#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
+
+#ifdef GGML_USE_SYCL
+static bool inline _isinf(float f) {
+    return (*(uint32_t *)&f & 0x7fffffff) == 0x7f800000;
+}
+#else
+static bool inline _isinf(float f) { return std::isinf(f); }
+#endif
+
+// accept FLT_MAX as infinity
+static bool isinf_or_max(float f) {
+    return _isinf(f) || f == FLT_MAX || f == -FLT_MAX;
+}
+
+static bool ggml_is_view_op(enum ggml_op op) {
+    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
+}
+
+enum test_mode {
+    MODE_TEST,
+    MODE_PERF,
+    MODE_GRAD,
+};
+
+struct test_case {
+    virtual ~test_case() {}
+
+    virtual std::string op_desc(ggml_tensor * t) {
+        return ggml_op_desc(t);
+    }
+
+    virtual std::string vars() {
+        return "";
+    }
+
+    virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
+
+    virtual double max_nmse_err() {
+        return 1e-7;
+    }
+
+    virtual double max_maa_err() {
+        return 1e-4;
+    }
+
+    virtual float grad_eps() {
+        return 1e-1f;
+    }
+
+    // If false, estimate gradient with 2 points, neglects 3rd order derivative and higher.
+    // If true,  estimate gradient with 4 points, neglects 5th order derivative and higher.
+    virtual bool grad_precise() {
+        return false;
+    }
+
+    // Skip gradient checks if total number of gradients to be checked is larger than this (to speed up the tests).
+    virtual int64_t grad_nmax() {
+        return 10000;
+    }
+
+    // No effect if empty.
+    // If not empty, skip all gradient checks where the numerical result does not match any of the values.
+    // Needed for dealing with noncontinuous gradients (e.g. ReLU) where estimation using finite differences is unreliable.
+    virtual std::vector grad_expect() {
+        return {};
+    }
+
+    virtual void initialize_tensors(ggml_context * ctx) {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t);
+        }
+    }
+
+    virtual size_t op_size(ggml_tensor * t) {
+        size_t size = ggml_nbytes(t);
+        // add source tensors
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (t->src[i] != NULL) {
+                size += ggml_nbytes(t->src[i]);
+            }
+        }
+        return size;
+    }
+
+    virtual uint64_t op_flops(ggml_tensor * t) {
+        GGML_UNUSED(t);
+        return 0;
+    }
+
+    ggml_cgraph * gf = nullptr;
+    ggml_cgraph * gb = nullptr;
+
+    static const int sentinel_size = 1024;
+
+    test_mode mode;
+
+    std::vector sentinels;
+
+    void add_sentinel(ggml_context * ctx) {
+        if (mode == MODE_PERF || mode == MODE_GRAD) {
+            return;
+        }
+        ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
+        ggml_format_name(sentinel, "sent_%zu", sentinels.size());
+        sentinels.push_back(sentinel);
+    }
+
+    // hijack ggml_new_tensor to add sentinels after each tensor to check for overflows in the backend
+
+    ggml_tensor * ggml_new_tensor(ggml_context * ctx, ggml_type type, int n_dims, const int64_t * ne) {
+        ggml_tensor * t = ::ggml_new_tensor(ctx, type, n_dims, ne);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_1d(ggml_context * ctx, ggml_type type, int64_t ne0) {
+        ggml_tensor * t = ::ggml_new_tensor_1d(ctx, type, ne0);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_2d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1) {
+        ggml_tensor * t = ::ggml_new_tensor_2d(ctx, type, ne0, ne1);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_3d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2) {
+        ggml_tensor * t = ::ggml_new_tensor_3d(ctx, type, ne0, ne1, ne2);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_4d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
+        ggml_tensor * t = ::ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) {
+        mode = MODE_TEST;
+
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context * ctx = ggml_init(params);
+        GGML_ASSERT(ctx);
+
+        gf = ggml_new_graph(ctx);
+
+        // pre-graph sentinel
+        add_sentinel(ctx);
+
+        ggml_tensor * out = build_graph(ctx);
+
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
+            ggml_free(ctx);
+            return true;
+        }
+
+        printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
+        fflush(stdout);
+
+        // check if the backends support the ops
+        bool supported = true;
+        for (ggml_backend_t backend : {backend1, backend2}) {
+            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+                if (!ggml_backend_supports_op(backend, t)) {
+                    printf("not supported [%s] ", ggml_backend_name(backend));
+                    supported = false;
+                    break;
+                }
+            }
+        }
+        if (!supported) {
+            printf("\n");
+            ggml_free(ctx);
+            return true;
+        }
+
+        // post-graph sentinel
+        add_sentinel(ctx);
+
+        // allocate
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
+
+        if (buf == NULL) {
+            printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
+            ggml_free(ctx);
+            return false;
+        }
+
+        // build graph
+        ggml_build_forward_expand(gf, out);
+
+        // add sentinels as graph nodes so that they are checked in the callback
+        for (ggml_tensor * sentinel : sentinels) {
+            ggml_graph_add_node(gf, sentinel);
+        }
+
+        // randomize tensors
+        initialize_tensors(ctx);
+
+        // compare
+        struct callback_userdata {
+            bool   ok;
+            double max_err;
+            ggml_backend_t backend1;
+            ggml_backend_t backend2;
+        };
+
+        callback_userdata ud {
+            true,
+            max_nmse_err(),
+            backend1,
+            backend2
+        };
+
+        auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
+            callback_userdata * ud = (callback_userdata *) user_data;
+            const char * bn1 = ggml_backend_name(ud->backend1);
+            const char * bn2 = ggml_backend_name(ud->backend2);
+
+            if (t1->op == GGML_OP_NONE) {
+                // sentinels must be unchanged
+                std::vector t1_data(ggml_nbytes(t1));
+                std::vector t2_data(ggml_nbytes(t2));
+                ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1));
+                ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2));
+
+                if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) {
+                    printf("sentinel mismatch: %s ", t1->name);
+                    ud->ok = false;
+                    return true;
+                }
+            }
+
+            std::vector f1 = tensor_to_float(t1);
+            std::vector f2 = tensor_to_float(t2);
+
+            for (size_t i = 0; i < f1.size(); i++) {
+                // check for nans
+                if (std::isnan(f1[i]) || std::isnan(f2[i])) {
+                    printf("[%s] NaN at index %zu (%s=%f %s=%f) ", ggml_op_desc(t1), i, bn1, f1[i], bn2, f2[i]);
+                    ud->ok = false;
+                    return true;
+                }
+                // check for infs: both must be inf of the same sign, or both must be finite
+                if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
+                    if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
+                        if (std::signbit(f1[i]) != std::signbit(f2[i])) {
+                            printf("[%s] inf sign mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
+                            ud->ok = false;
+                            return true;
+                        }
+                    } else {
+                        printf("[%s] inf mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
+                        ud->ok = false;
+                        return true;
+                    }
+                }
+            }
+
+            double err = nmse(f1.data(), f2.data(), f1.size());
+            if (err > ud->max_err) {
+                printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
+                //for (int i = 0; i < (int) f1.size(); i++) {
+                //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
+                //}
+                //printf("\n");
+                //exit(1);
+                ud->ok = false;
+            }
+            return true;
+
+            GGML_UNUSED(index);
+        };
+
+        const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
+
+        if (!cmp_ok) {
+            printf("compare failed ");
+        }
+
+        ggml_backend_buffer_free(buf);
+
+        ggml_free(ctx);
+
+        if (ud.ok && cmp_ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            return true;
+        }
+
+        printf("\033[1;31mFAIL\033[0m\n");
+        return false;
+    }
+
+    bool eval_perf(ggml_backend_t backend, const char * op_name) {
+        mode = MODE_PERF;
+
+        static const size_t graph_nodes = 8192;
+
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context_ptr ctx(ggml_init(params)); // smart ptr
+        GGML_ASSERT(ctx);
+
+        ggml_tensor * out = build_graph(ctx.get());
+
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
+            return true;
+        }
+
+        int len = printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
+        fflush(stdout);
+
+        // check if backends support op
+        if (!ggml_backend_supports_op(backend, out)) {
+            printf("not supported\n");
+            return true;
+        }
+
+        // align while also leaving some margin for variations in parameters
+        int align = 8;
+        int last = (len + align - 1) / align * align;
+        if (last - len < 5) {
+            last += align;
+        }
+        printf("%*s", last - len, "");
+
+        // allocate
+        ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
+
+        if (buf == NULL) {
+            printf("failed to allocate tensors\n");
+            return false;
+        }
+
+        // randomize tensors
+        initialize_tensors(ctx.get());
+
+        // build graph
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);
+        ggml_build_forward_expand(gf, out);
+
+        // warmup run
+        ggml_status status = ggml_backend_graph_compute(backend, gf);
+        if (status != GGML_STATUS_SUCCESS) {
+            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+            return false;
+        }
+
+        // determine number of runs
+        int n_runs;
+        bool is_cpu = ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_CPU;
+        if (op_flops(out) > 0) {
+            // based on flops
+            const uint64_t GFLOP = 1000 * 1000 * 1000;
+            const uint64_t target_flops_cpu =   8ULL * GFLOP;
+            const uint64_t target_flops_gpu = 100ULL * GFLOP;
+            uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
+            n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
+        } else {
+            // based on memory size
+            const size_t GB = 1ULL << 30;
+            const size_t target_size_cpu =  8 * GB;
+            const size_t target_size_gpu = 32 * GB;
+            size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
+            n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
+        }
+
+        // duplicate the op
+        for (int i = 1; i < n_runs; i++) {
+            ggml_graph_add_node(gf, out);
+        }
+
+        // calculate memory
+        size_t mem = n_runs * op_size(out);
+        auto tensor_op_size = [](ggml_tensor * t) {
+            size_t size = ggml_nbytes(t);
+            // add source tensors
+            for (int i = 0; i < GGML_MAX_SRC; i++) {
+                if (t->src[i] != NULL) {
+                    size += ggml_nbytes(t->src[i]);
+                }
+            }
+            return size;
+        };
+        for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
+            if (ggml_is_view_op(ggml_graph_node(gf, i)->op) || ggml_graph_node(gf, i) == out) {
+                continue;
+            }
+            mem += tensor_op_size(ggml_graph_node(gf, i));
+        }
+
+        // run
+        int64_t total_time_us = 0;
+        int64_t total_mem = 0;
+        int total_runs = 0;
+        do {
+            int64_t start_time = ggml_time_us();
+            ggml_status status = ggml_backend_graph_compute(backend, gf);
+            if (status != GGML_STATUS_SUCCESS) {
+                fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                return false;
+            }
+            int64_t end_time = ggml_time_us();
+
+            total_time_us += end_time - start_time;
+            total_mem += mem;
+            total_runs += n_runs;
+        } while (total_time_us < 1000*1000); // run for at least 1 second
+
+        printf("    %8d runs - %8.2f us/run - ",
+            total_runs,
+            (double)total_time_us / total_runs);
+
+        if (op_flops(out) > 0) {
+            double flops_per_sec = (op_flops(out) * total_runs) / (total_time_us / 1e6);
+            auto format_flops = [](double flops) -> std::string {
+                char buf[256];
+                if (flops >= 1e12) {
+                    snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12);
+                } else if (flops >= 1e9) {
+                    snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9);
+                } else if (flops >= 1e6) {
+                    snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6);
+                } else {
+                    snprintf(buf, sizeof(buf), "%6.2f KFLOP", flops / 1e3);
+                }
+                return buf;
+            };
+            printf("%s/run - \033[1;34m%sS\033[0m",
+                format_flops(op_flops(out)).c_str(),
+                format_flops(flops_per_sec).c_str());
+
+        } else {
+            printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m",
+                op_size(out) / 1024,
+                total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0);
+        }
+        printf("\n");
+
+        return true;
+    }
+
+    bool eval_grad(ggml_backend_t backend, const char * op_name) {
+        mode = MODE_GRAD;
+        const std::vector expect = grad_expect();
+
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + 2*ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, true),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context_ptr ctx(ggml_init(params)); // smart ptr
+        GGML_ASSERT(ctx);
+
+        gf = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);
+        gb = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);
+
+        ggml_tensor * out = build_graph(ctx.get());
+
+        if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
+            return true;
+        }
+
+        printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
+        fflush(stdout);
+
+        if (out->type != GGML_TYPE_F32) {
+            printf("not supported [%s->type != FP32]\n", out->name);
+            return true;
+        }
+
+        // check if the backend supports the ops
+        bool supported = true;
+        bool any_params = false;
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
+            if (!ggml_backend_supports_op(backend, t)) {
+                printf("not supported [%s] ", ggml_backend_name(backend));
+                supported = false;
+                break;
+            }
+            if ((t->flags & GGML_TENSOR_FLAG_PARAM)) {
+                any_params = true;
+                if (t->type != GGML_TYPE_F32) {
+                    printf("not supported [%s->type != FP32] ", t->name);
+                    supported = false;
+                    break;
+                }
+            }
+        }
+        if (!any_params) {
+            printf("not supported [%s] \n", op_desc(out).c_str());
+            supported = false;
+        }
+        if (!supported) {
+            printf("\n");
+            return true;
+        }
+
+        int64_t ngrads = 0;
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
+            if (t->flags & GGML_TENSOR_FLAG_PARAM) {
+                ngrads += ggml_nelements(t);
+            }
+        }
+        if (ngrads > grad_nmax()) {
+            printf("skipping large tensors for speed \n");
+            return true;
+        }
+
+
+        if (!ggml_is_scalar(out)) {
+            out = ggml_sum(ctx.get(), out);
+            ggml_set_name(out, "sum_of_out");
+        }
+        ggml_set_loss(out);
+
+        ggml_build_forward_expand(gf, out);
+        ggml_graph_cpy(gf, gb);
+        ggml_build_backward_expand(ctx.get(), gb, nullptr);
+        if (expect.size() != 1 || expect[0] != 0.0f) {
+            GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
+            for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
+                GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE);
+            }
+        }
+
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
+            if (!ggml_backend_supports_op(backend, t)) {
+                printf("not supported [%s] ", ggml_backend_name(backend));
+                supported = false;
+                break;
+            }
+            if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
+                printf("not supported [%s->type != FP32] ", t->name);
+                supported = false;
+                break;
+            }
+        }
+        if (!supported) {
+            printf("\n");
+            return true;
+        }
+
+        // allocate
+        ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
+        if (buf == NULL) {
+            printf("failed to allocate tensors [%s] ", ggml_backend_name(backend));
+            return false;
+        }
+
+        initialize_tensors(ctx.get()); // Randomizes all tensors (including gradients).
+        ggml_graph_reset(gb);    // Sets gradients to 1 if loss, 0 otherwise.
+
+        ggml_status status = ggml_backend_graph_compute(backend, gf);
+        if (status != GGML_STATUS_SUCCESS) {
+            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+            return false;
+        }
+        status = ggml_backend_graph_compute(backend, gb);
+        if (status != GGML_STATUS_SUCCESS) {
+            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+            return false;
+        }
+
+        bool ok = true;
+        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
+            if (!(t->flags & GGML_TENSOR_FLAG_PARAM)) {
+                continue;
+            }
+
+            const char * bn = ggml_backend_name(backend);
+            const int64_t ne = ggml_nelements(t);
+
+            std::vector ga;
+            struct ggml_tensor * grad = ggml_graph_get_grad(gb, t);
+            if (grad) {
+                ga = tensor_to_float(grad);
+            } else {
+                ga.resize(ne); // default value is 0.0f
+            }
+
+            for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
+                // check for nans
+                if (!std::isfinite(ga[i])) {
+                    printf("[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]);
+                    ok = false;
+                    break;
+                }
+            }
+            if (!ok) {
+                break;
+            }
+
+            std::vector gn(ne); // gradient numeric
+            GGML_ASSERT(ga.size() == gn.size());
+
+            std::vector x0 = tensor_to_float(t); // original t data
+            GGML_ASSERT(ggml_is_scalar(out));
+            GGML_ASSERT(out->type == GGML_TYPE_F32);
+
+            const float eps = grad_eps();
+            for (int64_t i = 0; i < ne; ++i) {
+                const float xiu  = x0[i] + 1.0f*eps; // x, index i, up
+                const float xiuh = x0[i] + 0.5f*eps; // x, index i, up half
+                const float xidh = x0[i] - 0.5f*eps; // x, index i, down half
+                const float xid  = x0[i] - 1.0f*eps; // x, index i, down
+
+                float fu, fuh, fdh, fd; // output values for xiu, xiuh, xid, xidh
+
+                ggml_backend_tensor_set(t, &xiu, i*sizeof(float), sizeof(float));
+                status = ggml_backend_graph_compute(backend, gf);
+                if (status != GGML_STATUS_SUCCESS) {
+                    fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                    return false;
+                }
+                ggml_backend_tensor_get(out, &fu, 0, ggml_nbytes(out));
+
+                ggml_backend_tensor_set(t, &xid, i*sizeof(float), sizeof(float));
+                status = ggml_backend_graph_compute(backend, gf);
+                if (status != GGML_STATUS_SUCCESS) {
+                    fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                    return false;
+                }
+                ggml_backend_tensor_get(out, &fd, 0, ggml_nbytes(out));
+
+                if (grad_precise()) {
+                    ggml_backend_tensor_set(t, &xiuh, i*sizeof(float), sizeof(float));
+                    status = ggml_backend_graph_compute(backend, gf);
+                    if (status != GGML_STATUS_SUCCESS) {
+                        fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                        return false;
+                    }
+                    ggml_backend_tensor_get(out, &fuh, 0, ggml_nbytes(out));
+
+                    ggml_backend_tensor_set(t, &xidh, i*sizeof(float), sizeof(float));
+                    status = ggml_backend_graph_compute(backend, gf);
+                    if (status != GGML_STATUS_SUCCESS) {
+                        fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                        return false;
+                    }
+                    ggml_backend_tensor_get(out, &fdh, 0, ggml_nbytes(out));
+
+                    gn[i] = (8.0*(double)fuh + (double)fd - (8.0*(double)fdh + (double)fu)) / (6.0*(double)eps);
+                } else {
+                    gn[i] = (fu - fd) / (2.0f*eps);
+                }
+
+                ggml_backend_tensor_set(t, x0.data(), 0, ggml_nbytes(t));
+            }
+
+            const double err = mean_abs_asymm(gn.data(), ga.data(), gn.size(), expect);
+            if (err > max_maa_err()) {
+                printf("[%s] MAA = %.9f > %.9f ", ggml_op_desc(t), err, max_maa_err());
+                ok = false;
+                break;
+            }
+            if (!ok) {
+                break;
+            }
+        }
+
+        if (!ok) {
+            printf("compare failed ");
+        }
+
+        if (ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            return true;
+        }
+
+        printf("\033[1;31mFAIL\033[0m\n");
+        return false;
+    }
+};
+
+
+// ###################################
+// ## Section 2: GGML Op Defintions ##
+// ###################################
+
+
+// The following is an example showing the bare minimum for creating a test for a GGML op.
+
+// GGML_OP_EXAMPLE
+struct test_example : public test_case {
+    // Always define these 2 or variants thereof:
+    const ggml_type type; // The type of the input tensors.
+    const std::array ne; // The shape of the input tensors.
+    // For some ops it's necessary to define multiple types or shapes for the inputs.
+    // Or they may need additional parameters.
+
+    // Put all parameters needed to fully define the test into one of the VARS_TO_STR macros.
+    // In most cases these are just the properties of the struct that you defined above.
+    // This is needed for info prints.
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    // Define a constructor for the struct.
+    // In most cases it will be sufficient to have the same arguments as the struct has properties
+    // and just use initializer lists.
+    test_example(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    // Define how a simple GGML compute graph can be constructed for the new GGML op.
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // Step 1: create input tensors that don't depend on any other tensors:
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a"); // Setting names is optional but it's useful for debugging.
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(b, "b");
+
+        // Step 2: use the op that you want to test in the GGML compute graph.
+        ggml_tensor * out = ggml_add(ctx, a, b); // For this example we're just doing a simple addition.
+        ggml_set_name(out, "out");
+
+        // Step 3: return the output tensor.
+        return out;
+    }
+    // In order to also check the gradients for your op, add calls like ggml_set_param(a)
+    // immediately after you create the tensors.
+    // This is optional and only makes sense if a backward pass has actually been implemented for the new op.
+};
+
+
+// GGML_OP_UNARY
+struct test_unary : public test_case {
+    const ggml_unary_op op;
+    const ggml_type type;
+    const std::array ne_a;
+    int v; // view (1 : non-contiguous a)
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, v);
+    }
+
+    test_unary(ggml_unary_op op,
+            ggml_type type = GGML_TYPE_F32,
+            std::array ne_a = {128, 2, 2, 2},
+            int v = 0)
+        : op(op), type(type), ne_a(ne_a), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const bool grad_supported = op == GGML_UNARY_OP_ABS || op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_NEG ||
+            op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU;
+
+        ggml_tensor * a;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            if (grad_supported) {
+                ggml_set_param(a);
+            }
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            if (grad_supported) {
+                ggml_set_param(a);
+            }
+            ggml_set_name(a, "a");
+        }
+
+        ggml_tensor * out = ggml_unary(ctx, a, op);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // test extended range of values to check for NaNs in GELU
+            init_tensor_uniform(t, -150.f, 150.f);
+        }
+    }
+
+    float grad_eps() override {
+        return 15.0f;
+    }
+
+    std::vector grad_expect() override {
+        if (op == GGML_UNARY_OP_ABS) {
+            return {-1.0f, 1.0f};
+        }
+        if (op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_STEP) {
+            return {0.0f};
+        }
+        if (op == GGML_UNARY_OP_RELU) {
+            return {0.0f, 1.0f};
+        }
+        return {};
+    }
+
+};
+
+// GGML_OP_GET_ROWS
+struct test_get_rows : public test_case {
+    const ggml_type type;
+    const int n; // cols
+    const int m; // rows
+    const int r; // rows to get
+    const int b; // batch size
+    const bool v; // view (non-contiguous src1)
+
+    std::string vars() override {
+        return VARS_TO_STR6(type, n, m, r, b, v);
+    }
+
+    test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
+        : type(type), n(n), m(m), r(r), b(b), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
+        ggml_set_name(in, "in");
+
+        ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
+        ggml_set_name(rows, "rows");
+        if (v) {
+            rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
+            ggml_set_name(rows, "view_of_rows");
+        }
+
+        const bool grad_supported = ggml_is_matrix(in) && ggml_is_vector(rows);
+        if (grad_supported) {
+            ggml_set_param(in);
+            // rows is a constant input -> no gradients
+        }
+
+        ggml_tensor * out = ggml_get_rows(ctx, in, rows);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
+                // rows
+                std::vector data(r*b);
+                for (int i = 0; i < r*b; i++) {
+                    data[i] = rand() % m;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_GET_ROWS_BACK
+struct test_get_rows_back : public test_case {
+    const ggml_type type;
+    const int n; // cols
+    const int m; // rows
+    const int r; // rows to get
+    const int b; // batch size
+    const bool v; // view (non-contiguous src1)
+
+    std::string vars() override {
+        return VARS_TO_STR6(type, n, m, r, b, v);
+    }
+
+    test_get_rows_back(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
+        : type(type), n(n), m(m), r(r), b(b), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * in_forward = ggml_new_tensor_3d(ctx, type, n, m, b);
+        ggml_set_name(in_forward, "in_forward");
+
+        ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
+        ggml_set_name(rows, "rows");
+        if (v) {
+            rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
+            ggml_set_name(rows, "view_of_rows");
+        }
+
+        ggml_tensor * grad = ggml_new_tensor_3d(ctx, type, n, r, b);
+        ggml_set_name(grad, "grad");
+
+        ggml_tensor * out = ggml_get_rows_back(ctx, grad, rows, in_forward);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
+                // rows
+                std::vector data(r*b);
+                for (int i = 0; i < r*b; i++) {
+                    data[i] = rand() % m;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_ARGMAX
+struct test_argmax : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_argmax(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 100, 1, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_argmax(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_F32) {
+                // initialize with unique values to avoid ties
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+
+    double max_nmse_err() override {
+        return 0.0;
+    }
+};
+
+// GGML_OP_COUNT_EQUAL
+struct test_count_equal : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_count_equal(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {4, 500, 1, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * a_argmax = ggml_argmax(ctx, a);
+        ggml_set_name(a_argmax, "a_argmax");
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(b, "b");
+
+        ggml_tensor * b_argmax = ggml_argmax(ctx, b);
+        ggml_set_name(b_argmax, "b_argmax");
+
+        ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    double max_nmse_err() override {
+        return 0.0;
+    }
+};
+
+// GGML_OP_REPEAT
+struct test_repeat : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const std::array nr;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, nr);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 2;
+    }
+
+    test_repeat(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3},
+            std::array nr = {2, 2, 2, 2})
+        : type(type), ne(ne), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * target = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(target, "target");
+
+        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(src);
+        ggml_set_name(src, "src");
+
+        ggml_tensor * out = ggml_repeat(ctx, src, target);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_REPEAT_BACK
+struct test_repeat_back : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const std::array nr;
+    const bool v; // whether src is a noncontiguous view
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, nr, v);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 2;
+    }
+
+    test_repeat_back(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {8, 6, 4, 2},
+            std::array nr = {2, 2, 2, 2},
+            bool v = false)
+        : type(type), ne(ne), nr(nr), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(src, "src");
+
+        if (v) {
+            GGML_ASSERT(ne[0] % 2 == 0);
+            GGML_ASSERT(ne[1] % 2 == 0);
+            GGML_ASSERT(ne[2] % 2 == 0);
+            GGML_ASSERT(ne[3] % 2 == 0);
+            GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
+            GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
+            GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
+            GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
+
+            const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
+            const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
+            const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
+            const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
+
+            src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
+        }
+
+        ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(target, "target");
+
+        ggml_tensor * out = ggml_repeat_back(ctx, src, target);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_DUP
+struct test_dup : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const std::array permute;
+    bool _use_permute;
+
+    std::string vars() override {
+        std::string v = VARS_TO_STR2(type, ne);
+        if (_use_permute) v += "," + VAR_TO_STR(permute);
+        return v;
+    }
+
+    test_dup(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 10, 20, 1},
+            std::array permute = {0, 0, 0, 0})
+        : type(type), ne(ne), permute(permute),
+            _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(src);
+        ggml_set_name(src, "src");
+
+        if (_use_permute) {
+            src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
+            ggml_set_name(src, "src_permuted");
+        }
+
+        ggml_tensor * out = ggml_dup(ctx, src);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_SET
+struct test_set : public test_case {
+    const ggml_type type_src;
+    const ggml_type type_dst;
+    const std::array ne;
+    const int dim;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type_src, type_dst, ne, dim);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
+    }
+
+    test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
+            std::array ne = {6, 5, 4, 3}, int dim = 1)
+        : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
+        ggml_set_param(src);
+        ggml_set_name(src, "src");
+
+        auto ne_dst = ne;
+        for (int i = 0; i < dim; ++i) {
+            ne_dst[i] *= 2;
+        }
+        ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
+        ggml_set_param(dst);
+        ggml_set_name(dst, "dst");
+
+        size_t offset = 0;
+        for (int i = 0; i < dim; ++i) {
+            offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
+        }
+        ggml_tensor * out = ggml_set(ctx, dst, src,
+            // The backward pass requires setting a contiguous region:
+            src->nb[1], src->nb[2], src->nb[3], offset);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_CPY
+struct test_cpy : public test_case {
+    const ggml_type type_src;
+    const ggml_type type_dst;
+    const std::array ne;
+    const std::array permute_src;
+    const std::array permute_dst;
+    bool _src_use_permute;
+    bool _dst_use_permute;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
+    }
+
+    double max_nmse_err() override {
+        return 1e-6;
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
+    }
+
+    test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
+            std::array ne = {10, 10, 10, 1},
+            std::array permute_src = {0, 0, 0, 0},
+            std::array permute_dst = {0, 0, 0, 0})
+        : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
+          _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
+          _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
+        ggml_set_param(src);
+        ggml_set_name(src, "src");
+
+        if (_src_use_permute) {
+            src = ggml_permute(ctx, src, permute_src[0], permute_src[1], permute_src[2], permute_src[3]);
+            ggml_set_name(src, "src_permuted");
+        }
+
+        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
+        ggml_set_name(dst, "dst");
+
+        if (_dst_use_permute) {
+            dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
+            ggml_set_name(dst, "dst_permuted");
+        }
+
+        ggml_tensor * out = ggml_cpy(ctx, src, dst);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_CONT
+struct test_cont : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_cont(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 10, 10, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(src);
+        ggml_set_name(src, "src");
+
+        src = ggml_transpose(ctx, src);
+        ggml_set_name(src, "src_transposed");
+
+        ggml_tensor * out = ggml_cont(ctx, src);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_ADD
+// GGML_OP_SUB
+// GGML_OP_MUL
+// GGML_OP_DIV
+struct test_bin_bcast : public test_case {
+    using op_t = ggml_tensor * (*) (ggml_context *, ggml_tensor *, ggml_tensor *);
+    op_t op;
+    const ggml_type type;
+    const std::array ne;
+    const std::array nr;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, nr);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 3;
+    }
+
+    test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 10, 1, 1},
+            std::array nr = {1, 2, 1, 1})
+        : op(op), type(type), ne(ne), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(b, "b");
+
+        // The backward pass supports broadcasting only for GGML_ADD:
+        const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
+        if (grad_supported) {
+            ggml_set_param(a);
+            ggml_set_param(b);
+        }
+
+        ggml_tensor * out = op(ctx, a, b);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (op == ggml_mul || op == ggml_div) {
+                // MUL and DIV have numerical issues around zero:
+                init_tensor_uniform(t, 0.9f, 1.1f);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+
+    float grad_eps() override {
+        return 0.1f * (op == ggml_mul ? ne[0]*ne[1]*ne[2]*ne[3] : 1);
+    }
+
+    bool grad_precise() override {
+        return op == ggml_div;
+    }
+
+    double max_maa_err() override {
+        return op == ggml_add ? 1e-4 : 1e-3;
+    }
+};
+
+// GGML_OP_ADD1
+struct test_add1 : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_add1(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1);
+        // ggml_set_param(b); // TODO: implement
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_add1(ctx, a, b);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
+    }
+};
+
+// GGML_OP_SCALE
+struct test_scale : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    float scale;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, scale);
+    }
+
+    test_scale(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 10, 10, 10},
+            float scale = 2.0f)
+        : type(type), ne(ne), scale(scale) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_scale(ctx, a, scale);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_SILU_BACK
+struct test_silu_back : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, eps);
+    }
+
+    test_silu_back(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {64, 5, 4, 3},
+            float eps = 1e-6f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * grad = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(grad, "grad");
+
+        ggml_tensor * out = ggml_silu_back(ctx, a, grad);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_NORM
+struct test_norm : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const bool v; // whether a is a non-contiguous view
+    const float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, v, eps);
+    }
+
+    test_norm(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {64, 5, 4, 3},
+            bool v = false,
+            float eps = 1e-6f)
+        : type(type), ne(ne), v(v), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        if (v) {
+            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view of a");
+        }
+
+        ggml_tensor * out = ggml_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_RMS_NORM
+struct test_rms_norm : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const bool v; // whether a is a non-contiguous view
+    const float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, v, eps);
+    }
+
+    test_rms_norm(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {64, 5, 4, 3},
+            bool v = false,
+            float eps = 1e-6f)
+        : type(type), ne(ne), v(v), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        if (v) {
+            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view of a");
+        }
+
+        ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.f, 10.f);
+        }
+    }
+
+    float grad_eps() override {
+        return 1.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_RMS_NORM_BACK
+struct test_rms_norm_back : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, eps);
+    }
+
+    test_rms_norm_back(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {64, 5, 4, 3},
+            float eps = 1e-6f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_rms_norm_back(ctx, a, b, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.f, 10.f);
+        }
+    }
+};
+
+// GGML_OP_SSM_CONV
+struct test_ssm_conv : public test_case {
+    const ggml_type type;
+    const std::array ne_a;
+    const std::array ne_b;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, ne_b);
+    }
+
+    test_ssm_conv(ggml_type type = GGML_TYPE_F32,
+            std::array ne_a = {10, 10, 10, 1},
+            std::array ne_b = {3, 3, 1, 1})
+        : type(type), ne_a(ne_a), ne_b(ne_b) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a   = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * b   = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_tensor * out = ggml_ssm_conv(ctx, a, b);
+        return out;
+    }
+};
+
+// GGML_OP_SSM_SCAN
+struct test_ssm_scan : public test_case {
+    const ggml_type type;
+
+    const int64_t d_state;
+    const int64_t d_inner;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
+    }
+
+    test_ssm_scan(ggml_type type = GGML_TYPE_F32,
+            int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner,      n_seqs, 1 }.data());
+        ggml_tensor * x   = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
+        ggml_tensor * dt  = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
+        ggml_tensor * A   = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner,      1     , 1 }.data());
+        ggml_tensor * B   = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data());
+        ggml_tensor * C   = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data());
+        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
+        return out;
+    }
+};
+
+// GGML_OP_RWKV_WKV6
+struct test_rwkv_wkv6 : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_rwkv_wkv6(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * r   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * tf  = ggml_new_tensor(ctx, type, 2, std::vector{ head_size, head_count }.data());
+        ggml_tensor * td  = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
+        return out;
+    }
+};
+
+// GGML_OP_GATED_LINEAR_ATTN
+struct test_gla : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_gla(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * q   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * g   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));
+        return out;
+    }
+};
+
+// GGML_OP_RWKV_WKV7
+struct test_rwkv_wkv7 : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * r   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * w   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * a   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * b   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        // Outputs may become NaN with long seqlen without these normalization
+        a = ggml_l2_norm(ctx, a, 1e-7F);
+        b = ggml_l2_norm(ctx, b, 1e-7F);
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
+        return out;
+    }
+};
+
+// GGML_OP_MUL_MAT
+struct test_mul_mat : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array bs;  // dims 3 and 4
+    const std::array nr;  // repeat in dims 3 and 4
+    const std::array per; // permutation of dimensions
+    const bool v; // whether a and b are non-contiguous views
+
+    std::string vars() override {
+        return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    int64_t grad_nmax() override {
+        return 20000;
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
+    }
+
+    test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32,
+            std::array bs = {10, 10},
+            std::array nr = {2, 2},
+            std::array per = {0, 1, 2, 3},
+            bool v = false)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
+        ggml_tensor * a;
+        ggml_tensor * b;
+
+        const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
+        if (npermuted > 0) {
+            GGML_ASSERT(npermuted == 2);
+            GGML_ASSERT(!v); // not handled
+            GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
+            GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
+
+            // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
+            const int64_t ne_a[4] = {k, m, bs[0],       bs[1]};
+            const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};
+
+            a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
+            b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(a);
+                }
+                ggml_set_param(b);
+            }
+            ggml_set_name(a, "a");
+            ggml_set_name(b, "b");
+
+            a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);
+            b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);
+            ggml_set_name(a, "a_permuted");
+            ggml_set_name(b, "b_permuted");
+        } else {
+            if (v) {
+                a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0],       bs[1]);
+                b = ggml_new_tensor_4d(ctx, type_b, k*2, n, bs[0]*nr[0], bs[1]*nr[1]);
+
+                if (!ggml_is_quantized(type_a)) {
+                    if (bs[1] == 1 && nr[1] == 1) {
+                        ggml_set_param(a);
+                    }
+                    ggml_set_param(b);
+                }
+
+                a = ggml_view_4d(ctx, a, k, m, bs[0],       bs[1],       a->nb[1], a->nb[2], a->nb[3], 0);
+                b = ggml_view_4d(ctx, b, k, n, bs[0]*nr[0], bs[1]*nr[1], b->nb[1], b->nb[2], b->nb[3], 0);
+            } else {
+                a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0],       bs[1]);
+                b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+
+                if (!ggml_is_quantized(type_a)) {
+                    if (bs[1] == 1 && nr[1] == 1) {
+                        ggml_set_param(a);
+                    }
+                    ggml_set_param(b);
+                }
+            }
+            ggml_set_name(a, "a");
+            ggml_set_name(b, "b");
+        }
+
+        ggml_tensor * out = ggml_mul_mat(ctx, a, b);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_MUL_MAT_ID
+struct test_mul_mat_id : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int n_mats;
+    const int n_used;
+    const bool b; // broadcast b matrix
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+
+    std::string vars() override {
+        return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return 2 * m * k * n * n_used;
+    }
+
+    test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int n_mats = 8, int n_used = 2, bool b = false,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32)
+        : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
+            m(m), n(n), k(k) {
+            GGML_ASSERT(n_used <= n_mats);
+        }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
+        ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
+        ggml_set_name(as, "as");
+
+        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
+        ggml_set_name(ids, "ids");
+        if (n_used != n_mats) {
+            ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
+            ggml_set_name(ids, "view_of_ids");
+        }
+
+        ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
+                // ids
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i % n_mats;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_OUT_PROD
+struct test_out_prod : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array bs; // dims 3 and 4
+    const std::array nr; // repeat in dims 3 and 4
+    const bool trans_b;
+
+    std::string vars() override {
+        return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, trans_b);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32,
+            std::array bs = {10, 10},
+            std::array nr = {2, 2},
+            bool trans_b = false)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), trans_b(trans_b) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b;
+        if (trans_b) {
+            b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+            b = ggml_transpose(ctx, b);
+        } else {
+            b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0]*nr[0], bs[1]*nr[1]);
+        }
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_out_prod(ctx, a, b);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_SQR
+struct test_sqr : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sqr(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_sqr(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 0.1f * 0.25f*ne[0]*ne[1]*ne[2]*ne[3]; // 10% of expected value of sum.
+    }
+};
+
+// GGML_OP_SQRT
+struct test_sqrt : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sqrt(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 3, 3, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_sqrt(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        // fill with positive values
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, 50.0f, 100.0f);
+        }
+    }
+
+    float grad_eps() override {
+        return 20.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_LOG
+struct test_log : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_log(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_log(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // log(1) == 0, cluster values there to keep the sum low for better precision in the backward pass:
+            init_tensor_uniform(t, 0.9f, 1.1f);
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_SIN
+struct test_sin : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sin(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 2, 2, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_sin(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
+        }
+    }
+
+    double max_maa_err() override {
+        return 1e-3;
+    }
+
+    float grad_eps() override {
+        return 0.2f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_COS
+struct test_cos : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_cos(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 2, 2, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_cos(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
+        }
+    }
+
+    double max_maa_err() override {
+        return 1e-3;
+    }
+
+    float grad_eps() override {
+        return 0.2f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_CLAMP
+struct test_clamp : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    float min;
+    float max;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, min, max);
+    }
+
+    test_clamp(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3},
+            float min = -0.5f, float max = 0.5f)
+        : type(type), ne(ne), min(min), max(max) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_clamp(ctx, a, min, max);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 1e-2f;
+    }
+
+    std::vector grad_expect() override {
+        return {0.0f, 1.0f};
+    }
+};
+
+// GGML_OP_DIAG_MASK_INF
+struct test_diag_mask_inf : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const int n_past;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, n_past);
+    }
+
+    test_diag_mask_inf(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 10, 3, 2},
+            int n_past = 5)
+        : type(type), ne(ne), n_past(n_past) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_SOFT_MAX
+struct test_soft_max : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const bool mask;
+    const ggml_type m_prec;
+    const float scale;
+    const float max_bias;
+
+    std::string vars() override {
+        return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
+    }
+
+    // the 1024 test with bias occasionally fails:
+    // SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL
+    virtual double max_nmse_err() override {
+        return 1e-6;
+    }
+
+    test_soft_max(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3},
+            bool mask = false,
+            ggml_type m_prec = GGML_TYPE_F32,
+            float scale = 1.0f,
+            float max_bias = 0.0f)
+        : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * mask = nullptr;
+        if (this->mask) {
+            mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
+            ggml_set_name(mask, "mask");
+        }
+
+        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_SOFT_MAX_BACK
+struct test_soft_max_back : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const float scale;
+    const float max_bias;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, scale, max_bias);
+    }
+
+    test_soft_max_back(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3},
+            float scale = 1.0f,
+            float max_bias = 0.0f)
+        : type(type), ne(ne), scale(scale), max_bias(max_bias) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_soft_max_ext_back(ctx, a, b, scale, max_bias);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_ROPE + GGML_OP_ROPE_BACK
+struct test_rope : public test_case {
+    const ggml_type type;
+    const std::array ne_a;
+    int n_dims;
+    int mode;
+    int n_ctx; // used to generate positions
+    float fs; // freq_scale
+    float ef; // ext_factor
+    float af; // attn_factor
+    bool ff;
+    int v; // view (1 : non-contiguous a)
+    bool forward;
+
+    std::string vars() override {
+        // forward can be inferred from the op, does not need to be printed
+        return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
+    }
+
+    test_rope(ggml_type type = GGML_TYPE_F32,
+            std::array ne_a = {10, 5, 3, 1},
+            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
+            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
+        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            if (forward) {
+                ggml_set_param(a);
+            }
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            if (forward) {
+                ggml_set_param(a);
+            }
+            ggml_set_name(a, "a");
+        }
+
+        const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+        ggml_tensor * pos;
+        if (is_mrope || is_vision) {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
+        } else {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+        }
+        ggml_set_name(pos, "pos");
+
+        ggml_tensor * freq = nullptr;
+        if (ff) {
+            freq = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2);
+            ggml_set_name(freq, "freq");
+        }
+
+        ggml_tensor * out;
+        if (is_mrope) {
+            if (is_vision) {
+                GGML_ASSERT(n_dims/4 > 0);
+                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
+                if (forward) {
+                    out = ggml_rope_multi     (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                } else {
+                    out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
+            } else {
+                GGML_ASSERT(n_dims/3 > 0);
+                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
+                if (forward) {
+                    out = ggml_rope_multi     (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                } else {
+                    out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
+            }
+        } else {
+            if (forward) {
+                out = ggml_rope_ext     (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            } else {
+                out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            }
+
+            // TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp
+        }
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // pos
+                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
+                std::vector data(num_pos_ids);
+                for (int i = 0; i < num_pos_ids; i++) {
+                    data[i] = rand() % n_ctx;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));
+            } else {
+                if (t->ne[0] == n_dims/2) {
+                    // frequency factors in the range [0.9f, 1.1f]
+                    init_tensor_uniform(t, 0.9f, 1.1f);
+                } else {
+                    init_tensor_uniform(t);
+                }
+            }
+        }
+    }
+
+    double max_maa_err() override {
+        return 1e-3;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_POOL2D
+struct test_pool2d : public test_case {
+    enum ggml_op_pool pool_type;
+    const ggml_type type_input;
+    const std::array ne_input;
+    // kernel size
+    const int k0;
+    const int k1;
+    // stride
+    const int s0;
+    const int s1;
+    // padding
+    const int p0;
+    const int p1;
+
+    std::string vars() override {
+        return VARS_TO_STR9(pool_type, type_input, ne_input, k0, k1, s0, s1, p0, p1);
+    }
+
+    test_pool2d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
+            ggml_type type_input = GGML_TYPE_F32,
+            std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+            int k0 = 3, int k1 = 3,
+            int s0 = 1, int s1 = 1,
+            int p0 = 1, int p1 = 1)
+        : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), k1(k1), s0(s0), s1(s1), p0(p0), p1(p1) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_set_param(input);
+        ggml_set_name(input, "input");
+
+        ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_CONV_TRANSPOSE_1D
+struct test_conv_transpose_1d : public test_case {
+    const std::array ne_input;
+    const std::array ne_kernel;
+
+    const int s0; // stride
+    const int p0; // padding
+    const int d0; // dilation
+
+    std::string vars() override {
+        return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
+    }
+
+    test_conv_transpose_1d(std::array ne_input = {197, 32, 1, 1}, // [input_width, input_channels, 1 /* assert in cpu kernel*/, 1 (should be batch)]
+                           std::array ne_kernel = {16, 32, 32, 1}, // [kernel_width, output_channels, input_channels, 1 (should be batch)]
+                           int s0 = 1, int p0 = 0, int d0 = 1)
+        : ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_IM2COL
+struct test_im2col : public test_case {
+    const ggml_type type_input;
+    const ggml_type type_kernel;
+    const ggml_type dst_type;
+    const std::array ne_input;
+    const std::array ne_kernel;
+    // stride
+    const int s0;
+    const int s1;
+    // padding
+    const int p0;
+    const int p1;
+    // dilation
+    const int d0;
+    const int d1;
+    // mode
+    const bool is_2D;
+
+    std::string vars() override {
+        return VARS_TO_STR12(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
+    }
+
+    test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,
+            std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+            std::array ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
+            int s0 = 1, int s1 = 1,
+            int p0 = 1, int p1 = 1,
+            int d0 = 1, int d1 = 1,
+            bool is_2D = true)
+        : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_set_param(input);
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_CONV_2D_DW
+struct test_conv_2d_dw : public test_case {
+    const std::array ne_input;
+    const std::array ne_kernel;
+    const int stride;
+    const int padding;
+    const int dilation;
+    const bool cwhn;
+
+    std::string vars() override {
+        return VARS_TO_STR6(ne_input, ne_kernel, stride, padding, dilation, cwhn);
+    }
+
+    test_conv_2d_dw(std::array ne_input = {64, 64, 16, 1},
+            std::array ne_kernel = {3, 3, 1, 16},
+            int stride = 1, int padding = 0, int dilation = 1, bool cwhn = false)
+        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), padding(padding), dilation(dilation), cwhn(cwhn) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        if (cwhn) {
+            // change memory layout to channel-most-contiguous (CWHN),
+            // then permute it back so NE matches the original input
+            input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
+            input = ggml_permute(ctx, input, 2, 0, 1, 3);
+            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
+            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
+        }
+
+        ggml_tensor * out = ggml_conv_2d_dw_direct(
+            ctx, kernel, input,
+            stride, stride, padding, padding, dilation, dilation);
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
+// GGML_OP_CONCAT
+struct test_concat : public test_case {
+    const ggml_type type;
+    const std::array ne_a;
+    const int64_t ne_b_d;
+    const int dim;
+    const int v; // view (1 << 0: non-cont a, 1 << 1: non-cont b)
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne_a, ne_b_d, dim, v);
+    }
+
+    test_concat(ggml_type type = GGML_TYPE_F32,
+            std::array ne_a = {10, 5, 5, 5},
+            int64_t ne_b_d = 5,
+            int dim = 2, int v = 0)
+        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        auto ne_b = ne_a;
+        ne_b[dim] = ne_b_d;
+        ggml_tensor * a;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_name(a, "a");
+        }
+        ggml_tensor * b;
+        if (v & 2) {
+            auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;
+            b = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(b, "b");
+
+            b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);
+            ggml_set_name(b, "view_of_b");
+        } else {
+            b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+            ggml_set_name(b, "b");
+        }
+
+        ggml_tensor * out = ggml_concat(ctx, a, b, dim);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_ARGSORT
+struct test_argsort : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    ggml_sort_order order;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, order);
+    }
+
+    test_argsort(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {16, 10, 10, 10},
+            ggml_sort_order order = GGML_SORT_ORDER_ASC)
+        : type(type), ne(ne), order(order) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_argsort(ctx, a, order);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // indices
+                std::vector data(ggml_nelements(t));
+                for (int i = 0; i < ggml_nelements(t); i++) {
+                    data[i] = rand();
+                }
+                std::shuffle(data.begin(), data.end(), rng);
+                ggml_backend_tensor_set(t, data.data(), 0, ne[0]*ne[1]*ne[2]*ne[3] * sizeof(int));
+            } else if (t->type == GGML_TYPE_F32) {
+                // initialize with unique values to avoid ties
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
+                }
+            } else {
+                GGML_ABORT("fatal error");
+            }
+        }
+    }
+};
+
+// GGML_OP_SUM
+struct test_sum : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sum(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_sum(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 0.1f * sqrtf(ne[0]*ne[1]*ne[2]*ne[3]);
+    }
+};
+
+// GGML_OP_SUM_ROWS
+struct test_sum_rows : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sum_rows(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_sum_rows(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_MEAN
+struct test_mean : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_mean(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_mean(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
+    }
+};
+
+// GGML_OP_UPSCALE
+struct test_upscale : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const int32_t scale_factor;
+    const bool transpose;
+    const ggml_scale_mode mode;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne, scale_factor, mode, transpose);
+    }
+
+    test_upscale(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {512, 512, 3, 1},
+            int32_t scale_factor = 2, ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST, bool transpose = false)
+        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose), mode(mode) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        if (transpose) {
+            a = ggml_transpose(ctx, a);
+            ggml_set_name(a, "a_transposed");
+        }
+
+        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor, mode);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_UPSCALE (ext)
+struct test_upscale_ext : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const std::array ne_tgt;
+    const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, ne_tgt, mode);
+    }
+
+    test_upscale_ext(ggml_type type = GGML_TYPE_F32,
+            std::array ne     = {2, 5,  7, 11},
+            std::array ne_tgt = {5, 7, 11, 13},
+            ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
+        : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_GROUP_NORM
+struct test_group_norm : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const int32_t num_groups;
+    const float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, num_groups, eps);
+    }
+
+    test_group_norm(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {64, 64, 320, 1},
+            int32_t num_groups = 32,
+            float eps = 1e-6f)
+        : type(type), ne(ne), num_groups(num_groups), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_L2_NORM
+struct test_l2_norm : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_l2_norm(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {64, 64, 320, 1},
+            float eps = 1e-12f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_ACC
+struct test_acc : public test_case {
+    const ggml_type type;
+    const std::array ne_a;
+    const std::array ne_b;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, ne_b);
+    }
+
+    test_acc(ggml_type type = GGML_TYPE_F32,
+            std::array ne_a = {256, 17, 1, 1},
+            std::array ne_b = {256, 16, 1, 1})
+        : type(type), ne_a(ne_a), ne_b(ne_b) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_set_param(b);
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_PAD
+struct test_pad : public test_case {
+    const ggml_type type;
+    const std::array ne_a;
+    const int pad_0;
+    const int pad_1;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+    }
+
+    test_pad(ggml_type type = GGML_TYPE_F32,
+            std::array ne_a = {512, 512, 1, 1},
+            int pad_0 = 1, int pad_1 = 1)
+        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_PAD_REFLECT_1D
+struct test_pad_reflect_1d : public test_case {
+    const ggml_type type;
+    const std::array ne_a;
+    const int pad_0;
+    const int pad_1;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+    }
+
+    test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,
+            std::array ne_a = {512, 34, 2, 1},
+            int pad_0 = 10, int pad_1 = 9)
+        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 2, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_ARANGE
+struct test_arange : public test_case {
+    const ggml_type type;
+    const float start;
+    const float stop;
+    const float step;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, start, stop, step);
+    }
+
+    test_arange(ggml_type type = GGML_TYPE_F32,
+            float start = 0.f, float stop = 10.f, float step = 1.f)
+        : type(type), start(start), stop(stop), step(step)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * out = ggml_arange(ctx, start, stop, step);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_TIMESTEP_EMBEDDING
+struct test_timestep_embedding : public test_case {
+    const ggml_type type;
+    const std::array ne_a;
+    const int dim;
+    const int max_period;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, dim, max_period);
+    }
+
+    test_timestep_embedding(ggml_type type = GGML_TYPE_F32,
+            std::array ne_a = {2, 1, 1, 1},
+            int dim = 320, int max_period=10000)
+        : type(type), ne_a(ne_a), dim(dim), max_period(max_period)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_LEAKY_RELU
+struct test_leaky_relu : public test_case {
+    const ggml_type type;
+    const std::array ne_a;
+    const float negative_slope;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, negative_slope);
+    }
+
+    test_leaky_relu(ggml_type type = GGML_TYPE_F32,
+            std::array ne_a = {10, 5, 4, 3},
+            float negative_slope = 0.1f)
+        : type(type), ne_a(ne_a), negative_slope(negative_slope)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_leaky_relu(ctx, a, negative_slope, true);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_FLASH_ATTN_EXT
+struct test_flash_attn_ext : public test_case {
+    const int64_t hsk; // K head size
+    const int64_t hsv; // V head size
+    const int64_t nh; // num heads
+    const int64_t nr; // repeat in Q, tests for grouped-query attention
+    const int64_t kv; // kv size
+    const int64_t nb; // batch size
+
+    const bool mask; // use mask
+
+    const float max_bias; // ALiBi
+    const float logit_softcap; // Gemma 2
+
+    const ggml_prec prec;
+    const ggml_type type_KV;
+    std::array permute;
+
+    std::string vars() override {
+        return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        // Just counting matmul costs:
+        // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
+        return 2 * nh*nr * nb * (hsk + hsv) * kv;
+    }
+
+    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
+                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+                        ggml_type type_KV = GGML_TYPE_F16, std::array permute = {0, 1, 2, 3})
+        : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
+        const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
+
+        auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
+            int64_t ne[4] = {ne0, ne1, ne2, ne3};
+            int64_t ne_perm[4];
+            for (int i = 0; i < 4; ++i) {
+                ne_perm[permute[i]] = ne[i];
+            }
+            ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
+            if (permute != std::array{0, 1, 2, 3}) {
+                t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
+            }
+            return t;
+        };
+
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
+        ggml_set_name(q, "q");
+
+        ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,    1);
+        ggml_set_name(k, "k");
+
+        ggml_tensor * v = create_permuted(type_KV,       hsv_padded, kv, nh,    1);
+        ggml_set_name(v, "v");
+
+        ggml_tensor * m = nullptr;
+        if (mask) {
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
+            ggml_set_name(m, "m");
+        }
+
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
+        ggml_flash_attn_ext_set_prec(out, prec);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_CROSS_ENTROPY_LOSS
+struct test_cross_entropy_loss : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(logits);
+        ggml_set_name(logits, "logits");
+
+        ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
+        // The labels are assumed to be constant -> no gradients.
+        ggml_set_name(labels, "labels");
+
+        // Ensure labels add up to 1:
+        labels = ggml_soft_max(ctx, labels);
+        ggml_set_name(labels, "labels_normalized");
+
+        ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        // For larger abs. diffs between logits softmax is more linear, therefore more precise num. gradients.
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -100.0f, 100.0f);
+        }
+    }
+
+    float grad_eps() override {
+        return 1.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_CROSS_ENTROPY_LOSS_BACK
+struct test_cross_entropy_loss_back : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_cross_entropy_loss_back(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * grad = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+        ggml_set_name(grad, "grad");
+
+        ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(logits, "logits");
+
+        ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(labels, "labels");
+
+        // Ensure labels add up to 1:
+        labels = ggml_soft_max(ctx, labels);
+        ggml_set_name(labels, "labels_normalized");
+
+        ggml_tensor * out = ggml_cross_entropy_loss_back(ctx, grad, logits, labels);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_OPT_STEP_ADAMW
+struct test_opt_step_adamw : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_opt_step_adamw(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_param(a); // Despite tensor a having gradients the output tensor will not.
+        ggml_set_name(a, "a");
+
+        ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_name(grad, "grad");
+
+        ggml_tensor * grad_m = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_name(grad_m, "grad_m");
+
+        ggml_tensor * grad_v = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_name(grad_v, "grad_v");
+
+        ggml_tensor * adamw_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 7);
+        ggml_set_name(adamw_params, "adamw_params");
+
+        ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, grad_m, grad_v, adamw_params);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, 0.0f, 1.0f); // grad_v and adamw_params need non-negative values.
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+enum llm_norm_type {
+    LLM_NORM,
+    LLM_NORM_RMS,
+};
+
+struct llama_hparams {
+    uint32_t n_vocab;
+    uint32_t n_embd;
+    uint32_t n_head;
+    uint32_t n_head_kv;
+    static constexpr uint32_t n_layer = 1;
+    uint32_t n_rot;
+    uint32_t n_embd_head; // dimension of values (d_v)
+    uint32_t n_ff;
+
+    float f_norm_eps;
+    float f_norm_rms_eps;
+
+    // cparams
+    static constexpr uint32_t n_ctx = 512; // user-specified context size
+    static constexpr uint32_t n_ctx_orig = n_ctx;
+
+    // batch
+    int32_t n_tokens;
+
+    // llm_build_context
+    static constexpr int32_t n_kv    = 32; // size of KV cache to consider (n_kv <= n_ctx
+    static constexpr int32_t kv_head = 1;  // index of where we store new KV data in the cache
+
+    uint32_t n_embd_gqa() const { // dimension of key embeddings across all k-v heads
+        return n_embd_head * n_head_kv;
+    }
+};
+
+// LLM base class
+struct test_llm : public test_case {
+    llama_hparams hp;
+
+protected:
+    test_llm(llama_hparams hp)
+        : hp(std::move(hp)) {
+    }
+
+public:
+    struct ggml_tensor * llm_build_norm(
+            struct ggml_context * ctx,
+             struct ggml_tensor * cur,
+             struct ggml_tensor * mw,
+             struct ggml_tensor * mb,
+                  llm_norm_type   type) {
+        switch (type) {
+            case LLM_NORM:     cur = ggml_norm    (ctx, cur, hp.f_norm_eps); break;
+            case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hp.f_norm_rms_eps); break;
+        }
+        cur = ggml_mul(ctx, cur, mw);
+        if (mb) {
+            cur = ggml_add(ctx, cur, mb);
+        }
+        return cur;
+    }
+
+    void llm_build_kv_store(
+            struct ggml_context * ctx,
+             struct ggml_tensor * k_l,
+             struct ggml_tensor * v_l,
+             struct ggml_tensor * k_cur,
+             struct ggml_tensor * v_cur) {
+        // compute the transposed [n_tokens, n_embd] V matrix
+        struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, hp.n_embd_gqa(), hp.n_tokens));
+
+        struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, hp.n_tokens*hp.n_embd_gqa(),
+                (ggml_row_size(k_l->type, hp.n_embd_gqa()))*hp.kv_head);
+
+        struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, hp.n_tokens, hp.n_embd_gqa(),
+                (  hp.n_ctx)*ggml_element_size(v_l),
+                (hp.kv_head)*ggml_element_size(v_l));
+
+        // important: storing RoPE-ed version of K in the KV cache!
+        ggml_cpy(ctx, k_cur,   k_cache_view);
+        ggml_cpy(ctx, v_cur_t, v_cache_view);
+    }
+
+    struct ggml_tensor * llm_build_kqv(
+            struct ggml_context * ctx,
+             struct ggml_tensor * k_l,
+             struct ggml_tensor * v_l,
+             struct ggml_tensor * q_cur,
+             struct ggml_tensor * kq_mask,
+                        float     kq_scale) {
+        struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
+
+        struct ggml_tensor * k =
+            ggml_view_3d(ctx, k_l,
+                    hp.n_embd_head, hp.n_kv, hp.n_head_kv,
+                    ggml_row_size(k_l->type, hp.n_embd_gqa()),
+                    ggml_row_size(k_l->type, hp.n_embd_head),
+                    0);
+
+        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);
+
+        // split cached v into n_head heads
+        struct ggml_tensor * v =
+            ggml_view_3d(ctx, v_l,
+                    hp.n_kv, hp.n_embd_head, hp.n_head_kv,
+                    ggml_element_size(v_l)*hp.n_ctx,
+                    ggml_element_size(v_l)*hp.n_ctx*hp.n_embd_head,
+                    0);
+
+        struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
+
+        struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
+
+        struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, hp.n_embd_head*hp.n_head, hp.n_tokens);
+
+        struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
+        cur = ggml_mul_mat(ctx, wo, cur);
+
+        return cur;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // pos
+                std::vector data(hp.n_tokens);
+                for (int i = 0; i < hp.n_tokens; i++) {
+                    data[i] = rand() % hp.n_ctx;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, hp.n_tokens * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// Llama
+struct test_llama : public test_llm {
+    static constexpr float freq_base = 10000.0f;
+    static constexpr float freq_scale = 1.0f;
+    static constexpr float ext_factor = 0.0f;
+    static constexpr float attn_factor = 1.0f;
+    static constexpr float beta_fast = 32.0f;
+    static constexpr float beta_slow = 1.0f;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "LLAMA";
+    }
+
+    std::string vars() override {
+        auto n_tokens = hp.n_tokens;
+        return VARS_TO_STR1(n_tokens);
+    }
+
+    double max_nmse_err() override {
+        return 2e-3;
+    }
+
+    test_llama(int n_tokens = 1)
+        : test_llm({
+            /*n_vocab        =*/ 32000,
+            /*n_embd         =*/ 3200,
+            /*n_head         =*/ 32,
+            /*n_head_kv      =*/ 32,
+            /*n_rot          =*/ 100,
+            /*n_embd_head    =*/ 100,
+            /*n_ff           =*/ 8640,
+            /*f_norm_eps     =*/ 0.f,
+            /*f_norm_rms_eps =*/ 1e-5f,
+            /*n_tokens       =*/ n_tokens,
+        }) {
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);
+
+        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+
+        for (uint32_t il = 0; il < hp.n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+            cur = llm_build_norm(ctx, inpL, attn_norm, nullptr, LLM_NORM_RMS);
+
+            // self-attention
+            {
+                ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
+                ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
+                ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
+
+                Qcur = ggml_rope_ext(
+                    ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens), inp_pos, nullptr,
+                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                Kcur = ggml_rope_ext(
+                    ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
+                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
+
+                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA);
+
+            // feed-forward network
+            ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+            cur = llm_build_norm(ctx, ffn_inp, ffn_norm, nullptr, LLM_NORM_RMS);
+
+            ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
+            ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff,   hp.n_embd);
+            ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur);
+            cur = ggml_mul_mat(ctx, ffn_gate, cur);
+            cur = ggml_silu(ctx, cur);
+            cur = ggml_mul(ctx, cur, tmp);
+            cur = ggml_mul_mat(ctx, ffn_down, cur);
+
+            cur = ggml_add(ctx, cur, ffn_inp);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+        cur = llm_build_norm(ctx, cur, output_norm, nullptr, LLM_NORM_RMS);
+
+        // lm_head
+        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_vocab);
+        cur = ggml_mul_mat(ctx, output, cur);
+
+        return cur;
+    }
+};
+
+// Falcon
+struct test_falcon : public test_llm {
+    static constexpr float freq_base = 10000.0f;
+    static constexpr float freq_scale = 1.0f;
+    static constexpr float ext_factor = 0.0f;
+    static constexpr float attn_factor = 1.0f;
+    static constexpr float beta_fast = 32.0f;
+    static constexpr float beta_slow = 1.0f;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "FALCON";
+    }
+
+    std::string vars() override {
+        auto n_tokens = hp.n_tokens;
+        return VARS_TO_STR1(n_tokens);
+    }
+
+    double max_nmse_err() override {
+        return 2e-3;
+    }
+
+    test_falcon(int n_tokens = 1)
+        : test_llm({
+            /*n_vocab        =*/ 32000,
+            /*n_embd         =*/ 3200,
+            /*n_head         =*/ 50,
+            /*n_head_kv      =*/ 1,
+            /*n_rot          =*/ 64,
+            /*n_embd_head    =*/ 64,
+            /*n_ff           =*/ 8640,
+            /*f_norm_eps     =*/ 1e-5f,
+            /*f_norm_rms_eps =*/ 0.f,
+            /*n_tokens       =*/ n_tokens,
+        }) {
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);
+
+        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+
+        for (uint32_t il = 0; il < hp.n_layer; ++il) {
+            // norm
+            ggml_tensor * attn_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+            ggml_tensor * attn_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+            ggml_tensor * attn_norm = llm_build_norm(ctx, inpL, attn_norm_w, attn_norm_b, LLM_NORM);
+
+            // self-attention
+            {
+                cur = attn_norm;
+
+                ggml_tensor * wqkv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd + 2*hp.n_embd_gqa());
+
+                cur = ggml_mul_mat(ctx, wqkv, cur);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd,     hp.n_tokens, cur->nb[1], 0*sizeof(float)*(hp.n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd + hp.n_embd_gqa())));
+
+                Qcur = ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens);
+                Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);
+
+                // using mode = 2 for neox mode
+                Qcur = ggml_rope_ext(
+                    ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                Kcur = ggml_rope_ext(
+                    ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
+
+                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
+            }
+
+            struct ggml_tensor * ffn_inp = cur;
+
+            // feed forward
+            {
+                ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
+                ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
+                cur = attn_norm;
+                cur = ggml_mul_mat(ctx, ffn_up, cur);
+                cur = ggml_gelu(ctx, cur);
+                cur = ggml_mul_mat(ctx, ffn_down, cur);
+            }
+
+            cur = ggml_add(ctx, cur, ffn_inp);
+
+            cur = ggml_add(ctx, cur, inpL);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        ggml_tensor * output_norm   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+        ggml_tensor * output_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+        cur = llm_build_norm(ctx, cur, output_norm, output_norm_b, LLM_NORM);
+
+        // lm_head
+        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, hp.n_embd, hp.n_vocab);
+        cur = ggml_mul_mat(ctx, output, cur);
+
+        return cur;
+    }
+};
+
+
+// ###########################################
+// ## Section 3: GGML Op Test Instantiation ##
+// ###########################################
+static const ggml_type all_types[] = {
+    GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
+    GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+    GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
+    GGML_TYPE_Q6_K,
+    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
+    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
+    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+};
+
+static const ggml_type base_types[] = {
+    GGML_TYPE_F32, GGML_TYPE_F16,
+    GGML_TYPE_Q8_0, // for I8MM tests
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1, // for I8MM tests
+    GGML_TYPE_Q4_K,
+    GGML_TYPE_IQ2_XXS
+};
+
+static const ggml_type other_types[] = {
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+    GGML_TYPE_Q5_K,
+    GGML_TYPE_Q6_K,
+    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
+    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
+    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+    GGML_TYPE_BF16,
+};
+
+// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
+static std::vector> make_test_cases_eval() {
+    std::vector> test_cases;
+    std::default_random_engine rng(0);
+
+    // unary ops
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        for (int v : {0, 1}) {
+            for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
+                test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));
+                test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));
+            }
+        }
+    }
+
+    test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
+    for (ggml_type type : all_types) {
+        for (int b : {1, 7}) {
+            for (bool v : {false, true}) {
+                test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v));
+            }
+        }
+    }
+    for (int b : {1, 7}) {
+        for (bool v : {false, true}) {
+            test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, v));
+        }
+    }
+
+    test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false));
+    for (ggml_type type : all_types) {
+        for (bool v : {false, true}) {
+            test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v));
+        }
+    }
+    for (bool v : {false, true}) {
+        test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
+    }
+
+    for (ggml_type type_input : {GGML_TYPE_F32}) {
+        for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
+            for (int k0 : {1, 3}) {
+                for (int k1 : {1, 3}) {
+                    for (int s0 : {1, 2}) {
+                        for (int s1 : {1, 2}) {
+                            for (int p0 : {0, 1}) {
+                                for (int p1 : {0, 1}) {
+                                    test_cases.emplace_back(new test_pool2d(pool_type, type_input, {10, 10, 3, 1}, k0, k1, s0, s1, p0, p1));
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // im2col 1D
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+    for (int s0 : {1, 3}) {
+        for (int p0 : {0, 3}) {
+            for (int d0 : {1, 3}) {
+                test_cases.emplace_back(new test_im2col(
+                    GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
+                    s0, 0, p0, 0, d0, 0, false));
+            }
+        }
+    }
+
+    // im2col 2D
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
+    for (int s0 : {1, 3}) {
+        for (int s1 : {1, 3}) {
+            for (int p0 : {0, 3}) {
+                for (int p1 : {0, 3}) {
+                    for (int d0 : {1, 3}) {
+                        for (int d1 : {1, 3}) {
+                            test_cases.emplace_back(new test_im2col(
+                                GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
+                                s0, s1, p0, p1, d0, d1, true));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // extra tests for im2col 2D
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
+
+    // sycl backend will limit task global_range < MAX_INT
+    // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
+    // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)
+    // these cases are verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)
+    // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
+    // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
+
+    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, true));
+    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
+
+    for(uint32_t Cout : {1, 9}){
+        for(uint32_t Cin : {1, 7}){
+            for(uint32_t K : {1, 3, 1337}){
+                for(uint32_t L : {1, 2, 13}){
+                    for(uint32_t s0: {1, 2, 3}){
+                        test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1));
+                    }
+                }
+            }
+        }
+    }
+
+    test_cases.emplace_back(new test_conv_transpose_1d());
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
+
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4,  500, 1, 1}));
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
+
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32,    1, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100,  10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438,  3, 1, 1}));
+
+    for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 2, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 2}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
+    }
+
+    for (bool view : {false, true}) {
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+    }
+
+    test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I16));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {0, 2, 1, 3})); // dup by rows
+    test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {1, 0, 2, 3}));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {1, 0, 2, 3})); // dup dst not-contiguous
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {1, 2, 0, 3}));
+
+    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
+        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
+    }
+
+    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
+        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
+    }
+
+    // same-type copy
+    for (ggml_type type : all_types) {
+        const auto nk = ggml_blck_size(type);
+
+        for (int k = 1; k < 4; ++k) {
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}));
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3}));
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3}));
+        }
+    }
+
+    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
+        for (ggml_type type_dst : all_types) {
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
+        }
+    }
+    for (ggml_type type_src : all_types) {
+        for (ggml_type type_dst : {GGML_TYPE_F32}) {
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
+        }
+    }
+    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
+        }
+    }
+
+    test_cases.emplace_back(new test_cont());
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 3 ,5}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 3, 5 ,7}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 1 ,1}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 3 ,5}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 3, 5 ,7}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 1 ,1}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 3 ,5}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
+
+    auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr) {
+        for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
+            test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
+        }
+    };
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
+
+        // stable diffusion
+        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});
+        add_test_bin_bcast(type, {1280, 16, 16, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 256, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1280, 1}, {16, 16, 1, 1});
+        add_test_bin_bcast(type, {16, 16, 1280, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1920, 1}, {16, 16, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 2560, 1}, {16, 16, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1280, 1}, {32, 32, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1920, 1}, {32, 32, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});
+        add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});
+        add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});
+        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+    }
+
+    test_cases.emplace_back(new test_add1());
+    test_cases.emplace_back(new test_scale());
+    test_cases.emplace_back(new test_silu_back());
+
+    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
+        for (bool v : {false, true}) {
+            test_cases.emplace_back(new test_norm    (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+            test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+        }
+        test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_l2_norm      (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+    }
+
+    test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
+
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
+
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
+
+    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
+
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));
+
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
+
+    for (ggml_type type_a : all_types) {
+        for (int i = 1; i < 10; ++i) {
+            test_cases.emplace_back(new test_mul_mat(type_a,    GGML_TYPE_F32, 16,  i, 256, { 1,  1}, {1, 1}));
+        }
+    }
+
+#if 1
+    for (ggml_type type_a : base_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+            // test cases without permutation
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 2}));
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
+
+            // test cases with permutation
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            // test cases with large ne00/ne10 to cover stream-k fixup
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
+        }
+    }
+    for (ggml_type type_a : other_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32}) {
+            if (ggml_blck_size(type_a) != 256) {
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1,  1}, {1, 1}));
+            }
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1,  1}, {1, 1}));
+        }
+    }
+#else
+    // m = a rows
+    // n = b rows
+    // k = cols
+    std::uniform_int_distribution<> dist_m(1, 128);
+    std::uniform_int_distribution<> dist_n(16, 128);
+    std::uniform_int_distribution<> dist_k(1, 16);
+    for (int i = 0; i < 1000; i++) {
+        for (ggml_type type_a : all_types) {
+            for (ggml_type type_b : {GGML_TYPE_F32}) {
+                int m = dist_m(rng);
+                int n = dist_n(rng);
+                int k = dist_k(rng) * ggml_blck_size(type_a);
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, m, n, k, { 1,  1}, {1, 1}));
+            }
+        }
+    }
+#endif
+
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,  128, { 8,  1}, {1, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,  128, { 8,  1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,   64, { 8,  1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,   64, { 8,  1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8,  1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8,  1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1,  1}, {4, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67,  {1,  1}, {4, 1}, {0, 2, 1, 3}));
+
+    for (auto bs : {1,2,4,8}) {
+        for (auto nr : {1,4}) {
+            for (uint32_t m = 0; m < 2; ++m) {
+                for (uint32_t k = 0; k < 2; ++k) {
+                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  1}, {nr, 1}, {0, 2, 1, 3}));
+                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  1}, {nr, 1}, {0, 1, 2, 3}, true));
+                }
+            }
+        }
+    }
+
+    // sycl backend will limit task global_range < MAX_INT
+    // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)
+    // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.)
+    // this case is verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)
+    // test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1}));
+
+    for (ggml_type type_a : base_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
+            for (int n_mats : {4, 8}) {
+                for (int n_used : {1, 2, 4}) {
+                    for (bool b : {false, true}) {
+                        for (int n : {1, 32, 129}) {
+                            int m = 512;
+                            int k = 256;
+                            test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    for (ggml_type type_a : other_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
+            for (int n_mats : {4}) {
+                for (int n_used : {2}) {
+                    for (bool b : {false}) {
+                        for (int n : {1, 32}) {
+                            int m = 512;
+                            int k = 256;
+                            test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    for (ggml_type type_a : base_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+            for (int n : {1, 16}) {
+                for (int k : {1, 16}) {
+                    for (int bs2 : {1, 3}) {
+                        for (int bs3 : {1, 3}) {
+                            for (int nr2 : {1, 2}) {
+                                for (int nr3 : {1, 2}) {
+                                    test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, n, k, {bs2, bs3}, {nr2, nr3}));
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        test_cases.emplace_back(new test_sqr(type));
+        test_cases.emplace_back(new test_sqrt(type));
+        test_cases.emplace_back(new test_log(type));
+        test_cases.emplace_back(new test_sin(type));
+        test_cases.emplace_back(new test_cos(type));
+        test_cases.emplace_back(new test_clamp(type));
+    }
+
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5));
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 2}, 5));
+
+#if 0
+    std::uniform_int_distribution<> dist_ne1(1, 50);
+    int exponent = 1;
+    while (exponent < (1 << 17)) {
+        std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent);
+
+        for (int n = 0; n < 10; ++n) {
+            int64_t ne0 = dist_ne0(rng);
+            int64_t ne1 = dist_ne1(rng);
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
+        }
+
+        exponent <<= 1;
+    }
+#endif
+    for (bool mask : {false, true}) {
+        for (float max_bias : {0.0f, 8.0f}) {
+            if (!mask && max_bias > 0.0f) continue;
+            for (float scale : {1.0f, 0.1f}) {
+                for (int64_t ne0 : {16, 1024}) {
+                    for (int64_t ne1 : {16, 1024}) {
+                        if (mask) {
+                            for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, scale, max_bias));
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
+                            }
+                        } else {
+                            /* The precision of mask here doesn't matter as boolean mask is false */
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+                        }
+                    }
+                }
+            }
+        }
+    }
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32,  0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16,  0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32,  0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16,  0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32,  0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16,  0.1f, 8.0f));
+
+    for (float max_bias : {0.0f, 8.0f}) {
+        for (float scale : {1.0f, 0.1f}) {
+            for (int64_t ne0 : {16, 1024}) {
+                for (int64_t ne1 : {16, 1024}) {
+                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, scale, max_bias));
+                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias));
+                }
+            }
+        }
+    }
+
+    for (bool fw : {true, false}) { // fw == forward
+        bool all = true;
+
+        for (float v : { 0, 1 }) {
+            for (float fs : { 1.0f, 1.4245f }) {
+                for (float ef : { 0.0f, 0.7465f }) {
+                    for (float af : { 1.0f, 1.4245f }) {
+                        for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                            for (bool ff : {false, true}) { // freq_factors
+                                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
+
+                                if (all) {
+                                    test_cases.emplace_back(new test_rope(type, {128,  40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
+                                    test_cases.emplace_back(new test_rope(type, {128,  52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
+                                    test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
+                                }
+
+                                if (all) {
+                                    test_cases.emplace_back(new test_rope(type, { 64,   1, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 64,  71, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 64,   8, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                                }
+
+                                if (all) {
+                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
+                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
+                                }
+
+                                test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                            }
+                        }
+
+                        all = false;
+                    }
+                }
+            }
+        }
+    }
+
+    for (int v : { 0, 1, 2, 3 }) {
+        for (int dim : { 0, 1, 2, 3, }) {
+            test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
+            test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim, v));
+        }
+    }
+
+    for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
+    }
+
+    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
+        test_cases.emplace_back(new test_upscale_ext(GGML_TYPE_F32, {2, 5,  7, 11}, {5, 7, 11, 13}, mode));
+    }
+
+    test_cases.emplace_back(new test_sum());
+    test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_mean());
+    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
+    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
+    test_cases.emplace_back(new test_acc());
+    test_cases.emplace_back(new test_pad());
+    test_cases.emplace_back(new test_pad_reflect_1d());
+    test_cases.emplace_back(new test_arange());
+    test_cases.emplace_back(new test_timestep_embedding());
+    test_cases.emplace_back(new test_leaky_relu());
+
+    for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
+        for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
+            if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
+            if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+            if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
+
+            for (bool mask : { true, false } ) {
+                for (float max_bias : { 0.0f, 8.0f }) {
+                    if (!mask && max_bias > 0.0f) continue;
+                    for (float logit_softcap : {0.0f, 10.0f}) {
+                        if (hsk != 128 && logit_softcap != 0.0f) continue;
+                        for (int nh : { 4, }) {
+                            for (int nr : { 1, 4, 16 }) {
+                                if (nr == 16 && hsk != 128) continue;
+                                for (int kv : { 512, 1024, }) {
+                                    if (nr != 1 && kv != 512) continue;
+                                    for (int nb : { 1, 3, 32, 35, }) {
+                                        for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                            if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                                test_cases.emplace_back(new test_flash_attn_ext(
+                                                    hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                                // run fewer test cases permuted
+                                                if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                    test_cases.emplace_back(new test_flash_attn_ext(
+                                                        hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    test_cases.emplace_back(new test_cross_entropy_loss     (GGML_TYPE_F32, {   10, 5, 4, 3}));
+    test_cases.emplace_back(new test_cross_entropy_loss     (GGML_TYPE_F32, {30000, 1, 1, 1}));
+    test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {   10, 5, 4, 3}));
+    test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
+
+    test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
+
+    // these tests are disabled to save execution time, but they can be handy for debugging
+#if 0
+    test_cases.emplace_back(new test_llama(1));
+    test_cases.emplace_back(new test_llama(2));
+    test_cases.emplace_back(new test_falcon(1));
+    test_cases.emplace_back(new test_falcon(2));
+#endif
+
+    return test_cases;
+}
+
+// Test cases for performance evaluation: should be representative of real-world use cases
+static std::vector> make_test_cases_perf() {
+    std::vector> test_cases;
+
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1,   1, 1, 1}));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
+
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
+
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
+
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8,  1}, {4, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8,  1}, {4, 1}, {0, 1, 2, 3}, true));
+
+    for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
+        for (ggml_type type_a : all_types) {
+            for (ggml_type type_b : {GGML_TYPE_F32}) {
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1,  1}, {1, 1}));
+            }
+        }
+    }
+
+    for (int K : {3, 5}) {
+        for (int IC : {256, 2560}) {
+            for (int IW_IH : {32, 64, 256}) {
+                if (IC == 2560 && IW_IH == 256) {
+                    // too big
+                    continue;
+                }
+                test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {IW_IH, IW_IH, IC, 1}, {K, K, IC, 1}, 1, 1, 1, 1, 1, 1, true));
+            }
+        }
+    }
+
+    for (int kv : { 4096, 8192, 16384, }) {
+        for (int hs : { 64, 128, }) {
+            for (int nr : { 1, 4, }) {
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            }
+        }
+    }
+
+    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
+
+    return test_cases;
+}
+
+static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter) {
+    auto filter_test_cases = [](std::vector> & test_cases, const char * params_filter) {
+        if (params_filter == nullptr) {
+            return;
+        }
+
+        std::regex params_filter_regex(params_filter);
+
+        for (auto it = test_cases.begin(); it != test_cases.end();) {
+            if (!std::regex_search((*it)->vars(), params_filter_regex)) {
+                it = test_cases.erase(it);
+                continue;
+            }
+
+            it++;
+        }
+    };
+
+    if (mode == MODE_TEST) {
+        auto test_cases = make_test_cases_eval();
+        filter_test_cases(test_cases, params_filter);
+        ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
+        if (backend_cpu == NULL) {
+            printf("  Failed to initialize CPU backend\n");
+            return false;
+        }
+
+        size_t n_ok = 0;
+        for (auto & test : test_cases) {
+            if (test->eval(backend, backend_cpu, op_name)) {
+                n_ok++;
+            }
+        }
+        printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
+
+        ggml_backend_free(backend_cpu);
+
+        return n_ok == test_cases.size();
+    }
+
+    if (mode == MODE_GRAD) {
+        auto test_cases = make_test_cases_eval();
+        filter_test_cases(test_cases, params_filter);
+        size_t n_ok = 0;
+        for (auto & test : test_cases) {
+            if (test->eval_grad(backend, op_name)) {
+                n_ok++;
+            }
+        }
+        printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
+
+        return n_ok == test_cases.size();
+    }
+
+    if (mode == MODE_PERF) {
+        auto test_cases = make_test_cases_perf();
+        filter_test_cases(test_cases, params_filter);
+        for (auto & test : test_cases) {
+            test->eval_perf(backend, op_name);
+        }
+        return true;
+    }
+
+    GGML_ABORT("fatal error");
+}
+
+static void usage(char ** argv) {
+    printf("Usage: %s [mode] [-o ] [-b ] [-p ]\n", argv[0]);
+    printf("    valid modes:\n");
+    printf("      - test (default, compare with CPU backend for correctness)\n");
+    printf("      - grad (compare gradients from backpropagation with method of finite differences)\n");
+    printf("      - perf (performance evaluation)\n");
+    printf("    op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n");
+}
+
+int main(int argc, char ** argv) {
+    test_mode mode = MODE_TEST;
+    const char * op_name_filter = nullptr;
+    const char * backend_filter = nullptr;
+    const char * params_filter = nullptr;
+
+    for (int i = 1; i < argc; i++) {
+        if (strcmp(argv[i], "test") == 0) {
+            mode = MODE_TEST;
+        } else if (strcmp(argv[i], "perf") == 0) {
+            mode = MODE_PERF;
+        } else if (strcmp(argv[i], "grad") == 0) {
+            mode = MODE_GRAD;
+        } else if (strcmp(argv[i], "-o") == 0) {
+            if (i + 1 < argc) {
+                op_name_filter = argv[++i];
+            } else {
+                usage(argv);
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-b") == 0) {
+            if (i + 1 < argc) {
+                backend_filter = argv[++i];
+            } else {
+                usage(argv);
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-p") == 0) {
+            if (i + 1 < argc) {
+                params_filter = argv[++i];
+            } else {
+                usage(argv);
+                return 1;
+            }
+        } else {
+            usage(argv);
+            return 1;
+        }
+    }
+
+    // load and enumerate backends
+    ggml_backend_load_all();
+
+    printf("Testing %zu devices\n\n", ggml_backend_dev_count());
+
+    size_t n_ok = 0;
+
+    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+
+        printf("Backend %zu/%zu: %s\n", i + 1, ggml_backend_dev_count(), ggml_backend_dev_name(dev));
+
+        if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_dev_name(dev)) != 0) {
+            printf("  Skipping\n");
+            n_ok++;
+            continue;
+        }
+
+        if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {
+            printf("  Skipping CPU backend\n");
+            n_ok++;
+            continue;
+        }
+
+        ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
+        GGML_ASSERT(backend != NULL);
+
+        ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
+        auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+        if (ggml_backend_set_n_threads_fn) {
+            // TODO: better value for n_threads
+            ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency());
+        }
+
+        printf("  Device description: %s\n", ggml_backend_dev_description(dev));
+        size_t free, total; // NOLINT
+        ggml_backend_dev_memory(dev, &free, &total);
+        printf("  Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
+        printf("\n");
+
+        bool ok = test_backend(backend, mode, op_name_filter, params_filter);
+
+        printf("  Backend %s: ", ggml_backend_name(backend));
+        if (ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            n_ok++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+
+        printf("\n");
+
+        ggml_backend_free(backend);
+    }
+
+    ggml_quantize_free();
+
+    printf("%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count());
+
+    if (n_ok != ggml_backend_dev_count()) {
+        printf("\033[1;31mFAIL\033[0m\n");
+        return 1;
+    }
+
+    printf("\033[1;32mOK\033[0m\n");
+    return 0;
+}
diff --git a/tests/test-barrier.cpp b/tests/test-barrier.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d85bf912b22e22a3f1a5da5459c5bb41cf165ccb
--- /dev/null
+++ b/tests/test-barrier.cpp
@@ -0,0 +1,94 @@
+#include "ggml.h"
+#include "ggml-cpu.h"
+#include "ggml-backend.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#define MAX_NARGS 2
+
+int main(int argc, char *argv[]) {
+
+    int n_threads = 4;
+    int n_rounds  = 100;
+
+    if (argc > 1) {
+        n_threads = std::atoi(argv[1]);
+    }
+
+    if (argc > 2) {
+        n_rounds  = std::atoi(argv[2]);
+    }
+
+    struct ggml_init_params params = {
+        /* .mem_size   = */ 1024*1024*1024,
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ false,
+    };
+
+    struct ggml_context * ctx = ggml_init(params);
+
+    // Create graph
+    struct ggml_cgraph * gf = ggml_new_graph(ctx);
+
+    // Lots of small, parallel ops where barriers in between will dominate
+    struct ggml_tensor * out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  64);
+    for (int i = 0; i < 1000; i++) {
+        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 64, 128);
+        out = ggml_mul_mat(ctx, a, out);
+
+        struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 128, 64);
+        out = ggml_mul_mat(ctx, d, out);
+    }
+
+    ggml_build_forward_expand(gf, out);
+    int n_nodes = ggml_graph_n_nodes(gf);
+
+    // Create threadpool
+    struct ggml_threadpool_params tpp  = ggml_threadpool_params_default(n_threads);
+    struct ggml_threadpool* threadpool = ggml_threadpool_new(&tpp);
+    if (!threadpool) {
+        fprintf(stderr, "threadpool create failed : n_threads %d\n", n_threads);
+        exit(1);
+    }
+
+    // Create compute plan
+    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads, threadpool);
+
+    std::vector work_data(cplan.work_size);
+    cplan.work_data = work_data.data();
+
+    std::cerr << "graph-compute with"
+              << "\n n_threads: " << n_threads
+              << "\n   n_nodes: " << n_nodes
+              << "\n  n_rounds: " << n_rounds
+              << "\n";
+    // ggml_graph_print(gf);
+
+    // Warmup
+    ggml_graph_compute(gf, &cplan);
+
+    auto t0 = std::chrono::high_resolution_clock::now();
+
+    for (int i=0; i < n_rounds; i++) {
+        ggml_graph_compute(gf, &cplan);
+    }
+
+    auto t1 = std::chrono::high_resolution_clock::now();
+
+    auto usec = std::chrono::duration_cast(t1-t0).count();
+    auto nsec = std::chrono::duration_cast(t1-t0).count();
+    std::cerr << "graph-compute took " << usec << " usec "
+              << "\n " << (float) usec / n_rounds << " usec per-iter"
+              << "\n " << (float) nsec / (n_rounds * n_nodes) << " nsec per-node"
+              << "\n";
+
+    ggml_threadpool_free(threadpool);
+    ggml_free(ctx);
+
+    return 0;
+}
diff --git a/tests/test-c.c b/tests/test-c.c
new file mode 100644
index 0000000000000000000000000000000000000000..95ba73df39a3c0341d1cdba53ffec2999f7893d9
--- /dev/null
+++ b/tests/test-c.c
@@ -0,0 +1,7 @@
+#include "llama.h"
+
+#ifdef GGML_USE_KOMPUTE
+#include "ggml-kompute.h"
+#endif
+
+int main(void) {}
diff --git a/tests/test-chat-parser.cpp b/tests/test-chat-parser.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..59e44e07d25ed510a7117beb1835c859ac70e5f9
--- /dev/null
+++ b/tests/test-chat-parser.cpp
@@ -0,0 +1,352 @@
+//  Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
+//
+//  Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
+//  e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
+//
+//    cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
+//
+#include 
+#include 
+#include 
+
+#include "chat-parser.h"
+#include "common.h"
+#include "log.h"
+#include "regex-partial.h"
+
+template 
+static void assert_equals(const T & expected, const T & actual) {
+    if (expected != actual) {
+        std::cerr << "Expected: " << expected << std::endl;
+        std::cerr << "Actual: " << actual << std::endl;
+        std::cerr << std::flush;
+        throw std::runtime_error("Test failed");
+    }
+}
+static void assert_equals(const char * expected, const std::string & actual) {
+  return assert_equals(expected, actual);
+}
+
+static void assert_throws(const std::function & fn, const std::string & expected_exception_pattern = "") {
+    try {
+        fn();
+    } catch (const std::exception & e) {
+      if (expected_exception_pattern.empty()) {
+          return;
+        }
+        std::regex expected_exception_regex(expected_exception_pattern);
+        std::string actual_message = e.what();
+        if (std::regex_search(actual_message, expected_exception_regex)) {
+            return;
+        }
+        throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
+        throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
+    }
+    throw std::runtime_error("Exception was expected but not thrown");
+}
+
+static void test_reasoning() {
+  {
+    common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(false, builder.try_parse_reasoning("", ""));
+    assert_equals("CogitoErgo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(true, builder.try_parse_reasoning("", ""));
+    assert_equals(std::string("Cogito"), builder.result().reasoning_content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(false, builder.try_parse_reasoning("", ""));
+    assert_equals("CogitoErgo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ true,
+    });
+    assert_equals(true, builder.try_parse_reasoning("", ""));
+    assert_equals(std::string("Cogito"), builder.result().reasoning_content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ true,
+        /* .thinking_forced_open = */ true,
+    });
+    assert_equals(true, builder.try_parse_reasoning("", ""));
+    assert_equals("Cogito", builder.result().content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+}
+
+static void test_regex() {
+  auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
+    common_chat_msg_parser builder(input, /* is_partial= */ false, {});
+    assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
+  };
+
+  test_throws("Hello, world!", "abc", "^abc$");
+  test_throws("Hello, world!", "e", "^e$");
+
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
+    builder.consume_regex(common_regex("Hello"));
+    assert_equals(", world!", builder.consume_rest());
+  }
+
+  {
+    // When in non partial mode, we can say whether the regex was consumed or not.
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
+    assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
+  }
+  {
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
+    auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
+    assert_equals(true, res.has_value());
+    // Verify captures
+    assert_equals(2, res->groups.size());
+    assert_equals("Hell", builder.str(res->groups[0]));
+    assert_equals("el", builder.str(res->groups[1]));
+    // Verify position is after the match
+    assert_equals(4, builder.pos());
+    assert_equals("o,", builder.consume_rest());
+  }
+  {
+    // But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
+    assert_throws([&]() {
+      builder.try_consume_regex(common_regex("Hello, world!"));
+    }, "^Hello, world!$");
+  }
+
+  // Now regardless of the mode, we can tell these aren't a match.
+  for (const auto is_partial : {false, true}) {
+    common_chat_msg_parser builder("Hello,", is_partial, {});
+    assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
+  }
+  for (const auto is_partial : {false, true}) {
+    common_chat_msg_parser builder("Hello,", is_partial, {});
+    assert_equals(false, builder.try_consume_literal("Oh"));
+  }
+}
+
+const std::vector barely_healable_jsons = {
+  "{",
+  "{\"",
+  "{\"\\",
+  "{\"n",
+  "{\"name\"",
+  "{\"name\":",
+  "{\"name\":\"",
+  "{\"name\":\"\\",
+  "{\"name\":\"python",
+  "{\"name\":\"python\\",
+  "{\",",
+  "{\":",
+  "{\"[",
+  "{\"]",
+  "{\"{",
+  "{\"}",
+  "{\"1",
+  "{\"name\":\",",
+  "{\"name\":\":",
+  "{\"name\":\"[",
+  "{\"name\":\"]",
+  "{\"name\":\"{",
+  "{\"name\":\"}",
+  "{\"name\":\"1",
+};
+
+static void test(const std::string & input, bool is_partial, const std::vector> & args_paths, const std::vector> & content_paths, const std::string & expected) {
+  common_chat_msg_parser builder(input, is_partial, {});
+  auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
+  assert_equals(true, js.has_value());
+  assert_equals(is_partial, js->is_partial);
+  assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get() : js->value.dump());
+}
+static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
+  common_chat_msg_parser builder(input, parse_as_partial, {});
+  auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
+  assert_equals(true, js.has_value());
+  assert_equals(is_partial, js->is_partial);
+  assert_equals(expected, js->value.dump());
+}
+
+static void test_json_with_dumped_args_no_args() {
+  // Normal JSON, nothing to heal, nothing to dump
+  test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
+  // Full json is args
+  test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
+
+  // If the arguments are further down, don't heal partial content.
+  for (const auto & src : barely_healable_jsons) {
+    test(src, true, {{"arguments"}}, {}, "{}");
+  }
+  // But heal content that isn't partial.
+  test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
+}
+
+static void test_json_with_dumped_args() {
+
+  // Partial content.
+  test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
+  test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
+  test("{\"content\": ", true, {}, {{"content"}}, "{}");
+
+  // If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
+  test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
+  for (const auto & src : barely_healable_jsons) {
+    test(src, true, {{}}, {}, src);
+  }
+
+  // Full JSON w/ args
+  for (auto parse_as_partial : {true, false}) {
+    test_with_args(
+      R"({"name": "python", "args": {"arg1": 1}})",
+      R"({"name":"python","args":"{\"arg1\":1}"})",
+      parse_as_partial,
+      /* is_partial= */ false
+    );
+  }
+
+  // Partial JSON w/ partial args
+  test_with_args(
+    R"({"foo": "bar", "args": {")",
+    R"({"foo":"bar","args":"{\""})"
+  );
+  // Partial args broken in object key
+  test_with_args(
+    R"({"foo": "bar", "args": {"ar)",
+    R"({"foo":"bar","args":"{\"ar"})"
+  );
+  // Partial args broken after object key
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1")",
+    R"({"foo":"bar","args":"{\"arg1\""})"
+  );
+  // Partial args broken before object value
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1":)",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken before object value (space)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": )",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken in object value that may not be complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": 1)",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken in object value that is complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": 1 )",
+    R"({"foo":"bar","args":"{\"arg1\":1"})"
+  );
+  // Partial args broken in object value that is incomplete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": ")",
+    R"({"foo":"bar","args":"{\"arg1\":\""})"
+  );
+  // Partial args broken in object value that is complete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": "1")",
+    R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
+  );
+  // Partial args broken on array opening
+  test_with_args(
+    R"({"foo": "bar", "args": [)",
+    R"({"foo":"bar","args":"["})"
+  );
+  // Partial args broken on array value that is incomplete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": [1)",
+    R"({"foo":"bar","args":"["})"
+  );
+  // Partial args broken on array value that is complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": [1 )",
+    R"({"foo":"bar","args":"[1"})"
+  );
+  // Partial args broken on array value that is complete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": ["1")",
+    R"({"foo":"bar","args":"[\"1\""})"
+  );
+  // Partial args broken after array value
+  test_with_args(
+    R"({"foo": "bar", "args": [1,)",
+    R"({"foo":"bar","args":"[1,"})"
+  );
+  // Partial args broken on nested array
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": [)",
+    R"({"foo":"bar","args":"{\"arg1\":["})"
+  );
+}
+
+static void test_positions() {
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
+    assert_equals(0, builder.pos());
+    assert_throws([&]() { builder.move_to(100); });
+    assert_equals(0, builder.pos());
+    assert_throws([&]() { builder.move_back(1); });
+    assert_equals(0, builder.pos());
+
+    builder.move_to(8);
+    assert_equals(8, builder.pos());
+    builder.move_back(1);
+    assert_equals(7, builder.pos());
+    assert_equals("world!", builder.consume_rest());
+
+    builder.move_to(0);
+    assert_equals(0, builder.pos());
+
+    assert_throws([&]() { builder.finish(); });
+    assert_equals(0, builder.pos());
+
+    builder.move_to(builder.input().size());
+    builder.finish();
+  }
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
+
+    builder.move_to(builder.input().size());
+    assert_equals(builder.input().size(), builder.pos());
+    builder.finish();
+  }
+}
+
+int main() {
+    test_positions();
+    test_json_with_dumped_args_no_args();
+    test_json_with_dumped_args();
+    test_reasoning();
+    test_regex();
+    std::cout << "All tests passed!\n";
+    return 0;
+}
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a0a50f9881fe0381661c1da92fbf2a3da20f8df5
--- /dev/null
+++ b/tests/test-chat-template.cpp
@@ -0,0 +1,417 @@
+#include 
+#include 
+#include 
+#include 
+
+#undef NDEBUG
+#include 
+
+#include "llama.h"
+#include "common.h"
+#include "chat.h"
+
+static std::string normalize_newlines(const std::string & s) {
+#ifdef _WIN32
+  static const std::regex nl_regex("\r\n");
+  return std::regex_replace(s, nl_regex, "\n");
+#else
+  return s;
+#endif
+}
+
+#define U8C(x) (const char*)(u8##x)
+
+static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
+    common_chat_msg msg;
+    msg.role = role;
+    msg.content = content;
+    return msg;
+}
+
+int main(void) {
+    std::vector conversation {
+        {"system", "You are a helpful assistant"},
+        {"user", "Hello"},
+        {"assistant", "Hi there"},
+        {"user", "Who are you"},
+        {"assistant", "   I am an assistant   "},
+        {"user", "Another question"},
+    };
+
+    // std::string wrong = /* .template_str= */ u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}";
+    struct TestCase {
+        std::string name;
+        std::string template_str;
+        std::string expected_output;
+        std::string expected_output_jinja;
+        std::string bos_token = "";
+        std::string eos_token = "";
+        bool supported_with_jinja = true;
+    };
+    std::vector test_cases {
+        {
+            /* .name= */ "teknium/OpenHermes-2.5-Mistral-7B",
+            /* .template_str= */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
+            /* .expected_output= */ "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)",
+            /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
+            /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST]   I am an assistant   [INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "TheBloke/FusionNet_34Bx2_MoE-AWQ",
+            /* .template_str= */ "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' '  + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}",
+            /* .expected_output= */       "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST]   I am an assistant   [INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST]    I am an assistant    [INST] Another question [/INST]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "bofenghuang/vigogne-2-70b-chat",
+            /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
+            /* .expected_output= */       "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "mlabonne/AlphaMonarch-7B",
+            /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
+            /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n   I am an assistant   \nuser\nAnother question\nassistant\n",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "google/gemma-7b-it",
+            /* .template_str= */ "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}",
+            /* .expected_output= */       "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n",
+            /* .expected_output_jinja= */ "user\nYou are a helpful assistant\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n",
+        },
+        {
+            /* .name= */ "OrionStarAI/Orion-14B-Chat",
+            /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
+            /* .expected_output= */       "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant:    I am an assistant   Human: Another question\n\nAssistant: ",
+            /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant:    I am an assistant   Human: Another question\n\nAssistant: ",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "openchat/openchat-3.5-0106",
+            // The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d
+            // So we match against the included template but implement the suggested version.
+            /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
+            /* .expected_output= */                            "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant:    I am an assistant   <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
+            /* .expected_output_jinja= */ "GPT4 Correct System: You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant:    I am an assistant   <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
+        },
+        {
+            /* .name= */ "deepseek-ai/deepseek-coder-33b-instruct",
+            /* .template_str= */ "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n    {%- if message['role'] == 'system' -%}\n        {%- set ns.found = true -%}\n    {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n    {%- else %}\n        {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n        {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
+            /* .expected_output= */ "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n   I am an assistant   \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
+            /* .expected_output_jinja= */ "",
+        },
+        {
+            /* .name= */ "eachadea/vicuna-13b-1.1",
+            // No template included in tokenizer_config.json, so this template likely needs to be manually set.
+            /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
+            /* .expected_output= */ "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT:    I am an assistant   \nUSER: Another question\nASSISTANT:",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "Orca-Vicuna",
+            // No template included in tokenizer_config.json, so this template likely needs to be manually set.
+            /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
+            /* .expected_output= */ "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT:    I am an assistant   \nUSER: Another question\nASSISTANT:",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "CohereForAI/c4ai-command-r-plus",
+            /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'  + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
+            /* .expected_output= */ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
+            /* .expected_output_jinja= */ "",
+        },
+        {
+            /* .name= */ "Llama-3",
+            /* .template_str= */ "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
+            /* .expected_output= */ "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+            /* .expected_output_jinja= */ "",
+        },
+        {
+            /* .name= */ "Phi-3-mini",
+            /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+            /* .expected_output= */     "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        },
+        {
+            /* .name= */ "Phi-3-small",
+            /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+            /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "",
+        },
+        {
+            /* .name= */ "Phi-3-medium",
+            /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+            /* .expected_output= */     "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        },
+        {
+            /* .name= */ "Phi-3-vision",
+            /* .template_str= */ "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
+            /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "ChatGLM3",
+            /* .template_str= */ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+            /* .expected_output= */       "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n    I am an assistant   <|user|>\n Another question<|assistant|>",
+            /* .expected_output_jinja= */ "[gMASK]sop<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n   I am an assistant   <|user|>\nAnother question<|assistant|>",
+        },
+        {
+            /* .name= */ "ChatGLM4",
+            /* .template_str= */ U8C("[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}"),
+            /* .expected_output= */ "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n   I am an assistant   <|user|>\nAnother question<|assistant|>\n",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "GLMEdge",
+            /* .template_str= */ "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}<|assistant|>",
+            /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n   I am an assistant   <|user|>\nAnother question<|assistant|>",
+            /* .expected_output_jinja= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n   I am an assistant   <|user|>\nAnother question<|assistant|>",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
+            /* .template_str= */ U8C("{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}"),
+            /* .expected_output= */ U8C("You are a helpful assistant<用户>HelloHi there<用户>Who are youI am an assistant<用户>Another question"),
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "DeepSeek-V2",
+            /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
+            /* .expected_output= */ U8C("You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant:    I am an assistant   <|end▁of▁sentence|>User: Another question\n\nAssistant:"),
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "<|end▁of▁sentence|>",
+        },
+        {
+            /* .name= */ "ibm-granite/granite-3.0-8b-instruct",
+            /* .template_str= */ "{%- if tools %}\n    {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n    {%- for tool in tools %}\n    {{- tool | tojson(indent=4) }}\n    {%- if not loop.last %}\n        {{- '\n\n' }}\n    {%- endif %}\n    {%- endfor %}\n    {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n    {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'user' %}\n    {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'assistant' %}\n    {{- '<|start_of_role|>assistant<|end_of_role|>'  + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'assistant_tool_call' %}\n    {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'tool_response' %}\n    {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- endif %}\n    {%- if loop.last and add_generation_prompt %}\n    {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n    {%- endif %}\n{%- endfor %}",
+            /* .expected_output= */       "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>   I am an assistant   <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
+            /* .expected_output_jinja= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>   I am an assistant   <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
+        },
+        {
+            /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)",
+            /* .template_str= */ "{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content'] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n    {%- endif %}\n    {%- if message['role'] == 'user' %}\n        {%- if loop.first and system_message is defined %}\n            {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n        {%- else %}\n            {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n        {%- endif %}\n    {%- elif message['role'] == 'assistant' %}\n        {{- ' ' + message['content'] + eos_token}}\n    {%- else %}\n        {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n    {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there [INST] Who are you [/INST]    I am an assistant    [INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)",
+            /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS] [\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n            {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n        {{- \"[TOOL_CALLS] [\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- \" \" + message[\"content\"]|trim + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */       "[INST] You are a helpful assistant\n\nHello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] Another question[/INST]",
+            /* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] You are a helpful assistant\n\nAnother question[/INST]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)",
+            /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS][\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n            {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n        {{- \"[TOOL_CALLS][\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- message[\"content\"] + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */       "[INST]You are a helpful assistant\n\nHello[/INST]Hi there[INST]Who are you[/INST]   I am an assistant   [INST]Another question[/INST]",
+            /* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there[INST]Who are you[/INST]   I am an assistant   [INST]You are a helpful assistant\n\nAnother question[/INST]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)",
+            /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}",
+            /* .expected_output= */ "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there[INST] Who are you[/INST]    I am an assistant   [INST] Another question[/INST]",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "ai-sage/GigaChat-20B-A3B-instruct",
+            /* .template_str= */ "{% if messages[0]['role'] == 'system' -%}\n    {%- set loop_messages = messages[1:] -%}\n    {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n    {%- set loop_messages = messages -%}\n    {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n    {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n    {% endif %}\n    \n    {%- if loop.index0 == 0 -%}\n        {{ system_message -}}\n    {%- endif -%}\n    {%- if message['role'] == 'user' -%}\n        {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n        {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3]  + additional_special_tokens[1] -}}\n    {%- endif -%}\n    {%- if message['role'] == 'assistant' -%}\n        {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n    {%- endif -%}\n    {%- if loop.last and add_generation_prompt -%}\n        {{ 'assistant' + additional_special_tokens[0] -}}\n    {%- endif -%}\n{%- endfor %}",
+            /* .expected_output= */ "You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>   I am an assistant   <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+            /* .supported_with_jinja= */ false, // Requires additional_special_tokens as extra context
+        },
+        {
+            /* .name= */ "Infinigence/Megrez-3B-Instruct",
+            /* .template_str= */ U8C("{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}"),
+            /* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|>   I am an assistant   <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "phi-4",
+            /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}",
+            /* .expected_output= */ "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|>   I am an assistant   <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "yandex/YandexGPT-5-Lite-8B-instruct",
+            /* .template_str= */ "{%- set names = {'assistant': ' Ассистент:', 'user': ' Пользователь:'} %}\n{%- set tools_prefix = 'Тебе доступны следующие функции:' %}\n{%- macro __render_tool(tool) %}\n    {%- set name = tool.function.name %}\n    {%- set description = tool.function.description|default('') %}\n    {%- set parameters = tool.function.parameters|tojson %}\n    {{- '\\n' }}function {{ '{' }}'name':'{{ name }}',\n    {%- if tool.function.description %}'description':'{{ description }}',{% endif %}\n'parameters':{{ parameters }}\n    {{- '}' }}\n{%- endmacro %}\n{%- macro __render_tools(tools) %}\n    {{- tools_prefix }}\n    {%- for tool in tools %}\n        {{- __render_tool(tool) }}\n    {%- endfor %}\n    {{- '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_tool_message(message) %}\n    {{- '\\n\\nРезультат вызова' }} {{ message.name }}: {{ message.content }} {{ '\\n\\n' }}\n{%- endmacro %}\n{%- if tools -%}\n    {{- __render_tools(tools) }}\n{%- endif -%}\n{%- macro __render_user_message(message) %}\n{{ names.user }} {{ message.content + '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_assistant_message(message) %}\n    {{- names.assistant }}\n    {%- set call = message['function_call'] %}\n    {%- if call %}\n        {{- '\\n[TOOL_CALL_START]' }}{{ call.name }}{{ '\\n' }}{{ call.arguments|tojson }}\n    {%- else %}\n        {{- ' ' + message.content + '\\n\\n' }}\n    {%- endif %}\n{%- endmacro %}\n{%- if not add_generation_prompt is defined %}\n{%- set add_generation_prompt = false %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'user' %}\n        {{- __render_user_message(message) }}\n    {%- endif %}\n    {%- if message.role == 'assistant' and not loop.last %}\n        {{- __render_assistant_message(message) }}\n    {%- endif %}\n    {%- if message.role == 'tool' %}\n        {{- __render_tool_message(message) }}\n    {%- endif %}\n    {%- if loop.last %}\n        {{- ' Ассистент:[SEP]' }}\n    {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */ " Пользователь: Hello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент:    I am an assistant   \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .expected_output_jinja= */ " Пользователь: You are a helpful assistant\nHello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент:    I am an assistant   \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "inclusionAI/Ling-lite",
+            /* .template_str */ "{% for message in messages %}{% set role = message['role'] | lower %}{% if role == 'user' %}{% set role = 'HUMAN' %}{% endif %}{% set role = role | upper %}{{ '' + role + '' + message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT' }}{% endif %}",
+            /* .expected_output= */ "SYSTEMYou are a helpful assistantHUMANHelloASSISTANTHi thereHUMANWho are youASSISTANT   I am an assistant   HUMANAnother questionASSISTANT",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+    };
+    std::vector formatted_chat(1024);
+    int32_t res;
+
+    // list all supported templates
+    std::vector supported_tmpl;
+    res = llama_chat_builtin_templates(nullptr, 0);
+    assert(res > 0);
+    supported_tmpl.resize(res);
+    res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size());
+    printf("Built-in chat templates:\n");
+    for (auto tmpl : supported_tmpl) {
+        printf("  %s\n", tmpl);
+    }
+
+    // test invalid chat template
+    res = llama_chat_apply_template("INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size());
+    assert(res < 0);
+    const auto add_generation_prompt = true;
+
+    for (const auto & test_case : test_cases) {
+        printf("\n\n=== %s ===\n\n", test_case.name.c_str());
+        formatted_chat.resize(1024);
+        res = llama_chat_apply_template(
+            test_case.template_str.c_str(),
+            conversation.data(),
+            conversation.size(),
+            add_generation_prompt,
+            formatted_chat.data(),
+            formatted_chat.size()
+        );
+        formatted_chat.resize(res);
+        std::string output(formatted_chat.data(), formatted_chat.size());
+        if (output != test_case.expected_output) {
+            printf("Expected:\n%s\n", test_case.expected_output.c_str());
+            printf("-------------------------\n");
+            printf("Actual:\n%s\n", output.c_str());
+            fflush(stdout);
+            assert(output == test_case.expected_output);
+        }
+    }
+
+    std::vector messages;
+    for (const auto & msg : conversation) {
+        messages.push_back(simple_msg(msg.role, msg.content));
+    }
+    for (const auto & test_case : test_cases) {
+        if (!test_case.supported_with_jinja) {
+            continue;
+        }
+        printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
+        try {
+            auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
+            common_chat_templates_inputs inputs;
+            inputs.use_jinja = true;
+            inputs.messages = messages;
+            inputs.add_generation_prompt = add_generation_prompt;
+            auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
+            output = normalize_newlines(output);
+            auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
+            if (output != expected_output) {
+                printf("Expected:\n%s\n", expected_output.c_str());
+                printf("-------------------------\n");
+                printf("Actual:\n%s\n", output.c_str());
+                fflush(stdout);
+                assert(output == expected_output);
+            }
+        } catch (const std::exception & e) {
+            printf("ERROR: %s\n", e.what());
+            assert(false);
+        }
+    }
+
+    // test llama_chat_format_single for system message
+    printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
+    std::vector chat2;
+    auto sys_msg = simple_msg("system", "You are a helpful assistant");
+
+    auto fmt_sys = [&](std::string tmpl_str) {
+        auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
+        auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
+        printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
+        printf("-------------------------\n");
+        return output;
+    };
+    assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
+    assert(fmt_sys("mistral-v1") == " [INST] You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v3") == "[INST] You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v3-tekken") == "[INST]You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v7") == "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT]");
+    assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
+    assert(fmt_sys("llama2-sys") == "[INST] <>\nYou are a helpful assistant\n<>\n\n");
+    assert(fmt_sys("mistral") == "[INST] You are a helpful assistant\n"); // for old pre-v1 templates
+    assert(fmt_sys("gemma")  == ""); // for gemma, system message is merged with user message
+    assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
+    assert(fmt_sys("gigachat") == "You are a helpful assistant<|message_sep|>");
+
+
+    // test llama_chat_format_single for user message
+    printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
+    chat2.push_back(simple_msg("system", "You are a helpful assistant"));
+    chat2.push_back(simple_msg("user", "Hello"));
+    chat2.push_back(simple_msg("assistant", "I am assistant"));
+    auto new_msg = simple_msg("user", "How are you");
+
+    auto fmt_single = [&](const std::string & tmpl_str) {
+        auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
+        auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
+        printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
+        printf("-------------------------\n");
+        return output;
+    };
+    assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("mistral-v1") == " [INST] How are you [/INST]");
+    assert(fmt_single("mistral-v3") == "[INST] How are you[/INST]");
+    assert(fmt_single("mistral-v3-tekken") == "[INST]How are you[/INST]");
+    assert(fmt_single("mistral-v7") == "[INST] How are you[/INST]");
+    assert(fmt_single("llama2") == "[INST] How are you [/INST]");
+    assert(fmt_single("mistral") == "[INST] How are you [/INST]"); // for old pre-v1 templates
+    assert(fmt_single("gemma")  == "\nuser\nHow are you\nmodel\n");
+    assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+    assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
+
+    return 0;
+}
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6ebf1464d911a9317f92beb7fc0c2b0916480c3f
--- /dev/null
+++ b/tests/test-chat.cpp
@@ -0,0 +1,1478 @@
+//  Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
+//
+//  Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
+//  e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
+//
+//    cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
+//
+#include "chat.h"
+
+#include "log.h"
+
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
+
+#include 
+
+#include 
+#include 
+#include 
+
+using json = nlohmann::ordered_json;
+
+static std::ostream & operator<<(std::ostream & os, const common_chat_msg_diff & diff) {
+    os << "{ content_delta: " << diff.content_delta << "; ";
+    os << "reasoning_content_delta: " << diff.reasoning_content_delta << "; ";
+    if (diff.tool_call_index != std::string::npos) {
+        os << "tool_call_index: " << diff.tool_call_index << "; ";
+        os << "tool_call_delta.name: " << diff.tool_call_delta.name << "; ";
+        os << "tool_call_delta.id: " << diff.tool_call_delta.id << "; ";
+        os << "tool_call_delta.arguments: " << diff.tool_call_delta.arguments << "; ";
+    }
+    os << "}";
+    return os;
+}
+// operator<< for vector:
+static std::ostream & operator<<(std::ostream & os, const std::vector & diffs) {
+    os << "[\n";
+    for (const auto & diff : diffs) {
+        os << "  " << diff << ",\n";
+    }
+    os << "]";
+    return os;
+}
+static std::ostream & operator<<(std::ostream & os, const common_chat_msg & msg) {
+    os << "{ role: " << msg.role << "; ";
+    os << "content: " << msg.content << "; ";
+    os << "content_parts: [\n";
+    for (const auto & part : msg.content_parts) {
+        os << "  { type: " << part.type << "; text: " << part.text << " },\n";
+    }
+    os << "]; ";
+    os << "reasoning_content: " << msg.reasoning_content << "; ";
+    os << "tool_calls: [\n";
+    for (const auto & tool_call : msg.tool_calls) {
+        os << "  { name: " << tool_call.name << "; arguments: " << tool_call.arguments << "; id: " << tool_call.id << " },\n";
+    }
+    os << "]";
+    os << "}";
+    return os;
+}
+
+template  static bool equals(const T & expected, const T & actual) {
+    return expected == actual;
+}
+
+static common_chat_msg normalize(const common_chat_msg & msg) {
+    common_chat_msg normalized = msg;
+    for (auto & tool_call : normalized.tool_calls) {
+        try {
+            tool_call.arguments = json::parse(tool_call.arguments).dump();
+        } catch (const std::exception &) {
+            // Do nothing
+        }
+    }
+    return normalized;
+}
+template <>
+bool equals(const common_chat_msg & expected, const common_chat_msg & actual) {
+    return normalize(expected) == normalize(actual);
+}
+
+template  static void assert_equals(const T & expected, const T & actual) {
+    if (!equals(expected, actual)) {
+        std::cerr << "Expected: " << expected << std::endl;
+        std::cerr << "Actual: " << actual << std::endl;
+        std::cerr << std::flush;
+        throw std::runtime_error("Test failed");
+    }
+}
+
+static std::string read_file(const std::string & path) {
+    std::cerr << "# Reading: " << path << '\n' << std::flush;
+    std::ifstream fs(path, std::ios_base::binary);
+    if (!fs.is_open()) {
+        fs = std::ifstream("../" + path, std::ios_base::binary);
+        if (!fs.is_open()) {
+            throw std::runtime_error("Failed to open file: " + path);
+        }
+    }
+    fs.seekg(0, std::ios_base::end);
+    auto size = fs.tellg();
+    fs.seekg(0);
+    std::string out;
+    out.resize(static_cast(size));
+    fs.read(out.data(), static_cast(size));
+    return out;
+}
+
+static common_chat_templates_ptr read_templates(const std::string & path) {
+    return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)));
+}
+
+static std::unique_ptr build_grammar(const std::string & grammar_str) {
+    return std::unique_ptr(
+        llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
+}
+
+// TODO: extract to common helper (copied from test-grammar-integration.cpp)
+static bool match_string(const std::string & input, llama_grammar * grammar) {
+    const auto cpts = unicode_cpts_from_utf8(input);
+
+    auto & stacks_cur = llama_grammar_get_stacks(grammar);
+
+    for (const auto & cpt : cpts) {
+        llama_grammar_accept(grammar, cpt);
+
+        if (stacks_cur.empty()) {
+            // no stacks means that the grammar failed to match at this point
+            return false;
+        }
+    }
+
+    if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) {
+        // An empty stack means that the grammar has been completed
+        return true;
+    }
+
+    return false;
+}
+
+static std::string renormalize_json(const std::string & json_str) {
+    try {
+        auto json_obj = json::parse(json_str);
+        return json_obj.dump();
+    } catch (const std::exception & e) {
+        std::cerr << "Failed to parse JSON: " << e.what() << '\n';
+        return json_str;
+    }
+}
+static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
+    assert_equals(expected.role, actual.role);
+    assert_equals(expected.content, actual.content);
+    assert_equals(expected.content_parts.size(), actual.content_parts.size());
+    for (size_t i = 0; i < expected.content_parts.size(); i++) {
+        const auto & expected_part = expected.content_parts[i];
+        const auto & actual_part   = actual.content_parts[i];
+        assert_equals(expected_part.type, actual_part.type);
+        assert_equals(expected_part.text, actual_part.text);
+    }
+    assert_equals(expected.reasoning_content, actual.reasoning_content);
+    assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
+    for (size_t i = 0; i < expected.tool_calls.size(); i++) {
+        const auto & expected_tool_call = expected.tool_calls[i];
+        const auto & actual_tool_call   = actual.tool_calls[i];
+        assert_equals(expected_tool_call.name, actual_tool_call.name);
+        assert_equals(renormalize_json(expected_tool_call.arguments), renormalize_json(actual_tool_call.arguments));
+        assert_equals(expected_tool_call.id, actual_tool_call.id);
+    }
+}
+
+common_chat_tool special_function_tool {
+    /* .name = */ "special_function",
+    /* .description = */ "I'm special",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "arg1": {
+                "type": "integer",
+                "description": "The arg."
+            }
+        },
+        "required": ["arg1"]
+    })",
+};
+common_chat_tool python_tool {
+    /* .name = */ "python",
+    /* .description = */ "an ipython interpreter",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "code": {
+                "type": "string",
+                "description": "Python code to execute."
+            }
+        },
+        "required": ["code"]
+    })",
+};
+common_chat_tool code_interpreter_tool {
+    /* .name = */ "code_interpreter",
+    /* .description = */ "an ipython interpreter",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "code": {
+                "type": "string",
+                "description": "Python code to execute."
+            }
+        },
+        "required": ["code"]
+    })",
+};
+std::vector tools           { special_function_tool, python_tool };
+std::vector llama_3_1_tools { special_function_tool, code_interpreter_tool };
+
+struct delta_data {
+    std::string        delta;
+    common_chat_params params;
+};
+
+static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector & end_tokens,
+                             const common_chat_msg & user_message,
+                             const common_chat_msg & delta_message,
+                             const std::vector & tools,
+                             const common_chat_tool_choice & tool_choice) {
+    common_chat_templates_inputs inputs;
+    inputs.parallel_tool_calls = true;
+    inputs.messages.push_back(user_message);
+    inputs.tools       = tools;
+    inputs.tool_choice = tool_choice;
+    auto params_prefix = common_chat_templates_apply(tmpls, inputs);
+
+    inputs.messages.push_back(delta_message);
+    inputs.add_generation_prompt = false;
+    auto params_full             = common_chat_templates_apply(tmpls, inputs);
+
+    std::string prefix = params_prefix.prompt;
+    std::string full   = params_full.prompt;
+
+    if (full == prefix) {
+        throw std::runtime_error("Full message is the same as the prefix");
+    }
+
+    size_t common_prefix_length = 0;
+    for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
+        if (prefix[i] != full[i]) {
+            break;
+        }
+        if (prefix[i] == '<') {
+            // DeepSeek R1's template (as of 20250209) adds a trailing  if add_generation_prompt,
+            // but it removes thinking tags for past messages.
+            // The prefix and full strings diverge at  vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
+            continue;
+        }
+        common_prefix_length = i + 1;
+    }
+    auto delta = full.substr(common_prefix_length);
+
+    // Strip end tokens
+    for (const auto & end_token : end_tokens) {
+        // rfind to find the last occurrence
+        auto pos = delta.rfind(end_token);
+        if (pos != std::string::npos) {
+            delta = delta.substr(0, pos);
+            break;
+        }
+    }
+    return { delta, params_full };
+}
+
+/*
+  Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
+  gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
+  the parsed message is the same as the test_message
+*/
+static void test_templates(const struct common_chat_templates * tmpls, const std::vector & end_tokens,
+                          const common_chat_msg & test_message,
+                          const std::vector & tools = {},
+                          const std::string & expected_delta = "",
+                          bool expect_grammar_triggered = true,
+                          bool test_grammar_if_triggered = true,
+                          common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE) {
+    common_chat_msg user_message;
+    user_message.role = "user";
+    user_message.content = "Hello, world!";
+
+    for (const auto & tool_choice : std::vector {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
+        auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice);
+        if (!expected_delta.empty()) {
+            assert_equals(expected_delta, data.delta);
+        }
+
+        if (expect_grammar_triggered) {
+            common_chat_syntax syntax;
+            syntax.format = data.params.format;
+            syntax.reasoning_format = reasoning_format;
+            const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, syntax);
+            assert_msg_equals(test_message, msg);
+        }
+
+        if (!test_message.tool_calls.empty()) {
+            GGML_ASSERT(!data.params.grammar.empty());
+        }
+        if (!data.params.grammar.empty()) {
+            auto grammar = build_grammar(data.params.grammar);
+            if (!grammar) {
+                throw std::runtime_error("Failed to build grammar");
+            }
+            auto earliest_trigger_pos = std::string::npos;
+            auto constrained = data.delta;
+            for (const auto & trigger : data.params.grammar_triggers) {
+                size_t pos = std::string::npos;
+                std::smatch match;
+                switch (trigger.type) {
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
+                    {
+                        const auto & word = trigger.value;
+                        pos = constrained.find(word);
+                        break;
+                    }
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
+                    {
+                        const auto & pattern = trigger.value;
+                        if (std::regex_search(constrained, match, std::regex(pattern))) {
+                            pos = match.position(1);
+                        }
+                        break;
+                    }
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
+                    {
+                        const auto & pattern = trigger.value;
+                        if (std::regex_match(constrained, match, std::regex(pattern))) {
+                            auto mpos = std::string::npos;
+                            for (size_t i = 1; i < match.size(); ++i) {
+                                if (match[i].length() > 0) {
+                                    mpos = match.position(i);
+                                    break;
+                                }
+                            }
+                            if (mpos == std::string::npos) {
+                                mpos = match.position(0);
+                            }
+                            pos = mpos;
+                        }
+                        break;
+                    }
+                    default:
+                        throw std::runtime_error("Unknown trigger type");
+                }
+                if (pos == std::string::npos) {
+                    continue;
+                }
+                if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
+                    earliest_trigger_pos = pos;
+                }
+            }
+            auto grammar_triggered = false;
+            if (earliest_trigger_pos != std::string::npos) {
+                constrained = constrained.substr(earliest_trigger_pos);
+                grammar_triggered = true;
+            }
+            if (data.params.grammar_lazy) {
+                assert_equals(expect_grammar_triggered, grammar_triggered);
+            }
+
+            if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
+                throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
+                    "\n\nConstrained: " + constrained +
+                    "\n\nGrammar: " + data.params.grammar);
+            }
+        }
+    }
+}
+
+const common_chat_msg message_user {
+    "user",
+    "Hey there!",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+
+const common_chat_msg message_user_parts {
+    "user",
+    /* .content = */ "",
+    /* .content_parts = */ {
+        { "text", "Hey" },
+        { "text", "there" },
+    },
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+static common_chat_msg simple_assist_msg(const std::string & content, const std::string & reasoning_content = "", const std::string & tool_name = "", const std::string & arguments = "", const std::string & id = "") {
+    common_chat_msg msg;
+    msg.role = "assistant";
+    msg.content = content;
+    msg.reasoning_content = reasoning_content;
+    if (!tool_name.empty()) {
+        msg.tool_calls.push_back({ tool_name, arguments, id });
+    }
+    return msg;
+}
+const common_chat_msg message_assist                              = simple_assist_msg("Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_empty                        = simple_assist_msg("");
+const common_chat_msg message_assist_thoughts_unparsed_deepseek   = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts_unparsed_md         = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```");
+const common_chat_msg message_assist_thoughts_unparsed_md_partial = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}");
+
+const common_chat_msg message_assist_thoughts_unparsed_r7b       = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts                    = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking");
+const common_chat_msg message_assist_thoughts_unopened_unparsed  = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts_no_content         = simple_assist_msg("", "I'm\nthinking");
+const common_chat_msg message_assist_call                        = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}");
+const common_chat_msg message_assist_call_content                = simple_assist_msg("Hello, world!\nWhat's up?", "", "special_function", "{\"arg1\":1}");
+const common_chat_msg message_assist_call_empty_args             = simple_assist_msg("", "", "special_function");
+const common_chat_msg message_assist_call_cutoff_args            = simple_assist_msg("", "", "special_function", "{\"arg");
+const common_chat_msg message_assist_call_thoughts               = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\":1}");
+const common_chat_msg message_assist_call_thoughts_unparsed      = simple_assist_msg("I'm\nthinking\n\n", "", "special_function", "{\"arg1\": 1}");
+const common_chat_msg message_assist_call_id                     = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "123456789");
+const common_chat_msg message_assist_call_idx                    = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "0");
+const common_chat_msg message_assist_thoughts_call_idx           = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}", /* id = */ "0");
+const common_chat_msg message_assist_call_python                 = simple_assist_msg("", "", "python", "{\"code\":\"print('hey')\"}");
+const common_chat_msg message_assist_call_python_lines           = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')\"}");
+const common_chat_msg message_assist_call_python_lines_unclosed  = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')");
+const common_chat_msg message_assist_call_code_interpreter       = simple_assist_msg("", "", "code_interpreter", "{\"code\":\"print('hey')\"}");
+
+static void test_msgs_oaicompat_json_conversion() {
+    printf("[%s]\n", __func__);
+    std::vector msgs{
+        message_user,
+        message_user_parts,
+        message_assist_call,
+        message_assist_call_thoughts,
+        message_assist_call_thoughts_unparsed,
+        message_assist_call_id,
+        message_assist_call_idx,
+        message_assist_call_python,
+        message_assist_call_code_interpreter,
+    };
+    for (const auto & msg : msgs) {
+        auto oai_json = common_chat_msgs_to_json_oaicompat({msg});
+        auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json);
+        assert_equals((size_t) 1, msgs2.size());
+        auto msg2 = msgs2[0];
+        assert_msg_equals(msg, msg2);
+    }
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"role\": \"user\",\n"
+            "    \"content\": [\n"
+            "      {\n"
+            "        \"type\": \"text\",\n"
+            "        \"text\": \"Hey\"\n"
+            "      },\n"
+            "      {\n"
+            "        \"type\": \"text\",\n"
+            "        \"text\": \"there\"\n"
+            "      }\n"
+            "    ]\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_msgs_to_json_oaicompat({message_user_parts}).dump(2));
+
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"role\": \"assistant\",\n"
+            "    \"content\": null,\n"
+            "    \"tool_calls\": [\n"
+            "      {\n"
+            "        \"type\": \"function\",\n"
+            "        \"function\": {\n"
+            "          \"name\": \"python\",\n"
+            "          \"arguments\": \"{\\\"code\\\":\\\"print('hey')\\\"}\"\n"
+            "        }\n"
+            "      }\n"
+            "    ]\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_msgs_to_json_oaicompat({message_assist_call_python}).dump(2));
+
+    auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]"));
+    assert_equals(1, res.size());
+    assert_equals(res[0].role, "assistant");
+    assert_equals(true, res[0].content.empty());
+    assert_equals(true, res[0].tool_calls.empty());
+
+    try {
+        common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\"}]"));
+        throw std::runtime_error("Expected exception");
+    } catch (const std::exception & e) {
+        if (std::string(e.what()).find("'content'") == std::string::npos) {
+            throw std::runtime_error("Expected exception about missing 'content'");
+        }
+    }
+}
+
+static void test_tools_oaicompat_json_conversion() {
+    printf("[%s]\n", __func__);
+    std::vector tools{
+        special_function_tool,
+        python_tool,
+        code_interpreter_tool,
+    };
+
+    for (const auto & tool : tools) {
+        auto oai_json = common_chat_tools_to_json_oaicompat({tool});
+        auto tools2 = common_chat_tools_parse_oaicompat(oai_json);
+        assert_equals((size_t) 1, tools2.size());
+        auto tool2 = tools2[0];
+        assert_equals(tool.name, tool2.name);
+        assert_equals(tool.description, tool2.description);
+        assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2));
+    }
+
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"type\": \"function\",\n"
+            "    \"function\": {\n"
+            "      \"name\": \"special_function\",\n"
+            "      \"description\": \"I'm special\",\n"
+            "      \"parameters\": {\n"
+            "        \"type\": \"object\",\n"
+            "        \"properties\": {\n"
+            "          \"arg1\": {\n"
+            "            \"type\": \"integer\",\n"
+            "            \"description\": \"The arg.\"\n"
+            "          }\n"
+            "        },\n"
+            "        \"required\": [\n"
+            "          \"arg1\"\n"
+            "        ]\n"
+            "      }\n"
+            "    }\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_tools_to_json_oaicompat({special_function_tool}).dump(2));
+}
+
+static void test_template_output_parsers() {
+    printf("[%s]\n", __func__);
+
+    common_chat_templates_inputs inputs_no_tools;
+    inputs_no_tools.messages                = {message_user};
+
+    common_chat_templates_inputs inputs_tools;
+    inputs_tools.messages                   = {message_user};
+    inputs_tools.tools                      = {special_function_tool};
+
+    common_chat_templates_inputs inputs_tools_builtin;
+    inputs_tools_builtin.messages           = {message_user};
+    inputs_tools_builtin.tools              = {python_tool};
+
+    {
+        // Not supported yet
+        auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+    }
+    {
+        auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
+        std::vector   end_tokens{ "<|END_OF_TURN_TOKEN|>" };
+
+        for (const auto & inputs : { inputs_no_tools, inputs_tools }) {
+            auto params = common_chat_templates_apply(tmpls.get(), inputs);
+            assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, params.format);
+            assert_equals(false, params.thinking_forced_open);
+        }
+
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_COMMAND_R7B}));
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_COMMAND_R7B}));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ true,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_COMMAND_R7B}));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts_call_idx,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_ACTION|>[\n"
+                "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
+                "]<|END_ACTION|>",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts_no_content,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_ACTION|>[\n"
+                "    {\"tool_call_id\": \"0\", \"tool_name\": \"special",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+
+        test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools,
+                      "<|START_THINKING|><|END_THINKING|>"
+                      "<|START_ACTION|>[\n"
+                      "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
+                      "]<|END_ACTION|>",
+                      /* expect_grammar_triggered= */ true,
+                      /* test_grammar_if_triggered= */ true,
+                      COMMON_REASONING_FORMAT_DEEPSEEK);
+        test_templates(tmpls.get(), end_tokens, message_assist, tools,
+                      "<|START_RESPONSE|>Hello, world!\n"
+                      "What's up?<|END_RESPONSE|>",
+                      /* expect_grammar_triggered= */ false);
+    }
+    {
+        auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja");
+        std::vector   end_tokens{ "" };
+
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC,
+                      common_chat_templates_apply(
+                          read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(),
+                          inputs_tools)
+                          .format);
+
+        // Generic tool calls doesn't generate / parse content-only messages symmetrically.
+
+        assert_equals(
+            simple_assist_msg("{ \"tool_call\" : { \"name\" : \"t"),
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"t",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_GENERIC,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                    /* .parse_tool_calls = */ false,
+                }));
+        assert_equals(
+            message_assist_empty,
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"t",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+
+        assert_equals(
+            simple_assist_msg("", "", "puppeteer_screenshot", "{\"name\":\"servethehome_homepage\","),
+            common_chat_parse(
+                R"({"tool_call": {"name": "puppeteer_screenshot", "arguments": {"name": "servethehome_homepage",)",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+
+        assert_equals(
+            message_assist_call_empty_args,
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"special_function\"",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+        assert_equals(
+            message_assist_call_cutoff_args,
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"special_function\", \"arguments\" : { \"arg",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "{\n"
+                "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
+                "}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+        test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
+                      "{\n"
+                      "  \"tool_calls\": [\n"
+                      "    {\n"
+                      "      \"name\": \"special_function\",\n"
+                      "      \"arguments\": {\n"
+                      "        \"arg1\": 1\n"
+                      "      },\n"
+                      "      \"id\": \"123456789\"\n"
+                      "    }\n"
+                      "  ]\n"
+                      "}");
+    }
+    {
+        auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja");
+        std::vector   end_tokens{ "" };
+
+        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(
+            tmpls.get(), end_tokens, message_assist_call_id, tools,
+            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
+    }
+    {
+        auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja");
+        std::vector end_tokens{ "<|im_end|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+    }
+    {
+        auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
+        std::vector end_tokens{ "<|im_end|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(
+            COMMON_CHAT_FORMAT_HERMES_2_PRO,
+            common_chat_templates_apply(
+                read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(),
+                inputs_tools)
+                .format);
+        assert_equals(
+            COMMON_CHAT_FORMAT_HERMES_2_PRO,
+            common_chat_templates_apply(
+                read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(),
+                inputs_tools)
+                .format);
+
+        // Test parsing
+        assert_msg_equals(
+            simple_assist_msg("", "", "python", ""),
+            common_chat_parse(
+                "```json\n"
+                " { \"name\" : \"python\"",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            simple_assist_msg("Let's call something\n"),
+            common_chat_parse(
+                "Let's call something\n"
+                "{\"name\"",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(
+            simple_assist_msg("Let's call something\n"),
+            common_chat_parse(
+                "Let's call something\n"
+                "{\"name",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_call_thoughts,
+            common_chat_parse(
+                // QwQ-32B's template adds a trailing  if add_generation_prompt
+                "I'm\nthinking\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist_call_content,
+            common_chat_parse(
+                "Hello, world!\nWhat's up?\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\"arg1\": 1}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "{\"arg1\": 1}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```xml\n"
+                "\n"
+                "    {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```xml\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```json\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```json\n"
+                "\n"
+                "                     {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
+                "                     \n"
+                "``` ",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\n"
+                "    \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
+                "  }\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\n  \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+
+        assert_msg_equals(
+            simple_assist_msg(
+                "This is not a tool call:",
+                "",
+                "special_function",
+                "{\"arg1\": 1}"),
+            common_chat_parse(
+                "This is not a tool call:\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        // assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+        //     common_chat_parse(
+        //         "I'm\nthinkingHello, world!\nWhat's up?",
+        //         COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts_unparsed_md,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ true,
+                    /* .thinking_forced_open = */ false,
+                    /* .parse_tool_calls = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_unparsed_md_partial,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ true,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_unopened_unparsed,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+                      "\n"
+                      "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                      "");
+        test_templates(tmpls.get(), end_tokens, message_assist_call_python_lines, tools,
+                      "\n"
+                      "{\"name\": \"python\", \"arguments\": {\"code\":\"# This is a program:\\nprint('hey')\"}}\n"
+                      "");
+        assert_msg_equals(
+            simple_assist_msg("", /* reasoning_content= */ "nah uhg"),
+            common_chat_parse(
+                "nah uhg",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+    }
+    {
+        auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
+        std::vector   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+                      common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+                      common_chat_templates_apply(
+                          read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(),
+                          inputs_tools_builtin)
+                          .format);
+
+        assert_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_LLAMA_3_X}));
+
+        // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
+                      "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
+        test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
+                      "<|python_tag|>python.call(code=\"print('hey')\")");
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
+    }
+    {
+        auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja");
+        std::vector   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
+    }
+    {
+        auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
+        std::vector   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
+                      common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
+            common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
+                        common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+
+        for (auto is_partial : { false, true }) {
+            assert_equals(
+                message_assist_call,
+                common_chat_parse(
+                    "{\"arg1\": 1}",
+                    is_partial,
+                    {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1}));
+        }
+
+        assert_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\"arg1\": 1}<",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1}));
+
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+                      "{\"arg1\": 1}");
+    }
+    {
+        auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja");
+        std::vector   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+
+        assert_msg_equals(
+            simple_assist_msg(
+                "Hello, world!\nnono\nWhat's up?",
+                "",
+                "special_function",
+                "{\"arg1\": 1}"),
+            common_chat_parse(
+                "all\n"
+                "Hello, world!\n"
+                "nono\n"
+                "What's up?>>>special_function\n"
+                "{\"arg1\": 1}\n",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist_call_python_lines,
+            common_chat_parse(
+                "python\n"
+                "# This is a program:\n"
+                "print('hey')",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist_call_python_lines_unclosed,
+            common_chat_parse(
+                "python\n"
+                "# This is a program:\n"
+                "print('hey')",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist_call,
+            common_chat_parse(
+                "special_function\n"
+                "{\"arg1\": 1} \n                    ",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "all\n"
+                "Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+
+        test_templates(tmpls.get(), end_tokens, message_assist, {},
+                      "all\n"
+                      "Hello, world!\n"
+                      "What's up?",
+                      /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+                      "special_function\n"
+                      "{\"arg1\": 1}");
+    }
+    {
+        auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
+        std::vector   end_tokens{ "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+                      " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
+    }
+    {
+        // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
+        auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
+        std::vector   end_tokens{ "<|end▁of▁sentence|>" };
+
+        for (const auto & inputs : { inputs_no_tools, inputs_tools }) {
+            auto params = common_chat_templates_apply(tmpls.get(), inputs);
+            assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, params.format);
+            assert_equals(true, params.thinking_forced_open);
+        }
+
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        assert_msg_equals(
+            simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"),
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+        assert_msg_equals(
+            simple_assist_msg("", "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with"),
+            common_chat_parse(
+                "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with",
+                /* is_partial= */ true,
+                {
+                    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts_unopened_unparsed,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            // Latest template update (ast of 20250209) adds a trailing \n if add_generation_prompt is true.
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+        // test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+        //               "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+        //               "```json\n"
+        //               "{\"arg1\": 1}\n"
+        //               // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic)
+        //               "```<|tool▁call▁end|>",
+        //               /* expect_grammar_triggered= */ true,
+        //               /* test_grammar_if_triggered= */ false);
+    }
+    {
+        // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
+        auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
+        std::vector   end_tokens{ "<|end▁of▁sentence|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_DEEPSEEK_R1}));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+
+        assert_msg_equals(message_assist_call_thoughts_unparsed,
+            common_chat_parse(
+                "I'm\nthinking\n\n"
+                "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+                "```json\n"
+                "{\"arg1\": 1}\n"
+                "```<|tool▁call▁end|><|tool▁calls▁end|>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_DEEPSEEK_R1}));
+        assert_msg_equals(message_assist_call,
+            common_chat_parse(
+                "<|tool▁calls|>function<|tool▁sep|>special_function\n"
+                "```json\n"
+                "{\"arg1\": 1}\n"
+                "```<|tool▁call▁end|><|tool▁calls▁end|>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_DEEPSEEK_R1}));
+
+        assert_msg_equals(message_assist_call_thoughts,
+            common_chat_parse(
+                "I'm\nthinking\n\n"
+                "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+                "```json\n"
+                "{\"arg1\": 1}\n"
+                "```<|tool▁call▁end|><|tool▁calls▁end|>",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+                "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+                "```json\n"
+                "{\"arg1\": 1}\n"
+                "```<|tool▁call▁end|><|tool▁calls▁end|>");
+    }
+}
+
+static void test_msg_diffs_compute() {
+    printf("[%s]\n", __func__);
+    {
+        common_chat_msg msg1;
+
+        common_chat_msg msg2;
+        msg2.content = "Hello, world!";
+
+        common_chat_msg_diff diff;
+        diff.content_delta = "Hello, world!";
+
+        assert_equals(
+            {diff},
+            common_chat_msg_diff::compute_diffs(msg1, msg2));
+    }
+    {
+        common_chat_msg msg1;
+        msg1.content = "Hello,";
+
+        common_chat_msg msg2;
+        msg2.content = "Hello, world!";
+
+        common_chat_msg_diff diff;
+        diff.content_delta = " world!";
+
+        assert_equals(
+            {diff},
+            common_chat_msg_diff::compute_diffs(msg1, msg2));
+    }
+    {
+        common_chat_msg msg0;
+
+        common_chat_msg msg1;
+        msg1.tool_calls = { { "special_function", "{\"ar", /* .id = */ "123" } };
+
+        common_chat_msg msg2;
+        msg2.tool_calls = { { "special_function", "{\"arg1\": 1}", /* .id = */ "123" } };
+
+        common_chat_msg_diff diff01;
+        diff01.tool_call_index = 0;
+        diff01.tool_call_delta.name = "special_function";
+        diff01.tool_call_delta.id = "123";
+        diff01.tool_call_delta.arguments = "{\"ar";
+
+        assert_equals(
+            {diff01},
+            common_chat_msg_diff::compute_diffs(msg0, msg1));
+
+        common_chat_msg_diff diff12;
+        diff12.tool_call_index = 0;
+        // Note: neither id nor name change here.
+        diff12.tool_call_delta.arguments = "g1\": 1}";
+
+        assert_equals(
+            {diff12},
+            common_chat_msg_diff::compute_diffs(msg1, msg2));
+    }
+    {
+        common_chat_msg msg0;
+
+        common_chat_msg msg2;
+        msg2.tool_calls = {
+            { "f1", "{\"arg1\": 1}", /* .id = */ "123" },
+            { "f2", "{\"arg2\": 2}", /* .id = */ "222" },
+        };
+
+        common_chat_msg_diff diff1;
+        diff1.tool_call_index = 0;
+        diff1.tool_call_delta.name = "f1";
+        diff1.tool_call_delta.id = "123";
+        diff1.tool_call_delta.arguments = "{\"arg1\": 1}";
+
+        common_chat_msg_diff diff2;
+        diff2.tool_call_index = 1;
+        diff2.tool_call_delta.name = "f2";
+        diff2.tool_call_delta.id = "222";
+        diff2.tool_call_delta.arguments = "{\"arg2\": 2}";
+
+        assert_equals(
+            {diff1, diff2},
+            common_chat_msg_diff::compute_diffs(msg0, msg2));
+    }
+}
+
+int main(int argc, char ** argv) {
+    common_log_set_verbosity_thold(999);
+
+    // try {
+#ifndef _WIN32
+        if (argc > 1) {
+            common_chat_templates_inputs inputs;
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = "Hey";
+            inputs.messages = {msg};
+            inputs.tools = { special_function_tool };
+
+            std::cout << "| Template | Format |\n";
+            std::cout << "|----------|--------|\n";
+
+            for (int i = 1; i < argc; i++) {
+                try {
+                    std::string path = argv[i];
+                    if (path.rfind(".jinja") != path.size() - 6) {
+                        std::cerr << "Skipping non-jinja file: " << path << '\n';
+                        continue;
+                    }
+                    auto tmpls = read_templates(path);
+                    auto parts  = string_split(path, "/");
+                    auto name   = parts[parts.size() - 1];
+                    auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format);
+                    std::cout << "| " << name << " | " << format << " |\n";
+                } catch (const std::exception & e) {
+                    std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n';
+                }
+            }
+        } else
+#endif
+        {
+            test_msg_diffs_compute();
+            test_msgs_oaicompat_json_conversion();
+            test_tools_oaicompat_json_conversion();
+            test_template_output_parsers();
+            std::cout << "\n[chat] All tests passed!" << '\n';
+        }
+        return 0;
+    // } catch (const std::exception & e) {
+    //     std::cerr << "Error: " << e.what() << '\n';
+    //     return 1;
+    // }
+}
diff --git a/tests/test-double-float.cpp b/tests/test-double-float.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6aac4737a7df0f5d0621d1c2a8f53c2789da146b
--- /dev/null
+++ b/tests/test-double-float.cpp
@@ -0,0 +1,57 @@
+// These tests may take a long time!
+// They are to prove that conversion from double to float of various functions in ggml.c doesn't affect the result.
+// This is done by checking all finite (non-NaN, non-infinite) floats.
+
+#undef NDEBUG
+#include 
+#if !defined(__riscv) && !defined(__s390__) && !defined(__ARM_NEON)
+#include 
+#endif
+#include 
+#include 
+#include 
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
+
+// ggml.c::quantize_row_q4_0_ref
+inline static uint8_t round_orig(float v0) { return ((int8_t) (round(v0))) + 8; }
+
+// ggml.c::ggml_silu_f32
+inline static float silu_orig(float x) {
+    return x/(1.0 + exp(-x));
+}
+
+#pragma GCC diagnostic pop
+
+// ggml.c::quantize_row_q4_0_ref
+inline static uint8_t round_float(float v0) { return (int8_t)roundf(v0) + 8; }
+
+// ggml.c::ggml_silu_f32
+inline static float silu_float(float x) {
+    return x/(1.0f + expf(-x));
+}
+
+int main(void) {
+    uint32_t x = UINT32_MAX;
+    do {
+        float f;
+        memcpy(&f, &x, sizeof(x));
+        assert(!std::isfinite(f) || (round_orig(f) == round_float(f)));
+    } while (x--);
+
+#ifdef __F16C__
+    // GELU and SILU implementations are used with a FP16 lookup table.
+    // The original and float-only results are not equal for all inputs after converting to FP16.
+    // GELU is an approximation anyway (tanh), not tested here.
+    // For SILU, verify that the results are at least the closest floating point numbers, if the FP16 values don't match.
+    for (x = 0; x <= UINT16_MAX; x++) {
+        float f = _cvtsh_ss(x);
+        const float so = silu_orig(f);
+        const float sf = silu_float(f);
+        assert(   (_cvtss_sh(so, 0) == _cvtss_sh(sf, 0))
+               || (nextafterf(so, sf) == sf)
+               || (nextafterf(sf, so) == so));
+    }
+#endif
+}
diff --git a/tests/test-gbnf-validator.cpp b/tests/test-gbnf-validator.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6547eec32fab47b75e8ceadddf5a9c2231e4e1a9
--- /dev/null
+++ b/tests/test-gbnf-validator.cpp
@@ -0,0 +1,109 @@
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
+    const auto cpts = unicode_cpts_from_utf8(input_str);
+
+    auto & stacks_cur = llama_grammar_get_stacks(grammar);
+
+    size_t pos = 0;
+    for (const auto & cpt : cpts) {
+        llama_grammar_accept(grammar, cpt);
+
+        if (stacks_cur.empty()) {
+            error_pos = pos;
+            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
+            return false;
+        }
+        ++pos;
+    }
+
+    for (const auto & stack : stacks_cur) {
+        if (stack.empty()) {
+            return true;
+        }
+    }
+
+    error_pos = pos;
+    error_msg = "Unexpected end of input";
+    return false;
+}
+
+static void print_error_message(const std::string & input_str, size_t error_pos, const std::string & error_msg) {
+    fprintf(stdout, "Input string is invalid according to the grammar.\n");
+    fprintf(stdout, "Error: %s at position %zu\n", error_msg.c_str(), error_pos);
+    fprintf(stdout, "\n");
+    fprintf(stdout, "Input string:\n");
+    fprintf(stdout, "%s", input_str.substr(0, error_pos).c_str());
+    if (error_pos < input_str.size()) {
+        fprintf(stdout, "\033[1;31m%c", input_str[error_pos]);
+        if (error_pos+1 < input_str.size()) {
+            fprintf(stdout, "\033[0;31m%s", input_str.substr(error_pos+1).c_str());
+        }
+        fprintf(stdout, "\033[0m\n");
+    }
+}
+
+int main(int argc, char** argv) {
+    if (argc != 3) {
+        fprintf(stdout, "Usage: %s  \n", argv[0]);
+        return 1;
+    }
+
+    const std::string grammar_filename = argv[1];
+    const std::string input_filename = argv[2];
+
+    // Read the GBNF grammar file
+    FILE* grammar_file = fopen(grammar_filename.c_str(), "r");
+    if (!grammar_file) {
+        fprintf(stdout, "Failed to open grammar file: %s\n", grammar_filename.c_str());
+        return 1;
+    }
+
+    std::string grammar_str;
+    {
+        std::ifstream grammar_file(grammar_filename);
+        GGML_ASSERT(grammar_file.is_open() && "Failed to open grammar file");
+        std::stringstream buffer;
+        buffer << grammar_file.rdbuf();
+        grammar_str = buffer.str();
+    }
+
+    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
+    if (grammar == nullptr) {
+        fprintf(stdout, "Failed to initialize llama_grammar\n");
+        return 1;
+    }
+    // Read the input file
+    std::string input_str;
+    {
+        std::ifstream input_file(input_filename);
+        GGML_ASSERT(input_file.is_open() && "Failed to open input file");
+        std::stringstream buffer;
+        buffer << input_file.rdbuf();
+        input_str = buffer.str();
+    }
+
+    // Validate the input string against the grammar
+    size_t error_pos;
+    std::string error_msg;
+    bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg);
+
+    if (is_valid) {
+        fprintf(stdout, "Input string is valid according to the grammar.\n");
+    } else {
+        print_error_message(input_str, error_pos, error_msg);
+    }
+
+    // Clean up
+    llama_grammar_free_impl(grammar);
+
+    return 0;
+}
diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3f0c312e2f003341ba7deb15817ac4fb8d2aacda
--- /dev/null
+++ b/tests/test-gguf.cpp
@@ -0,0 +1,1344 @@
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "../ggml/src/ggml-impl.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+constexpr int offset_has_kv      = 1000;
+constexpr int offset_has_tensors = 2000;
+constexpr int offset_has_data    = 3000;
+
+enum handcrafted_file_type {
+    HANDCRAFTED_HEADER_BAD_MAGIC           =  10,
+    HANDCRAFTED_HEADER_BAD_VERSION_0       =  15,
+    HANDCRAFTED_HEADER_BAD_VERSION_1       =  20,
+    HANDCRAFTED_HEADER_BAD_VERSION_FUTURE  =  30,
+    HANDCRAFTED_HEADER_BAD_N_TENSORS       =  40,
+    HANDCRAFTED_HEADER_BAD_N_KV            =  50,
+    HANDCRAFTED_HEADER_EMPTY               = 800,
+
+    HANDCRAFTED_KV_BAD_KEY_SIZE            =  10 + offset_has_kv,
+    HANDCRAFTED_KV_BAD_TYPE                =  20 + offset_has_kv,
+    // HANDCRAFTED_KV_BAD_VALUE_SIZE          =  30 + offset_has_kv, // removed because it can result in allocations > 1 TB (default sanitizer limit)
+    HANDCRAFTED_KV_DUPLICATE_KEY           =  40 + offset_has_kv,
+    HANDCRAFTED_KV_BAD_ALIGN               =  50 + offset_has_kv,
+    HANDCRAFTED_KV_SUCCESS                 = 800 + offset_has_kv,
+
+    HANDCRAFTED_TENSORS_BAD_NAME_SIZE      =  10 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_N_DIMS         =  20 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_SHAPE          =  30 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_NE_TOO_BIG         =  40 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_TYPE           =  50 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_OFFSET         =  60 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_DUPLICATE_NAME     =  70 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_ALIGN          =  75 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN =  80 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_SUCCESS            = 800 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_CUSTOM_ALIGN       = 810 + offset_has_tensors,
+
+    HANDCRAFTED_DATA_NOT_ENOUGH_DATA       =  10 + offset_has_data,
+    HANDCRAFTED_DATA_BAD_ALIGN             =  15 + offset_has_data,
+    HANDCRAFTED_DATA_INCONSISTENT_ALIGN    =  20 + offset_has_data,
+    HANDCRAFTED_DATA_SUCCESS               = 800 + offset_has_data,
+    HANDCRAFTED_DATA_CUSTOM_ALIGN          = 810 + offset_has_data,
+};
+
+static std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) {
+    switch (hft) {
+        case HANDCRAFTED_HEADER_BAD_MAGIC:           return "HEADER_BAD_MAGIC";
+        case HANDCRAFTED_HEADER_BAD_VERSION_0:       return "HEADER_BAD_VERSION_0";
+        case HANDCRAFTED_HEADER_BAD_VERSION_1:       return "HEADER_BAD_VERSION_1";
+        case HANDCRAFTED_HEADER_BAD_VERSION_FUTURE:  return "HEADER_BAD_VERSION_FUTURE";
+        case HANDCRAFTED_HEADER_BAD_N_KV:            return "HEADER_BAD_N_KV";
+        case HANDCRAFTED_HEADER_BAD_N_TENSORS:       return "HEADER_BAD_N_TENSORS";
+        case HANDCRAFTED_HEADER_EMPTY:               return "HEADER_EMPTY";
+
+        case HANDCRAFTED_KV_BAD_KEY_SIZE:            return "KV_BAD_KEY_SIZE";
+        case HANDCRAFTED_KV_BAD_TYPE:                return "KV_BAD_TYPE";
+        case HANDCRAFTED_KV_DUPLICATE_KEY:           return "KV_DUPLICATE_KEY";
+        case HANDCRAFTED_KV_BAD_ALIGN:               return "KV_BAD_ALIGN";
+        case HANDCRAFTED_KV_SUCCESS:                 return "KV_RANDOM_KV";
+
+        case HANDCRAFTED_TENSORS_BAD_NAME_SIZE:      return "TENSORS_BAD_NAME_SIZE";
+        case HANDCRAFTED_TENSORS_BAD_N_DIMS:         return "TENSORS_BAD_N_DIMS";
+        case HANDCRAFTED_TENSORS_BAD_SHAPE:          return "TENSORS_BAD_SHAPE";
+        case HANDCRAFTED_TENSORS_NE_TOO_BIG:         return "TENSORS_NE_TOO_BIG";
+        case HANDCRAFTED_TENSORS_BAD_TYPE:           return "TENSORS_BAD_TYPE";
+        case HANDCRAFTED_TENSORS_BAD_OFFSET:         return "TENSORS_BAD_OFFSET";
+        case HANDCRAFTED_TENSORS_DUPLICATE_NAME:     return "TENSORS_DUPLICATE_NAME";
+        case HANDCRAFTED_TENSORS_BAD_ALIGN:          return "TENSORS_BAD_ALIGN";
+        case HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN: return "TENSORS_INCONSISTENT_ALIGN";
+        case HANDCRAFTED_TENSORS_SUCCESS:            return "TENSORS_SUCCESS";
+        case HANDCRAFTED_TENSORS_CUSTOM_ALIGN:       return "TENSORS_CUSTOM_ALIGN";
+
+        case HANDCRAFTED_DATA_NOT_ENOUGH_DATA:       return "DATA_NOT_ENOUGH_DATA";
+        case HANDCRAFTED_DATA_BAD_ALIGN:             return "DATA_BAD_ALIGN";
+        case HANDCRAFTED_DATA_INCONSISTENT_ALIGN:    return "DATA_INCONSISTENT_ALIGN";
+        case HANDCRAFTED_DATA_SUCCESS:               return "DATA_SUCCESS";
+        case HANDCRAFTED_DATA_CUSTOM_ALIGN:          return "DATA_CUSTOM_ALIGN";
+    }
+    GGML_ABORT("fatal error");
+}
+
+static bool expect_context_not_null(const enum handcrafted_file_type hft) {
+    if (hft < offset_has_kv) {
+        return hft >= HANDCRAFTED_HEADER_EMPTY;
+    }
+    if (hft < offset_has_tensors) {
+        return hft >= HANDCRAFTED_KV_SUCCESS;
+    }
+    if (hft < offset_has_data) {
+        return hft >= HANDCRAFTED_TENSORS_SUCCESS;
+    }
+    return hft >= HANDCRAFTED_DATA_SUCCESS;
+}
+
+typedef std::pair> tensor_config_t;
+
+static std::vector get_tensor_configs(std::mt19937 & rng) {
+    std::vector tensor_configs;
+    tensor_configs.reserve(100);
+
+    for (int i = 0; i < 100; ++i) {
+        const enum ggml_type type = ggml_type(rng() % GGML_TYPE_COUNT);
+        if (ggml_type_size(type) == 0) {
+            continue;
+        }
+
+        std::array shape = {1, 1, 1, 1};
+        shape[0] = (1 + rng() % 10) * ggml_blck_size(type);
+        const int n_dims = 1 + rng() % GGML_MAX_DIMS;
+        for (int i = 1; i < n_dims; ++i) {
+            shape[i] = 1 + rng() % 10;
+        }
+
+        tensor_configs.push_back(std::make_pair(type, shape));
+    }
+
+    return tensor_configs;
+}
+
+static std::vector> get_kv_types(std::mt19937 rng) {
+    std::vector> kv_types;
+    kv_types.reserve(100);
+
+    for (int i = 0; i < 100; ++i) {
+        const gguf_type type = gguf_type(rng() % GGUF_TYPE_COUNT);
+
+        if (type == GGUF_TYPE_ARRAY) {
+            const gguf_type type_arr = gguf_type(rng() % GGUF_TYPE_COUNT);
+            if (type_arr == GGUF_TYPE_ARRAY) {
+                continue;
+            }
+            kv_types.push_back(std::make_pair(type, type_arr));
+            continue;
+        }
+
+        kv_types.push_back(std::make_pair(type, gguf_type(-1)));
+    }
+    std::shuffle(kv_types.begin(), kv_types.end(), rng);
+
+    return kv_types;
+}
+
+template 
+static void helper_write(FILE * file, const T & val) {
+    GGML_ASSERT(fwrite(&val, 1, sizeof(val), file) == sizeof(val));
+}
+
+static void helper_write(FILE * file, const void * data, const size_t nbytes) {
+    GGML_ASSERT(fwrite(data, 1, nbytes, file) == nbytes);
+}
+
+static FILE * get_handcrafted_file(const unsigned int seed, const enum handcrafted_file_type hft, const int extra_bytes = 0) {
+    FILE * file = tmpfile();
+
+    if (!file) {
+        return file;
+    }
+
+    std::mt19937 rng(seed);
+    uint32_t alignment = GGUF_DEFAULT_ALIGNMENT;
+
+    if (hft == HANDCRAFTED_HEADER_BAD_MAGIC) {
+        const char bad_magic[4] = {'F', 'U', 'G', 'G'};
+        helper_write(file, bad_magic, sizeof(bad_magic));
+    } else {
+        helper_write(file, GGUF_MAGIC, 4);
+    }
+
+    if (hft == HANDCRAFTED_HEADER_BAD_VERSION_0) {
+        const uint32_t version = 0;
+        helper_write(file, version);
+    } else if (hft == HANDCRAFTED_HEADER_BAD_VERSION_1) {
+        const uint32_t version = 1;
+        helper_write(file, version);
+    } else if (hft == HANDCRAFTED_HEADER_BAD_VERSION_FUTURE) {
+        const uint32_t version = GGUF_VERSION + 1;
+        helper_write(file, version);
+    } else {
+        const uint32_t version = GGUF_VERSION;
+        helper_write(file, version);
+    }
+
+    std::vector tensor_configs;
+    if (hft >= offset_has_tensors) {
+        tensor_configs = get_tensor_configs(rng);
+    }
+
+    if (hft == HANDCRAFTED_HEADER_BAD_N_TENSORS) {
+        const uint64_t n_tensors = -1;
+        helper_write(file, n_tensors);
+    } else {
+        const uint64_t n_tensors = tensor_configs.size();
+        helper_write(file, n_tensors);
+    }
+
+    std::vector> kv_types;
+    if (hft >= offset_has_kv) {
+        kv_types = get_kv_types(rng);
+    }
+    {
+        uint64_t n_kv = kv_types.size();
+        if (hft == HANDCRAFTED_KV_BAD_ALIGN      ||
+            hft == HANDCRAFTED_TENSORS_BAD_ALIGN || hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN ||
+            hft == HANDCRAFTED_DATA_BAD_ALIGN    || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
+
+            n_kv += 1;
+        } else if (hft == HANDCRAFTED_HEADER_BAD_N_KV) {
+            n_kv = -1;
+        }
+        helper_write(file, n_kv);
+    }
+
+    if (hft < offset_has_kv) {
+        while (ftell(file) % alignment != 0) {
+            const char pad = 0;
+            helper_write(file, pad);
+        }
+
+        for (int i = 0; i < extra_bytes; ++i) {
+            const char tmp = 0;
+            helper_write(file, tmp);
+        }
+        rewind(file);
+        return file;
+    }
+
+    for (int i = 0; i < int(kv_types.size()); ++i) {
+        const enum gguf_type type     = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? GGUF_TYPE_COUNT : kv_types[i].first);
+        const enum gguf_type type_arr = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? GGUF_TYPE_COUNT : kv_types[i].second);
+
+        const std::string key = "my_key_" + std::to_string((hft == HANDCRAFTED_KV_DUPLICATE_KEY ? i/2 : i));
+
+        if (hft == HANDCRAFTED_KV_BAD_KEY_SIZE) {
+            const uint64_t n = -1;
+            helper_write(file, n);
+        } else {
+            const uint64_t n = key.length();
+            helper_write(file, n);
+        }
+        helper_write(file, key.data(), key.length());
+
+        {
+            const int32_t type32 = int32_t(type);
+            helper_write(file, type32);
+        }
+
+        uint32_t data[16];
+        for (int j = 0; j < 16; ++j) {
+            data[j] = rng();
+            if (type == GGUF_TYPE_STRING || type_arr == GGUF_TYPE_STRING) {
+                data[j] |= 0x01010101; // avoid random null-termination of string
+            }
+        }
+
+        if (type == GGUF_TYPE_STRING) {
+            const uint64_t n = rng() % sizeof(data);
+            helper_write(file, n);
+            helper_write(file, data, n);
+            continue;
+        }
+
+        if (type == GGUF_TYPE_ARRAY) {
+            {
+                const int32_t type32 = int32_t(type_arr);
+                helper_write(file, type32);
+            }
+            if (type_arr == GGUF_TYPE_STRING) {
+                const uint64_t nstr = rng() % (16 + 1);
+                helper_write(file, nstr);
+                for (uint64_t istr = 0; istr < nstr; ++istr) {
+                    const uint64_t n = rng() % (sizeof(uint32_t) + 1);
+                    helper_write(file, n);
+                    helper_write(file, &data[istr], n);
+                }
+                continue;
+            }
+            const size_t type_size = gguf_type_size(type_arr);
+            const uint64_t n = (rng() % sizeof(data)) / type_size;
+            helper_write(file, n);
+            helper_write(file, &data, n*type_size);
+            continue;
+        }
+
+        helper_write(file, data, hft == HANDCRAFTED_KV_BAD_TYPE ? 1 : gguf_type_size(type));
+    }
+
+    if (hft == HANDCRAFTED_KV_BAD_ALIGN      ||
+        hft == HANDCRAFTED_TENSORS_BAD_ALIGN || hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN ||
+        hft == HANDCRAFTED_DATA_BAD_ALIGN    || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
+
+        const uint64_t n = strlen(GGUF_KEY_GENERAL_ALIGNMENT);
+        helper_write(file, n);
+        helper_write(file, GGUF_KEY_GENERAL_ALIGNMENT, n);
+
+        const int32_t type = gguf_type(GGUF_TYPE_UINT32);
+        helper_write(file, type);
+
+        alignment = expect_context_not_null(hft) ? 1 : 13;
+        helper_write(file, alignment);
+    }
+
+    if (hft < offset_has_tensors) {
+        while (ftell(file) % alignment != 0) {
+            const char pad = 0;
+            helper_write(file, pad);
+        }
+
+        for (int i = 0; i < extra_bytes; ++i) {
+            const char tmp = 0;
+            helper_write(file, tmp);
+        }
+        rewind(file);
+        return file;
+    }
+
+    if (hft == HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN || hft == HANDCRAFTED_DATA_INCONSISTENT_ALIGN) {
+        alignment = 1;
+    }
+
+    uint64_t offset = 0;
+    for (int i = 0; i < int(tensor_configs.size()); ++i) {
+        const ggml_type                          type  = tensor_configs[i].first;
+        const std::array shape = tensor_configs[i].second;
+
+        std::string name = "my_tensor";
+        if (hft != HANDCRAFTED_TENSORS_DUPLICATE_NAME) {
+            name += "_" + std::to_string(i);
+        }
+        if (hft == HANDCRAFTED_TENSORS_BAD_NAME_SIZE) {
+            name += "_with_a_very_long_name_which_is_longer_than_what_is_allowed_for_ggml_tensors";
+            GGML_ASSERT(name.length() >= GGML_MAX_NAME);
+        }
+        {
+            const uint64_t n = name.length();
+            helper_write(file, n);
+        }
+        helper_write(file, name.data(), name.length());
+
+        uint32_t n_dims = hft == HANDCRAFTED_TENSORS_NE_TOO_BIG ? 2 : 1;
+        for (int i = GGML_MAX_DIMS-1; i >= 1; --i) {
+            if (shape[i] != 1) {
+                n_dims = i + 1;
+                break;
+            }
+        }
+        if (hft == HANDCRAFTED_TENSORS_BAD_N_DIMS) {
+            const uint32_t n_dims_bad = GGML_MAX_DIMS + 1;
+            helper_write(file, n_dims_bad);
+        } else {
+            helper_write(file, n_dims);
+        }
+
+        if (hft == HANDCRAFTED_TENSORS_BAD_SHAPE) {
+            for (uint32_t j = 0; j < n_dims; ++j) {
+                const int64_t bad_dim = -1;
+                helper_write(file, bad_dim);
+            }
+        } else if (hft == HANDCRAFTED_TENSORS_NE_TOO_BIG){
+            for (uint32_t j = 0; j < n_dims; ++j) {
+                const int64_t big_dim = 4*int64_t(INT32_MAX);
+                helper_write(file, big_dim);
+            }
+        } else {
+            helper_write(file, shape.data(), n_dims*sizeof(int64_t));
+        }
+
+        {
+            const int32_t type32 = hft == HANDCRAFTED_TENSORS_BAD_TYPE ? GGML_TYPE_COUNT : int32_t(type);
+            helper_write(file, type32);
+        }
+
+        if (hft == HANDCRAFTED_TENSORS_BAD_OFFSET) {
+            const uint64_t bad_offset = -1;
+            helper_write(file, bad_offset);
+        } else {
+            helper_write(file, offset);
+        }
+
+        int64_t ne = shape[0];
+        for (uint32_t i = 1; i < n_dims; ++i) {
+            ne *= shape[i];
+        }
+        offset += GGML_PAD(ggml_row_size(type, ne), alignment);
+    }
+
+    while (ftell(file) % alignment != 0) {
+        const char pad = 0;
+        helper_write(file, pad);
+    }
+
+    if (hft >= offset_has_data) {
+        rng.seed(seed + 1);
+        uint64_t nbytes = offset;
+        if (hft == HANDCRAFTED_DATA_NOT_ENOUGH_DATA) {
+            nbytes -= 1;
+        }
+        for (uint64_t i = 0; i < nbytes; ++i) {
+            const uint8_t random_byte = i % 256;
+            helper_write(file, random_byte);
+        }
+    }
+
+    for (int i = 0; i < extra_bytes; ++i) {
+        const char tmp = 0;
+        helper_write(file, tmp);
+    }
+    rewind(file);
+    return file;
+}
+
+static bool handcrafted_check_header(const gguf_context * gguf_ctx, const unsigned int seed, const bool has_kv, const bool has_tensors, const bool alignment_defined) {
+    if (!gguf_ctx) {
+        return false;
+    }
+
+    std::mt19937 rng(seed);
+
+    std::vector tensor_configs;
+    if (has_tensors) {
+        tensor_configs = get_tensor_configs(rng);
+    }
+    std::vector> kv_types;
+    if (has_kv) {
+        kv_types = get_kv_types(rng);
+    }
+
+    bool ok = true;
+
+    if (gguf_get_version(gguf_ctx) != GGUF_VERSION) {
+        ok = false;
+    }
+    if (gguf_get_n_tensors(gguf_ctx) != int(tensor_configs.size())) {
+        ok = false;
+    }
+    if (gguf_get_n_kv(gguf_ctx) != int(alignment_defined ? kv_types.size() + 1 : kv_types.size())) {
+        ok = false;
+    }
+
+    return ok;
+}
+
+static bool handcrafted_check_kv(const gguf_context * gguf_ctx, const unsigned int seed, const bool has_tensors, const bool alignment_defined) {
+    if (!gguf_ctx) {
+        return false;
+    }
+
+    std::mt19937 rng(seed);
+
+    std::vector tensor_configs;
+    if (has_tensors) {
+        tensor_configs = get_tensor_configs(rng);
+    }
+
+    std::vector> kv_types = get_kv_types(rng);
+
+    bool ok = true;
+
+    for (int i = 0; i < int(kv_types.size()); ++i) {
+        const enum gguf_type type     = gguf_type(kv_types[i].first);
+        const enum gguf_type type_arr = gguf_type(kv_types[i].second);
+
+        const std::string key = "my_key_" + std::to_string(i);
+
+        uint32_t data[16];
+        for (int j = 0; j < 16; ++j) {
+            data[j] = rng();
+            if (type == GGUF_TYPE_STRING || type_arr == GGUF_TYPE_STRING) {
+                data[j] |= 0x01010101; // avoid random null-termination of string
+            }
+        }
+
+        const char * data8 = reinterpret_cast(data);
+        const int id = gguf_find_key(gguf_ctx, key.c_str());
+
+        if (type == GGUF_TYPE_STRING) {
+            const char * str = gguf_get_val_str(gguf_ctx, id);
+            const uint64_t n = strlen(str);
+            const uint64_t n_expected = rng() % sizeof(data);
+            if (n != n_expected) {
+                ok = false;
+                continue;
+            }
+            if (!std::equal(str, str + n, data8)) {
+                ok = false;
+            }
+            continue;
+        }
+
+        if (type == GGUF_TYPE_ARRAY) {
+            const size_t type_size = gguf_type_size(type_arr);
+            const uint64_t arr_n = gguf_get_arr_n(gguf_ctx, id);
+
+            if (type_arr == GGUF_TYPE_STRING) {
+                const uint64_t nstr_expected = rng() % (16 + 1);
+                if (arr_n != nstr_expected) {
+                    ok = false;
+                    continue;
+                }
+                for (uint64_t istr = 0; istr < nstr_expected; ++istr) {
+                    const char * str = gguf_get_arr_str(gguf_ctx, id, istr);
+                    const uint64_t n = strlen(str);
+                    const uint64_t n_expected = rng() % (sizeof(uint32_t) + 1);
+
+                    if (n != n_expected) {
+                        ok = false;
+                        continue;
+                    }
+                    const char * str_expected = reinterpret_cast(&data[istr]);
+                    if (strncmp(str, str_expected, n) != 0) {
+                        ok = false;
+                        continue;
+                    }
+                }
+                continue;
+            }
+
+            const uint64_t arr_n_expected = (rng() % sizeof(data)) / type_size;
+            if (arr_n != arr_n_expected) {
+                ok = false;
+                continue;
+            }
+
+            const char * data_gguf = reinterpret_cast(gguf_get_arr_data(gguf_ctx, id));
+
+            if (type_arr == GGUF_TYPE_BOOL) {
+                for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
+                    if (bool(data8[arr_i]) != bool(data_gguf[arr_i])) {
+                        ok = false;
+                    }
+                }
+                continue;
+            }
+
+            if (!std::equal(data8, data8 + arr_n*type_size, data_gguf)) {
+                ok = false;
+            }
+            continue;
+        }
+
+        const char * data_gguf = reinterpret_cast(gguf_get_val_data(gguf_ctx, id));
+
+        if (type == GGUF_TYPE_BOOL) {
+            if (bool(*data8) != bool(*data_gguf)) {
+                ok = false;
+            }
+            continue;
+        }
+
+        if (!std::equal(data8, data8 + gguf_type_size(type), data_gguf)) {
+            ok = false;
+        }
+    }
+
+    const uint32_t expected_alignment = alignment_defined ? 1 : GGUF_DEFAULT_ALIGNMENT;
+    if (gguf_get_alignment(gguf_ctx) != expected_alignment) {
+        ok = false;
+    }
+
+    return ok;
+}
+
+static bool handcrafted_check_tensors(const gguf_context * gguf_ctx, const unsigned int seed) {
+    if (!gguf_ctx) {
+        return false;
+    }
+
+    std::mt19937 rng(seed);
+
+    std::vector tensor_configs = get_tensor_configs(rng);
+
+    // Call get_kv_types to get the same RNG state:
+    get_kv_types(rng);
+
+    bool ok = true;
+
+    const int id_alignment = gguf_find_key(gguf_ctx, GGUF_KEY_GENERAL_ALIGNMENT);
+    const uint32_t alignment = id_alignment >= 0 ? gguf_get_val_u32(gguf_ctx, id_alignment) : GGUF_DEFAULT_ALIGNMENT;
+
+    uint64_t expected_offset = 0;
+    for (int i = 0; i < int(tensor_configs.size()); ++i) {
+        const ggml_type                          type  = tensor_configs[i].first;
+        const std::array shape = tensor_configs[i].second;
+
+        const std::string name = "my_tensor_" + std::to_string(i);
+        const int id = gguf_find_tensor(gguf_ctx, name.c_str());
+
+        if (id >= 0) {
+            if (std::string(gguf_get_tensor_name(gguf_ctx, id)) != name) {
+                ok = false;
+            }
+
+            if (gguf_get_tensor_type(gguf_ctx, id) != type) {
+                ok = false;
+            }
+        } else {
+            ok = false;
+            continue;
+        }
+
+        const size_t offset = gguf_get_tensor_offset(gguf_ctx, id);
+
+        if (offset != expected_offset) {
+            ok = false;
+        }
+
+        int64_t ne = shape[0];
+        for (size_t j = 1; j < GGML_MAX_DIMS; ++j) {
+            ne *= shape[j];
+        }
+        expected_offset += GGML_PAD(ggml_row_size(type, ne), alignment);
+    }
+
+    return ok;
+}
+
+static bool handcrafted_check_tensor_data(const gguf_context * gguf_ctx, const unsigned int seed, FILE * file) {
+    if (!gguf_ctx) {
+        return false;
+    }
+
+    std::mt19937 rng(seed);
+
+    std::vector tensor_configs = get_tensor_configs(rng);
+
+    bool ok = true;
+
+    for (int i = 0; i < int(tensor_configs.size()); ++i) {
+        const ggml_type                          type  = tensor_configs[i].first;
+        const std::array shape = tensor_configs[i].second;
+
+        int64_t ne = shape[0];
+        for (size_t j = 1; j < GGML_MAX_DIMS; ++j) {
+            ne *= shape[j];
+        }
+        const size_t size = ggml_row_size(type, ne);
+
+        const std::string name = "my_tensor_" + std::to_string(i);
+        const size_t offset = gguf_get_tensor_offset(gguf_ctx, gguf_find_tensor(gguf_ctx, name.c_str()));
+
+        std::vector data(size);
+        GGML_ASSERT(fseek(file, gguf_get_data_offset(gguf_ctx) + offset, SEEK_SET) == 0);
+        GGML_ASSERT(fread(data.data(), 1, data.size(), file) == data.size());
+
+        for (size_t j = 0; j < size; ++j) {
+            const uint8_t expected_byte = (j + offset) % 256;
+            if (data[j] != expected_byte) {
+                ok = false;
+            }
+        }
+    }
+
+    return ok;
+}
+
+static std::pair test_handcrafted_file(const unsigned int seed) {
+    int npass = 0;
+    int ntest = 0;
+
+    const std::vector hfts = {
+        HANDCRAFTED_HEADER_BAD_MAGIC,
+        HANDCRAFTED_HEADER_BAD_VERSION_0,
+        HANDCRAFTED_HEADER_BAD_VERSION_1,
+        HANDCRAFTED_HEADER_BAD_VERSION_FUTURE,
+        HANDCRAFTED_HEADER_BAD_N_KV,
+        HANDCRAFTED_HEADER_BAD_N_TENSORS,
+        HANDCRAFTED_HEADER_EMPTY,
+
+        HANDCRAFTED_KV_BAD_KEY_SIZE,
+        HANDCRAFTED_KV_BAD_TYPE,
+        HANDCRAFTED_KV_DUPLICATE_KEY,
+        HANDCRAFTED_KV_BAD_ALIGN,
+        HANDCRAFTED_KV_SUCCESS,
+
+        HANDCRAFTED_TENSORS_BAD_NAME_SIZE,
+        HANDCRAFTED_TENSORS_BAD_N_DIMS,
+        HANDCRAFTED_TENSORS_BAD_SHAPE,
+        HANDCRAFTED_TENSORS_NE_TOO_BIG,
+        HANDCRAFTED_TENSORS_BAD_TYPE,
+        HANDCRAFTED_TENSORS_BAD_OFFSET,
+        HANDCRAFTED_TENSORS_DUPLICATE_NAME,
+        HANDCRAFTED_TENSORS_BAD_ALIGN,
+        HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN,
+        HANDCRAFTED_TENSORS_SUCCESS,
+        HANDCRAFTED_TENSORS_CUSTOM_ALIGN,
+
+        HANDCRAFTED_DATA_NOT_ENOUGH_DATA,
+        HANDCRAFTED_DATA_BAD_ALIGN,
+        HANDCRAFTED_DATA_INCONSISTENT_ALIGN,
+        HANDCRAFTED_DATA_SUCCESS,
+        HANDCRAFTED_DATA_CUSTOM_ALIGN,
+    };
+
+    for (enum handcrafted_file_type hft : hfts) {
+        printf("%s: handcrafted_file_type=%s\n", __func__, handcrafted_file_type_name(hft).c_str());
+        FILE * file = get_handcrafted_file(seed, hft);
+
+#ifdef _WIN32
+        if (!file) {
+            printf("failed to create tmpfile(), needs elevated privileges on Windows");
+            printf("skipping tests");
+            continue;
+        }
+#else
+        GGML_ASSERT(file);
+#endif // _WIN32
+
+        struct ggml_context * ctx = nullptr;
+        struct gguf_init_params gguf_params = {
+            /*no_alloc =*/ false,
+            /*ctx      =*/ hft >= offset_has_data ? &ctx : nullptr,
+        };
+
+        struct gguf_context * gguf_ctx = gguf_init_from_file_impl(file, gguf_params);
+
+        if (expect_context_not_null(hft)) {
+            printf("%s:   - context_not_null: ", __func__);
+        } else {
+            printf("%s:   - context_null: ", __func__);
+        }
+        if (bool(gguf_ctx) == expect_context_not_null(hft)) {
+            printf("\033[1;32mOK\033[0m\n");
+            npass++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+        ntest++;
+
+        if (hft >= offset_has_data && !expect_context_not_null(hft)) {
+            printf("%s:   - no_dangling_ggml_context_pointer: ", __func__);
+            if (ctx) {
+                printf("\033[1;31mFAIL\033[0m\n");
+            } else {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            }
+            ntest++;
+        }
+
+        const bool alignment_defined = hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN;
+
+        if (expect_context_not_null(hft)) {
+            printf("%s:   - check_header: ", __func__);
+            if (handcrafted_check_header(gguf_ctx, seed, hft >= offset_has_kv, hft >= offset_has_tensors, alignment_defined)) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
+
+        if (expect_context_not_null(hft) && hft >= offset_has_kv) {
+            printf("%s:   - check_kv: ", __func__);
+            if (handcrafted_check_kv(gguf_ctx, seed, hft >= offset_has_tensors, alignment_defined)) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
+
+        if (expect_context_not_null(hft) && hft >= offset_has_tensors) {
+            printf("%s:   - check_tensors: ", __func__);
+            if (handcrafted_check_tensors(gguf_ctx, seed)) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
+
+        if (expect_context_not_null(hft) && hft >= offset_has_data) {
+            printf("%s:   - check_tensor_data: ", __func__);
+            if (handcrafted_check_tensor_data(gguf_ctx, seed, file)) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
+
+        fclose(file);
+        if (gguf_ctx) {
+            ggml_free(ctx);
+            gguf_free(gguf_ctx);
+        }
+        printf("\n");
+    }
+
+
+    return std::make_pair(npass, ntest);
+}
+
+struct random_gguf_context_result {
+    struct gguf_context * gguf_ctx;
+    struct ggml_context * ctx;
+    ggml_backend_buffer_t buffer;
+};
+
+static struct random_gguf_context_result get_random_gguf_context(ggml_backend_t backend, const unsigned int seed) {
+    std::mt19937 rng(seed);
+
+    struct gguf_context * gguf_ctx = gguf_init_empty();
+
+    for (int i = 0; i < 256; ++i) {
+        const std::string key = "my_key_" + std::to_string(rng() % 1024);
+        const enum gguf_type type = gguf_type(rng() % GGUF_TYPE_COUNT);
+
+        switch (type) {
+            case GGUF_TYPE_UINT8:   gguf_set_val_u8  (gguf_ctx, key.c_str(), rng() % (1 <<  7));             break;
+            case GGUF_TYPE_INT8:    gguf_set_val_i8  (gguf_ctx, key.c_str(), rng() % (1 <<  7) - (1 <<  6)); break;
+            case GGUF_TYPE_UINT16:  gguf_set_val_u16 (gguf_ctx, key.c_str(), rng() % (1 << 15));             break;
+            case GGUF_TYPE_INT16:   gguf_set_val_i16 (gguf_ctx, key.c_str(), rng() % (1 << 15) - (1 << 14)); break;
+            case GGUF_TYPE_UINT32:  gguf_set_val_u32 (gguf_ctx, key.c_str(), rng());                         break;
+            case GGUF_TYPE_INT32:   gguf_set_val_i32 (gguf_ctx, key.c_str(), rng()             - (1 << 30)); break;
+            case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (gguf_ctx, key.c_str(), rng() % 1024      - 512);       break;
+            case GGUF_TYPE_BOOL:    gguf_set_val_bool(gguf_ctx, key.c_str(), rng() % 2 == 0);                break;
+            case GGUF_TYPE_STRING:  gguf_set_val_str (gguf_ctx, key.c_str(), std::to_string(rng()).c_str()); break;
+            case GGUF_TYPE_UINT64:  gguf_set_val_u64 (gguf_ctx, key.c_str(), rng());                         break;
+            case GGUF_TYPE_INT64:   gguf_set_val_i64 (gguf_ctx, key.c_str(), rng()             - (1 << 30)); break;
+            case GGUF_TYPE_FLOAT64: gguf_set_val_f32 (gguf_ctx, key.c_str(), rng() % 1024      - 512);       break;
+            case GGUF_TYPE_ARRAY: {
+                const enum gguf_type type_arr = gguf_type(rng() % GGUF_TYPE_COUNT);
+                const uint64_t ne = rng() % 1024;
+
+                switch (type_arr) {
+                    case GGUF_TYPE_UINT8:
+                    case GGUF_TYPE_INT8:
+                    case GGUF_TYPE_UINT16:
+                    case GGUF_TYPE_INT16:
+                    case GGUF_TYPE_UINT32:
+                    case GGUF_TYPE_INT32:
+                    case GGUF_TYPE_FLOAT32:
+                    case GGUF_TYPE_BOOL:
+                    case GGUF_TYPE_UINT64:
+                    case GGUF_TYPE_INT64:
+                    case GGUF_TYPE_FLOAT64: {
+                        const size_t nbytes = ne*gguf_type_size(type_arr);
+                        std::vector random_data((nbytes + sizeof(uint32_t) - 1) / sizeof(uint32_t));
+                        for (size_t j = 0; j < random_data.size(); ++j) {
+                            random_data[j] = rng();
+                            if (type_arr == GGUF_TYPE_BOOL) {
+                                random_data[j] &= 0x01010101; // the sanitizer complains if booleans are not 0 or 1
+                            }
+                        }
+                        gguf_set_arr_data(gguf_ctx, key.c_str(), type_arr, random_data.data(), ne);
+                    } break;
+                    case GGUF_TYPE_STRING: {
+                        std::vector  data_cpp(ne);
+                        std::vector data_c(ne);
+                        for (size_t j = 0; j < data_cpp.size(); ++j) {
+                            data_cpp[j] = std::to_string(rng());
+                            data_c[j]   = data_cpp[j].c_str();
+                        }
+                        gguf_set_arr_str(gguf_ctx, key.c_str(), data_c.data(), ne);
+                    } break;
+                    case GGUF_TYPE_ARRAY: {
+                        break; // not supported
+                    }
+                    case GGUF_TYPE_COUNT:
+                    default: {
+                        GGML_ABORT("fatal error");
+                    }
+                }
+            } break;
+            case GGUF_TYPE_COUNT:
+            default: {
+                GGML_ABORT("fatal error");
+            }
+        }
+    }
+
+    struct ggml_init_params ggml_params = {
+        /*.mem_size   =*/ 256*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
+    };
+    struct ggml_context * ctx = ggml_init(ggml_params);
+
+    for (int i = 0; i < 256; ++i) {
+        const std::string name = "my_tensor_" + std::to_string(i);
+        const enum ggml_type type = ggml_type(rng() % GGML_TYPE_COUNT);
+        const size_t type_size = ggml_type_size(type);
+
+        if (type_size == 0) {
+            continue;
+        }
+
+        const int n_dims = 1 + rng() % GGML_MAX_DIMS;
+        int64_t ne[GGML_MAX_DIMS];
+        ne[0] = (1 + rng() % 10) * ggml_blck_size(type);
+        for (int j = 1; j < n_dims; ++j) {
+            ne[j] = 1 + rng() % 10;
+        }
+
+        struct ggml_tensor * tensor = ggml_new_tensor(ctx, type, n_dims, ne);
+        ggml_set_name(tensor, name.c_str());
+    }
+
+    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+        const size_t nbytes = ggml_nbytes(t);
+        std::vector random_data((nbytes + sizeof(uint32_t) - 1) / sizeof(uint32_t));
+        for (size_t j = 0; j < random_data.size(); ++j) {
+            random_data[j] = rng();
+        }
+        ggml_backend_tensor_set(t, random_data.data(), 0, nbytes);
+
+        gguf_add_tensor(gguf_ctx, t);
+    }
+
+    return {gguf_ctx, ctx, buf};
+}
+
+static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other) {
+    bool ok = true;
+
+    const int n_kv = gguf_get_n_kv(ctx);
+    for (int id = 0; id < n_kv; ++id) {
+        const char * name = gguf_get_key(ctx, id);
+
+        const int idx_other = gguf_find_key(other, name);
+        if (idx_other < 0) {
+            ok = false;
+            continue;
+        }
+
+        const gguf_type type = gguf_get_kv_type(ctx, id);
+        if (type != gguf_get_kv_type(other, idx_other)) {
+            ok = false;
+            continue;
+        }
+
+        if (type == GGUF_TYPE_ARRAY) {
+            const size_t arr_n = gguf_get_arr_n(ctx, id);
+            if (arr_n != gguf_get_arr_n(other, idx_other)) {
+                ok = false;
+                continue;
+            }
+
+            const gguf_type type_arr = gguf_get_arr_type(ctx, id);
+            if (type_arr != gguf_get_arr_type(other, idx_other)) {
+                ok = false;
+                continue;
+            }
+
+            if (type_arr == GGUF_TYPE_BOOL) {
+                const int8_t * data       = reinterpret_cast(gguf_get_arr_data(ctx,   id));
+                const int8_t * data_other = reinterpret_cast(gguf_get_arr_data(other, idx_other));
+                for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
+                    if (bool(data[arr_i]) != bool(data_other[arr_i])) {
+                        ok = false;
+                    }
+                }
+                continue;
+            }
+
+            if (type_arr == GGUF_TYPE_STRING) {
+                for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
+                    const std::string str       = gguf_get_arr_str(ctx,   id,       arr_i);
+                    const std::string str_other = gguf_get_arr_str(other, idx_other, arr_i);
+                    if (str != str_other) {
+                        ok = false;
+                    }
+                }
+                continue;
+            }
+
+            const int8_t * data       = reinterpret_cast(gguf_get_arr_data(ctx,   id));
+            const int8_t * data_other = reinterpret_cast(gguf_get_arr_data(other, idx_other));
+            if (!std::equal(data, data + arr_n*gguf_type_size(type_arr), data_other)) {
+                ok = false;
+            }
+            continue;
+        }
+
+        if (type == GGUF_TYPE_STRING) {
+            const std::string str       = gguf_get_val_str(ctx,   id);
+            const std::string str_other = gguf_get_val_str(other, idx_other);
+            if (str != str_other) {
+                ok = false;
+            }
+            continue;
+        }
+
+        const char * data       = reinterpret_cast(gguf_get_val_data(ctx,   id));
+        const char * data_other = reinterpret_cast(gguf_get_val_data(other, idx_other));
+        if (!std::equal(data, data + gguf_type_size(type), data_other)) {
+            ok = false;
+        }
+    }
+
+    return ok;
+}
+
+static bool all_tensors_in_other(const gguf_context * ctx, const gguf_context * other) {
+    bool ok = true;
+
+    const int n_tensors = gguf_get_n_tensors(ctx);
+    for (int id = 0; id < n_tensors; ++id) {
+        const std::string name = gguf_get_tensor_name(ctx, id);
+
+        const int idx_other = gguf_find_tensor(other, name.c_str());
+        if (id != idx_other) {
+            ok = false;
+            if (idx_other < 0) {
+                continue;
+            }
+        }
+
+        const ggml_type type = gguf_get_tensor_type(ctx, id);
+        if (type != gguf_get_tensor_type(other, id)) {
+            ok = false;
+        }
+
+        const size_t offset = gguf_get_tensor_offset(ctx, id);
+        if (offset != gguf_get_tensor_offset(other, id)) {
+            ok = false;
+        }
+    }
+
+    return ok;
+}
+
+static bool same_tensor_data(const struct ggml_context * orig, const struct ggml_context * read) {
+    bool ok = true;
+
+    struct ggml_tensor * t_orig = ggml_get_first_tensor(orig);
+    struct ggml_tensor * t_read = ggml_get_first_tensor(read);
+
+    if (std::string(t_read->name) != "GGUF tensor data binary blob") {
+        return false;
+    }
+    t_read = ggml_get_next_tensor(read, t_read);
+
+    while (t_orig) {
+        if (!t_read) {
+            ok = false;
+            break;
+        }
+
+        const size_t nbytes = ggml_nbytes(t_orig);
+        if (ggml_nbytes(t_read) != nbytes) {
+            ok = false;
+            break;
+        }
+        std::vector data_orig(nbytes);
+        ggml_backend_tensor_get(t_orig, data_orig.data(), 0, nbytes);
+        if (!std::equal(data_orig.data(), data_orig.data() + nbytes, reinterpret_cast(t_read->data))) {
+            ok = false;
+        }
+
+        t_orig = ggml_get_next_tensor(orig, t_orig);
+        t_read = ggml_get_next_tensor(read, t_read);
+    }
+    if (t_read) {
+        ok = false;
+    }
+
+    return ok;
+}
+
+static std::pair test_roundtrip(ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta) {
+    ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+    printf("%s: device=%s, backend=%s, only_meta=%s\n",
+        __func__, ggml_backend_dev_description(dev), ggml_backend_name(backend), only_meta ? "yes" : "no");
+
+    int npass = 0;
+    int ntest = 0;
+
+    struct gguf_context * gguf_ctx_0;
+    struct ggml_context * ctx_0;
+    ggml_backend_buffer_t bbuf;
+    {
+        struct random_gguf_context_result result = get_random_gguf_context(backend, seed);
+        gguf_ctx_0 = result.gguf_ctx;
+        ctx_0      = result.ctx;
+        bbuf       = result.buffer;
+    }
+
+    FILE * file = tmpfile();
+
+#ifdef _WIN32
+    if (!file) {
+        printf("failed to create tmpfile(), needs elevated privileges on Windows");
+        printf("skipping tests");
+        return std::make_pair(0, 0);
+    }
+#else
+    GGML_ASSERT(file);
+#endif // _WIN32
+
+    {
+        std::vector buf;
+        gguf_write_to_buf(gguf_ctx_0, buf, only_meta);
+        GGML_ASSERT(fwrite(buf.data(), 1, buf.size(), file) == buf.size());
+        rewind(file);
+    }
+
+    struct ggml_context * ctx_1 = nullptr;
+    struct gguf_init_params gguf_params = {
+        /*no_alloc =*/ false,
+        /*ctx      =*/ only_meta ? nullptr : &ctx_1,
+    };
+    struct gguf_context * gguf_ctx_1 = gguf_init_from_file_impl(file, gguf_params);
+
+    printf("%s: same_version: ", __func__);
+    if (gguf_get_version(gguf_ctx_0) == gguf_get_version(gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: same_n_kv: ", __func__);
+    if (gguf_get_n_kv(gguf_ctx_0) == gguf_get_n_kv(gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: same_n_tensors: ", __func__);
+    if (gguf_get_n_tensors(gguf_ctx_0) == gguf_get_n_tensors(gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_orig_kv_in_read: ", __func__);
+    if (all_kv_in_other(gguf_ctx_0, gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_read_kv_in_orig: ", __func__);
+    if (all_kv_in_other(gguf_ctx_1, gguf_ctx_0)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_orig_tensors_in_read: ", __func__);
+    if (all_tensors_in_other(gguf_ctx_0, gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_read_tensors_in_orig: ", __func__);
+    if (all_tensors_in_other(gguf_ctx_1, gguf_ctx_0)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    if (!only_meta) {
+        printf("%s: same_tensor_data: ", __func__);
+        if (same_tensor_data(ctx_0, ctx_1)) {
+            printf("\033[1;32mOK\033[0m\n");
+            npass++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+        ntest++;
+    }
+
+    ggml_backend_buffer_free(bbuf);
+    ggml_free(ctx_0);
+    ggml_free(ctx_1);
+    gguf_free(gguf_ctx_0);
+    gguf_free(gguf_ctx_1);
+    ggml_backend_free(backend);
+    fclose(file);
+
+    printf("\n");
+    return std::make_pair(npass, ntest);
+}
+
+static std::pair test_gguf_set_kv(ggml_backend_dev_t dev, const unsigned int seed) {
+    ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+    printf("%s: device=%s, backend=%s\n", __func__, ggml_backend_dev_description(dev), ggml_backend_name(backend));
+
+    int npass = 0;
+    int ntest = 0;
+
+    struct gguf_context * gguf_ctx_0;
+    struct ggml_context * ctx_0;
+    ggml_backend_buffer_t bbuf_0;
+    {
+        struct random_gguf_context_result result = get_random_gguf_context(backend, seed);
+        gguf_ctx_0 = result.gguf_ctx;
+        ctx_0      = result.ctx;
+        bbuf_0     = result.buffer;
+    }
+
+    struct gguf_context * gguf_ctx_1;
+    struct ggml_context * ctx_1;
+    ggml_backend_buffer_t bbuf_1;
+    {
+        struct random_gguf_context_result result = get_random_gguf_context(backend, seed + 1);
+        gguf_ctx_1 = result.gguf_ctx;
+        ctx_1      = result.ctx;
+        bbuf_1     = result.buffer;
+    }
+
+    struct gguf_context * gguf_ctx_2 = gguf_init_empty();
+
+    gguf_set_kv(gguf_ctx_1, gguf_ctx_0);
+    gguf_set_kv(gguf_ctx_2, gguf_ctx_0);
+
+    printf("%s: same_n_kv: ", __func__);
+    if (gguf_get_n_kv(gguf_ctx_0) == gguf_get_n_kv(gguf_ctx_2)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_kv_0_in_1: ", __func__);
+    if (all_kv_in_other(gguf_ctx_0, gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_kv_0_in_2: ", __func__);
+    if (all_kv_in_other(gguf_ctx_0, gguf_ctx_2)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    gguf_set_kv(gguf_ctx_0, gguf_ctx_1);
+
+    printf("%s: same_n_kv_after_double_copy: ", __func__);
+    if (gguf_get_n_kv(gguf_ctx_0) == gguf_get_n_kv(gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_kv_1_in_0_after_double_copy: ", __func__);
+    if (all_kv_in_other(gguf_ctx_1, gguf_ctx_0)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    ggml_backend_buffer_free(bbuf_0);
+    ggml_backend_buffer_free(bbuf_1);
+    ggml_free(ctx_0);
+    ggml_free(ctx_1);
+    gguf_free(gguf_ctx_0);
+    gguf_free(gguf_ctx_1);
+    gguf_free(gguf_ctx_2);
+    ggml_backend_free(backend);
+
+    printf("\n");
+    return std::make_pair(npass, ntest);
+}
+
+static void print_usage() {
+    printf("usage: test-gguf [seed]\n");
+    printf("  if no seed is unspecified then a random seed is used\n");
+}
+
+int main(int argc, char ** argv) {
+    if (argc > 2) {
+        print_usage();
+        return 1;
+    }
+
+    std::random_device rd;
+    const unsigned int seed = argc < 2 ? rd() : std::stoi(argv[1]);
+
+    // Initialize ggml backends early so the prints aren't interleaved with the test results:
+    ggml_backend_dev_count();
+    fprintf(stderr, "\n");
+
+    int npass = 0;
+    int ntest = 0;
+    {
+        std::pair result = test_handcrafted_file(seed);
+        npass += result.first;
+        ntest += result.second;
+    }
+
+    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+
+        for (bool only_meta : {true, false}) {
+            std::pair result = test_roundtrip(dev, seed, only_meta);
+            npass += result.first;
+            ntest += result.second;
+        }
+
+        {
+            std::pair result = test_gguf_set_kv(dev, seed);
+            npass += result.first;
+            ntest += result.second;
+        }
+    }
+
+    printf("%d/%d tests passed\n", npass, ntest);
+    if (npass != ntest) {
+        printf("\033[1;31mFAIL\033[0m\n");
+        return 1;
+    }
+    printf("\033[1;32mOK\033[0m\n");
+    return 0;
+}
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6d64f07376fb8cdf242b38061008271ff28ad8c3
--- /dev/null
+++ b/tests/test-grammar-integration.cpp
@@ -0,0 +1,1310 @@
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include "json-schema-to-grammar.h"
+
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
+
+#include 
+
+#include 
+#include 
+#include 
+
+using json = nlohmann::ordered_json;
+
+static llama_grammar * build_grammar(const std::string & grammar_str) {
+    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
+}
+
+static bool test_build_grammar_fails(const std::string & grammar_str) {
+    fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
+    bool grammar_fails = false;
+    llama_grammar * grammar = build_grammar(grammar_str);
+    if (grammar != nullptr) {
+        fprintf(stderr, "  ❌ Expected build failure, but succeeded\n");
+    } else {
+        grammar_fails = true;
+        fprintf(stdout, "  ✅︎\n");
+    }
+    return grammar_fails;
+}
+
+static bool match_string(const std::string & input, llama_grammar * grammar) {
+    const auto cpts = unicode_cpts_from_utf8(input);
+
+    auto & stacks_cur = llama_grammar_get_stacks(grammar);
+
+    for (const auto & cpt : cpts) {
+        llama_grammar_accept(grammar, cpt);
+
+        if (stacks_cur.empty()) {
+            // no stacks means that the grammar failed to match at this point
+            return false;
+        }
+    }
+
+    for (const auto & stack : stacks_cur) {
+        if (stack.empty()) {
+            // An empty stack means that the grammar has been completed
+            return true;
+        }
+    }
+
+    return false;
+}
+
+static void test(const std::string & test_desc, const std::string & grammar_str, const std::vector & passing_strings, const std::vector & failing_strings) {
+    fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
+    fflush(stderr);
+
+    auto * grammar = build_grammar(grammar_str);
+
+    // Save the original grammar stacks so that we can reset after every new string we want to test
+    const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); // copy
+
+    llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
+
+    fprintf(stderr, "  🔵 Valid strings:\n");
+
+    // Passing strings
+    for (const auto & test_string : passing_strings) {
+        fprintf(stderr, "    \"%s\" ", test_string.c_str());
+        fflush(stderr);
+
+        bool matched = match_string(test_string, grammar);
+
+        if (!matched) {
+            fprintf(stderr, "❌ (failed to match)\n");
+
+            // DEBUG: Write strings to files so that we can analyze more easily with gbnf-validator program to see exactly where things failed.
+            // DEBUG: Write the grammar_str to test-grammar-integration.grammar.gbnf
+            FILE* grammar_file = fopen("test-grammar-integration.grammar.gbnf", "w");
+            if (grammar_file) {
+                fprintf(grammar_file, "%s", grammar_str.c_str());
+                fclose(grammar_file);
+            }
+
+            // DEBUG: Write the test string to test-grammar-integration.string.txt
+            FILE* string_file = fopen("test-grammar-integration.string.txt", "w");
+            if (string_file) {
+                fprintf(string_file, "%s", test_string.c_str());
+                fclose(string_file);
+            }
+
+            fprintf(stderr, "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following command:     ./llama-gbnf-validator test-grammar-integration.grammar.gbnf test-grammar-integration.string.txt\n\n");
+        } else {
+            fprintf(stdout, "✅︎\n");
+        }
+
+        assert(matched);
+
+        // Reset the grammar stacks
+        stacks_cur = stacks_org;
+    }
+
+    fprintf(stderr, "  🟠 Invalid strings:\n");
+
+    // Failing strings
+    for (const auto & test_string : failing_strings) {
+        fprintf(stderr, "    \"%s\" ", test_string.c_str());
+        fflush(stderr);
+
+        bool matched = match_string(test_string, grammar);
+
+        if (matched) {
+            fprintf(stderr, "❌ (incorrectly matched)\n");
+        } else {
+            fprintf(stdout, "✅︎\n");
+        }
+        assert(!matched);
+
+        // Reset the grammar stacks
+        stacks_cur = stacks_org;
+    }
+
+    // Clean up allocated memory
+    llama_grammar_free_impl(grammar);
+}
+static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector & passing_strings, const std::vector & failing_strings) {
+    test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
+}
+static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector & passing_strings, const std::vector & failing_strings) {
+    test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str), true), passing_strings, failing_strings);
+}
+
+static void test_simple_grammar() {
+    test_schema(
+        "min 0",
+        R"""({
+            "type": "integer",
+            "minimum": 0
+        })""",
+        // Passing strings
+        {
+            "0",
+            "10",
+            "12",
+            "10000",
+        },
+        // Failing strings
+        {
+            "-1",
+            "-10",
+            "-10000",
+            "-100000000000000000000000000000000",
+            "100000000000000000000000000000000",
+            "00",
+            "01",
+            "-0",
+        }
+    );
+    test_schema(
+        "min 2",
+        // Schema
+        R"""({
+            "type": "integer",
+            "minimum": 2
+        })""",
+        // Passing strings
+        {
+            "2",
+            "3",
+            "4",
+            "10",
+            "20",
+            "1234567890000000",
+        },
+        // Failing strings
+        {
+            "0",
+            "1",
+            "-1",
+            "-100",
+            "0",
+            "1",
+            "01",
+            "02",
+            "12345678900000000",
+        }
+    );
+    test_schema(
+        "min 456",
+        R"""({
+            "type": "integer",
+            "minimum": 456
+        })""",
+        // Passing strings
+        {
+            "456",
+            "4560",
+            "457",
+            "460",
+            "500",
+        },
+        // Failing strings
+        {
+            "455",
+            "356",
+            "50",
+            "050",
+            "-1",
+            "-456",
+        }
+    );
+    test_schema(
+        "min -123",
+        R"""({
+            "type": "integer",
+            "minimum": -123
+        })""",
+        // Passing strings
+        {
+            "-123",
+            "-122",
+            "-11",
+            "-1",
+            "0",
+            "1",
+            "123",
+            "1234",
+            "2345",
+        },
+        // Failing strings
+        {
+            "-1234",
+            "-124",
+        }
+    );
+
+    test_schema(
+        "max 9999",
+        // Schema
+        R"""({
+            "type": "integer",
+            "maximum": 9999
+        })""",
+        // Passing strings
+        {
+            "-99999",
+            "0",
+            "9999",
+        },
+        // Failing strings
+        {
+            "10000",
+            "99991",
+        }
+    );
+    test_schema(
+        "max -9999",
+        // Schema
+        R"""({
+            "type": "integer",
+            "maximum": -9999
+        })""",
+        // Passing strings
+        {
+            "-10000",
+            "-9999",
+        },
+        // Failing strings
+        {
+            "-9998",
+            "0",
+            "9999",
+        }
+    );
+    test_schema(
+        "min 5 max 30",
+        // Schema
+        R"""({
+            "type": "integer",
+            "minimum": 5,
+            "maximum": 30
+        })""",
+        // Passing strings
+        {
+            "5",
+            "10",
+            "30",
+        },
+        // Failing strings
+        {
+            "05",
+            "4",
+            "-1",
+            "31",
+            "123",
+            "0123",
+        }
+    );
+    test_schema(
+        "min -1 max 1",
+        R"""({
+            "type": "integer",
+            "minimum": -1,
+            "maximum": 1
+        })""",
+        // Passing strings
+        {
+            "-1",
+            "0",
+            "1",
+        },
+        // Failing strings
+        {
+            "-11",
+            "-10",
+            "-2",
+            "2",
+            "10",
+            "11",
+        }
+    );
+    test_schema(
+        "min -123 max 42",
+        R"""({
+            "type": "integer",
+            "minimum": -123,
+            "maximum": 42
+        })""",
+        // Passing strings
+        {
+            "-123",
+            "-122",
+            "-13",
+            "-11",
+            "-2",
+            "-1",
+            "0",
+            "1",
+            "5",
+            "10",
+            "39",
+            "40",
+            "42",
+        },
+        // Failing strings
+        {
+            "-0123",
+            "-124",
+            "-1123",
+            "-200",
+            "43",
+            "123",
+            "0123",
+        }
+    );
+    test_schema(
+        "exclusive min / max",
+        // Schema
+        R"""({
+            "type": "integer",
+            "exclusiveMinimum": 0,
+            "exclusiveMaximum": 10000
+        })""",
+        // Passing strings
+        {
+            "1",
+            "9999",
+        },
+        // Failing strings
+        {
+            "0",
+            "01",
+            "10000",
+            "99999",
+        }
+    );
+
+    // Test case for a simple grammar
+    test_grammar(
+        "simple grammar",
+        R"""(
+            root ::= expr
+            expr ::= term ("+" term)*
+            term ::= number
+            number ::= [0-9]+)""",
+        // Passing strings
+        {
+            "42",
+            "1+2+3+4+5",
+            "123+456",
+        },
+        // Failing strings
+        {
+            "+",
+            "/ 3",
+            "1+2+3+4+5+",
+            "12a45",
+        }
+    );
+}
+
+static void test_complex_grammar() {
+    // Test case for a more complex grammar, with both failure strings and success strings
+    test_grammar(
+        "medium complexity grammar",
+        // Grammar
+        R"""(
+            root ::= expression
+            expression ::= term ws (("+"|"-") ws term)*
+            term ::= factor ws (("*"|"/") ws factor)*
+            factor ::= number | variable | "(" expression ")" | function-call
+            number ::= [0-9]+
+            variable ::= [a-zA-Z_][a-zA-Z0-9_]*
+            function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
+            ws ::= [ \t\n\r]?)""",
+        // Passing strings
+        {
+            "42",
+            "1*2*3*4*5",
+            "x",
+            "x+10",
+            "x1+y2",
+            "(a+b)*(c-d)",
+            "func()",
+            "func(x,y+2)",
+            "a*(b+c)-d/e",
+            "f(g(x),h(y,z))",
+            "x + 10",
+            "x1 + y2",
+            "(a + b) * (c - d)",
+            "func()",
+            "func(x, y + 2)",
+            "a * (b + c) - d / e",
+            "f(g(x), h(y, z))",
+            "123+456",
+            "123*456*789-123/456+789*123",
+            "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
+        },
+        // Failing strings
+        {
+            "+",
+            "/ 3x",
+            "x + + y",
+            "a * / b",
+            "func(,)",
+            "func(x y)",
+            "(a + b",
+            "x + y)",
+            "a + b * (c - d",
+            "42 +",
+            "x +",
+            "x + 10 +",
+            "(a + b) * (c - d",
+            "func(",
+            "func(x, y + 2",
+            "a * (b + c) - d /",
+            "f(g(x), h(y, z)",
+            "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
+        }
+    );
+}
+
+static void test_special_chars() {
+    // A collection of tests to exercise special characters such as "."
+    test_grammar(
+        "special characters",
+        // Grammar
+        R"""(
+            root ::= ... "abc" ...
+            )""",
+        // Passing strings
+        {
+            "abcabcabc",
+            "aaaabcccc",
+            // NOTE: Also ensures that multi-byte characters still count as a single character
+            "🔵🟠✅abc❌🟠🔵"
+        },
+        // Failing strings
+        {
+            "aaabcccc",
+            "aaaaabcccc",
+            "aaaabccc",
+            "aaaabccccc",
+            "🔵🟠✅❌abc❌✅🟠🔵",
+            "🔵🟠abc🟠🔵"
+        }
+    );
+}
+
+static void test_quantifiers() {
+    // A collection of tests to exercise * + and ? quantifiers
+
+    test_grammar(
+        "* quantifier",
+        // Grammar
+        R"""(root ::= "a"*)""",
+        // Passing strings
+        {
+            "",
+            "a",
+            "aaaaa",
+            "aaaaaaaaaaaaaaaaaa",
+            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+        },
+        // Failing strings
+        {
+            "b",
+            "ab",
+            "aab",
+            "ba",
+            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
+        }
+    );
+    test_grammar(
+        "+ quantifier",
+        // Grammar
+        R"""(root ::= "a"+)""",
+        // Passing strings
+        {
+            "a",
+            "aaaaa",
+            "aaaaaaaaaaaaaaaaaa",
+            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+        },
+        // Failing strings
+        {
+            "",
+            "b",
+            "ab",
+            "aab",
+            "ba",
+            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
+        }
+    );
+    test_grammar(
+        "? quantifier",
+        // Grammar
+        R"""(root ::= "a"?)""",
+        // Passing strings
+        {
+            "",
+            "a"
+        },
+        // Failing strings
+        {
+            "b",
+            "ab",
+            "aa",
+            "ba",
+        }
+    );
+    test_grammar(
+        "mixed quantifiers",
+        // Grammar
+        R"""(
+            root ::= cons+ vowel* cons? (vowel cons)*
+            vowel ::= [aeiouy]
+            cons ::= [bcdfghjklmnpqrstvwxyz]
+            )""",
+        // Passing strings
+        {
+            "yes",
+            "no",
+            "noyes",
+            "crwth",
+            "four",
+            "bryyyy",
+        },
+        // Failing strings
+        {
+            "yess",
+            "yesno",
+            "forty",
+            "catyyy",
+        }
+    );
+    test_grammar(
+        "simple exact repetition",
+        // Grammar
+        R"""(
+            root ::= [ab]{4}
+        )""",
+        // Passing strings
+        {
+            "aaaa",
+            "bbbb",
+            "abab",
+        },
+        // Failing strings
+        {
+            "a",
+            "b",
+            "aaaaa",
+        }
+    );
+    test_grammar(
+        "simple min repetition",
+        // Grammar
+        R"""(
+            root ::= [ab]{4,}
+        )""",
+        // Passing strings
+        {
+            "aaaa",
+            "aaaaab",
+            "bbbb",
+            "ababab",
+        },
+        // Failing strings
+        {
+            "",
+            "aba",
+        }
+    );
+    test_grammar(
+        "simple max repetition",
+        // Grammar
+        R"""(
+            root ::= [ab]{0,4}
+        )""",
+        // Passing strings
+        {
+            "",
+            "a",
+            "aa",
+            "aaa",
+            "aaab",
+        },
+        // Failing strings
+        {
+            "aaaaa",
+        }
+    );
+    test_grammar(
+        "min / max repetition",
+        // Grammar
+        R"""(
+            root ::= ("0x" [A-F0-9]{2} " "?){3,5}
+        )""",
+        // Passing strings
+        {
+            "0xFF 0x12 0xAB",
+            "0xFF 0x12 0xAB 0x00 0x00",
+        },
+        // Failing strings
+        {
+            "",
+            "0xFF",
+            "0xFF 0x12",
+            "0xFF 0x12 0xAB 0x00 0x00 0x00",
+        }
+    );
+}
+
+static void test_failure_missing_root() {
+    fprintf(stderr, "⚫ Testing missing root node:\n");
+    // Test case for a grammar that is missing a root rule
+    const std::string grammar_str = R"""(
+        rot ::= expr
+        expr ::= term ("+" term)*
+        term ::= number
+        number ::= [0-9]+)""";
+
+    llama_grammar_parser parsed_grammar;
+    parsed_grammar.parse(grammar_str.c_str());
+
+    // Ensure we parsed correctly
+    assert(!parsed_grammar.rules.empty());
+
+    // Ensure we do NOT have a root node
+    assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
+    fprintf(stderr, "  ✅︎ Passed\n");
+}
+
+static void test_failure_missing_reference() {
+    fprintf(stderr, "⚫ Testing missing reference node:\n");
+
+    // Test case for a grammar that is missing a referenced rule
+    const std::string grammar_str =
+        R"""(root ::= expr
+        expr ::= term ("+" term)*
+        term ::= numero
+        number ::= [0-9]+)""";
+
+    fprintf(stderr, "    Expected error:  ");
+
+    llama_grammar_parser parsed_grammar;
+    parsed_grammar.parse(grammar_str.c_str());
+
+    // Ensure we did NOT parsed correctly
+    assert(parsed_grammar.rules.empty());
+
+    fprintf(stderr, "    End of expected error.\n");
+    fprintf(stderr, "  ✅︎ Passed\n");
+}
+
+static void test_failure_left_recursion() {
+    fprintf(stderr, "⚫ Testing left recursion detection:\n");
+
+    // Test simple left recursion detection
+    const std::string simple_str = R"""(root ::= "a" | root "a")""";
+    assert(test_build_grammar_fails(simple_str));
+
+    // Test more complicated left recursion detection
+    const std::string medium_str = R"""(
+        root ::= asdf
+        asdf ::= "a" | asdf "a"
+        )""";
+    assert(test_build_grammar_fails(medium_str));
+
+    // Test even more complicated left recursion detection
+    const std::string hard_str = R"""(
+        root ::= asdf
+        asdf ::= "a" | foo "b"
+        foo ::= "c" | asdf "d" | "e")""";
+    assert(test_build_grammar_fails(hard_str));
+
+    // Test yet even more complicated left recursion detection
+    const std::string hardest_str = R"""(
+        root ::= asdf
+        asdf ::= "a" | foo "b"
+        foo ::= "c" | empty asdf "d" | "e"
+        empty ::= "blah" | )""";
+    assert(test_build_grammar_fails(hardest_str));
+
+    fprintf(stderr, "  ✅︎ Passed\n");
+}
+
+static void test_json_schema() {
+    // Note that this is similar to the regular grammar tests,
+    //  but we convert each json schema to a grammar before parsing.
+    // Otherwise, this test structure is the same.
+
+    test_schema(
+        "empty schema (object)",
+        // Schema
+        R"""(
+            {}
+        )""",
+        // Passing strings
+        {
+            R"""({})""",
+            R"""({"foo": "bar"})""",
+        },
+        // Failing strings
+        {
+            "",
+            "[]",
+            "null",
+            R"""("")""",
+            "true",
+        }
+    );
+
+    test_schema(
+        "exotic formats (list)",
+        // Schema
+        R"""({
+            "items": [
+                { "format": "date" },
+                { "format": "uuid" },
+                { "format": "time" },
+                { "format": "date-time" }
+            ]
+        })""",
+        // Passing strings
+        {
+            // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+            // "[]", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+            R"""(["2012-04-23", "12345678-1234-1234-1234-1234567890ab", "18:25:43.511Z", "2012-04-23T18:25:43.511Z"])""",
+            //R"""(["2012-04-23","12345678-1234-1234-1234-1234567890ab"])""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+            //R"""({"foo": "bar"})""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+        },
+        // Failing strings
+        {
+            R"""(["foo", "bar"])""",
+            R"""(["12345678-1234-1234-1234-1234567890ab"])""",
+        }
+    );
+
+    test_schema(
+        "string",
+        // Schema
+        R"""({
+            "type": "string"
+        })""",
+        // Passing strings
+        {
+            R"""("foo")""",
+            R"""("bar")""",
+            R"""("")""",
+        },
+        // Failing strings
+        {
+            R"""({})""",
+            R"""("foo": "bar")""",
+        }
+    );
+
+    test_schema(
+        "string w/ min length 1",
+        // Schema
+        R"""({
+            "type": "string",
+            "minLength": 1
+        })""",
+        // Passing strings
+        {
+            R"""("foo")""",
+            R"""("bar")""",
+        },
+        // Failing strings
+        {
+            R"""("")""",
+            R"""({})""",
+            R"""("foo": "bar")""",
+        }
+    );
+
+    test_schema(
+        "string w/ min length 3",
+        // Schema
+        R"""({
+                "type": "string",
+                "minLength": 3
+        })""",
+        // Passing strings
+        {
+            R"""("foo")""",
+            R"""("bar")""",
+            R"""("foobar")""",
+        },
+        // Failing strings
+        {
+            R"""("")""",
+            R"""("f")""",
+            R"""("fo")""",
+        }
+    );
+
+    test_schema(
+        "string w/ max length",
+        // Schema
+        R"""({
+            "type": "string",
+            "maxLength": 3
+        })""",
+        // Passing strings
+        {
+            R"""("foo")""",
+            R"""("bar")""",
+            R"""("")""",
+            R"""("f")""",
+            R"""("fo")""",
+        },
+        // Failing strings
+        {
+            R"""("foobar")""",
+        }
+    );
+
+    test_schema(
+        "string w/ min & max length",
+        // Schema
+        R"""({
+            "type": "string",
+            "minLength": 1,
+            "maxLength": 4
+        })""",
+        // Passing strings
+        {
+            R"""("foo")""",
+            R"""("bar")""",
+            R"""("f")""",
+            R"""("barf")""",
+        },
+        // Failing strings
+        {
+            R"""("")""",
+            R"""("barfo")""",
+            R"""("foobar")""",
+        }
+    );
+
+    test_schema(
+        "boolean",
+        // Schema
+        R"""({
+            "type": "boolean"
+        })""",
+        // Passing strings
+        {
+            "true",
+            "false",
+        },
+        // Failing strings
+        {
+            R"""("")""",
+            R"""("true")""",
+            R"""(True)""",
+            R"""(FALSE)""",
+        }
+    );
+
+    test_schema(
+        "integer",
+        // Schema
+        R"""({
+            "type": "integer"
+        })""",
+        // Passing strings
+        {
+            R"""(0)""",
+            R"""(12345)""",
+            R"""(1234567890123456)""",
+        },
+        // Failing strings
+        {
+            R"""()""",
+            R"""(01)""",
+            R"""(007)""",
+            R"""(12345678901234567  )""",
+        }
+    );
+
+    test_schema(
+        "string const",
+        // Schema
+        R"""({
+            "const": "foo"
+        })""",
+        // Passing strings
+        {
+            R"""("foo")""",
+        },
+        // Failing strings
+        {
+            R"""(foo)""",
+            R"""("bar")""",
+        }
+    );
+
+    test_schema(
+        "non-string const",
+        // Schema
+        R"""({
+            "const": true
+        })""",
+        // Passing strings
+        {
+            R"""(true)""",
+        },
+        // Failing strings
+        {
+            R"""()""",
+            R"""(foo)""",
+            R"""("true")""",
+        }
+    );
+
+    test_schema(
+        "non-string const",
+        // Schema
+        R"""({
+            "enum": ["red", "amber", "green", null, 42, ["foo"]]
+        })""",
+        // Passing strings
+        {
+            R"""("red")""",
+            R"""(null)""",
+            R"""(42)""",
+            R"""(["foo"])""",
+        },
+        // Failing strings
+        {
+            R"""()""",
+            R"""(420)""",
+            R"""(true)""",
+            R"""(foo)""",
+        }
+    );
+
+    test_schema(
+        "simple pattern",
+        // Schema
+        R"""({
+            "pattern": "^[a-zA-Z0-9_-]*$"
+        })""",
+        // Passing strings
+        {
+            R"""("")""",
+            R"""("He_llo-12")""",
+        },
+        // Failing strings
+        {
+            R"""("!")""",
+            R"""("Hello World")""",
+        }
+    );
+
+    test_schema(
+        "pattern with escapes",
+        // Schema
+        R"""({
+            "pattern": "^a\\^\\$\\.\\[\\]\\(\\)\\|\\{\\}\\*\\+\\?b$"
+        })""",
+        // Passing strings
+        {
+            R"""("a^$.[]()|{}*+?b")""",
+        },
+        // Failing strings
+        {
+            R"""("ab")""",
+        }
+    );
+
+    test_schema(
+        "",
+        // Schema
+        R"""(
+            {
+                "type": ["array", "null"],
+                "items": { "type": "string" }
+            }
+        )""",
+        // Passing strings
+        {
+            "null",
+            "[]",
+            "[\"123\"]",
+            "[\"foo\", \"bar\"]",
+        },
+        // Failing strings
+        {
+            "",
+            "[123]",
+            "\"foo\"",
+            "[\"foo\", 42]",
+        }
+    );
+
+    test_schema(
+        "min+max items",
+        // Schema
+        R"""({
+            "items": {
+                "type": ["number", "integer"]
+            },
+            "minItems": 3,
+            "maxItems": 5
+        })""",
+        // Passing strings
+        {
+            R"""([1, 2, 3])""",
+            R"""([1, 2, 3, 4])""",
+            R"""([1, 2, 3, 4, 5])""",
+        },
+        // Failing strings
+        {
+            R"""([1, 2])""",
+            R"""([1, 2, 3, 4, 5, 6])""",
+            R"""(1)""",
+        }
+    );
+
+    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
+    test_schema(
+        "object properties",
+        // Schema
+        R"""({
+            "type": "object",
+            "properties": {
+                "number": { "type": "number" },
+                "street_name": { "type": "string" },
+                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
+            }
+        })""",
+        // Passing strings
+        {
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+            // "By default, leaving out properties is valid"
+            R"""({ "street_name": "Pennsylvania" })""",
+            R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
+            // "By extension, even an empty object is valid"
+            R"""({})""",
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
+        },
+        // Failing strings
+        {
+            // Change datatype from number to string
+            R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+            // Reorder properties
+            R"""({ "street_name": "Pennsylvania", "number": 1600 })""",
+            // Reorder properties
+            R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+            // "Additional properties default to false for generation, even though the spec says true.
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
+
+        }
+    );
+
+    test_schema(
+        "additional properties can't override other properties",
+        R"""({
+            "properties": {
+                "a": {"type": "integer"},
+                "b": {"type": "integer"}
+            },
+            "additionalProperties": true
+        })""",
+        // Passing strings
+        {
+            R"""({"a": 42})""",
+            R"""({"c": ""})""",
+            R"""({"a": 42, "c": ""})""",
+            R"""({"a_": ""})""",
+        },
+        // Failing strings
+        {
+            R"""()""",
+            R"""({"a": ""})""",
+            R"""({"a": "", "b": ""})""",
+        }
+    );
+
+    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
+    test_schema(
+        "object properties, additionalProperties: true",
+        // Schema
+        R"""({
+            "type": "object",
+            "properties": {
+                "number": { "type": "number" },
+                "street_name": { "type": "string" },
+                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
+            },
+            "additionalProperties": true
+        })""",
+        // Passing strings
+        {
+            // "By extension, even an empty object is valid"
+            R"""({})""",
+            R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""",
+            // "By default, leaving out properties is valid"
+            R"""({ "street_name": "Pennsylvania" })""",
+            R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
+            // "By default, providing additional properties is valid"
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
+        },
+        // Failing strings
+        {
+            // Change datatype from number to string
+            R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+            // Reorder properties
+            R"""({ "street_name": "Pennsylvania", "number": 1600, "street_type":"Avenue"})""",
+        }
+    );
+
+    // Additional properties: false
+    test_schema(
+        "required + optional props each in original order",
+        // Schema
+        R"""({
+            "type": "object",
+            "properties": {
+                "number": { "type": "number" },
+                "street_name": { "type": "string" },
+                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
+            },
+            "additionalProperties": false
+        })""",
+        // Passing strings
+        {
+            R"""({ "street_name": "Pennsylvania" })""",
+            R"""({ "number": 1600, "street_type":"Avenue"})""",
+            R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+            // Spaces are permitted around enum values
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
+        },
+        // Failing strings
+        {
+            // Reorder properties
+            R"""({ "street_type": "Avenue", "number": 1600 })""",
+            // Add "direction"
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue", "direction": "NW" })""",
+        }
+    );
+
+    test_schema(
+        "required + optional props each in original order",
+        // Schema
+        R"""({
+            "properties": {
+                "b": {"type": "string"},
+                "a": {"type": "string"},
+                "d": {"type": "string"},
+                "c": {"type": "string"}
+            },
+            "required": ["a", "b"],
+            "additionalProperties": false
+        })""",
+        // Passing strings
+        {
+            R"""({"b": "foo", "a": "bar"})""",
+            R"""({"b":"foo","a":"bar","d":"qux"})""",
+            R"""({"b":"foo", "a":"bar", "d":"qux", "c":"baz"})""",
+        },
+        // Failing strings
+        {
+            R"""({"a": "foo", "b": "bar"})""",
+            R"""({"b": "bar"})""",
+            R"""({"a": "foo", "c": "baz"})""",
+            R"""({"a":"foo", "b":"bar", "c":"baz", "d":"qux"})""",
+        }
+    );
+
+    // NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties
+    test_schema(
+        "required props",
+        // Schema
+        R"""({
+            "$schema": "https://json-schema.org/draft/2020-12/schema",
+            "$id": "https://example.com/product.schema.json",
+            "title": "Product",
+            "description": "A product from Acme's catalog",
+            "type": "object",
+            "properties": {
+                "productId": {
+                "description": "The unique identifier for a product",
+                "type": "integer"
+                },
+                "productName": {
+                "description": "Name of the product",
+                "type": "string"
+                },
+                "price": {
+                "description": "The price of the product",
+                "type": "number",
+                "exclusiveMinimum": 0
+                },
+                "tags": {
+                "description": "Tags for the product",
+                "type": "array",
+                "items": {
+                    "type": "string"
+                },
+                "minItems": 1,
+                "uniqueItems": true
+                },
+                "dimensions": {
+                "type": "object",
+                "properties": {
+                    "length": {
+                    "type": "number"
+                    },
+                    "width": {
+                    "type": "number"
+                    },
+                    "height": {
+                    "type": "number"
+                    }
+                },
+                "required": [ "length", "width", "height" ]
+                }
+            },
+            "required": [ "productId", "productName", "price" ]
+        })""",
+        // Passing strings
+        {
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50})""",
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"]})""",
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"], "dimensions": {"length": 785, "width": 250.5, "height": -0.359}})""",
+        },
+        // Failing strings
+        {
+            R"""({})""", // Missing all required properties
+            R"""({"productName": "A green door", "price": 12.50, "productId": 1})""", // Out of order properties
+            // TODO: The following line should fail, but currently it passes. `exclusiveMinimum` is not supported, as it would likely be too difficult to implement.
+            //  Perhaps special checks for minimum and maximum values of 0 could be added (since that's relatively easy to do with grammars), but anything else would likely be too complex.
+            // R"""({"productId": 1, "productName": "A green door", "price": -12.50})""",
+            R"""({"productId": 1, "productName": "A green door"})""", // Missing required property (price)
+            R"""({"productName": "A green door", "price": 12.50})""", // Missing required property (productId)
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": []})""", // tags is empty, but minItems is 1
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "dimensions": {"length": 785, "width": 250.5, "height": -0.359}, "tags": ["home", "green"]})""", // Tags and dimensions are out of order
+            // TODO: The following line should fail, but currently it passes. `uniqueItems` is not supported, as it would likely be too difficult to implement.
+            // R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""",
+        }
+    );
+}
+
+int main() {
+    fprintf(stdout, "Running grammar integration tests...\n");
+    test_simple_grammar();
+    test_complex_grammar();
+    test_special_chars();
+    test_quantifiers();
+    test_failure_missing_root();
+    test_failure_missing_reference();
+    test_failure_left_recursion();
+    test_json_schema();
+    fprintf(stdout, "All tests passed.\n");
+    return 0;
+}
diff --git a/tests/test-grammar-llguidance.cpp b/tests/test-grammar-llguidance.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..566b039a0703804dd99b6ded66f701f00e94056e
--- /dev/null
+++ b/tests/test-grammar-llguidance.cpp
@@ -0,0 +1,1201 @@
+#ifdef NDEBUG
+#    undef NDEBUG
+#endif
+
+#include "sampling.h"
+
+#include 
+#include 
+#include 
+
+static const llama_vocab * vocab;
+
+static bool match_string(const std::string & input, llama_sampler * grammar) {
+    llama_sampler_reset(grammar);
+    auto tokens = common_tokenize(vocab, input, false, false);
+
+    auto n_vocab = llama_vocab_n_tokens(vocab);
+
+    std::vector cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
+    }
+    auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
+
+    for (const auto token : tokens) {
+        for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+            cur[token_id].logit = 0.0f;
+        }
+        llama_sampler_apply(grammar, &tok_arr);
+        if (cur[token].logit < 0.0f) {
+            return false;
+        }
+        llama_sampler_accept(grammar, token);
+    }
+
+    // do we allow EOS at the end? if so the grammar is accepting
+
+    auto tok_eos = llama_vocab_eot(vocab);
+    if (tok_eos == LLAMA_TOKEN_NULL) {
+        tok_eos = llama_vocab_eos(vocab);
+    }
+
+    cur[tok_eos].logit = 0.0f;
+    llama_sampler_apply(grammar, &tok_arr);
+
+    return cur[tok_eos].logit >= 0.0f;
+}
+
+static void test(const std::string & test_desc, const std::string & grammar_str,
+                 const std::vector & passing_strings, const std::vector & failing_strings) {
+    fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
+    fflush(stderr);
+
+    auto * grammar = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
+
+    fprintf(stderr, "  🔵 Valid strings:\n");
+
+    // Passing strings
+    for (const auto & test_string : passing_strings) {
+        fprintf(stderr, "    \"%s\" ", test_string.c_str());
+        fflush(stderr);
+
+        bool matched = match_string(test_string, grammar);
+
+        if (!matched) {
+            fprintf(stderr, "❌ (failed to match)\n");
+
+            // DEBUG: Write strings to files so that we can analyze more easily with gbnf-validator program to see exactly where things failed.
+            // DEBUG: Write the grammar_str to test-grammar-integration.grammar.gbnf
+            FILE * grammar_file = fopen("test-grammar-integration.grammar.gbnf", "w");
+            if (grammar_file) {
+                fprintf(grammar_file, "%s", grammar_str.c_str());
+                fclose(grammar_file);
+            }
+
+            // DEBUG: Write the test string to test-grammar-integration.string.txt
+            FILE * string_file = fopen("test-grammar-integration.string.txt", "w");
+            if (string_file) {
+                fprintf(string_file, "%s", test_string.c_str());
+                fclose(string_file);
+            }
+
+            fprintf(stderr,
+                    "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following "
+                    "command:     ./test-gbnf-validator test-grammar-integration.grammar.gbnf "
+                    "test-grammar-integration.string.txt\n\n");
+        } else {
+            fprintf(stdout, "✅︎\n");
+        }
+
+        assert(matched);
+    }
+
+    fprintf(stderr, "  🟠 Invalid strings:\n");
+
+    // Failing strings
+    for (const auto & test_string : failing_strings) {
+        fprintf(stderr, "    \"%s\" ", test_string.c_str());
+        fflush(stderr);
+
+        bool matched = match_string(test_string, grammar);
+
+        if (matched) {
+            fprintf(stderr, "❌ (incorrectly matched)\n");
+        } else {
+            fprintf(stdout, "✅︎\n");
+        }
+        assert(!matched);
+    }
+
+    llama_sampler_free(grammar);
+}
+
+static void test_grammar(const std::string & test_desc, const std::string & grammar_str,
+                         const std::vector & passing_strings,
+                         const std::vector & failing_strings) {
+    test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
+}
+
+static void test_schema(const std::string & test_desc, const std::string & schema_str,
+                        const std::vector & passing_strings,
+                        const std::vector & failing_strings) {
+    test(test_desc + ". Schema: " + schema_str, "%llguidance {}\nstart: %json " + schema_str, passing_strings,
+         failing_strings);
+}
+
+static void test_simple_grammar() {
+    test_schema("min 0",
+                R"""({
+            "type": "integer",
+            "minimum": 0
+        })""",
+                // Passing strings
+                {
+                    "0",
+                    "10",
+                    "12",
+                    "10000",
+                },
+                // Failing strings
+                {
+                    "-1",
+                    "-10",
+                    "-10000",
+                    "-100000000000000000000000000000000",
+                    // "100000000000000000000000000000000",
+                    "00",
+                    "01",
+                    "-0",
+                });
+    test_schema("min 2",
+                // Schema
+                R"""({
+            "type": "integer",
+            "minimum": 2
+        })""",
+                // Passing strings
+                {
+                    "2",
+                    "3",
+                    "4",
+                    "10",
+                    "20",
+                    "1234567890000000",
+                },
+                // Failing strings
+                {
+                    "0", "1", "-1", "-100", "0", "1", "01", "02",
+                    // "12345678900000000",
+                });
+    test_schema("min 456",
+                R"""({
+            "type": "integer",
+            "minimum": 456
+        })""",
+                // Passing strings
+                {
+                    "456",
+                    "4560",
+                    "457",
+                    "460",
+                    "500",
+                },
+                // Failing strings
+                {
+                    "455",
+                    "356",
+                    "50",
+                    "050",
+                    "-1",
+                    "-456",
+                });
+    test_schema("min -123",
+                R"""({
+            "type": "integer",
+            "minimum": -123
+        })""",
+                // Passing strings
+                {
+                    "-123",
+                    "-122",
+                    "-11",
+                    "-1",
+                    "0",
+                    "1",
+                    "123",
+                    "1234",
+                    "2345",
+                },
+                // Failing strings
+                {
+                    "-1234",
+                    "-124",
+                });
+
+    test_schema("max 9999",
+                // Schema
+                R"""({
+            "type": "integer",
+            "maximum": 9999
+        })""",
+                // Passing strings
+                {
+                    "-99999",
+                    "0",
+                    "9999",
+                },
+                // Failing strings
+                {
+                    "10000",
+                    "99991",
+                });
+    test_schema("max -9999",
+                // Schema
+                R"""({
+            "type": "integer",
+            "maximum": -9999
+        })""",
+                // Passing strings
+                {
+                    "-10000",
+                    "-9999",
+                },
+                // Failing strings
+                {
+                    "-9998",
+                    "0",
+                    "9999",
+                });
+    test_schema("min 5 max 30",
+                // Schema
+                R"""({
+            "type": "integer",
+            "minimum": 5,
+            "maximum": 30
+        })""",
+                // Passing strings
+                {
+                    "5",
+                    "10",
+                    "30",
+                },
+                // Failing strings
+                {
+                    "05",
+                    "4",
+                    "-1",
+                    "31",
+                    "123",
+                    "0123",
+                });
+    test_schema("min -1 max 1",
+                R"""({
+            "type": "integer",
+            "minimum": -1,
+            "maximum": 1
+        })""",
+                // Passing strings
+                {
+                    "-1",
+                    "0",
+                    "1",
+                },
+                // Failing strings
+                {
+                    "-11",
+                    "-10",
+                    "-2",
+                    "2",
+                    "10",
+                    "11",
+                });
+    test_schema("min -123 max 42",
+                R"""({
+            "type": "integer",
+            "minimum": -123,
+            "maximum": 42
+        })""",
+                // Passing strings
+                {
+                    "-123",
+                    "-122",
+                    "-13",
+                    "-11",
+                    "-2",
+                    "-1",
+                    "0",
+                    "1",
+                    "5",
+                    "10",
+                    "39",
+                    "40",
+                    "42",
+                },
+                // Failing strings
+                {
+                    "-0123",
+                    "-124",
+                    "-1123",
+                    "-200",
+                    "43",
+                    "123",
+                    "0123",
+                });
+    test_schema("exclusive min / max",
+                // Schema
+                R"""({
+            "type": "integer",
+            "exclusiveMinimum": 0,
+            "exclusiveMaximum": 10000
+        })""",
+                // Passing strings
+                {
+                    "1",
+                    "9999",
+                },
+                // Failing strings
+                {
+                    "0",
+                    "01",
+                    "10000",
+                    "99999",
+                });
+
+    // Test case for a simple grammar
+    test_grammar("simple grammar",
+                 R"""(
+            start: expr
+            expr: term ("+" term)*
+            term: number
+            number: /[0-9]+/ )""",
+                 // Passing strings
+                 {
+                     "42",
+                     "1+2+3+4+5",
+                     "123+456",
+                 },
+                 // Failing strings
+                 {
+                     "+",
+                     "/ 3",
+                     "1+2+3+4+5+",
+                     "12a45",
+                 });
+}
+
+static void test_complex_grammar() {
+    // Test case for a more complex grammar, with both failure strings and success strings
+    test_grammar("medium complexity grammar",
+                 // Grammar
+                 R"""(
+            start: expression
+            expression: term ws (("+"|"-") ws term)*
+            term: factor ws (("*"|"/") ws factor)*
+            factor: number | variable | "(" expression ")" | function-call
+            number: /[0-9]+/
+            variable: /[a-zA-Z_][a-zA-Z0-9_]*/
+            function-call: variable ws "(" (expression ("," ws expression)*)? ")"
+            ws: /[ \t\n\r]?/ )""",
+                 // Passing strings
+                 { "42",
+                   "1*2*3*4*5",
+                   "x",
+                   "x+10",
+                   "x1+y2",
+                   "(a+b)*(c-d)",
+                   "func()",
+                   "func(x,y+2)",
+                   "a*(b+c)-d/e",
+                   "f(g(x),h(y,z))",
+                   "x + 10",
+                   "x1 + y2",
+                   "(a + b) * (c - d)",
+                   "func()",
+                   "func(x, y + 2)",
+                   "a * (b + c) - d / e",
+                   "f(g(x), h(y, z))",
+                   "123+456",
+                   "123*456*789-123/456+789*123",
+                   "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" },
+                 // Failing strings
+                 {
+                     "+",
+                     "/ 3x",
+                     "x + + y",
+                     "a * / b",
+                     "func(,)",
+                     "func(x y)",
+                     "(a + b",
+                     "x + y)",
+                     "a + b * (c - d",
+                     "42 +",
+                     "x +",
+                     "x + 10 +",
+                     "(a + b) * (c - d",
+                     "func(",
+                     "func(x, y + 2",
+                     "a * (b + c) - d /",
+                     "f(g(x), h(y, z)",
+                     "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
+                 });
+}
+
+static void test_special_chars() {
+    // A collection of tests to exercise special characters such as "."
+    test_grammar("special characters",
+                 // Grammar
+                 R"""(
+            start: /.../ "abc" /.../
+            )""",
+                 // Passing strings
+                 { "abcabcabc", "aaaabcccc",
+                   // NOTE: Also ensures that multi-byte characters still count as a single character
+                   "🔵🟠✅abc❌🟠🔵" },
+                 // Failing strings
+                 { "aaabcccc", "aaaaabcccc", "aaaabccc", "aaaabccccc", "🔵🟠✅❌abc❌✅🟠🔵", "🔵🟠abc🟠🔵" });
+}
+
+static void test_quantifiers() {
+    // A collection of tests to exercise * + and ? quantifiers
+
+    test_grammar(
+        "* quantifier",
+        // Grammar
+        R"""(start: "a"*)""",
+        // Passing strings
+        { "", "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
+        // Failing strings
+        { "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });
+    test_grammar(
+        "+ quantifier",
+        // Grammar
+        R"""(start: "a"+)""",
+        // Passing strings
+        { "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
+        // Failing strings
+        { "", "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });
+    test_grammar("? quantifier",
+                 // Grammar
+                 R"""(start: "a"?)""",
+                 // Passing strings
+                 { "", "a" },
+                 // Failing strings
+                 {
+                     "b",
+                     "ab",
+                     "aa",
+                     "ba",
+                 });
+    test_grammar("mixed quantifiers",
+                 // Grammar
+                 R"""(
+            start: cons+ vowel* cons? (vowel cons)*
+            vowel: /[aeiouy]/
+            cons: /[bcdfghjklmnpqrstvwxyz]/
+            )""",
+                 // Passing strings
+                 {
+                     "yes",
+                     "no",
+                     "noyes",
+                     "crwth",
+                     "four",
+                     "bryyyy",
+                 },
+                 // Failing strings
+                 {
+                     "yess",
+                     "yesno",
+                     "forty",
+                     "catyyy",
+                 });
+    test_grammar("simple exact repetition",
+                 // Grammar
+                 R"""(
+            start: /[ab]{4}/
+        )""",
+                 // Passing strings
+                 {
+                     "aaaa",
+                     "bbbb",
+                     "abab",
+                 },
+                 // Failing strings
+                 {
+                     "a",
+                     "b",
+                     "aaaaa",
+                 });
+    test_grammar("simple min repetition",
+                 // Grammar
+                 R"""(
+            start: /[ab]{4,}/
+        )""",
+                 // Passing strings
+                 {
+                     "aaaa",
+                     "aaaaab",
+                     "bbbb",
+                     "ababab",
+                 },
+                 // Failing strings
+                 {
+                     "",
+                     "aba",
+                 });
+    test_grammar("simple max repetition",
+                 // Grammar
+                 R"""(
+            start: /[ab]{0,4}/
+        )""",
+                 // Passing strings
+                 {
+                     "",
+                     "a",
+                     "aa",
+                     "aaa",
+                     "aaab",
+                 },
+                 // Failing strings
+                 {
+                     "aaaaa",
+                 });
+    // test_grammar("min / max repetition",
+    //              // Grammar
+    //              R"""(
+    //         start: ("0x" /[A-F0-9]{2}/ " "?){3,5}
+    //     )""",
+    //              // Passing strings
+    //              {
+    //                  "0xFF 0x12 0xAB",
+    //                  "0xFF 0x12 0xAB 0x00 0x00",
+    //              },
+    //              // Failing strings
+    //              {
+    //                  "",
+    //                  "0xFF",
+    //                  "0xFF 0x12",
+    //                  "0xFF 0x12 0xAB 0x00 0x00 0x00",
+    //              });
+}
+
+static void test_json_schema() {
+    // Note that this is similar to the regular grammar tests,
+    //  but we convert each json schema to a grammar before parsing.
+    // Otherwise, this test structure is the same.
+
+    test_schema("empty schema (object)",
+                // Schema
+                R"""(
+            {"type":"object"}
+        )""",
+                // Passing strings
+                {
+                    R"""({})""",
+                    R"""({"foo": "bar"})""",
+                },
+                // Failing strings
+                {
+                    "",
+                    "[]",
+                    "null",
+                    R"""("")""",
+                    "true",
+                });
+
+    test_schema(
+        "exotic formats (list)",
+        // Schema
+        R"""({
+            "items": [
+                { "format": "date" },
+                { "format": "uuid" },
+                { "format": "time" },
+                { "format": "date-time" }
+            ]
+        })""",
+        // Passing strings
+        {
+            // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+            // "[]", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+            R"""(["2012-04-23", "12345678-1234-1234-1234-1234567890ab", "18:25:43.511Z", "2012-04-23T18:25:43.511Z"])""",
+            //R"""(["2012-04-23","12345678-1234-1234-1234-1234567890ab"])""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+            //R"""({"foo": "bar"})""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+        },
+        // Failing strings
+        {
+            R"""(["foo", "bar"])""",
+            R"""(["12345678-1234-1234-1234-1234567890ab"])""",
+        });
+
+    test_schema("string",
+                // Schema
+                R"""({
+            "type": "string"
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                    R"""("bar")""",
+                    R"""("")""",
+                },
+                // Failing strings
+                {
+                    R"""({})""",
+                    R"""("foo": "bar")""",
+                });
+
+    test_schema("string w/ min length 1",
+                // Schema
+                R"""({
+            "type": "string",
+            "minLength": 1
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                    R"""("bar")""",
+                },
+                // Failing strings
+                {
+                    R"""("")""",
+                    R"""({})""",
+                    R"""("foo": "bar")""",
+                });
+
+    test_schema("string w/ min length 3",
+                // Schema
+                R"""({
+                "type": "string",
+                "minLength": 3
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                    R"""("bar")""",
+                    R"""("foobar")""",
+                },
+                // Failing strings
+                {
+                    R"""("")""",
+                    R"""("f")""",
+                    R"""("fo")""",
+                });
+
+    test_schema("string w/ max length",
+                // Schema
+                R"""({
+            "type": "string",
+            "maxLength": 3
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                    R"""("bar")""",
+                    R"""("")""",
+                    R"""("f")""",
+                    R"""("fo")""",
+                },
+                // Failing strings
+                {
+                    R"""("foobar")""",
+                });
+
+    test_schema("string w/ min & max length",
+                // Schema
+                R"""({
+            "type": "string",
+            "minLength": 1,
+            "maxLength": 4
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                    R"""("bar")""",
+                    R"""("f")""",
+                    R"""("barf")""",
+                },
+                // Failing strings
+                {
+                    R"""("")""",
+                    R"""("barfo")""",
+                    R"""("foobar")""",
+                });
+
+    test_schema("boolean",
+                // Schema
+                R"""({
+            "type": "boolean"
+        })""",
+                // Passing strings
+                {
+                    "true",
+                    "false",
+                },
+                // Failing strings
+                {
+                    R"""("")""",
+                    R"""("true")""",
+                    R"""(True)""",
+                    R"""(FALSE)""",
+                });
+
+    test_schema("integer",
+                // Schema
+                R"""({
+            "type": "integer"
+        })""",
+                // Passing strings
+                {
+                    R"""(0)""",
+                    R"""(12345)""",
+                    R"""(1234567890123456)""",
+                },
+                // Failing strings
+                {
+                    R"""()""",
+                    R"""(01)""",
+                    R"""(007)""",
+                    R"""(12345678901234567  )""",
+                });
+
+    test_schema("string const",
+                // Schema
+                R"""({
+            "const": "foo"
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                },
+                // Failing strings
+                {
+                    R"""(foo)""",
+                    R"""("bar")""",
+                });
+
+    test_schema("non-string const",
+                // Schema
+                R"""({
+            "const": true
+        })""",
+                // Passing strings
+                {
+                    R"""(true)""",
+                },
+                // Failing strings
+                {
+                    R"""()""",
+                    R"""(foo)""",
+                    R"""("true")""",
+                });
+
+    test_schema("non-string const",
+                // Schema
+                R"""({
+            "enum": ["red", "amber", "green", null, 42, ["foo"]]
+        })""",
+                // Passing strings
+                {
+                    R"""("red")""",
+                    R"""(null)""",
+                    R"""(42)""",
+                    R"""(["foo"])""",
+                },
+                // Failing strings
+                {
+                    R"""()""",
+                    R"""(420)""",
+                    R"""(true)""",
+                    R"""(foo)""",
+                });
+
+    test_schema("simple pattern",
+                // Schema
+                R"""({
+            "pattern": "^[a-zA-Z0-9_-]*$"
+        })""",
+                // Passing strings
+                {
+                    R"""("")""",
+                    R"""("He_llo-12")""",
+                },
+                // Failing strings
+                {
+                    R"""("!")""",
+                    R"""("Hello World")""",
+                });
+
+    test_schema("pattern with escapes",
+                // Schema
+                R"""({
+            "pattern": "^a\\^\\$\\.\\[\\]\\(\\)\\|\\{\\}\\*\\+\\?b$"
+        })""",
+                // Passing strings
+                {
+                    R"""("a^$.[]()|{}*+?b")""",
+                },
+                // Failing strings
+                {
+                    R"""("ab")""",
+                });
+
+    test_schema("",
+                // Schema
+                R"""(
+            {
+                "type": ["array", "null"],
+                "items": { "type": "string" }
+            }
+        )""",
+                // Passing strings
+                {
+                    "null",
+                    "[]",
+                    "[\"123\"]",
+                    "[\"foo\", \"bar\"]",
+                },
+                // Failing strings
+                {
+                    "",
+                    "[123]",
+                    "\"foo\"",
+                    "[\"foo\", 42]",
+                });
+
+    test_schema("min+max items",
+                // Schema
+                R"""({
+            "items": {
+                "type": ["number", "integer"]
+            },
+            "minItems": 3,
+            "maxItems": 5
+        })""",
+                // Passing strings
+                {
+                    R"""([1, 2, 3])""",
+                    R"""([1, 2, 3, 4])""",
+                    R"""([1, 2, 3, 4, 5])""",
+                    // this is in fact correct; keyword do not apply if the type is wrong
+                    R"""(1)""",
+                },
+                // Failing strings
+                {
+                    R"""([1, 2])""",
+                    R"""([1, 2, 3, 4, 5, 6])""",
+                });
+
+    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
+    test_schema("object properties",
+                // Schema
+                R"""({
+            "type": "object",
+            "properties": {
+                "number": { "type": "number" },
+                "street_name": { "type": "string" },
+                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
+            },
+            "additionalProperties": false
+        })""",
+                // Passing strings
+                {
+                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+                    // "By default, leaving out properties is valid"
+                    R"""({ "street_name": "Pennsylvania" })""",
+                    R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
+                    // "By extension, even an empty object is valid"
+                    R"""({})""",
+                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
+                },
+                // Failing strings
+                {
+                    // Change datatype from number to string
+                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+                    // Reorder properties
+                    R"""({ "street_name": "Pennsylvania", "number": 1600 })""",
+                    // Reorder properties
+                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+                    // Additional properties set to false
+                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
+
+                });
+
+    test_schema("additional properties can't override other properties",
+                R"""({
+            "properties": {
+                "a": {"type": "integer"},
+                "b": {"type": "integer"}
+            },
+            "additionalProperties": true
+        })""",
+                // Passing strings
+                {
+                    R"""({"a": 42})""",
+                    R"""({"c": ""})""",
+                    R"""({"a": 42, "c": ""})""",
+                    R"""({"a_": ""})""",
+                },
+                // Failing strings
+                {
+                    R"""()""",
+                    R"""({"a": ""})""",
+                    R"""({"a": "", "b": ""})""",
+                });
+
+    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
+    test_schema("object properties, additionalProperties: true",
+                // Schema
+                R"""({
+            "type": "object",
+            "properties": {
+                "number": { "type": "number" },
+                "street_name": { "type": "string" },
+                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
+            },
+            "additionalProperties": true
+        })""",
+                // Passing strings
+                {
+                    // "By extension, even an empty object is valid"
+                    R"""({})""",
+                    R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""",
+                    // "By default, leaving out properties is valid"
+                    R"""({ "street_name": "Pennsylvania" })""",
+                    R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
+                    // "By default, providing additional properties is valid"
+                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
+                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
+                },
+                // Failing strings
+                {
+                    // Change datatype from number to string
+                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+                    // Reorder properties
+                    R"""({ "street_name": "Pennsylvania", "number": 1600, "street_type":"Avenue"})""",
+                });
+
+    // Additional properties: false
+    test_schema(
+        "required + optional props each in original order",
+        // Schema
+        R"""({
+            "type": "object",
+            "properties": {
+                "number": { "type": "number" },
+                "street_name": { "type": "string" },
+                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
+            },
+            "additionalProperties": false
+        })""",
+        // Passing strings
+        {
+            R"""({ "street_name": "Pennsylvania" })""",
+            R"""({ "number": 1600, "street_type":"Avenue"})""",
+            R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+            // Spaces are permitted around enum values
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
+        },
+        // Failing strings
+        {
+            // Reorder properties
+            R"""({ "street_type": "Avenue", "number": 1600 })""",
+            // Add "direction"
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue", "direction": "NW" })""",
+        });
+
+    test_schema("required + optional props each in original order",
+                // Schema
+                R"""({
+            "properties": {
+                "b": {"type": "string"},
+                "a": {"type": "string"},
+                "d": {"type": "string"},
+                "c": {"type": "string"}
+            },
+            "required": ["a", "b"],
+            "additionalProperties": false
+        })""",
+                // Passing strings
+                {
+                    R"""({"b": "foo", "a": "bar"})""",
+                    R"""({"b":"foo","a":"bar","d":"qux"})""",
+                    R"""({"b":"foo", "a":"bar", "d":"qux", "c":"baz"})""",
+                },
+                // Failing strings
+                {
+                    R"""({"a": "foo", "b": "bar"})""",
+                    R"""({"b": "bar"})""",
+                    R"""({"a": "foo", "c": "baz"})""",
+                    R"""({"a":"foo", "b":"bar", "c":"baz", "d":"qux"})""",
+                });
+
+    // NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties
+    test_schema(
+        "required props",
+        // Schema
+        R"""({
+            "$schema": "https://json-schema.org/draft/2020-12/schema",
+            "$id": "https://example.com/product.schema.json",
+            "title": "Product",
+            "description": "A product from Acme's catalog",
+            "type": "object",
+            "properties": {
+                "productId": {
+                "description": "The unique identifier for a product",
+                "type": "integer"
+                },
+                "productName": {
+                "description": "Name of the product",
+                "type": "string"
+                },
+                "price": {
+                "description": "The price of the product",
+                "type": "number",
+                "exclusiveMinimum": 0
+                },
+                "tags": {
+                "description": "Tags for the product",
+                "type": "array",
+                "items": {
+                    "type": "string"
+                },
+                "minItems": 1,
+                "DISABLED_uniqueItems": true
+                },
+                "dimensions": {
+                "type": "object",
+                "properties": {
+                    "length": {
+                    "type": "number"
+                    },
+                    "width": {
+                    "type": "number"
+                    },
+                    "height": {
+                    "type": "number"
+                    }
+                },
+                "required": [ "length", "width", "height" ]
+                }
+            },
+            "required": [ "productId", "productName", "price" ]
+        })""",
+        // Passing strings
+        {
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50})""",
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"]})""",
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"], "dimensions": {"length": 785, "width": 250.5, "height": -0.359}})""",
+        },
+        // Failing strings
+        {
+            R"""({})""",  // Missing all required properties
+            R"""({"productName": "A green door", "price": 12.50, "productId": 1})""",  // Out of order properties
+            // `exclusiveMinimum` is OK for llg
+            R"""({"productId": 1, "productName": "A green door", "price": -12.50})""",
+            R"""({"productId": 1, "productName": "A green door"})""",  // Missing required property (price)
+            R"""({"productName": "A green door", "price": 12.50})""",  // Missing required property (productId)
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": []})""",  // tags is empty, but minItems is 1
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "dimensions": {"length": 785, "width": 250.5, "height": -0.359}, "tags": ["home", "green"]})""",  // Tags and dimensions are out of order
+            // TODO: The following line should fail, but currently it passes. `uniqueItems` is not supported, as it would likely be too difficult to implement.
+            // R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""",
+        });
+}
+
+static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
+    auto n_vocab = tok_arr.size;
+
+    tok_arr.selected = -1;
+    tok_arr.sorted   = false;
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        tok_arr.data[token_id].id    = token_id;
+        tok_arr.data[token_id].logit = 0.0f;
+    }
+
+    tok_arr.data[selected].logit = 100.0f;
+}
+
+static void test_sampler_chain(void) {
+    auto sparams            = llama_sampler_chain_default_params();
+    sparams.no_perf         = false;
+    llama_sampler * sampler = llama_sampler_chain_init(sparams);
+
+    const auto grammar_data = R"(%llguidance {}
+start: /[A-Z ]*/)";
+
+    llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
+    llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
+
+    auto input  = "ALL YOUR BASE ARE BELONG TO US";
+    auto tokens = common_tokenize(vocab, input, false, false);
+
+    auto n_vocab = llama_vocab_n_tokens(vocab);
+
+    std::vector cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
+    }
+    auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
+
+    for (const auto token : tokens) {
+        one_hot(tok_arr, token);
+
+        fprintf(stderr, "applying token: %d\n", token);
+        llama_sampler_apply(sampler, &tok_arr);
+
+        auto idx = tok_arr.selected;
+        fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
+        assert(cur[tok_arr.selected].id == token);
+        llama_sampler_accept(sampler, token);
+    }
+
+    auto tok_eos = llama_vocab_eot(vocab);
+    if (tok_eos == LLAMA_TOKEN_NULL) {
+        tok_eos = llama_vocab_eos(vocab);
+    }
+
+    one_hot(tok_arr, tok_eos);
+
+    llama_sampler_apply(sampler, &tok_arr);
+    assert(cur[tok_arr.selected].id == tok_eos);
+}
+
+int main(int argc, const char ** argv) {
+    fprintf(stdout, "Running llguidance integration tests...\n");
+
+    if (argc != 2) {
+        fprintf(stderr, "Usage: %s \n", argv[0]);
+        return 1;
+    }
+
+    const char * vocab_file = argv[1];
+
+    fprintf(stderr, "reading vocab from: '%s'\n", vocab_file);
+
+    llama_model *   model;
+    llama_context * ctx;
+
+    llama_backend_init();
+
+    // load the vocab
+    {
+        auto mparams = llama_model_default_params();
+
+        mparams.vocab_only = true;
+
+        model = llama_model_load_from_file(vocab_file, mparams);
+
+        if (model == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
+            return 1;
+        }
+
+        // needed?
+        auto cparams = llama_context_default_params();
+
+        ctx = llama_init_from_model(model, cparams);
+
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
+            llama_model_free(model);
+            return 1;
+        }
+    }
+
+    vocab = llama_model_get_vocab(model);
+
+    test_simple_grammar();
+    test_complex_grammar();
+    test_special_chars();
+    test_quantifiers();
+    test_json_schema();
+
+    test_sampler_chain();
+
+    fprintf(stdout, "All tests passed.\n");
+    return 0;
+}
diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..67821a2d5c6097e8f4c4f13a1b9204d1a9573d3b
--- /dev/null
+++ b/tests/test-grammar-parser.cpp
@@ -0,0 +1,519 @@
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include "llama.h"
+
+// TODO: shold not include libllama sources
+#include "../src/llama-grammar.h"
+
+#include 
+
+static const char * type_str(llama_gretype type) {
+    switch (type) {
+        case LLAMA_GRETYPE_CHAR: return "LLAMA_GRETYPE_CHAR";
+        case LLAMA_GRETYPE_CHAR_NOT: return "LLAMA_GRETYPE_CHAR_NOT";
+        case LLAMA_GRETYPE_CHAR_ALT: return "LLAMA_GRETYPE_CHAR_ALT";
+        case LLAMA_GRETYPE_CHAR_RNG_UPPER: return "LLAMA_GRETYPE_CHAR_RNG_UPPER";
+        case LLAMA_GRETYPE_RULE_REF: return "LLAMA_GRETYPE_RULE_REF";
+        case LLAMA_GRETYPE_ALT: return "LLAMA_GRETYPE_ALT";
+        case LLAMA_GRETYPE_END: return "LLAMA_GRETYPE_END";
+        default: return "?";
+    }
+}
+
+static void verify_parsing(const char *grammar_bytes, const std::vector> expected, const std::vector &expected_rules) {
+    uint32_t index = 0;
+    llama_grammar_parser parsed_grammar;
+    parsed_grammar.parse(grammar_bytes);
+
+    std::map symbol_names;
+    for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
+        symbol_names[it->second] = it->first;
+    }
+
+    auto print_all = [&]() {
+        fprintf(stderr, "    verify_parsing(R\"\"\"(%s)\"\"\", {\n", grammar_bytes);
+        for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
+            fprintf(stderr, "        {\"%s\", %u},\n", it->first.c_str(), it->second);
+        }
+        fprintf(stderr, "    }, {\n");
+        for (size_t i_rule = 0; i_rule < parsed_grammar.rules.size(); i_rule++) {
+            fprintf(stderr, "        // %s (index %zu)\n", symbol_names[i_rule].c_str(), i_rule);
+            auto & rule = parsed_grammar.rules[i_rule];
+            for (uint32_t i = 0; i < rule.size(); i++) {
+                std::string rule_str;
+                fprintf(stderr, "        {%s, ", type_str(rule[i].type));
+                if (rule[i].type == LLAMA_GRETYPE_CHAR || rule[i].type == LLAMA_GRETYPE_CHAR_ALT ||
+                    rule[i].type == LLAMA_GRETYPE_CHAR_NOT || rule[i].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
+                    char c = rule[i].value;
+                    if (c == '\n') {
+                        fprintf(stderr, "'\\n'");
+                    } else if (c == '\t') {
+                        fprintf(stderr, "'\\t'");
+                    } else if (c == '\r') {
+                        fprintf(stderr, "'\\r'");
+                    } else if (c == '\0') {
+                        fprintf(stderr, "'\\0'");
+                    } else {
+                        fprintf(stderr, "'%c'", c);
+                    }
+                } else if (rule[i].type == LLAMA_GRETYPE_RULE_REF) {
+                    fprintf(stderr, "/* %s */ %u", symbol_names[rule[i].value].c_str(), rule[i].value);
+                } else {
+                    fprintf(stderr, "%u", rule[i].value);
+                }
+                fprintf(stderr, "},\n");
+            }
+        }
+        fprintf(stderr, "    });\n");
+    };
+
+    if (getenv("TEST_GRAMMAR_PARSER_PRINT_ALL")) {
+        print_all();
+        fprintf(stderr, "\n");
+        return;
+    }
+
+    fprintf(stderr, "Testing grammar:%s\n", grammar_bytes);
+
+    if (parsed_grammar.symbol_ids.size() != expected.size()) {
+        fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
+        print_all();
+        assert(parsed_grammar.symbol_ids.size() == expected.size());
+    }
+
+    for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
+    {
+        std::string key = it->first;
+        uint32_t value = it->second;
+        std::pair expected_pair = expected[index];
+
+        // pretty print error message before asserting
+        if (expected_pair.first != key || expected_pair.second != value)
+        {
+            fprintf(stderr, "index: %u\n", index);
+            fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second);
+            fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value);
+            fprintf(stderr, "expected_pair != actual_pair\n");
+            fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
+            print_all();
+        }
+
+        assert(expected_pair.first == key && expected_pair.second == value);
+
+        index++;
+    }
+
+    index = 0;
+    for (auto rule : parsed_grammar.rules)
+    {
+        // compare rule to expected rule
+        for (uint32_t i = 0; i < rule.size(); i++)
+        {
+            llama_grammar_element element = rule[i];
+            llama_grammar_element expected_element = expected_rules[index];
+
+            // pretty print error message before asserting
+            if (expected_element.type != element.type || expected_element.value != element.value)
+            {
+                fprintf(stderr, "index: %u\n", index);
+                fprintf(stderr, "expected_element: %s, %u\n", type_str(expected_element.type), expected_element.value);
+                fprintf(stderr, "actual_element: %s, %u\n", type_str(element.type), element.value);
+                fprintf(stderr, "expected_element != actual_element\n");
+                fprintf(stderr, "all elements:\n");
+                fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
+                print_all();
+            }
+
+            assert(expected_element.type == element.type && expected_element.value == element.value);
+            index++;
+        }
+    }
+}
+
+static void verify_failure(const char * grammar_bytes) {
+    fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes);
+    llama_grammar_parser result;
+    result.parse(grammar_bytes);
+    assert(result.rules.empty() && "should have failed");
+}
+
+int main()
+{
+    verify_failure(R"""(
+        root ::= "a"{,}"
+    )""");
+
+    verify_failure(R"""(
+        root ::= "a"{,10}"
+    )""");
+
+    verify_parsing(R"""(
+        root  ::= "a"
+    )""", {
+        {"root", 0},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= "a" | [bdx-z] | [^1-3]
+    )""", {
+        {"root", 0},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_CHAR, 'b'},
+        {LLAMA_GRETYPE_CHAR_ALT, 'd'},
+        {LLAMA_GRETYPE_CHAR_ALT, 'x'},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_CHAR_NOT, '1'},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '3'},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= a+
+        a     ::= "a"
+    )""", {
+        {"a", 1},
+        {"root", 0},
+        {"root_2", 2},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
+        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
+        {LLAMA_GRETYPE_END, 0},
+        // a (index 1)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_END, 0},
+        // root_2 (index 2)
+        {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
+        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= "a"+
+    )""", {
+        {"root", 0},
+        {"root_1", 1},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_END, 0},
+        // root_1 (index 1)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= a?
+        a     ::= "a"
+    )""", {
+        {"a", 1},
+        {"root", 0},
+        {"root_2", 2},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
+        {LLAMA_GRETYPE_END, 0},
+        // a (index 1)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_END, 0},
+        // root_2 (index 2)
+        {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= "a"?
+    )""", {
+        {"root", 0},
+        {"root_1", 1},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_END, 0},
+        // root_1 (index 1)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= a*
+        a     ::= "a"
+    )""", {
+        {"a", 1},
+        {"root", 0},
+        {"root_2", 2},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
+        {LLAMA_GRETYPE_END, 0},
+        // a (index 1)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_END, 0},
+        // root_2 (index 2)
+        {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
+        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= "a"*
+    )""", {
+        {"root", 0},
+        {"root_1", 1},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_END, 0},
+        // root_1 (index 1)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= "a"{2}
+    )""", {
+        {"root", 0},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= "a"{2,}
+    )""", {
+        {"root", 0},
+        {"root_1", 1},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_END, 0},
+        // root_1 (index 1)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= "a"{ 4}
+    )""", {
+        {"root", 0},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= "a"{2,4}
+    )""", {
+        {"root", 0},
+        {"root_1", 1},
+        {"root_2", 2},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
+        {LLAMA_GRETYPE_END, 0},
+        // root_1 (index 1)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+        // root_2 (index 2)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= (expr "=" term "\n")+
+        expr  ::= term ([-+*/] term)*
+        term  ::= [0-9]+
+    )""", {
+        {"expr", 2},
+        {"expr_5", 5},
+        {"expr_6", 6},
+        {"root", 0},
+        {"root_1", 1},
+        {"root_4", 4},
+        {"term", 3},
+        {"term_7", 7},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4},
+        {LLAMA_GRETYPE_END, 0},
+        // root_1 (index 1)
+        {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
+        {LLAMA_GRETYPE_CHAR, '='},
+        {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
+        {LLAMA_GRETYPE_CHAR, '\n'},
+        {LLAMA_GRETYPE_END, 0},
+        // expr (index 2)
+        {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
+        {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
+        {LLAMA_GRETYPE_END, 0},
+        // term (index 3)
+        {LLAMA_GRETYPE_CHAR, '0'},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
+        {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7},
+        {LLAMA_GRETYPE_END, 0},
+        // root_4 (index 4)
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+        // expr_5 (index 5)
+        {LLAMA_GRETYPE_CHAR, '-'},
+        {LLAMA_GRETYPE_CHAR_ALT, '+'},
+        {LLAMA_GRETYPE_CHAR_ALT, '*'},
+        {LLAMA_GRETYPE_CHAR_ALT, '/'},
+        {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
+        {LLAMA_GRETYPE_END, 0},
+        // expr_6 (index 6)
+        {LLAMA_GRETYPE_RULE_REF, /* expr_5 */ 5},
+        {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+        // term_7 (index 7)
+        {LLAMA_GRETYPE_CHAR, '0'},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
+        {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    verify_parsing(R"""(
+        root  ::= (expr "=" ws term "\n")+
+        expr  ::= term ([-+*/] term)*
+        term  ::= ident | num | "(" ws expr ")" ws
+        ident ::= [a-z] [a-z0-9_]* ws
+        num   ::= [0-9]+ ws
+        ws    ::= [ \t\n]*
+    )""", {
+        {"expr", 2},
+        {"expr_6", 6},
+        {"expr_7", 7},
+        {"ident", 8},
+        {"ident_10", 10},
+        {"num", 9},
+        {"num_11", 11},
+        {"root", 0},
+        {"root_1", 1},
+        {"root_5", 5},
+        {"term", 4},
+        {"ws", 3},
+        {"ws_12", 12},
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5},
+        {LLAMA_GRETYPE_END, 0},
+        // root_1 (index 1)
+        {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
+        {LLAMA_GRETYPE_CHAR, '='},
+        {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
+        {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
+        {LLAMA_GRETYPE_CHAR, '\n'},
+        {LLAMA_GRETYPE_END, 0},
+        // expr (index 2)
+        {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
+        {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7},
+        {LLAMA_GRETYPE_END, 0},
+        // ws (index 3)
+        {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12},
+        {LLAMA_GRETYPE_END, 0},
+        // term (index 4)
+        {LLAMA_GRETYPE_RULE_REF, /* ident */ 8},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_RULE_REF, /* num */ 9},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_CHAR, '('},
+        {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
+        {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
+        {LLAMA_GRETYPE_CHAR, ')'},
+        {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
+        {LLAMA_GRETYPE_END, 0},
+        // root_5 (index 5)
+        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
+        {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+        // expr_6 (index 6)
+        {LLAMA_GRETYPE_CHAR, '-'},
+        {LLAMA_GRETYPE_CHAR_ALT, '+'},
+        {LLAMA_GRETYPE_CHAR_ALT, '*'},
+        {LLAMA_GRETYPE_CHAR_ALT, '/'},
+        {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
+        {LLAMA_GRETYPE_END, 0},
+        // expr_7 (index 7)
+        {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
+        {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+        // ident (index 8)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
+        {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10},
+        {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
+        {LLAMA_GRETYPE_END, 0},
+        // num (index 9)
+        {LLAMA_GRETYPE_CHAR, '0'},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
+        {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11},
+        {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
+        {LLAMA_GRETYPE_END, 0},
+        // ident_10 (index 10)
+        {LLAMA_GRETYPE_CHAR, 'a'},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
+        {LLAMA_GRETYPE_CHAR_ALT, '0'},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
+        {LLAMA_GRETYPE_CHAR_ALT, '_'},
+        {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+        // num_11 (index 11)
+        {LLAMA_GRETYPE_CHAR, '0'},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
+        {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+        // ws_12 (index 12)
+        {LLAMA_GRETYPE_CHAR, ' '},
+        {LLAMA_GRETYPE_CHAR_ALT, '\t'},
+        {LLAMA_GRETYPE_CHAR_ALT, '\n'},
+        {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
+    return 0;
+}
diff --git a/tests/test-json-partial.cpp b/tests/test-json-partial.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..bc136beceb9ae33770a4e9b8146ea0969c122855
--- /dev/null
+++ b/tests/test-json-partial.cpp
@@ -0,0 +1,237 @@
+#include "common.h"
+#include "json-partial.h"
+#include 
+#include 
+#include 
+
+template  static void assert_equals(const T & expected, const T & actual) {
+  if (expected != actual) {
+      std::cerr << "Expected: " << expected << std::endl;
+      std::cerr << "Actual: " << actual << std::endl;
+      std::cerr << std::flush;
+      throw std::runtime_error("Test failed");
+  }
+}
+
+static void test_json_healing() {
+  auto parse = [](const std::string & str) {
+      std::cerr << "# Parsing: " << str << '\n';
+      std::string::const_iterator it = str.begin();
+      const auto end = str.end();
+      common_json out;
+      std::string healing_marker = "$llama.cpp.json$";
+      if (common_json_parse(it, end, healing_marker, out)) {
+          auto dump = out.json.dump();
+          std::cerr << "Parsed: " << dump << '\n';
+          std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
+          std::string result;
+          if (!out.healing_marker.json_dump_marker.empty()) {
+              auto i = dump.find(out.healing_marker.json_dump_marker);
+              if (i == std::string::npos) {
+                  throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
+              }
+              result = dump.substr(0, i);
+          } else {
+            result = dump;
+          }
+          std::cerr << "Result: " << result << '\n';
+          if (string_starts_with(str, result)) {
+            std::cerr << "Failure!\n";
+          }
+        //   return dump;
+      } else {
+        throw std::runtime_error("Failed to parse: " + str);
+      }
+
+  };
+  auto parse_all = [&](const std::string & str) {
+      for (size_t i = 1; i < str.size(); i++) {
+          parse(str.substr(0, i));
+      }
+  };
+  parse_all("{\"a\": \"b\"}");
+  parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
+
+  parse_all("[{\"a\": \"b\"}]");
+
+  auto test = [&](const std::vector & inputs, const std::string & expected, const std::string & expected_marker) {
+      for (const auto & input : inputs) {
+        common_json out;
+        assert_equals(true, common_json_parse(input, "$foo", out));
+        assert_equals(expected, out.json.dump());
+        assert_equals(expected_marker, out.healing_marker.json_dump_marker);
+      }
+  };
+  // No healing needed:
+  test(
+    {
+      R"([{"a":"b"}, "y"])",
+    },
+    R"([{"a":"b"},"y"])",
+    ""
+  );
+  // Partial literals can't be healed:
+  test(
+    {
+      R"([1)",
+      R"([tru)",
+      R"([n)",
+      R"([nul)",
+      R"([23.2)",
+    },
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({"a": 1)",
+      R"({"a": tru)",
+      R"({"a": n)",
+      R"({"a": nul)",
+      R"({"a": 23.2)",
+    },
+    R"({"a":"$foo"})",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({)",
+    },
+    R"({"$foo":1})",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([)",
+    },
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  // Healing right after a full literal
+  test(
+    {
+      R"(1 )",
+    },
+    R"(1)",
+    ""
+  );
+  test(
+    {
+      R"(true)",
+      R"(true )",
+    },
+    R"(true)",
+    ""
+  );
+  test(
+    {
+      R"(null)",
+      R"(null )",
+    },
+    R"(null)",
+    ""
+  );
+  test(
+    {
+      R"([1 )",
+    },
+    R"([1,"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([{})",
+      R"([{} )",
+    },
+    R"([{},"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([true)",
+    },
+    // TODO: detect the true/false/null literal was complete
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([true )",
+    },
+    R"([true,"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([true,)",
+    },
+    R"([true,"$foo"])",
+    R"("$foo)"
+  );
+  // Test nesting
+  test(
+    {
+      R"([{"a": [{"b": [{)",
+    },
+    R"([{"a":[{"b":[{"$foo":1}]}]}])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([{"a": [{"b": [)",
+    },
+    R"([{"a":[{"b":["$foo"]}]}])",
+    R"("$foo)"
+  );
+
+  test(
+    {
+      R"([{"a": "b"})",
+      R"([{"a": "b"} )",
+    },
+    R"([{"a":"b"},"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([{"a": "b"},)",
+      R"([{"a": "b"}, )",
+    },
+    R"([{"a":"b"},"$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({ "code)",
+    },
+    R"({"code$foo":1})",
+    R"($foo)"
+  );
+  test(
+    {
+      R"({ "code\)",
+    },
+    R"({"code\\$foo":1})",
+    R"(\$foo)"
+  );
+  test(
+    {
+      R"({ "code")",
+    },
+    R"({"code":"$foo"})",
+    R"(:"$foo)"
+  );
+  test(
+    {
+      R"({ "key")",
+    },
+    R"({"key":"$foo"})",
+    R"(:"$foo)"
+  );
+}
+
+int main() {
+    test_json_healing();
+    std::cerr << "All tests passed.\n";
+    return 0;
+}
diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..78ee55e246f3d4d34fb36f0f4be94a0008fcc0d6
--- /dev/null
+++ b/tests/test-json-schema-to-grammar.cpp
@@ -0,0 +1,1306 @@
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include "json-schema-to-grammar.h"
+
+#include "../src/llama-grammar.h"
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+static std::string trim(const std::string & source) {
+    std::string s(source);
+    s.erase(0,s.find_first_not_of(" \n\r\t"));
+    s.erase(s.find_last_not_of(" \n\r\t")+1);
+    return std::regex_replace(s, std::regex("(^|\n)[ \t]+"), "$1");
+}
+
+enum TestCaseStatus {
+    SUCCESS,
+    FAILURE
+};
+
+struct TestCase {
+    TestCaseStatus expected_status;
+    std::string name;
+    std::string schema;
+    std::string expected_grammar;
+
+    void _print_failure_header() const {
+        fprintf(stderr, "#\n# Test '%s' failed.\n#\n%s\n", name.c_str(), schema.c_str());
+    }
+    void verify(const std::string & actual_grammar) const {
+        if (trim(actual_grammar) != trim(expected_grammar)) {
+        _print_failure_header();
+        fprintf(stderr, "# EXPECTED:\n%s\n# ACTUAL:\n%s\n", expected_grammar.c_str(), actual_grammar.c_str());
+        assert(false);
+        }
+    }
+    void verify_expectation_parseable() const {
+        try {
+            llama_grammar_parser state;
+            state.parse(expected_grammar.c_str());
+            if (state.symbol_ids.find("root") == state.symbol_ids.end()) {
+                throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar);
+            }
+        } catch (const std::runtime_error & ex) {
+            _print_failure_header();
+            fprintf(stderr, "# GRAMMAR ERROR: %s\n", ex.what());
+            assert(false);
+        }
+    }
+    void verify_status(TestCaseStatus status) const {
+        if (status != expected_status) {
+            _print_failure_header();
+            fprintf(stderr, "# EXPECTED STATUS: %s\n", expected_status == SUCCESS ? "SUCCESS" : "FAILURE");
+            fprintf(stderr, "# ACTUAL STATUS: %s\n", status == SUCCESS ? "SUCCESS" : "FAILURE");
+            assert(false);
+        }
+    }
+};
+
+static void write(const std::string & file, const std::string & content) {
+    std::ofstream f;
+    f.open(file.c_str());
+    f << content.c_str();
+    f.close();
+}
+
+static std::string read(const std::string & file) {
+    std::ostringstream actuals;
+    actuals << std::ifstream(file.c_str()).rdbuf();
+    return actuals.str();
+}
+
+static void test_all(const std::string & lang, std::function runner) {
+    fprintf(stderr, "#\n# Testing JSON schema conversion (%s)\n#\n", lang.c_str());
+    auto test = [&](const TestCase & tc) {
+        fprintf(stderr, "- %s%s\n", tc.name.c_str(), tc.expected_status == FAILURE ? " (failure expected)" : "");
+        runner(tc);
+    };
+
+    test({
+        SUCCESS,
+        "min 0",
+        R"""({
+            "type": "integer",
+            "minimum": 0
+        })""",
+        R"""(
+            root ::= ([0] | [1-9] [0-9]{0,15}) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min 1",
+        R"""({
+            "type": "integer",
+            "minimum": 1
+        })""",
+        R"""(
+            root ::= ([1-9] [0-9]{0,15}) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min 3",
+        R"""({
+            "type": "integer",
+            "minimum": 3
+        })""",
+        R"""(
+            root ::= ([1-2] [0-9]{1,15} | [3-9] [0-9]{0,15}) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min 9",
+        R"""({
+            "type": "integer",
+            "minimum": 9
+        })""",
+        R"""(
+            root ::= ([1-8] [0-9]{1,15} | [9] [0-9]{0,15}) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min 10",
+        R"""({
+            "type": "integer",
+            "minimum": 10
+        })""",
+        R"""(
+            root ::= ([1] ([0-9]{1,15}) | [2-9] [0-9]{1,15}) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min 25",
+        R"""({
+            "type": "integer",
+            "minimum": 25
+        })""",
+        R"""(
+            root ::= ([1] [0-9]{2,15} | [2] ([0-4] [0-9]{1,14} | [5-9] [0-9]{0,14}) | [3-9] [0-9]{1,15}) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "max 30",
+        R"""({
+            "type": "integer",
+            "maximum": 30
+        })""",
+        R"""(
+            root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-2] [0-9] | [3] "0")) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min -5",
+        R"""({
+            "type": "integer",
+            "minimum": -5
+        })""",
+        R"""(
+            root ::= ("-" ([0-5]) | [0] | [1-9] [0-9]{0,15}) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min -123",
+        R"""({
+            "type": "integer",
+            "minimum": -123
+        })""",
+        R"""(
+            root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0] | [1-9] [0-9]{0,15}) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "max -5",
+        R"""({
+            "type": "integer",
+            "maximum": -5
+        })""",
+        R"""(
+            root ::= ("-" ([0-4] [0-9]{1,15} | [5-9] [0-9]{0,15})) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "max 1",
+        R"""({
+            "type": "integer",
+            "maximum": 1
+        })""",
+        R"""(
+            root ::= ("-" [1-9] [0-9]{0,15} | [0-1]) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "max 100",
+        R"""({
+            "type": "integer",
+            "maximum": 100
+        })""",
+        R"""(
+            root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-8] [0-9] | [9] [0-9]) | "100") space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min 0 max 23",
+        R"""({
+            "type": "integer",
+            "minimum": 0,
+            "maximum": 23
+        })""",
+        R"""(
+            root ::= ([0-9] | ([1] [0-9] | [2] [0-3])) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min 15 max 300",
+        R"""({
+            "type": "integer",
+            "minimum": 15,
+            "maximum": 300
+        })""",
+        R"""(
+            root ::= (([1] ([5-9]) | [2-9] [0-9]) | ([1-2] [0-9]{2} | [3] "00")) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min 5 max 30",
+        R"""({
+            "type": "integer",
+            "minimum": 5,
+            "maximum": 30
+        })""",
+        R"""(
+            root ::= ([5-9] | ([1-2] [0-9] | [3] "0")) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min -123 max 42",
+        R"""({
+            "type": "integer",
+            "minimum": -123,
+            "maximum": 42
+        })""",
+        R"""(
+            root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0-9] | ([1-3] [0-9] | [4] [0-2])) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min -10 max 10",
+        R"""({
+            "type": "integer",
+            "minimum": -10,
+            "maximum": 10
+        })""",
+        R"""(
+            root ::= ("-" ([0-9] | "10") | [0-9] | "10") space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        FAILURE,
+        "unknown type",
+        R"""({
+            "type": "kaboom"
+        })""",
+        ""
+    });
+
+    test({
+        FAILURE,
+        "invalid type",
+        R"""({
+            "type": 123
+        })""",
+        ""
+    });
+
+    test({
+        SUCCESS,
+        "empty schema (object)",
+        "{}",
+        R"""(
+            array ::= "[" space ( value ("," space value)* )? "]" space
+            boolean ::= ("true" | "false") space
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            null ::= "null" space
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
+            root ::= object
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+            value ::= object | array | string | number | boolean | null
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "exotic formats",
+        R"""({
+            "items": [
+                { "format": "date" },
+                { "format": "uuid" },
+                { "format": "time" },
+                { "format": "date-time" }
+            ]
+        })""",
+        R"""(
+            date ::= [0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )
+            date-string ::= "\"" date "\"" space
+            date-time ::= date "T" time
+            date-time-string ::= "\"" date-time "\"" space
+            root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
+            time-string ::= "\"" time "\"" space
+            tuple-0 ::= date-string
+            tuple-2 ::= time-string
+            tuple-3 ::= date-time-string
+            uuid ::= "\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "string",
+        R"""({
+            "type": "string"
+        })""",
+        R"""(
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            root ::= "\"" char* "\"" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "string w/ min length 1",
+        R"""({
+            "type": "string",
+            "minLength": 1
+        })""",
+        R"""(
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            root ::= "\"" char+ "\"" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "string w/ min length 3",
+        R"""({
+            "type": "string",
+            "minLength": 3
+        })""",
+        R"""(
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            root ::= "\"" char{3,} "\"" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "string w/ max length",
+        R"""({
+            "type": "string",
+            "maxLength": 3
+        })""",
+        R"""(
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            root ::= "\"" char{0,3} "\"" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "string w/ min & max length",
+        R"""({
+            "type": "string",
+            "minLength": 1,
+            "maxLength": 4
+        })""",
+        R"""(
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            root ::= "\"" char{1,4} "\"" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "boolean",
+        R"""({
+            "type": "boolean"
+        })""",
+        R"""(
+            root ::= ("true" | "false") space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "integer",
+        R"""({
+            "type": "integer"
+        })""",
+        R"""(
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            root ::= ("-"? integral-part) space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "string const",
+        R"""({
+            "const": "foo"
+        })""",
+        R"""(
+            root ::= "\"foo\"" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "non-string const",
+        R"""({
+            "const": 123
+        })""",
+        R"""(
+            root ::= "123" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "non-string enum",
+        R"""({
+            "enum": ["red", "amber", "green", null, 42, ["foo"]]
+        })""",
+        R"""(
+            root ::= ("\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]") space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "string array",
+        R"""({
+            "type": "array",
+            "prefixItems": { "type": "string" }
+        })""",
+        R"""(
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            root ::= "[" space (string ("," space string)*)? "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "nullable string array",
+        R"""({
+            "type": ["array", "null"],
+            "prefixItems": { "type": "string" }
+        })""",
+        R"""(
+            alternative-0 ::= "[" space (string ("," space string)*)? "]" space
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            null ::= "null" space
+            root ::= alternative-0 | null
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "tuple1",
+        R"""({
+            "prefixItems": [{ "type": "string" }]
+        })""",
+        R"""(
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            root ::= "[" space string "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "tuple2",
+        R"""({
+            "prefixItems": [{ "type": "string" }, { "type": "number" }]
+        })""",
+        R"""(
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            root ::= "[" space string "," space number "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "number",
+        R"""({
+            "type": "number"
+        })""",
+        R"""(
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "minItems",
+        R"""({
+            "items": {
+                "type": "boolean"
+            },
+            "minItems": 2
+        })""",
+        R"""(
+            boolean ::= ("true" | "false") space
+            root ::= "[" space boolean ("," space boolean)+ "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "maxItems 0",
+        R"""({
+            "items": {
+                "type": "boolean"
+            },
+            "maxItems": 0
+        })""",
+        R"""(
+            boolean ::= ("true" | "false") space
+            root ::= "[" space  "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "maxItems 1",
+        R"""({
+            "items": {
+                "type": "boolean"
+            },
+            "maxItems": 1
+        })""",
+        R"""(
+            boolean ::= ("true" | "false") space
+            root ::= "[" space boolean? "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "maxItems 2",
+        R"""({
+            "items": {
+                "type": "boolean"
+            },
+            "maxItems": 2
+        })""",
+        R"""(
+            boolean ::= ("true" | "false") space
+            root ::= "[" space (boolean ("," space boolean)?)? "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min + maxItems",
+        R"""({
+            "items": {
+                "type": ["number", "integer"]
+            },
+            "minItems": 3,
+            "maxItems": 5
+        })""",
+        R"""(
+            decimal-part ::= [0-9]{1,16}
+            integer ::= ("-"? integral-part) space
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            item ::= number | integer
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            root ::= "[" space item ("," space item){2,4} "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min + max items with min + max values across zero",
+        R"""({
+            "items": {
+                "type": "integer",
+                "minimum": -12,
+                "maximum": 207
+            },
+            "minItems": 3,
+            "maxItems": 5
+        })""",
+        R"""(
+            item ::= ("-" ([0-9] | "1" [0-2]) | [0-9] | ([1-8] [0-9] | [9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
+            root ::= "[" space item ("," space item){2,4} "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "min + max items with min + max values",
+        R"""({
+            "items": {
+                "type": "integer",
+                "minimum": 12,
+                "maximum": 207
+            },
+            "minItems": 3,
+            "maxItems": 5
+        })""",
+        R"""(
+            item ::= (([1] ([2-9]) | [2-9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
+            root ::= "[" space item ("," space item){2,4} "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "simple regexp",
+        R"""({
+            "type": "string",
+            "pattern": "^abc?d*efg+(hij)?kl$"
+        })""",
+        R"""(
+            root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\"" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "regexp escapes",
+        R"""({
+            "type": "string",
+            "pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
+        })""",
+        R"""(
+            root ::= "\"" ("[]{}()|+*?") "\"" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "regexp quote",
+        R"""({
+            "type": "string",
+            "pattern": "^\"$"
+        })""",
+        R"""(
+            root ::= "\"" ("\"") "\"" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "regexp with top-level alternation",
+        R"""({
+            "type": "string",
+            "pattern": "^A|B|C|D$"
+        })""",
+        R"""(
+            root ::= "\"" ("A" | "B" | "C" | "D") "\"" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "regexp",
+        R"""({
+            "type": "string",
+            "pattern": "^(\\([0-9]{1,3}\\))?[0-9]{3}-[0-9]{4} a{3,5}nd...$"
+        })""",
+        R"""(
+            dot ::= [^\x0A\x0D]
+            root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\"" space
+            root-1 ::= [0-9]
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "required props in original order",
+        R"""({
+            "type": "object",
+            "properties": {
+                "b": {"type": "string"},
+                "c": {"type": "string"},
+                "a": {"type": "string"}
+            },
+            "required": [
+                "a",
+                "b",
+                "c"
+            ],
+            "additionalProperties": false,
+            "definitions": {}
+        })""",
+        R"""(
+            a-kv ::= "\"a\"" space ":" space string
+            b-kv ::= "\"b\"" space ":" space string
+            c-kv ::= "\"c\"" space ":" space string
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "1 optional prop",
+        R"""({
+            "properties": {
+                "a": {
+                "type": "string"
+                }
+            },
+            "additionalProperties": false
+        })""",
+        R"""(
+            a-kv ::= "\"a\"" space ":" space string
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            root ::= "{" space  (a-kv )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "N optional props",
+        R"""({
+            "properties": {
+                "a": {"type": "string"},
+                "b": {"type": "string"},
+                "c": {"type": "string"}
+            },
+            "additionalProperties": false
+        })""",
+        R"""(
+            a-kv ::= "\"a\"" space ":" space string
+            a-rest ::= ( "," space b-kv )? b-rest
+            b-kv ::= "\"b\"" space ":" space string
+            b-rest ::= ( "," space c-kv )?
+            c-kv ::= "\"c\"" space ":" space string
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            root ::= "{" space  (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "required + optional props each in original order",
+        R"""({
+            "properties": {
+                "b": {"type": "string"},
+                "a": {"type": "string"},
+                "d": {"type": "string"},
+                "c": {"type": "string"}
+            },
+            "required": ["a", "b"],
+            "additionalProperties": false
+        })""",
+        R"""(
+            a-kv ::= "\"a\"" space ":" space string
+            b-kv ::= "\"b\"" space ":" space string
+            c-kv ::= "\"c\"" space ":" space string
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            d-kv ::= "\"d\"" space ":" space string
+            d-rest ::= ( "," space c-kv )?
+            root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "additional props",
+        R"""({
+            "type": "object",
+            "additionalProperties": {"type": "array", "items": {"type": "number"}}
+        })""",
+        R"""(
+            additional-kv ::= string ":" space additional-value
+            additional-value ::= "[" space (number ("," space number)*)? "]" space
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            root ::= "{" space  (additional-kv ( "," space additional-kv )* )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "additional props (true)",
+        R"""({
+            "type": "object",
+            "additionalProperties": true
+        })""",
+        R"""(
+            array ::= "[" space ( value ("," space value)* )? "]" space
+            boolean ::= ("true" | "false") space
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            null ::= "null" space
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
+            root ::= object
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+            value ::= object | array | string | number | boolean | null
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "additional props (implicit)",
+        R"""({
+            "type": "object"
+        })""",
+        R"""(
+            array ::= "[" space ( value ("," space value)* )? "]" space
+            boolean ::= ("true" | "false") space
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            null ::= "null" space
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
+            root ::= object
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+            value ::= object | array | string | number | boolean | null
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "empty w/o additional props",
+        R"""({
+            "type": "object",
+            "additionalProperties": false
+        })""",
+        R"""(
+            root ::= "{" space  "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "required + additional props",
+        R"""({
+            "type": "object",
+            "properties": {
+                "a": {"type": "number"}
+            },
+            "required": ["a"],
+            "additionalProperties": {"type": "string"}
+        })""",
+        R"""(
+            a-kv ::= "\"a\"" space ":" space number
+            additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["] space
+            additional-kv ::= additional-k ":" space string
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            root ::= "{" space a-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "optional + additional props",
+        R"""({
+            "type": "object",
+            "properties": {
+                "a": {"type": "number"}
+            },
+            "additionalProperties": {"type": "number"}
+        })""",
+        R"""(
+            a-kv ::= "\"a\"" space ":" space number
+            a-rest ::= ( "," space additional-kv )*
+            additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["] space
+            additional-kv ::= additional-k ":" space number
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            root ::= "{" space  (a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "required + optional + additional props",
+        R"""({
+            "type": "object",
+            "properties": {
+                "and": {"type": "number"},
+                "also": {"type": "number"}
+            },
+            "required": ["and"],
+            "additionalProperties": {"type": "number"}
+        })""",
+        R"""(
+            additional-k ::= ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
+            additional-kv ::= additional-k ":" space number
+            also-kv ::= "\"also\"" space ":" space number
+            also-rest ::= ( "," space additional-kv )*
+            and-kv ::= "\"and\"" space ":" space number
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            root ::= "{" space and-kv ( "," space ( also-kv also-rest | additional-kv ( "," space additional-kv )* ) )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "optional props with empty name",
+        R"""({
+            "properties": {
+                "": {"type": "integer"},
+                "a": {"type": "integer"}
+            },
+            "additionalProperties": {"type": "integer"}
+        })""",
+        R"""(
+            -kv ::= "\"\"" space ":" space root
+            -rest ::= ( "," space a-kv )? a-rest
+            a-kv ::= "\"a\"" space ":" space integer
+            a-rest ::= ( "," space additional-kv )*
+            additional-k ::= ["] ( [a] char+ | [^"a] char* ) ["] space
+            additional-kv ::= additional-k ":" space integer
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            integer ::= ("-"? integral-part) space
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            root ::= ("-"? integral-part) space
+            root0 ::= "{" space  (-kv -rest | a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "optional props with nested names",
+        R"""({
+            "properties": {
+                "a": {"type": "integer"},
+                "aa": {"type": "integer"}
+            },
+            "additionalProperties": {"type": "integer"}
+        })""",
+        R"""(
+            a-kv ::= "\"a\"" space ":" space integer
+            a-rest ::= ( "," space aa-kv )? aa-rest
+            aa-kv ::= "\"aa\"" space ":" space integer
+            aa-rest ::= ( "," space additional-kv )*
+            additional-k ::= ["] ( [a] ([a] char+ | [^"a] char*) | [^"a] char* )? ["] space
+            additional-kv ::= additional-k ":" space integer
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            integer ::= ("-"? integral-part) space
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            root ::= "{" space  (a-kv a-rest | aa-kv aa-rest | additional-kv ( "," space additional-kv )* )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "optional props with common prefix",
+        R"""({
+            "properties": {
+                "ab": {"type": "integer"},
+                "ac": {"type": "integer"}
+            },
+            "additionalProperties": {"type": "integer"}
+        })""",
+        R"""(
+            ab-kv ::= "\"ab\"" space ":" space integer
+            ab-rest ::= ( "," space ac-kv )? ac-rest
+            ac-kv ::= "\"ac\"" space ":" space integer
+            ac-rest ::= ( "," space additional-kv )*
+            additional-k ::= ["] ( [a] ([b] char+ | [c] char+ | [^"bc] char*) | [^"a] char* )? ["] space
+            additional-kv ::= additional-k ":" space integer
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            integer ::= ("-"? integral-part) space
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            root ::= "{" space  (ab-kv ab-rest | ac-kv ac-rest | additional-kv ( "," space additional-kv )* )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "top-level $ref",
+        R"""({
+            "$ref": "#/definitions/foo",
+            "definitions": {
+                "foo": {
+                    "type": "object",
+                    "properties": {
+                        "a": {
+                            "type": "string"
+                        }
+                    },
+                    "required": [
+                        "a"
+                    ],
+                    "additionalProperties": false
+                }
+            }
+        })""",
+        R"""(
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            foo ::= "{" space foo-a-kv "}" space
+            foo-a-kv ::= "\"a\"" space ":" space string
+            root ::= foo
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "anyOf",
+        R"""({
+            "anyOf": [
+                {"$ref": "#/definitions/foo"},
+                {"$ref": "#/definitions/bar"}
+            ],
+            "definitions": {
+                "foo": {
+                    "properties": {"a": {"type": "number"}}
+                },
+                "bar": {
+                    "properties": {"b": {"type": "number"}}
+                }
+            },
+            "type": "object"
+        })""",
+        R"""(
+            alternative-0 ::= foo
+            alternative-1 ::= bar
+            bar ::= "{" space  (bar-b-kv )? "}" space
+            bar-b-kv ::= "\"b\"" space ":" space number
+            decimal-part ::= [0-9]{1,16}
+            foo ::= "{" space  (foo-a-kv )? "}" space
+            foo-a-kv ::= "\"a\"" space ":" space number
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            root ::= alternative-0 | alternative-1
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "mix of allOf, anyOf and $ref (similar to https://json.schemastore.org/tsconfig.json)",
+        R"""({
+            "allOf": [
+                {"$ref": "#/definitions/foo"},
+                {"$ref": "#/definitions/bar"},
+                {
+                "anyOf": [
+                    {"$ref": "#/definitions/baz"},
+                    {"$ref": "#/definitions/bam"}
+                ]
+                }
+            ],
+            "definitions": {
+                "foo": {
+                    "properties": {"a": {"type": "number"}}
+                },
+                "bar": {
+                    "properties": {"b": {"type": "number"}}
+                },
+                "bam": {
+                    "properties": {"c": {"type": "number"}}
+                },
+                "baz": {
+                    "properties": {"d": {"type": "number"}}
+                }
+            },
+            "type": "object"
+        })""",
+        R"""(
+            a-kv ::= "\"a\"" space ":" space number
+            b-kv ::= "\"b\"" space ":" space number
+            c-kv ::= "\"c\"" space ":" space number
+            d-kv ::= "\"d\"" space ":" space number
+            d-rest ::= ( "," space c-kv )?
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "conflicting names",
+        R"""({
+            "type": "object",
+            "properties": {
+                "number": {
+                "type": "object",
+                "properties": {
+                    "number": {
+                    "type": "object",
+                        "properties": {
+                            "root": {
+                                "type": "number"
+                            }
+                        },
+                        "required": [
+                            "root"
+                        ],
+                        "additionalProperties": false
+                    }
+                },
+                "required": [
+                    "number"
+                ],
+                "additionalProperties": false
+                }
+            },
+            "required": [
+                "number"
+            ],
+            "additionalProperties": false,
+            "definitions": {}
+        })""",
+        R"""(
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            number- ::= "{" space number-number-kv "}" space
+            number-kv ::= "\"number\"" space ":" space number-
+            number-number ::= "{" space number-number-root-kv "}" space
+            number-number-kv ::= "\"number\"" space ":" space number-number
+            number-number-root-kv ::= "\"root\"" space ":" space number
+            root ::= "{" space number-kv "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+}
+
+int main() {
+    fprintf(stderr, "LLAMA_NODE_AVAILABLE = %s\n", getenv("LLAMA_NODE_AVAILABLE") ? "true" : "false");
+    fprintf(stderr, "LLAMA_PYTHON_AVAILABLE = %s\n", getenv("LLAMA_PYTHON_AVAILABLE") ? "true" : "false");
+
+    test_all("C++", [](const TestCase & tc) {
+        try {
+            tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
+            tc.verify_status(SUCCESS);
+        } catch (const std::runtime_error & ex) {
+            fprintf(stderr, "Error: %s\n", ex.what());
+            tc.verify_status(FAILURE);
+        }
+    });
+
+    if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
+        fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
+    } else {
+        if (getenv("LLAMA_PYTHON_AVAILABLE") || (std::system("python -c \"import sys; exit(1) if sys.version_info < (3, 8) else print('Python version is sufficient')\"") == 0)) {
+            test_all("Python", [](const TestCase & tc) {
+                write("test-json-schema-input.tmp", tc.schema);
+                tc.verify_status(std::system(
+                    "python ./examples/json_schema_to_grammar.py test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
+                tc.verify(read("test-grammar-output.tmp"));
+            });
+        } else {
+            fprintf(stderr, "\033[33mWARNING: Python not found (min version required is 3.8), skipping Python JSON schema -> grammar tests.\n\033[0m");
+        }
+
+        if (getenv("LLAMA_NODE_AVAILABLE") || (std::system("node --version") == 0)) {
+            test_all("JavaScript", [](const TestCase & tc) {
+                write("test-json-schema-input.tmp", tc.schema);
+                tc.verify_status(std::system(
+                    "node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
+                tc.verify(read("test-grammar-output.tmp"));
+            });
+        } else {
+            fprintf(stderr, "\033[33mWARNING: Node not found, skipping JavaScript JSON schema -> grammar tests.\n\033[0m");
+        }
+    }
+
+    test_all("Check Expectations Validity", [](const TestCase & tc) {
+        if (tc.expected_status == SUCCESS) {
+            tc.verify_expectation_parseable();
+        }
+    });
+}
diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cc198f3e3c90366aa994441a2b8e8b4930b05030
--- /dev/null
+++ b/tests/test-llama-grammar.cpp
@@ -0,0 +1,408 @@
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include "llama.h"
+
+#include "../src/llama-grammar.h"
+
+#include 
+#include 
+
+int main()
+{
+    llama_grammar_parser parsed_grammar;
+
+    std::vector> expected = {
+        {"expr", 2},
+        {"expr_6", 6},
+        {"expr_7", 7},
+        {"ident", 8},
+        {"ident_10", 10},
+        {"num", 9},
+        {"num_11", 11},
+        {"root", 0},
+        {"root_1", 1},
+        {"root_5", 5},
+        {"term", 4},
+        {"ws", 3},
+        {"ws_12", 12},
+    };
+
+    std::vector> expected_rules = {
+        {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_RULE_REF, 2},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_RULE_REF, 4},
+            {LLAMA_GRETYPE_CHAR, 10},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
+        {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_RULE_REF, 8},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_RULE_REF, 9},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_CHAR, 40},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_RULE_REF, 2},
+            {LLAMA_GRETYPE_CHAR, 41},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_CHAR, 45},
+            {LLAMA_GRETYPE_CHAR_ALT, 43},
+            {LLAMA_GRETYPE_CHAR_ALT, 42},
+            {LLAMA_GRETYPE_CHAR_ALT, 47},
+            {LLAMA_GRETYPE_RULE_REF, 4},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_CHAR, 97},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
+            {LLAMA_GRETYPE_RULE_REF, 10},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_CHAR, 97},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
+            {LLAMA_GRETYPE_CHAR_ALT, 48},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+            {LLAMA_GRETYPE_CHAR_ALT, 95},
+            {LLAMA_GRETYPE_RULE_REF, 10},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 48},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+            {LLAMA_GRETYPE_RULE_REF, 11},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_CHAR, 48},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 32},
+            {LLAMA_GRETYPE_CHAR_ALT, 9},
+            {LLAMA_GRETYPE_CHAR_ALT, 10},
+            {LLAMA_GRETYPE_RULE_REF, 12},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_END, 0},
+        },
+    };
+
+    for (auto pair : expected)
+    {
+        parsed_grammar.symbol_ids[pair.first] = pair.second;
+    }
+
+    for (auto rule : expected_rules)
+    {
+        parsed_grammar.rules.emplace_back();
+        for (auto element : rule)
+        {
+            parsed_grammar.rules.back().push_back(element);
+        }
+    }
+
+    std::vector grammar_rules(parsed_grammar.c_rules());
+
+    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    if (grammar == nullptr) {
+        throw std::runtime_error("Failed to initialize llama_grammar");
+    }
+
+    std::vector> expected_stacks = {
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 97},
+        },
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 40},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 97},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 40},
+        }};
+
+    auto index = 0;
+    for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
+    {
+        // compare stack to expected_stack
+        for (uint32_t i = 0; i < stack.size(); i++)
+        {
+            const llama_grammar_element * element = stack[i];
+            const llama_grammar_element & expected_element = expected_stacks[index][i];
+
+            // pretty print error message before asserting
+            if (expected_element.type != element->type || expected_element.value != element->value)
+            {
+                fprintf(stderr, "index: %d\n", index);
+                fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
+                fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value);
+                fprintf(stderr, "expected_element != actual_element\n");
+            }
+
+            assert(expected_element.type == element->type && expected_element.value == element->value);
+        }
+        index++;
+    }
+
+    std::vector next_candidates;
+    next_candidates.resize(24);
+
+    for (size_t i = 0; i < 24; ++i)
+    {
+        uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
+        cp[0] = 37 + i;
+        cp[1] = 0;
+        next_candidates[i] = {i, cp, {}};
+    }
+
+    std::vector>> expected_reject = {
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+    };
+
+    std::vector rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates);
+
+    std::vector> all_rejects;
+
+    for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count)
+    {
+        rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates);
+        all_rejects.push_back(rejects);
+    }
+
+    index = 0;
+    for (auto rej : all_rejects)
+    {
+        for (uint32_t i = 0; i < rej.size(); i++)
+        {
+            auto element = rej[i];
+            auto expected_element = expected_reject[index][i];
+            assert(element.index == expected_element.first && *element.code_points == expected_element.second);
+        }
+        index++;
+    }
+
+    for (auto &candidate : next_candidates)
+    {
+        delete[] candidate.code_points;
+        candidate.code_points = nullptr;
+    }
+
+    llama_grammar_free_impl(grammar);
+
+    return 0;
+}
diff --git a/tests/test-log.cpp b/tests/test-log.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..306f28c6191cbd750f933f503410b9984456e2b3
--- /dev/null
+++ b/tests/test-log.cpp
@@ -0,0 +1,39 @@
+#include "log.h"
+
+#include 
+#include 
+
+int main() {
+    const int n_thread = 8;
+
+    std::thread threads[n_thread];
+    for (int i = 0; i < n_thread; i++) {
+        threads[i] = std::thread([i]() {
+            const int n_msg = 1000;
+
+            for (int j = 0; j < n_msg; j++) {
+                const int log_type = std::rand() % 4;
+
+                switch (log_type) {
+                    case 0: LOG_INF("Thread %d: %d\n", i, j); break;
+                    case 1: LOG_WRN("Thread %d: %d\n", i, j); break;
+                    case 2: LOG_ERR("Thread %d: %d\n", i, j); break;
+                    case 3: LOG_DBG("Thread %d: %d\n", i, j); break;
+                    default:
+                        break;
+                }
+
+                if (rand () % 10 < 5) {
+                    common_log_set_timestamps(common_log_main(), rand() % 2);
+                    common_log_set_prefix    (common_log_main(), rand() % 2);
+                }
+            }
+        });
+    }
+
+    for (int i = 0; i < n_thread; i++) {
+        threads[i].join();
+    }
+
+    return 0;
+}
diff --git a/tests/test-lora-conversion-inference.sh b/tests/test-lora-conversion-inference.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1d1f4886caaa500dfd3d358587dda68be8c97b2c
--- /dev/null
+++ b/tests/test-lora-conversion-inference.sh
@@ -0,0 +1,144 @@
+#!/bin/bash
+set -e
+
+# Array of models to iterate over
+declare -a params=(
+    "Gemma2ForCausalLM 64"
+    "LlamaForCausalLM 64"
+    "Phi3ForCausalLM 64"
+)
+
+MODELS_REPO=lora-tests
+MODELS_REPO_URL=https://huggingface.co/ggml-org/$MODELS_REPO
+COMMIT=c26d5fb85b4070a9e9c4e65d132c783b98086890
+
+# Clone the Hugging Face repository if the directory does not exist
+if [ ! -d "$MODELS_REPO" ]; then
+    echo "Cloning the Hugging Face repository..."
+    git clone $MODELS_REPO_URL --depth 1
+    cd $MODELS_REPO
+    git fetch --depth=1 origin $COMMIT
+    git reset --hard $COMMIT
+    cd -
+else
+    echo "Repository already exists. Skipping clone."
+fi
+
+# Array to store results to print
+results=()
+
+trim_leading_whitespace() {
+    local input_string="$1"
+    echo "${input_string#"${input_string%%[![:space:]]*}"}"
+}
+
+extract_starting_substring() {
+    local reference_string="$1"
+    local target_string="$2"
+
+    local target_length=${#target_string}
+    echo "${reference_string:0:$target_length}"
+}
+
+get_first_word() {
+    local input_string="$1"
+    read -r first_word _ <<< "$input_string"
+    echo "$first_word"
+}
+
+# Load the expected strings
+EXPECTED_BASE_FULL=$(cat $MODELS_REPO/data/pale_blue_dot.txt)
+EXPECTED_LORA_FULL=$(cat $MODELS_REPO/data/bohemian_rhapsody.txt)
+EXPECTED_BASE_FIRST_WORD=$(get_first_word "$EXPECTED_BASE_FULL")
+EXPECTED_LORA_FIRST_WORD=$(get_first_word "$EXPECTED_LORA_FULL")
+
+run_conversion_and_inference_lora() {
+    local model_name=$1
+    local hidden_size=$2
+
+    echo -e "\n\n-------- RUNNING TEST FOR MODEL $model_name --------\n\n"
+
+    # Convert safetensors to gguf
+    echo "Running convert_hf_to_gguf.py for $model_name with hidden_size $hidden_size..."
+    python convert_hf_to_gguf.py $MODELS_REPO/$model_name/hidden_size=$hidden_size/base \
+        --outfile $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
+        --outtype f32
+
+    echo -e "\n\n---------------------------\n\n"
+    echo "Running convert_lora_to_gguf.py for $model_name with hidden_size $hidden_size..."
+    python3 convert_lora_to_gguf.py $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora \
+        --base $MODELS_REPO/$model_name/hidden_size=$hidden_size/base \
+        --outtype f32
+
+    echo -e "\n\n---------------------------\n\n"
+    echo "Running llama-export-lora with lora for $model_name with hidden_size $hidden_size..."
+    ./llama-export-lora \
+        -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
+        -o $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
+        --lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf
+
+    # Run inference
+    echo -e "\n\n---------------------------\n\n"
+    echo "Running llama-cli without lora for $model_name with hidden_size $hidden_size..."
+    OUTPUT_BASE=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
+        -p "$EXPECTED_BASE_FIRST_WORD" -n 50 --seed 42 --temp 0)
+
+    echo -e "\n\n---------------------------\n\n"
+    echo "Running llama-cli with hot lora for $model_name with hidden_size $hidden_size..."
+    OUTPUT_LORA_HOT=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
+        --lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf \
+        -p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)
+
+    echo -e "\n\n---------------------------\n\n"
+    echo "Running llama-cli with merged lora for $model_name with hidden_size $hidden_size..."
+    OUTPUT_LORA_MERGED=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
+        -p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)
+
+    # Remove any initial white space
+    OUTPUT_BASE=$(trim_leading_whitespace "$OUTPUT_BASE")
+    OUTPUT_LORA_HOT=$(trim_leading_whitespace "$OUTPUT_LORA_HOT")
+    OUTPUT_LORA_MERGED=$(trim_leading_whitespace "$OUTPUT_LORA_MERGED")
+    # Extract the corresponding substring from full string
+    EXPECTED_BASE=$(extract_starting_substring "$EXPECTED_BASE_FULL" "$OUTPUT_BASE")
+    EXPECTED_LORA=$(extract_starting_substring "$EXPECTED_LORA_FULL" "$OUTPUT_LORA_HOT")
+
+    # Assert output equals the expected output
+    if [[ "$OUTPUT_BASE" != "$EXPECTED_BASE" ]]; then
+        echo "Error: $model_name OUTPUT_BASE does not start with the expected string."
+        echo -e "Out=$OUTPUT_BASE\n\nExp=$EXPECTED_BASE"
+        exit 1
+    fi
+    if [[ "$OUTPUT_LORA_HOT" != "$EXPECTED_LORA" ]]; then
+        echo "Error: $model_name OUTPUT_LORA_HOT does not start with the expected string."
+        echo -e "Out=$OUTPUT_LORA_HOT\n\nExp=$EXPECTED_LORA"
+        exit 1
+    fi
+    if [[ "$OUTPUT_LORA_MERGED" != "$EXPECTED_LORA" ]]; then
+        echo "Error: $model_name OUTPUT_LORA_MERGED does not start with the expected string."
+        echo -e "Out=$OUTPUT_LORA_MERGED\n\nExp=$EXPECTED_LORA"
+        exit 1
+    fi
+
+    # Store the results
+    results+=("
+    \n\033[1mResults for $model_name with hidden_size $hidden_size:\033[0m
+    \n\033[32m  • Base:\n$OUTPUT_BASE
+    \n\033[34m  • Lora hot:\n$OUTPUT_LORA_HOT
+    \n\033[36m  • Lora merged:\n$OUTPUT_LORA_MERGED
+    \n \033[0m
+    ")
+
+    echo "All tests passed for $model_name with hidden_size $hidden_size!"
+}
+
+# Run test for each model
+for param in "${params[@]}"; do
+    run_conversion_and_inference_lora $param
+done
+
+# Print results
+echo -e "\n\n---------------------------\n\n"
+echo -e "\n\033[1mSummary of All Results:\033[0m"
+for result in "${results[@]}"; do
+    echo -e "$result"
+done
diff --git a/tests/test-model-load-cancel.cpp b/tests/test-model-load-cancel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9095826fa9884f876fdb2d8265b03a49b9b68d40
--- /dev/null
+++ b/tests/test-model-load-cancel.cpp
@@ -0,0 +1,27 @@
+#include "llama.h"
+#include "get-model.h"
+
+#include 
+
+int main(int argc, char *argv[] ) {
+    auto * model_path = get_model_or_exit(argc, argv);
+    auto * file = fopen(model_path, "r");
+    if (file == nullptr) {
+        fprintf(stderr, "no model at '%s' found\n", model_path);
+        return EXIT_FAILURE;
+    }
+
+    fprintf(stderr, "using '%s'\n", model_path);
+    fclose(file);
+
+    llama_backend_init();
+    auto params = llama_model_params{};
+    params.use_mmap = false;
+    params.progress_callback = [](float progress, void * ctx){
+        (void) ctx;
+        return progress > 0.50;
+    };
+    auto * model = llama_model_load_from_file(model_path, params);
+    llama_backend_free();
+    return model == nullptr ? EXIT_SUCCESS : EXIT_FAILURE;
+}
diff --git a/tests/test-mtmd-c-api.c b/tests/test-mtmd-c-api.c
new file mode 100644
index 0000000000000000000000000000000000000000..02e762e6a2d3e5aecb1678a55080fa64ff2afd9f
--- /dev/null
+++ b/tests/test-mtmd-c-api.c
@@ -0,0 +1,63 @@
+#include 
+#include 
+
+#include "mtmd.h"
+
+int main(void) {
+    printf("\n\nTesting libmtmd C API...\n");
+    printf("--------\n\n");
+
+    struct mtmd_context_params params = mtmd_context_params_default();
+    printf("Default image marker: %s\n", params.image_marker);
+
+    mtmd_input_chunks * chunks = mtmd_test_create_input_chunks();
+
+    if (!chunks) {
+        fprintf(stderr, "Failed to create input chunks\n");
+        return 1;
+    }
+
+    size_t n_chunks = mtmd_input_chunks_size(chunks);
+    printf("Number of chunks: %zu\n", n_chunks);
+    assert(n_chunks > 0);
+
+    for (size_t i = 0; i < n_chunks; i++) {
+        const mtmd_input_chunk * chunk = mtmd_input_chunks_get(chunks, i);
+        assert(chunk != NULL);
+        enum mtmd_input_chunk_type type = mtmd_input_chunk_get_type(chunk);
+        printf("Chunk %zu type: %d\n", i, type);
+
+        if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            size_t n_tokens;
+            const llama_token * tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+            printf("    Text chunk with %zu tokens\n", n_tokens);
+            assert(tokens != NULL);
+            assert(n_tokens > 0);
+            for (size_t j = 0; j < n_tokens; j++) {
+                assert(tokens[j] >= 0);
+                printf("    > Token %zu: %d\n", j, tokens[j]);
+            }
+
+        } else if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+            size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+            size_t nx = mtmd_image_tokens_get_nx(image_tokens);
+            size_t ny = mtmd_image_tokens_get_ny(image_tokens);
+            const char * id = mtmd_image_tokens_get_id(image_tokens);
+            assert(n_tokens > 0);
+            assert(nx > 0);
+            assert(ny > 0);
+            assert(id != NULL);
+            printf("    Image chunk with %zu tokens\n", n_tokens);
+            printf("    Image size: %zu x %zu\n", nx, ny);
+            printf("    Image ID: %s\n", id);
+        }
+    }
+
+    // Free the chunks
+    mtmd_input_chunks_free(chunks);
+
+    printf("\n\nDONE: test libmtmd C API...\n");
+
+    return 0;
+}
diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..558f877210e7d734078ae0fea327f3a19b4086f3
--- /dev/null
+++ b/tests/test-opt.cpp
@@ -0,0 +1,904 @@
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "ggml-cpu.h"
+#include "ggml-opt.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+static bool almost_equal(const double a, const double b, const double atol) {
+    return fabs(a - b) < atol;
+}
+
+constexpr int64_t ne_datapoint = 2;
+constexpr int64_t ne_label     = 1;
+constexpr int64_t ndata        = 6;
+
+struct helper_ctx_data {
+    std::vector   datasets_supervised;
+    std::vector data_batch;
+    std::vector labels_batch;
+
+    ggml_opt_dataset_t       dataset_unsupervised;
+    struct ggml_context    * ctx_static;
+    struct ggml_context    * ctx_compute;
+    struct ggml_opt_params   opt_params;
+    ggml_opt_context_t       opt_ctx;
+    struct ggml_tensor     * inputs;
+    struct ggml_tensor     * weights;
+    struct ggml_tensor     * outputs;
+    ggml_backend_buffer_t    buf;
+    ggml_opt_result_t        result;
+    ggml_opt_result_t        result2;
+};
+
+// These default values make it easier to check optimization results vs. expected values.
+static ggml_opt_optimizer_params helper_get_test_opt_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
+    result.adamw.alpha = 1.0f;
+    result.adamw.beta1 = 0.0f;
+    result.adamw.beta2 = 0.0f;
+    result.adamw.eps   = 0.0f;
+    return result;
+}
+
+static helper_ctx_data helper_get_ctx_data(
+        ggml_backend_sched_t    backend_sched,
+        ggml_backend_t          backend,
+        const bool              init_opt_ctx       = true,
+        const bool              optimizer_defaults = true,
+        int64_t                 nbatch_logical     = 1,
+        int64_t                 nbatch_physical    = 1,
+        enum ggml_opt_loss_type loss_type          = GGML_OPT_LOSS_TYPE_SUM) {
+    std::vector datasets(ndata);
+    for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
+        ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
+            GGML_TYPE_F32, GGML_TYPE_F32, ne_datapoint, ne_label, ndata, ndata_shard);
+
+        float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));
+        float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
+
+        for (int64_t idata = 0; idata < ndata; ++idata) {
+            for (int64_t id = 0; id < ne_datapoint; ++id) {
+                data[  idata*ne_datapoint + id] =     16*idata + id;
+            }
+            for (int64_t il = 0; il < ne_label;     ++il) {
+                labels[idata*ne_label     + il] = 16*(16*idata + il);
+            }
+        }
+
+        datasets[ndata_shard-1] = dataset;
+    }
+
+    ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(
+        GGML_TYPE_F32, GGML_TYPE_F32, 1, 0, ndata, /*ndata_shard =*/ 1);
+
+    float * data = ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised));
+
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        data[idata] = idata;
+    }
+
+    struct ggml_context * ctx_static;
+    struct ggml_context * ctx_compute;
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ (2*ndata + 2)*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_static = ggml_init(params);
+    }
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_compute = ggml_init(params);
+    }
+
+    std::vector   data_batch(ndata);
+    std::vector labels_batch(ndata);
+    for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) {
+        data_batch[ndata_batch-1]   = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_datapoint);
+        labels_batch[ndata_batch-1] = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_label);
+    }
+
+    struct ggml_tensor * inputs = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, nbatch_physical);
+    ggml_set_name(inputs, "inputs");
+
+    struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
+    ggml_set_name(weights, "weights");
+    ggml_set_param(weights);
+
+    struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights);
+
+    struct ggml_tensor * outputs = ggml_scale(ctx_compute, intermediary, 1.0f);
+    ggml_set_name(outputs, "outputs");
+
+    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);
+    const float w0 = float(ndata)/2;
+    ggml_backend_tensor_set(weights, &w0, 0, sizeof(float));
+
+    GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
+    const int32_t opt_period = nbatch_logical / nbatch_physical;
+
+    struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, loss_type);
+    opt_params.ctx_compute = ctx_compute;
+    opt_params.inputs      = inputs;
+    opt_params.outputs     = outputs;
+    opt_params.opt_period  = opt_period;
+    if (!optimizer_defaults) {
+        opt_params.get_opt_pars = helper_get_test_opt_pars;
+    }
+    ggml_opt_context_t opt_ctx = init_opt_ctx ? ggml_opt_init(opt_params) : nullptr;
+
+    ggml_opt_result_t result  = ggml_opt_result_init();
+    ggml_opt_result_t result2 = ggml_opt_result_init();
+
+    return {datasets, data_batch, labels_batch, dataset_unsupervised, ctx_static, ctx_compute, opt_params, opt_ctx, inputs, weights, outputs, buf, result, result2};
+}
+
+static void helper_free_ctx_data(struct helper_ctx_data ctx_data) {
+    ggml_opt_result_free(ctx_data.result);
+    ggml_opt_result_free(ctx_data.result2);
+    ggml_opt_free(ctx_data.opt_ctx);
+    ggml_backend_buffer_free(ctx_data.buf);
+    ggml_free(ctx_data.ctx_static);
+    ggml_free(ctx_data.ctx_compute);
+    for (ggml_opt_dataset_t dataset : ctx_data.datasets_supervised) {
+        ggml_opt_dataset_free(dataset);
+    }
+    ggml_opt_dataset_free(ctx_data.dataset_unsupervised);
+}
+
+static void helper_after_test(
+        const char * func, const bool high_level, const std::string options,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    printf("  %s(high_level=%s%s, subtest=%s): ",
+           func, high_level ? "yes" : "no", options.c_str(), subtest.c_str());
+    if (subtest_ok) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+}
+
+static std::pair test_dataset(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool shuffle) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend);
+
+    for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
+        ggml_opt_dataset_t dataset = cd.datasets_supervised[ndata_shard-1];
+
+        if (shuffle) {
+            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        }
+
+        for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) {
+            if (ndata_batch % ndata_shard != 0) {
+                continue;
+            }
+            bool subtest_ok = true;
+
+            struct ggml_tensor *   data_batch =   cd.data_batch[ndata_batch-1];
+            struct ggml_tensor * labels_batch = cd.labels_batch[ndata_batch-1];
+
+            std::vector   data(ggml_nelements(  data_batch));
+            std::vector labels(ggml_nelements(labels_batch));
+
+            std::vector idata_shuffled;
+            const int64_t nbatches = ndata / ndata_batch;
+            for (int64_t ibatch = 0; ibatch < nbatches; ++ibatch) {
+                ggml_opt_dataset_get_batch(dataset, data_batch, labels_batch, ibatch);
+
+                ggml_backend_tensor_get(  data_batch,   data.data(), 0, ggml_nbytes(  data_batch));
+                ggml_backend_tensor_get(labels_batch, labels.data(), 0, ggml_nbytes(labels_batch));
+
+                for (int64_t idata_batch = 0; idata_batch < ndata_batch; ++idata_batch) {
+                    const int64_t idata = ibatch*ndata_batch + idata_batch;
+                    const int64_t idata_found = data[idata_batch*ne_datapoint] / 16;
+                    subtest_ok = subtest_ok && (shuffle || idata_found == idata);
+                    idata_shuffled.push_back(idata_found);
+
+                    for (int64_t id = 0; id < ne_datapoint; ++id) {
+                        if (data[  idata_batch*ne_datapoint + id] != 16*idata_found + id) {
+                            subtest_ok = false;
+                        }
+                    }
+                    for (int64_t il = 0; il < ne_label;     ++il) {
+                        if (labels[idata_batch*ne_label     + il] != 16*(16*idata_found + il)) {
+                            subtest_ok = false;
+                        }
+                    }
+                }
+            }
+
+            if (!shuffle || ndata % ndata_batch == 0) {
+                const int ndata_max = (ndata / ndata_batch) * ndata_batch;
+
+                for (int64_t idata = 0; subtest_ok && idata < ndata_max; ++idata) {
+                    int ninstances = 0;
+                    for (int64_t id : idata_shuffled) {
+                        ninstances += id == idata;
+                    }
+                    if (ninstances != 1) {
+                        subtest_ok = false;
+                    }
+                }
+            }
+
+            printf("  %s(shuffle=%s, ndata_shard=%" PRId64 ", ndata_batch=%" PRId64 "): ",
+                   __func__, shuffle ? "yes" : "no", ndata_shard, ndata_batch);
+            if (subtest_ok) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static std::pair test_grad(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false,
+    /*nbatch_logical =*/ 999999, /*nbatch_physical =*/ 1);
+
+    std::vector grad_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        grad_history[idata] = NAN;
+    }
+
+    for (int idata = 0; idata < ndata; ++idata) {
+        const float idataf = idata;
+        ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
+        ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+        ggml_opt_eval(cd.opt_ctx, cd.result);
+        ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float));
+    }
+
+    {
+        bool subtest_ok = true;
+        for (int idata = 0; idata < ndata; ++idata) {
+            if (grad_history[idata] != idata + 1) {
+                subtest_ok = false;
+            }
+        }
+        printf("  %s(): ", __func__);
+        if (subtest_ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            npass++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+        ntest++;
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static void helper_after_test_forward_backward(
+        const char * func, const bool high_level, const bool shuffle,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    std::string options = ", shuffle=";
+    options += shuffle ? "yes" : "no";
+    helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass);
+}
+
+static std::pair test_forward_backward(
+        ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level, const bool shuffle) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
+    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
+
+    std::vector loss_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        loss_history[idata] = NAN;
+    }
+
+    {
+        int64_t ndata;
+        ggml_opt_result_ndata(cd.result, &ndata);
+        double loss;
+        double loss_unc;
+        ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+        double accuracy;
+        double accuracy_unc;
+        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+        const bool subtest_ok = ndata == 0 && loss == 0.0 && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc);
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass);
+    }
+
+    if (high_level) {
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+        if (shuffle) {
+            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        }
+        ggml_opt_epoch(cd.opt_ctx, dataset, nullptr, cd.result, 0, nullptr, nullptr);
+    } else {
+        for (int idata = 0; idata < ndata; ++idata) {
+            const float idataf = idata;
+            ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);
+            ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+            ggml_opt_eval(cd.opt_ctx, cd.result);
+            ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+        }
+    }
+
+    {
+        float weights;
+        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+        const bool subtest_ok = weights == ndata/2;
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass);
+    }
+    {
+        int64_t ndata;
+        ggml_opt_result_ndata(cd.result, &ndata);
+        bool subtest_ok = ndata == 6;
+
+        double loss;
+        double loss_unc;
+        ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+        subtest_ok = subtest_ok && loss == 33.0 && almost_equal(loss_unc, sqrt(3.5), 1e-10);
+
+        double accuracy;
+        double accuracy_unc;
+        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "results_after_forward", subtest_ok, ntest, npass);
+    }
+
+    float w0;
+    ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float));
+    for (int i = 0; i < 10; ++i) {
+        ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
+        ggml_opt_eval(cd.opt_ctx, cd.result);
+    }
+    ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float));
+
+    ggml_opt_reset(cd.opt_ctx, /*optimizer =*/ false);
+    ggml_opt_result_reset(cd.result);
+
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        loss_history[idata] = NAN;
+    }
+
+    if (high_level) {
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+        if (shuffle) {
+            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        }
+        ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr);
+    } else {
+        for (int idata = 0; idata < ndata; ++idata) {
+            const float idataf = idata;
+            ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
+            ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+            ggml_opt_eval(cd.opt_ctx, cd.result);
+            ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+        }
+    }
+
+    {
+        float weights;
+        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+        const bool subtest_ok = weights == -ndata/2;
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass);
+    }
+    {
+        int64_t ndata;
+        ggml_opt_result_ndata(cd.result, &ndata);
+        bool subtest_ok = ndata == 6;
+
+        double loss;
+        double loss_unc;
+        ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+        subtest_ok = subtest_ok && loss == 18.0 && (shuffle || loss_unc == 0.0);
+
+        double accuracy;
+        double accuracy_unc;
+        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "result_after_forward_backward", subtest_ok, ntest, npass);
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static std::pair test_epoch_vs_fit(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int ntest = 0;
+    int npass = 0;
+
+    float weights_epoch;
+    float weights_fit;
+
+    {
+        struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true);
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+
+        ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr);
+
+        ggml_backend_tensor_get(cd.weights, &weights_epoch, 0, ggml_nbytes(cd.weights));
+        helper_free_ctx_data(cd);
+    }
+    {
+        struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ false);
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+
+        ggml_opt_fit(backend_sched, cd.ctx_compute, cd.inputs, cd.outputs, dataset,
+            GGML_OPT_LOSS_TYPE_SUM, ggml_opt_get_default_optimizer_params, 1, 1, 0.0f, true);
+
+        ggml_backend_tensor_get(cd.weights, &weights_fit, 0, ggml_nbytes(cd.weights));
+        helper_free_ctx_data(cd);
+    }
+
+    const bool subtest_ok = weights_epoch == weights_fit;
+
+    printf("  %s(): ", __func__);
+    if (subtest_ok) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    return std::make_pair(npass, ntest);
+}
+
+static void helper_after_test_idata_split(
+        const char * func, const bool high_level, const int epoch,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    std::string options = ", epoch=";
+    options += std::to_string(epoch);
+    helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass);
+}
+
+static std::pair test_idata_split(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
+    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
+    const int idata_split = ndata * 2/3;
+
+    std::vector loss_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        loss_history[idata] = NAN;
+    }
+
+    for (int epoch = 1; epoch <= 4; ++epoch) {
+        if (high_level) {
+            ggml_opt_epoch(cd.opt_ctx, cd.dataset_unsupervised, cd.result, cd.result2, idata_split, nullptr, nullptr);
+        } else {
+            int idata = 0;
+            for (; idata < idata_split; ++idata) {
+                const float idataf = idata;
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
+                ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+                ggml_opt_eval(cd.opt_ctx, cd.result);
+                ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+            }
+            for (; idata < ndata; ++idata) {
+                const float idataf = idata;
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);
+                ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+                ggml_opt_eval(cd.opt_ctx, cd.result2);
+                ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+            }
+        }
+
+        {
+            float weights;
+            ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+            const bool subtest_ok = weights == ndata/2 - epoch*idata_split;
+            helper_after_test_idata_split(__func__, high_level, epoch, "weights", subtest_ok, ntest, npass);
+        }
+        {
+            int64_t ndata_result;
+            ggml_opt_result_ndata(cd.result, &ndata_result);
+            bool subtest_ok = ndata_result == idata_split;
+
+            double loss;
+            double loss_unc;
+            ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+            subtest_ok = subtest_ok && loss == 28.0 - epoch*16.0 && loss_unc == 0.0;
+
+            double accuracy;
+            double accuracy_unc;
+            ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+            helper_after_test_idata_split(__func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass);
+        }
+        {
+            int64_t ndata_result;
+            ggml_opt_result_ndata(cd.result2, &ndata_result);
+            bool subtest_ok = ndata_result == ndata - idata_split;
+
+            double loss;
+            double loss_unc;
+            ggml_opt_result_loss(cd.result2, &loss, &loss_unc);
+            subtest_ok = subtest_ok && loss == 15.0 - epoch*8 && almost_equal(loss_unc, sqrt(0.5), 1e-10);
+
+            double accuracy;
+            double accuracy_unc;
+            ggml_opt_result_accuracy(cd.result2, &accuracy, &accuracy_unc);
+            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+            helper_after_test_idata_split(__func__, high_level, epoch, "results_forward", subtest_ok, ntest, npass);
+        }
+
+        ggml_opt_result_reset(cd.result);
+        ggml_opt_result_reset(cd.result2);
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static void helper_after_test_gradient_accumulation(
+        const char * func, const int nbatch_physical, const enum ggml_opt_loss_type loss_type, const int epoch,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    std::string options = ", nbatch_physical=";
+    options += std::to_string(nbatch_physical);
+    options += ", loss_type=";
+    options += loss_type == GGML_OPT_LOSS_TYPE_MEAN ? "mean" : "sum";
+    options += ", epoch=";
+    options += std::to_string(epoch);
+    helper_after_test(func, false, options, subtest, subtest_ok, ntest, npass);
+}
+
+static std::pair test_gradient_accumulation(
+        ggml_backend_sched_t backend_sched, ggml_backend_t backend, const int32_t nbatch_physical, const enum ggml_opt_loss_type loss_type) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(
+        backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type);
+
+    std::vector grad_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        grad_history[idata] = NAN;
+    }
+
+    for (int epoch = 1; epoch <= 4; ++epoch) {
+        if (nbatch_physical == 1) {
+            for (int idata = 0; idata < ndata; ++idata) {
+                const float idataf = idata;
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
+                ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float));
+                ggml_opt_eval(cd.opt_ctx, cd.result);
+                ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float));
+            }
+        } else if (nbatch_physical == 2) {
+            for (int idata = 0; idata < ndata; idata += 2) {
+                const float idataf[2] = {float(idata + 0), float(idata + 1)};
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
+                ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float));
+                ggml_opt_eval(cd.opt_ctx, cd.result);
+
+                grad_history[idata + 0] = 0.0f;
+                ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata + 1, 0, 1*sizeof(float));
+            }
+        } else {
+            GGML_ASSERT(false);
+        }
+
+        {
+            GGML_ASSERT(ndata == 6);
+            constexpr double atol = 1e-6;
+            bool subtest_ok = true;
+            if (loss_type == GGML_OPT_LOSS_TYPE_SUM) {
+                if (nbatch_physical == 1) {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0, atol);
+                } else {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0, atol);
+                }
+                subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0, atol);
+            } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
+                if (nbatch_physical == 1) {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0/ndata, atol);
+                } else {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0/ndata, atol);
+                }
+                subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0/ndata, atol);
+            } else {
+                GGML_ASSERT(false);
+            }
+            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "grads", subtest_ok, ntest, npass);
+        }
+        {
+            float weights;
+            ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+            const bool subtest_ok = weights == (ndata/2) - epoch;
+            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass);
+        }
+        {
+            int64_t ndata_result;
+            ggml_opt_result_ndata(cd.result, &ndata_result);
+            bool subtest_ok = ndata_result == ndata/nbatch_physical;
+
+            double loss;
+            ggml_opt_result_loss(cd.result, &loss, /*loss_unc =*/ nullptr);
+            if (loss_type == GGML_OPT_LOSS_TYPE_SUM) {
+                subtest_ok = subtest_ok && loss == (39.0 - epoch*6.0);
+            } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
+                subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, 1e-6);
+            } else {
+                GGML_ASSERT(false);
+            }
+
+            double accuracy;
+            double accuracy_unc;
+            ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "results", subtest_ok, ntest, npass);
+        }
+
+        ggml_opt_result_reset(cd.result);
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static ggml_opt_optimizer_params helper_get_regression_opt_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
+    result.adamw.alpha = 0.1f;
+    return result;
+}
+
+static std::pair test_regression(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int ntest = 0;
+    int npass = 0;
+
+    // Test for simple regression with f(x) = a*x + b
+
+    constexpr int64_t ndata_regression = 201;
+    constexpr float a_true = 1.2f;
+    constexpr float b_true = 3.4f;
+
+    std::mt19937 gen(12345);
+    std::normal_distribution nd{0.0f, 0.1f};
+
+    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
+        GGML_TYPE_F32, GGML_TYPE_F32, 1, 1, ndata_regression, ndata_regression);
+
+    float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));
+    float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
+
+    constexpr float x_min = -100.0f;
+    constexpr float x_max =  100.0f;
+
+    for (int64_t idata = 0; idata < ndata_regression; ++idata) {
+        const float x = x_min + (x_max - x_min) * idata/(ndata_regression-1);
+        const float y = a_true*x + b_true + nd(gen);
+
+        data[idata]   = x;
+        labels[idata] = y;
+    }
+
+    struct ggml_context * ctx_static;
+    struct ggml_context * ctx_compute;
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ 3*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_static = ggml_init(params);
+    }
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_compute = ggml_init(params);
+    }
+
+    // The first dimension is the dimension of the datapoints, the second dimension is the number of datapoints.
+    struct ggml_tensor * x = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, 1, ndata_regression);
+    ggml_set_name(x, "x");
+
+    struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
+    ggml_set_name(a, "a");
+    ggml_set_param(a);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
+    ggml_set_name(b, "b");
+    ggml_set_param(b);
+
+    struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b);
+    ggml_set_name(f, "f");
+
+    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);
+    const float a0 = 1.0f;
+    const float b0 = 3.0f;
+    ggml_backend_tensor_set(a, &a0, 0, sizeof(float));
+    ggml_backend_tensor_set(b, &b0, 0, sizeof(float));
+
+    ggml_opt_fit(backend_sched, ctx_compute, x, f, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
+        helper_get_regression_opt_pars, 100, ndata_regression, 0.0f, true);
+
+    {
+        float a_fit;
+        ggml_backend_tensor_get(a, &a_fit, 0, sizeof(float));
+        float b_fit;
+        ggml_backend_tensor_get(b, &b_fit, 0, sizeof(float));
+        const bool subtest_ok = almost_equal(a_fit, a_true, 1e-2) && almost_equal(b_fit, b_true, 1e-2);
+        printf("  %s(subtest=weights): ", __func__);
+        if (subtest_ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            npass++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+        ntest++;
+    }
+
+    ggml_backend_buffer_free(buf);
+    ggml_free(ctx_static);
+    ggml_opt_dataset_free(dataset);
+
+    return std::make_pair(npass, ntest);
+}
+
+static std::pair test_backend(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int npass = 0;
+    int ntest = 0;
+
+    for (bool shuffle : {false, true}) {
+        std::pair partial = test_dataset(backend_sched, backend, shuffle);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    {
+        std::pair partial = test_grad(backend_sched, backend);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    for (bool high_level : {false, true}){
+        for (bool shuffle : {false, true}) {
+            if (!high_level && shuffle) {
+                continue;
+            }
+
+            std::pair partial = test_forward_backward(backend_sched, backend, high_level, shuffle);
+            npass += partial.first;
+            ntest += partial.second;
+        }
+    }
+    {
+        std::pair partial = test_epoch_vs_fit(backend_sched, backend);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    for (bool high_level : {false, true}){
+        std::pair partial = test_idata_split(backend_sched, backend, high_level);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    for (int32_t nbatch_physical : {2, 1}) {
+        for (enum ggml_opt_loss_type loss_type : {GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_MEAN}) {
+            std::pair partial = test_gradient_accumulation(backend_sched, backend, nbatch_physical, loss_type);
+            npass += partial.first;
+            ntest += partial.second;
+        }
+    }
+    {
+        std::pair partial = test_regression(backend_sched, backend);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+
+    return std::make_pair(npass, ntest);
+}
+
+int main(void) {
+    const size_t dev_count = ggml_backend_dev_count();
+    printf("Testing %zu devices\n\n", dev_count);
+    size_t n_ok = 0;
+
+    std::vector devs;
+    std::vector     backends;
+
+    for (size_t i = 0; i < dev_count; ++i) {
+        devs.push_back(ggml_backend_dev_get(i));
+
+        ggml_backend_t backend = ggml_backend_dev_init(devs[i], NULL);
+        GGML_ASSERT(backend != NULL);
+
+        if (ggml_backend_is_cpu(backend)) {
+            ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
+        }
+
+        backends.push_back(backend);
+    }
+
+    for (size_t i = 0; i < dev_count; ++i) {
+        // Put the backend to be tested in front so that it's prioritized:
+        std::vector backends_modded = {backends[i]};
+        backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());
+
+        ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
+            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, true);
+
+        printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i]));
+        printf("  Device description: %s\n", ggml_backend_dev_description(devs[i]));
+        size_t free, total; // NOLINT
+        ggml_backend_dev_memory(devs[i], &free, &total);
+        printf("  Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
+        printf("\n");
+
+        std::pair result = test_backend(backend_sched, backends[i]);
+
+        printf("  %d/%d tests passed\n", result.first, result.second);
+        printf("  Backend %s: ", ggml_backend_name(backends[i]));
+        if (result.first == result.second) {
+            printf("\033[1;32mOK\033[0m\n");
+            n_ok++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+
+        printf("\n");
+
+        ggml_backend_sched_free(backend_sched);
+    }
+
+    for (ggml_backend_t backend : backends) {
+        ggml_backend_free(backend);
+    }
+
+    printf("%zu/%zu backends passed\n", n_ok, dev_count);
+    if (n_ok != dev_count) {
+        printf("\033[1;31mFAIL\033[0m\n");
+        return 1;
+    }
+    printf("\033[1;32mOK\033[0m\n");
+    return 0;
+}
diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..037c0582bbbf8005cdb18729b0ddbd0adc63bb96
--- /dev/null
+++ b/tests/test-quantize-fns.cpp
@@ -0,0 +1,186 @@
+// Unit tests for quantization specific functions - quantize, dequantize and dot product
+
+#include "ggml.h"
+#include "ggml-cpu.h"
+
+#undef NDEBUG
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
+constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
+constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.01f;
+constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
+constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
+constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
+constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
+constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
+constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f;
+
+static const char* RESULT_STR[] = {"ok", "FAILED"};
+
+
+// Generate synthetic data
+static void generate_data(float offset, size_t n, float * dst) {
+    for (size_t i = 0; i < n; i++) {
+        dst[i] = 0.1 + 2*cosf(i + offset);
+    }
+}
+
+// Calculate RMSE between two float arrays
+static float array_rmse(const float * a1, const float * a2, size_t n) {
+    double sum = 0;
+    for (size_t i = 0; i < n; i++) {
+        double diff = a1[i] - a2[i];
+        sum += diff * diff;
+    }
+    return sqrtf(sum) / n;
+}
+
+// Total quantization error on test data
+static float total_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
+    std::vector tmp_q(2*test_size);
+    std::vector tmp_out(test_size);
+
+    qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
+    qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
+    return array_rmse(test_data, tmp_out.data(), test_size);
+}
+
+// Total quantization error on test data
+static float reference_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
+    std::vector tmp_q(2*test_size);
+    std::vector tmp_out(test_size);
+    std::vector tmp_out_ref(test_size);
+
+    // FIXME: why is done twice?
+    qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
+    qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
+
+    qfns->from_float_ref(test_data, tmp_q.data(), test_size);
+    qfns->to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
+
+    return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
+}
+
+static float dot_product(const float * a1, const float * a2, size_t test_size) {
+    double sum = 0;
+    for (size_t i = 0; i < test_size; i++) {
+        sum += a1[i] * a2[i];
+    }
+    return sum;
+}
+
+// Total dot product error
+static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2) {
+    GGML_UNUSED(qfns);
+
+    std::vector tmp_q1(2*test_size);
+    std::vector tmp_q2(2*test_size);
+
+    const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
+
+    qfns_cpu->from_float(test_data1, tmp_q1.data(), test_size);
+    vdot->from_float(test_data2, tmp_q2.data(), test_size);
+
+    float result = INFINITY;
+    qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
+
+    const float dot_ref = dot_product(test_data1, test_data2, test_size);
+
+    return fabsf(result - dot_ref) / test_size;
+}
+
+int main(int argc, char * argv[]) {
+    bool verbose = false;
+    const size_t test_size = 32 * 128;
+
+    std::string arg;
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+
+        if (arg == "-v") {
+            verbose = true;
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            return 1;
+        }
+    }
+
+    std::vector test_data(test_size);
+    std::vector test_data2(test_size);
+
+    generate_data(0.0, test_data.size(), test_data.data());
+    generate_data(1.0, test_data2.size(), test_data2.data());
+
+    ggml_cpu_init();
+
+    int num_failed = 0;
+    bool failed = false;
+
+    for (int i = 0; i < GGML_TYPE_COUNT; i++) {
+        ggml_type type = (ggml_type) i;
+        const auto * qfns = ggml_get_type_traits(type);
+        const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
+
+        // deprecated - skip
+        if (qfns->blck_size == 0) {
+            continue;
+        }
+
+        const ggml_type ei = (ggml_type)i;
+
+        printf("Testing %s\n", ggml_type_name((ggml_type) i));
+        ggml_quantize_init(ei);
+
+        if (qfns_cpu->from_float && qfns->to_float) {
+            const float total_error = total_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
+            const float max_quantization_error =
+                type == GGML_TYPE_TQ1_0   ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
+                type == GGML_TYPE_TQ2_0   ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
+                type == GGML_TYPE_Q2_K    ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
+                type == GGML_TYPE_IQ2_S   ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
+                type == GGML_TYPE_Q3_K    ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
+                type == GGML_TYPE_IQ3_S   ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
+                type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR;
+            failed = !(total_error < max_quantization_error);
+            num_failed += failed;
+            if (failed || verbose) {
+                printf("%5s absolute quantization error:    %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
+            }
+
+            const float reference_error = reference_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
+            failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
+            num_failed += failed;
+            if (failed || verbose) {
+                printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
+            }
+
+            const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());
+            const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
+                                            type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
+                                          ? MAX_DOT_PRODUCT_ERROR_LOWBIT
+                                          : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
+                                          ? MAX_DOT_PRODUCT_ERROR_TERNARY
+                                          : MAX_DOT_PRODUCT_ERROR;
+            failed = !(vec_dot_error < max_allowed_error);
+            num_failed += failed;
+            if (failed || verbose) {
+                printf("%5s dot product error:              %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
+            }
+        }
+    }
+
+    if (num_failed || verbose) {
+        printf("%d tests failed\n", num_failed);
+    }
+
+    return num_failed > 0;
+}
diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2882884938388393e5b54bd1c584385951df60d5
--- /dev/null
+++ b/tests/test-quantize-perf.cpp
@@ -0,0 +1,365 @@
+// Benchmark quantization specific functions on synthetic data
+
+#include "ggml.h"
+#include "ggml-cpu.h"
+
+#undef NDEBUG
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#define MAX_ALIGNMENT 64
+#define QK 32
+#define WARMUP 5
+#define ITERATIONS 10
+#define MAX_ITERATIONS 100000000
+
+#define L1_SIZE      32*128
+#define L2_SIZE     32*2048
+#define L3_SIZE    32*20480
+#define MEM_SIZE 32*2048000
+
+struct quantize_perf_params {
+    std::vector include_types;
+    std::vector test_sizes;
+    size_t alignment_offset = 0;
+    bool op_quantize_row_q_reference = false;
+    bool op_quantize_row_q = false;
+    bool op_dequantize_row_q = false;
+    bool op_quantize_row_q_dot = false;
+    bool op_vec_dot_q = false;
+    int64_t iterations = ITERATIONS;
+};
+
+#if defined(__x86_64__) || defined(__i386__)
+
+#include 
+inline int64_t cpu_cycles() {
+// Rough way to detect new-ish CPUs
+#ifdef __POPCNT__
+    unsigned int dummy;
+    return __rdtscp(&dummy);
+#else
+    return __rdtsc();
+#endif
+}
+
+#else
+
+#define cpu_cycles() 0
+
+#endif
+
+
+// Generate synthetic data
+static void generate_data(float offset, size_t n, float * dst) {
+    for (size_t i = 0; i < n; i++) {
+        dst[i] = 0.1 + 2*cosf(i + offset);
+    }
+}
+
+static float gigabytes_per_second(size_t bytes, int64_t usecs) {
+    return bytes / (float) usecs * 1000000 / (1024*1024*1024);
+}
+
+static void * align_with_offset(void * ptr, int offset) {
+    size_t dummy_size = MAX_ALIGNMENT * 4;
+    return (char *) std::align(MAX_ALIGNMENT, MAX_ALIGNMENT, ptr, dummy_size) + offset;
+}
+
+static void benchmark_function(size_t size, size_t q_size, int64_t iterations, const std::function & func) {
+    int64_t min_time_us = INT64_MAX;
+    int64_t total_time_us = 0;
+    int64_t min_time_cycles = INT64_MAX;
+    int64_t total_time_cycles = 0;
+
+    for (int i = 0; i < WARMUP; i++) {
+        func();
+    }
+
+    for (int i = 0; i < iterations; i++) {
+        const int64_t start_time = ggml_time_us();
+        const int64_t start_cycles = cpu_cycles();
+
+        func();
+
+        const int64_t end_cycles = cpu_cycles();
+        const int64_t end_time = ggml_time_us();
+
+        total_time_cycles += end_cycles - start_cycles;
+        min_time_cycles = std::min(min_time_cycles, end_cycles - start_cycles);
+        total_time_us += end_time - start_time;
+        min_time_us = std::min(min_time_us, end_time - start_time);
+    }
+
+    printf("      min cycles/%d vals   : %9.2f\n",  QK, QK * min_time_cycles / (float) size);
+    printf("      avg cycles/%d vals   : %9.2f\n",  QK, QK * total_time_cycles / (float) (size * iterations));
+    printf("      float32 throughput   : %9.2f GB/s\n",  gigabytes_per_second(4 * size * iterations, total_time_us));
+    printf("      quantized throughput : %9.2f GB/s\n",  gigabytes_per_second(q_size * iterations, total_time_us));
+}
+
+static void usage(char * argv[]) {
+    printf("Benchmark quantization specific functions on synthetic data\n");
+    printf("\n");
+    printf("usage: %s [options]\n", argv[0]);
+    printf("\n");
+    printf("options: (default)\n");
+    printf("  -h, --help            show this help message and exit\n");
+    printf("  --size SIZE           set test size, divisible by 32 (L1_SIZE:%d)\n", L1_SIZE);
+    printf("  -3                    use size as L1, L2, L3 sizes (L1:%d L2:%d L3:%d)\n", L1_SIZE, L2_SIZE, L3_SIZE);
+    printf("  -4                    use size as L1, L2, L3, MEM sizes (L1:%d L2:%d L3:%d MEM:%d)\n", L1_SIZE, L2_SIZE, L3_SIZE, MEM_SIZE);
+    printf("  --op OP               set test operation as quantize_row_q_reference, quantize_row_q, dequantize_row_q,\n");
+    printf("                        quantize_row_q_dot, vec_dot_q (all)\n");
+    printf("  --type TYPE           set test type as");
+    for (int i = 0; i < GGML_TYPE_COUNT; i++) {
+        ggml_type type = (ggml_type) i;
+        const auto * qfns     = ggml_get_type_traits(type);
+        const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
+        if (ggml_type_name(type) != NULL) {
+            if (qfns_cpu->from_float && qfns->to_float) {
+                printf(" %s", ggml_type_name(type));
+            }
+        }
+    }
+    printf(" (all)\n");
+    printf("  --alignment-offset OFFSET\n");
+    printf("                        set alignment offset as OFFSET (0)\n");
+    printf("  -i NUM, --iterations NUM\n");
+    printf("                        set test iteration number (%d)\n", ITERATIONS);
+}
+
+int main(int argc, char * argv[]) {
+    quantize_perf_params params {};
+
+    // read command line
+
+    bool invalid_param = false;
+    std::string arg;
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+
+        if (arg == "--size") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            size_t size = std::stoi(argv[i]);
+            if (size % 32 != 0) {
+                fprintf(stderr, "error: size %zu not divisible by 32\n", size);
+                invalid_param = true;
+                break;
+            }
+            params.test_sizes.push_back(size);
+        } else if (arg == "-3") {
+            // quick select sizes that probably fit in CPU caches
+            params.test_sizes.push_back(L1_SIZE);
+            params.test_sizes.push_back(L2_SIZE);
+            params.test_sizes.push_back(L3_SIZE);
+        } else if (arg == "-4") {
+            // quick select cache sizes + memory
+            params.test_sizes.push_back(L1_SIZE);
+            params.test_sizes.push_back(L2_SIZE);
+            params.test_sizes.push_back(L3_SIZE);
+            params.test_sizes.push_back(MEM_SIZE);
+        } else if (arg == "--op") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string op {argv[i]};
+            if (op == "quantize_row_q_reference") {
+                params.op_quantize_row_q_reference = true;
+            } else if (op == "quantize_row_q") {
+                params.op_quantize_row_q = true;
+            } else if (op == "dequantize_row_q") {
+                params.op_dequantize_row_q = true;
+            } else if (op == "quantize_row_q_dot") {
+                params.op_quantize_row_q_dot = true;
+            } else if (op == "vec_dot_q") {
+                params.op_vec_dot_q = true;
+            } else {
+                invalid_param = true;
+                break;
+            }
+        } else if (arg == "--type") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.include_types.push_back(argv[i]);
+        } else if (arg == "--alignment-offset") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            int alignment = std::stoi(argv[i]);
+            if (alignment < 0 || alignment > MAX_ALIGNMENT) {
+            fprintf(stderr, "error: alignment-offset must be less than %d\n", MAX_ALIGNMENT);
+                invalid_param = true;
+                break;
+            }
+            params.alignment_offset = alignment;
+        } else if ((arg == "-i") || (arg == "--iterations")) {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            int number = std::stoi(argv[i]);
+            if (number < 0 || number > MAX_ITERATIONS) {
+            fprintf(stderr, "error: iterations must be less than %d\n", MAX_ITERATIONS);
+                invalid_param = true;
+                break;
+            }
+            params.iterations = number;
+        } else if ((arg == "-h") || (arg == "--help")) {
+            usage(argv);
+            return 1;
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            return 1;
+        }
+    }
+    if (invalid_param) {
+        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+        return 1;
+    }
+
+    if (params.test_sizes.empty()) {
+        params.test_sizes.push_back(L1_SIZE);
+    }
+    if (!(params.op_quantize_row_q_reference || params.op_quantize_row_q || params.op_dequantize_row_q || params.op_quantize_row_q_dot || params.op_vec_dot_q)) {
+        params.op_quantize_row_q_reference = params.op_quantize_row_q = params.op_dequantize_row_q = params.op_quantize_row_q_dot = params.op_vec_dot_q = true;
+    }
+
+    std::sort(params.test_sizes.begin(), params.test_sizes.end());
+    size_t largest = params.test_sizes.back();
+
+    std::vector test_data1_v(largest*4 + MAX_ALIGNMENT*2);
+    std::vector test_data2_v(largest*4 + MAX_ALIGNMENT*2);
+    std::vector test_q1_v   (largest*4 + MAX_ALIGNMENT*2);
+    std::vector test_q2_v   (largest*4 + MAX_ALIGNMENT*2);
+    std::vector test_out_v  (largest*4 + MAX_ALIGNMENT*2);
+
+    float * test_data1 = (float *) align_with_offset(test_data1_v.data(), params.alignment_offset);
+    float * test_data2 = (float *) align_with_offset(test_data2_v.data(), params.alignment_offset);
+    float * test_q1    = (float *) align_with_offset(test_q1_v.data(),    params.alignment_offset);
+    float * test_q2    = (float *) align_with_offset(test_q2_v.data(),    params.alignment_offset);
+    float * test_out   = (float *) align_with_offset(test_out_v.data(),   params.alignment_offset);
+
+    generate_data(0, largest, test_data1);
+    generate_data(1, largest, test_data2);
+
+    int64_t iterations = params.iterations;
+
+
+    // Initialize GGML, ensures float conversion tables are initialized
+    struct ggml_init_params ggml_params = {
+        /* .mem_size   = */ 1*1024,
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ true,
+    };
+    struct ggml_context * ctx = ggml_init(ggml_params);
+
+    for (int i = 0; i < GGML_TYPE_COUNT; i++) {
+        ggml_type type = (ggml_type) i;
+        const auto * qfns = ggml_get_type_traits(type);
+        const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
+        if (!params.include_types.empty() && ggml_type_name(type) && std::find(params.include_types.begin(), params.include_types.end(), ggml_type_name(type)) == params.include_types.end()) {
+            continue;
+        }
+
+        if (qfns_cpu->from_float && qfns->to_float) {
+            printf("%s\n", ggml_type_name(type));
+
+            ggml_quantize_init(type);
+
+            if (params.op_quantize_row_q_reference) {
+                printf("  quantize_row_q_reference\n");
+                for (size_t size : params.test_sizes) {
+                    printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
+                    auto quantize_fn = [&](void) -> float {
+                        qfns->from_float_ref(test_data1, test_q1, size);
+                        return test_q1[0];
+                    };
+                    size_t quantized_size = ggml_row_size(type, size);
+                    benchmark_function(size, quantized_size, iterations, quantize_fn);
+                }
+                printf("\n");
+            }
+
+            if (params.op_quantize_row_q) {
+                printf("  quantize_row_q\n");
+                for (size_t size : params.test_sizes) {
+                    printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
+                    auto quantize_fn = [&](void) -> float {
+                        qfns_cpu->from_float(test_data1, test_q1, size);
+                        return test_q1[0];
+                    };
+                    size_t quantized_size = ggml_row_size(type, size);
+                    benchmark_function(size, quantized_size, iterations, quantize_fn);
+                }
+                printf("\n");
+            }
+
+            if (params.op_dequantize_row_q) {
+                printf("  dequantize_row_q\n");
+                qfns_cpu->from_float(test_data1, test_q1, largest);
+                for (size_t size : params.test_sizes) {
+                    printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
+                    auto quantize_fn = [&](void) -> float {
+                        qfns->to_float(test_q1, test_out, size);
+                        return test_out[0];
+                    };
+                    size_t quantized_size = ggml_row_size(type, size);
+                    benchmark_function(size, quantized_size, iterations, quantize_fn);
+                }
+                printf("\n");
+            }
+
+            if (params.op_quantize_row_q_dot) {
+                printf("  quantize_row_q_dot\n");
+                for (size_t size : params.test_sizes) {
+                    printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
+                    auto quantize_fn = [&](void) -> float {
+                        const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
+                        vdot->from_float(test_data1, test_q1, size);
+                        return test_q1[0];
+                    };
+                    size_t quantized_size = ggml_row_size(type, size);
+                    benchmark_function(size, quantized_size, iterations, quantize_fn);
+                }
+                printf("\n");
+            }
+
+            if (params.op_vec_dot_q) {
+                printf("  vec_dot_q\n");
+                qfns_cpu->from_float(test_data1, test_q1, largest);
+                qfns_cpu->from_float(test_data2, test_q2, largest);
+                for (size_t size : params.test_sizes) {
+                    printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
+                    auto quantize_fn = [&](void) -> float {
+                        float result;
+                        qfns_cpu->vec_dot(size, &result, 0, test_q1, 0, test_q2, 0, 1);
+                        return result;
+                    };
+                    size_t quantized_size = ggml_row_size(type, size);
+                    benchmark_function(size, quantized_size, iterations, quantize_fn);
+                }
+                printf("\n");
+            }
+        }
+    }
+
+    ggml_free(ctx);
+
+    return 0;
+}
diff --git a/tests/test-quantize-stats.cpp b/tests/test-quantize-stats.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a284a1f0c5e31c8471c7bf223fe33cf803a52224
--- /dev/null
+++ b/tests/test-quantize-stats.cpp
@@ -0,0 +1,424 @@
+#include "ggml.h"
+#include "ggml-cpu.h"
+#include "llama.h"
+#include "common.h"
+
+#include "../src/llama-model.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+struct quantize_stats_params {
+    std::string model = DEFAULT_MODEL_PATH;
+    bool verbose = false;
+    bool per_layer_stats = false;
+    bool print_histogram = false;
+    bool reference = false;
+    std::vector include_layers;
+    std::vector exclude_layers;
+    std::vector include_types;
+};
+
+constexpr size_t HISTOGRAM_BUCKETS = 150;
+constexpr double HISTOGRAM_RANGE = 0.03;
+
+struct error_stats {
+    size_t num_samples;
+    double total_error;
+    double max_error;
+    uint64_t error_histogram[HISTOGRAM_BUCKETS];
+};
+
+static void quantize_stats_print_usage(int /*argc*/, char ** argv) {
+    quantize_stats_params params;
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -m FNAME, --model FNAME\n");
+    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stderr, "  -r, --reference\n");
+    fprintf(stderr, "                        use reference implementation (default: false)\n");
+    fprintf(stderr, "  -v, --verbose\n");
+    fprintf(stderr, "                        verbose output (default: false)\n");
+    fprintf(stderr, "  -p, --per-layer-stats\n");
+    fprintf(stderr, "                        print stats per layer (default: false)\n");
+    fprintf(stderr, "  --histogram\n");
+    fprintf(stderr, "                        print error histogram (default: false)\n");
+    fprintf(stderr, "  -l LAYER, --include-layer LAYER\n");
+    fprintf(stderr, "                        only test layers matching pattern\n");
+    fprintf(stderr, "  -L LAYER, --exclude-layer LAYER\n");
+    fprintf(stderr, "                        exclude layers matching pattern\n");
+    fprintf(stderr, "  -t TYPE, --type TYPE\n");
+    fprintf(stderr, "                        only test given type (q4_0, q4_1)\n");
+    fprintf(stderr, "\n");
+}
+
+// Check if a layer is included/excluded by command line
+static bool layer_included(const quantize_stats_params & params, const std::string & layer) {
+    for (const auto& excluded : params.exclude_layers) {
+        if (std::regex_search(layer, std::regex(excluded))) {
+            return false;
+        }
+    }
+    for (const auto& included : params.include_layers) {
+        if (std::regex_search(layer, std::regex(included))) {
+            return true;
+        }
+    }
+    return params.include_layers.empty();
+}
+
+// Update error statistics given vectors with the before/after result of quantization
+static void update_error_stats(int64_t nelements, const float * input, const float * output, error_stats & stats) {
+    for (int64_t i = 0; i < nelements; i++) {
+        double diff = input[i] - output[i];
+        stats.total_error += diff * diff;
+        stats.max_error = fmax(fabs(diff), stats.max_error);
+        stats.error_histogram[std::max(std::min((size_t) floor(fabs(diff) / HISTOGRAM_RANGE * HISTOGRAM_BUCKETS), HISTOGRAM_BUCKETS-1), (size_t) 0)]++;
+    }
+    stats.num_samples += nelements;
+}
+
+static void combine_error_stats(error_stats & into, const error_stats & from) {
+    into.num_samples += from.num_samples;
+    into.total_error += from.total_error;
+    if (from.max_error > into.max_error) into.max_error = from.max_error;
+    for (size_t i=0; i= sum*quantile) {
+            return (i+1) * HISTOGRAM_RANGE / HISTOGRAM_BUCKETS;
+        }
+    }
+    return INFINITY;
+}
+
+static void print_error_stats(const std::string & name, const error_stats & stats, bool print_histogram) {
+    double rmse = sqrt(stats.total_error / (double) stats.num_samples);
+    double median = find_quantile(stats, .5);
+    double pct95 = find_quantile(stats, .95);
+    printf("%-50s: rmse %.8f, maxerr %.8f, 95pct<%.4f, median<%.4f\n", name.c_str(), rmse, stats.max_error, pct95, median);
+    if (print_histogram) {
+        printf("Error distribution:\n");
+        for (size_t i = 0; i < HISTOGRAM_BUCKETS; i++) {
+            double lower = i * HISTOGRAM_RANGE / HISTOGRAM_BUCKETS;
+            double upper = (i+1) * HISTOGRAM_RANGE / HISTOGRAM_BUCKETS;
+            if (i == HISTOGRAM_BUCKETS -1) upper = INFINITY;
+            printf("[%3.4f, %3.4f): %11" PRIu64 "\n", lower, upper, stats.error_histogram[i]);
+        }
+    }
+}
+
+// copied from ggml.h - verify that we can access this as a flat array
+static bool tensor_is_contiguous(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
+        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
+static void test_roundtrip_on_chunk(
+    const ggml_tensor * layer, int64_t offset, int64_t chunk_size, const ggml_type_traits & qfns, const ggml_type_traits_cpu & qfns_cpu, bool use_reference,
+    float * input_scratch, char * quantized_scratch, float * output_scratch, error_stats & stats
+) {
+    if (layer->type == GGML_TYPE_F16) {
+        for (int i = 0; i < chunk_size; i++) {
+            input_scratch[i] = ggml_get_f32_1d(layer, i + offset);
+        }
+    } else {
+        input_scratch = ggml_get_data_f32(layer) + offset;
+    }
+
+    if (use_reference) {
+        qfns.from_float_ref(input_scratch, quantized_scratch, chunk_size);
+    } else {
+        qfns_cpu.from_float(input_scratch, quantized_scratch, chunk_size);
+    }
+    qfns.to_float(quantized_scratch, output_scratch, chunk_size);
+
+    update_error_stats(chunk_size, input_scratch, output_scratch, stats);
+}
+
+
+// Run quantization function for a single layer and update error stats
+static void test_roundtrip_on_layer(
+    std::string & name, bool print_layer_stats, const ggml_type_traits & qfns, const ggml_type_traits_cpu & qfns_cpu, bool use_reference,
+    const ggml_tensor * layer, std::vector & input_scratch, std::vector & quantized_scratch,
+    std::vector & output_scratch, error_stats & total_error, int max_thread = 0
+) {
+    assert(tensor_is_contiguous(layer));
+    error_stats layer_error {};
+    uint64_t nelements = ggml_nelements(layer);
+
+    float* input_scratch_ptr = nullptr;
+    if (layer->type == GGML_TYPE_F16) {
+        if (input_scratch.size() < nelements) input_scratch.resize(nelements);
+        input_scratch_ptr = input_scratch.data();
+    }
+    if (quantized_scratch.size() < 4*nelements) quantized_scratch.resize(4*nelements);
+    if (output_scratch.size() < nelements) output_scratch.resize(nelements);
+
+    if (max_thread < 1) max_thread = std::thread::hardware_concurrency();
+    int chunk_size = 32*512;
+    int num_chunks = (nelements + chunk_size - 1)/chunk_size;
+
+    if (num_chunks < 2 || max_thread < 2) {
+        test_roundtrip_on_chunk(layer, 0, nelements, qfns, qfns_cpu, use_reference, input_scratch_ptr, quantized_scratch.data(),
+                output_scratch.data(), print_layer_stats ? layer_error : total_error);
+    } else {
+        auto & stats = print_layer_stats ? layer_error : total_error;
+        std::mutex mutex;
+        uint64_t counter = 0;
+        auto compute = [&mutex, &counter, &stats, &qfns, &qfns_cpu, nelements, layer, use_reference, input_scratch_ptr,
+             &quantized_scratch, &output_scratch, chunk_size] () {
+            error_stats local_stats {};
+            while (true) {
+                std::unique_lock lock(mutex);
+                uint64_t offset = counter; counter += chunk_size;
+                if (offset >= nelements) {
+                    combine_error_stats(stats, local_stats);
+                    break;
+                }
+                lock.unlock();
+                uint64_t chunk = offset + chunk_size < nelements ? chunk_size : nelements - offset;
+                test_roundtrip_on_chunk(layer, offset, chunk, qfns, qfns_cpu, use_reference, input_scratch_ptr + offset,
+                        quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats);
+            }
+        };
+        int nthread = std::min(num_chunks, max_thread);
+        std::vector workers(nthread-1);
+        for (auto& w : workers) w = std::thread(compute);
+        compute();
+        for (auto& w : workers) w.join();
+    }
+
+    if (print_layer_stats) {
+        print_error_stats(name, layer_error, false);
+        combine_error_stats(total_error, layer_error);
+    }
+}
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    quantize_stats_params params;
+
+    // read command line
+
+    int max_thread = 0;
+    bool invalid_param = false;
+    std::string arg;
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+
+        if (arg == "-h" || arg == "--help") {
+            quantize_stats_print_usage(argc, argv);
+            exit(0);
+        } else if (arg == "-r" || arg == "--reference") {
+            params.reference = true;
+        } else if (arg == "-v") {
+            params.verbose = true;
+        } else if (arg == "-p" || arg == "--per-layer-stats") {
+            params.per_layer_stats = true;
+        } else if (arg == "--histogram") {
+            params.print_histogram = true;
+        } else if (arg == "-m" || arg == "--model") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model = argv[i];
+        } else if (arg == "-l" || arg == "--include-layer") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.include_layers.emplace_back(argv[i]);
+        } else if (arg == "-L" || arg == "--exclude-layer") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.exclude_layers.emplace_back(argv[i]);
+        } else if (arg == "-t" || arg == "--type") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            int j;
+            for (j = 0; j < GGML_TYPE_COUNT; ++j) {
+               const auto * name = ggml_type_name((ggml_type) j);
+               if (name && strcmp(argv[i], name) == 0) break;
+            }
+            if (j < GGML_TYPE_COUNT) {
+                params.include_types.push_back((ggml_type) j);
+            } else {
+                fprintf(stderr, "error: %s not in list of types\n", argv[i]);
+                invalid_param = true;
+            }
+        } else if (arg == "-n" || arg == "--num-threads") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            max_thread = atoi(argv[i]);
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            quantize_stats_print_usage(argc, argv);
+            return 1;
+        }
+    }
+    if (invalid_param) {
+        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+        quantize_stats_print_usage(argc, argv);
+        return 1;
+    }
+
+    print_build_info();
+
+    // load the model
+    fprintf(stderr, "Loading model\n");
+
+    const int64_t t_main_start_us = ggml_time_us();
+    llama_model * model;
+    llama_context * ctx;
+
+    {
+        auto mparams = llama_model_default_params();
+        mparams.use_mlock  = false;
+
+        model = llama_model_load_from_file(params.model.c_str(), mparams);
+
+        if (model == NULL) {
+            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        auto cparams = llama_context_default_params();
+        cparams.n_ctx = 256;
+
+        ctx = llama_init_from_model(model, cparams);
+
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+            llama_model_free(model);
+            return 1;
+        }
+    }
+
+    const auto & tensors = llama_internal_get_tensor_map(model);
+
+    // check layer tensors
+    int included_layers = 0;
+    int64_t max_nelements = 0;
+    bool is_f16 = false;
+    for (const auto & kv_tensor : tensors) {
+        if (!layer_included(params, kv_tensor.first)) {
+            continue;
+        }
+        if (params.verbose) {
+            printf("%s: type %s, size %" PRId64 "\n", kv_tensor.first.c_str(), ggml_type_name(kv_tensor.second->type), ggml_nelements(kv_tensor.second));
+        }
+        if (kv_tensor.second->type == GGML_TYPE_F16) {
+            is_f16 = true;
+        } else if (kv_tensor.second->type != GGML_TYPE_F32) {
+            fprintf(stderr, "%s: error: Quantization should be tested with a float model, "
+                "this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type);
+            llama_free(ctx);
+            llama_model_free(model);
+            return 1;
+        }
+        included_layers++;
+        max_nelements = std::max(max_nelements, ggml_nelements(kv_tensor.second));
+    }
+
+    if (is_f16) {
+        printf("note: source model is f16\n");
+    }
+    printf("testing %d layers with max size %" PRId64 "\n", included_layers, max_nelements);
+    // allocate scratch space
+    std::vector input_scratch;
+    std::vector quantized_scratch;
+    std::vector output_scratch;
+
+    // loop throught quantization types
+    for (int i = 0; i < GGML_TYPE_COUNT; i++) {
+        const ggml_type type = (ggml_type) i;
+        if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), i) == params.include_types.end()) {
+            continue;
+        }
+        const auto * qfns     = ggml_get_type_traits(type);
+        const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
+        if (qfns_cpu->from_float && qfns->to_float) {
+            if (params.verbose) {
+                printf("testing %s ...\n",  ggml_type_name(type));
+            }
+
+            ggml_quantize_init(type);
+
+            error_stats global_stats {};
+
+            for (const auto & kv_tensor : tensors) {
+                if (!layer_included(params, kv_tensor.first)) {
+                    continue;
+                }
+                if (params.verbose) {
+                    printf("  %s ...\n",  kv_tensor.first.c_str());
+                }
+                std::string layer_name { ggml_type_name(type) };
+                layer_name += "::" + kv_tensor.first;
+                test_roundtrip_on_layer(
+                        layer_name,
+                        params.per_layer_stats,
+                        *qfns, *qfns_cpu,
+                        params.reference,
+                        kv_tensor.second,
+                        input_scratch,
+                        quantized_scratch,
+                        output_scratch,
+                        global_stats,
+                        max_thread
+                );
+            }
+
+            print_error_stats(ggml_type_name(type), global_stats, params.print_histogram);
+        }
+    }
+
+
+    llama_free(ctx);
+    llama_model_free(model);
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n");
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
+    }
+
+    return 0;
+}
diff --git a/tests/test-regex-partial.cpp b/tests/test-regex-partial.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ffad1897860a59cb70c4c9daf80e98291797917f
--- /dev/null
+++ b/tests/test-regex-partial.cpp
@@ -0,0 +1,288 @@
+//  Tests common_regex (esp. its partial final matches support).
+
+#include "common.h"
+#include "regex-partial.h"
+
+#include 
+#include 
+#include 
+
+template  static void assert_equals(const T & expected, const T & actual) {
+    if (expected != actual) {
+        std::cerr << "Expected: " << expected << std::endl;
+        std::cerr << "  Actual: " << actual << std::endl;
+        std::cerr << std::flush;
+        throw std::runtime_error("Test failed");
+    }
+}
+
+struct test_case {
+    std::string pattern;
+    struct input_output {
+        std::string input;
+        common_regex_match output;
+    };
+    std::vector inputs_outputs;
+};
+
+static std::string common_regex_match_type_name(common_regex_match_type type) {
+    switch (type) {
+        case COMMON_REGEX_MATCH_TYPE_NONE:
+            return "COMMON_REGEX_MATCH_TYPE_NONE";
+        case COMMON_REGEX_MATCH_TYPE_PARTIAL:
+            return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
+        case COMMON_REGEX_MATCH_TYPE_FULL:
+            return "COMMON_REGEX_MATCH_TYPE_FULL";
+    }
+    return "?";
+}
+
+static void test_regex() {
+    printf("[%s]\n", __func__);
+    auto test = [](const test_case & test_case) {
+        common_regex cr(test_case.pattern);
+        std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
+        // std::cout << "    partial rev: " << cr.reversed_partial_pattern.str() << '\n';
+        for (const auto & input_output : test_case.inputs_outputs) {
+            std::cout << "  Input: " << input_output.input << '\n';
+            auto m = cr.search(input_output.input, 0);
+            if (m != input_output.output) {
+                auto match_to_str = [&](const std::optional & m) {
+                    std::ostringstream ss;
+                    if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
+                        ss << "";
+                    } else {
+                        GGML_ASSERT(!input_output.output.groups.empty());
+                        std::vector parts;
+                        for (const auto & g : m->groups) {
+                            parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
+                        }
+                        ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
+                    }
+                    return ss.str();
+                };
+                std::cout << "    Expected: " << match_to_str(input_output.output) << '\n';
+                std::cout << "         Got: " << match_to_str(m) << '\n';
+                std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
+
+                throw std::runtime_error("Test failed");
+            }
+        }
+    };
+    test({
+        "a",
+        {
+            {"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
+            {"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
+            {"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
+        }
+    });
+    test({
+        "abcd",
+        {
+            {"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"d", {}},
+            {"bcd", {}},
+            {"cde", {}},
+            {"cd", {}},
+            {"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
+            {"abbie", {}},
+            {"", {}},
+        }
+    });
+    test({
+        ".*?ab",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+        }
+    });
+    test({
+        "a.*?b",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"d", {}},
+            {"b", {}},
+        }
+    });
+    test({
+        "ab(?:cd){2,4}ef",
+        {
+            // {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"abcde", {}},
+            {"abcdef", {}},
+            {"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
+            {"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
+            {"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
+            {"abcdcdcdcdcdef", {}},
+            {"abcde", {}},
+            {"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
+        }
+    });
+    test({
+        "a(?:rte| pure )fact",
+        {
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"fact", {}},
+            {"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
+            {"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
+            {"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
+            {"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
+            {"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
+            {"" , {}},
+            {"pure", {}},
+            {"pure fact", {}},
+        }
+    });
+    test({
+        "abc",
+        {
+            {" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"b", {}},
+            {"c", {}},
+            {"", {}},
+        }
+    });
+
+    test({
+        "(?:abc)?\\s*def",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
+            {"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
+            {"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
+            {" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+        }
+    });
+
+    test({
+        "a+b",
+        {
+            {"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+        }
+    });
+
+    test({
+        "(?:"
+            "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
+            "("                          // match 2 (open_tag)
+                ""
+                "|"
+                "|"
+                "|"
+                "|"
+                "|"
+                "|"
+                "|"
+            ")?"
+            "(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
+        ")"
+        "|]+)>"            // match 4 (function name)
+        "|", // match 5 (function name again)
+        {
+            {"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
+            {" {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
+            {"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
+            {"Let's call something\n{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
+            {"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
+            {"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
+            {" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
+            {" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
+            {" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
+            {"", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
+
+        }
+    });
+}
+
+static void test_regex_to_reversed_partial_regex() {
+    printf("[%s]\n", __func__);
+
+    assert_equals(
+        "((?:(?:c)?b)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("abc"));
+
+    assert_equals(
+        "(a+)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a+"));
+
+    assert_equals(
+        "(a*)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a*"));
+
+    assert_equals(
+        "(a?)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a?"));
+
+    assert_equals(
+        "([a-z])[\\s\\S]*",
+        regex_to_reversed_partial_regex("[a-z]"));
+
+    assert_equals(
+        "((?:\\w+)?[a-z])[\\s\\S]*",
+        regex_to_reversed_partial_regex("[a-z]\\w+"));
+
+    assert_equals(
+        "((?:a|b))[\\s\\S]*",
+        regex_to_reversed_partial_regex("(?:a|b)"));
+    assert_equals(
+        "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("abcd"));
+    assert_equals(
+        "((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
+        regex_to_reversed_partial_regex("a*b"));
+    assert_equals(
+        "((?:(?:b)?a)?.*)[\\s\\S]*",
+        regex_to_reversed_partial_regex(".*?ab"));
+    assert_equals(
+        "((?:(?:b)?.*)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a.*?b"));
+    assert_equals(
+        "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a(bc)d"));
+    assert_equals(
+        "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a(bc|de)"));
+    assert_equals(
+        "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("ab{2,4}c"));
+}
+
+int main() {
+    test_regex_to_reversed_partial_regex();
+    test_regex();
+    std::cout << "All tests passed.\n";
+}
diff --git a/tests/test-rope.cpp b/tests/test-rope.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..322b8bb99ec6c223461337a9308d38f74dfe6236
--- /dev/null
+++ b/tests/test-rope.cpp
@@ -0,0 +1,262 @@
+#include "ggml.h"
+#include "ggml-cpu.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
+#endif
+
+#define MAX_NARGS 3
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+#define GGML_SILU_FP16
+
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
+static float frand(void) {
+    return (float)rand()/(float)RAND_MAX;
+}
+
+static int irand(int n) {
+    if (n == 0) return 0;
+    return rand()%n;
+}
+
+static void get_random_dims(int64_t * dims, int ndims) {
+    dims[0] = dims[1] = dims[2] = dims[3] = 1;
+
+    for (int i = 0; i < ndims; i++) {
+        dims[i] = 1 + irand(4);
+    }
+}
+
+static struct ggml_tensor * get_random_tensor_f32(
+        struct ggml_context * ctx0,
+        int ndims,
+        const int64_t ne[],
+        float fmin,
+        float fmax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
+int main(int /*argc*/, const char ** /*argv*/) {
+    struct ggml_init_params params = {
+        /* .mem_size   = */ 128*1024*1024,
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ false,
+    };
+
+    std::vector work_buffer;
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    struct ggml_tensor * x;
+
+    // rope f32
+    for (int m = 0; m < 5; ++m) {
+        const int ndims = 4;
+
+        const int64_t n_rot = 128;
+        const int64_t ne[4] = { 2*n_rot, 32, 73, 1 };
+
+        const int n_past_0 = 100;
+        const int n_past_2 = 33;
+
+        struct ggml_tensor * r0;
+        struct ggml_tensor * r1;
+        struct ggml_tensor * r2;
+        x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+        int mode = -1;
+
+        if (m < 3) {
+            struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
+            struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
+            struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
+
+            for (int i = 0; i < ne[2]; ++i) {
+                ((int32_t *) p0->data)[i] = n_past_0 + i;
+                ((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
+                ((int32_t *) p2->data)[i] = n_past_2 + i;
+            }
+            // test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
+            mode = m == 0 ? 0 : m == 1 ? 2 : 4;
+
+            // 100, 101, 102, ..., 172
+            r0 = ggml_rope(ctx0, x,  p0, n_rot, mode);
+            // -67, -67, -67, ..., -67
+            r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
+
+            //  33,  34,  35, ..., 105
+            r2 = ggml_rope(ctx0, x,  p2, n_rot, mode);
+        } else {
+            // testing multi-dimension rope position embedding mode
+            struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
+            struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
+            struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
+
+            int sections[4] = {16, 24, 24, 0};
+            mode = (m == 3) ? GGML_ROPE_TYPE_MROPE : GGML_ROPE_TYPE_VISION;
+
+            for (int i = 0; i < ne[2]; ++i) {
+                for (int j = 0; j < 4; ++j) {
+                    ((int32_t *) p0->data)[i + ne[2] * j] = n_past_0 + i + j;
+                    ((int32_t *) p1->data)[i + ne[2] * j] = n_past_2 - n_past_0;
+                    ((int32_t *) p2->data)[i + ne[2] * j] = n_past_2 + i + j;
+                }
+            }
+
+            // [[100, 101, 102, ..., 172],
+            // [101, 102, 103, ..., 173],
+            // [102, 103, 104, ..., 174]]
+            r0 = ggml_rope_multi(
+                ctx0, x, p0, nullptr,
+                n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
+            // [[-67, -67, -67, ..., -67]
+            // [-67, -67, -67, ..., -67]
+            // [-67, -67, -67, ..., -67]]
+            r1 = ggml_rope_multi(
+                ctx0, r0, p1, nullptr,
+                n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
+
+            //  [[33,  34,  35, ..., 105]
+            //  [34,  35,  36, ..., 106]
+            //  [35,  36,  37, ..., 107]]
+            r2 = ggml_rope_multi(
+                ctx0, x, p2, nullptr,
+                n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
+        }
+
+        ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+        ggml_build_forward_expand(gf, r0);
+        ggml_build_forward_expand(gf, r1);
+        ggml_build_forward_expand(gf, r2);
+
+        ggml_graph_compute_helper(work_buffer, gf, 4);
+
+        // check that r1 and r2 are the same
+        {
+            double sum0 = 0.0f;
+            double sum1 = 0.0f;
+            double diff = 0.0f;
+
+            const float * r1_data = (float *) r1->data;
+            const float * r2_data = (float *) r2->data;
+
+            const int n_elements = ggml_nelements(r1);
+
+            for (int i = 0; i < n_elements; ++i) {
+                sum0 += fabs(r1_data[i]);
+                sum1 += fabs(r2_data[i]);
+                diff += fabs(r1_data[i] - r2_data[i]);
+                //if (fabs(r1_data[i] - r2_data[i]) > 0.0001f) {
+                //    printf("%d: %f %f\n", i, r1_data[i], r2_data[i]);
+                //    printf("diff: %f\n", fabs(r1_data[i] - r2_data[i]));
+                //}
+            }
+
+            //for (int i = 4096; i < 4096 + 128; ++i) {
+            //    printf("%f %f\n", r1_data[i], r2_data[i]);
+            //}
+
+            printf("mode: %d\n", mode);
+            printf("sum0: %f\n", sum0);
+            printf("sum1: %f\n", sum1);
+            printf("diff: %f\n", diff);
+            printf("rel err: %f\n", diff / sum0);
+            printf("rel err: %f\n", diff / sum1);
+
+            GGML_ASSERT(diff / sum0 < 0.0001f);
+            GGML_ASSERT(diff / sum1 < 0.0001f);
+        }
+    }
+
+    ggml_free(ctx0);
+
+    return 0;
+}
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6300f25caebe3090fb2a157f44ea5d49a4c3941e
--- /dev/null
+++ b/tests/test-sampling.cpp
@@ -0,0 +1,400 @@
+#include "ggml.h"
+#include "llama.h"
+
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include 
+#include 
+#include 
+#include 
+
+extern struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector>& seq_breakers);
+
+static void dump(const llama_token_data_array * cur_p) {
+    for (size_t i = 0; i < cur_p->size; i++) {
+        printf("%d: %f (%f)\n", cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+}
+
+#define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
+
+struct sampler_tester {
+    sampler_tester(size_t n_vocab) {
+        cur.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+            const float logit = logf(token_id);
+            cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        }
+
+        cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
+    }
+
+    sampler_tester(const std::vector & probs, const std::vector & probs_expected) : probs_expected(probs_expected) {
+        cur.reserve(probs.size());
+        for (llama_token token_id = 0; token_id < (llama_token)probs.size(); token_id++) {
+            const float logit = logf(probs[token_id]);
+            cur.emplace_back(llama_token_data{token_id, logit, probs[token_id]});
+        }
+
+        cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
+    }
+
+    void apply(llama_sampler * sampler) {
+        llama_sampler_apply(sampler, &cur_p);
+        llama_sampler_free(sampler);
+    }
+
+    void check() {
+        GGML_ASSERT(cur_p.size == probs_expected.size());
+        for (size_t i = 0; i < cur_p.size; i++) {
+            GGML_ASSERT(fabs(cur_p.data[i].p - probs_expected[i]) < 1e-5);
+        }
+    }
+
+    llama_token_data_array cur_p;
+
+private:
+    const std::vector probs_expected;
+
+    std::vector cur;
+};
+
+static void test_temp(const std::vector & probs, const std::vector & probs_expected, float temp) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_temp(temp));
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_temp_ext(const std::vector & probs, const std::vector & probs_expected, float temp, float delta, float exponent) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_top_k(const std::vector & probs, const std::vector & probs_expected, int k) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_k(k));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_top_p(const std::vector & probs, const std::vector & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_p(p, 0));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_min_p(const std::vector & probs, const std::vector & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_min_p(p, 0));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_xtc(const std::vector & probs, const std::vector & probs_expected, float p, float t) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_xtc(p, t, 0, 0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_typical(const std::vector & probs, const std::vector & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_typical(p, 0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_penalties(
+    const std::vector & probs, const std::vector & last_tokens,
+    const std::vector & probs_expected, float repeat_penalty, float alpha_frequency, float alpha_presence
+) {
+    GGML_ASSERT(probs.size() == probs_expected.size());
+
+    sampler_tester tester(probs, probs_expected);
+
+    auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
+
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        llama_sampler_accept(sampler, last_tokens[i]);
+    }
+
+    DUMP(&tester.cur_p);
+    tester.apply(sampler);
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_dry(
+    const std::vector & probs, const std::vector & last_tokens,
+    const std::vector & expected_probs, float dry_multiplier, float dry_base,
+    int dry_allowed_length, int dry_penalty_last_n,
+    const std::vector> & seq_breakers
+) {
+    GGML_ASSERT(probs.size() == expected_probs.size());
+
+    sampler_tester tester(probs, expected_probs);
+
+    auto * sampler = llama_sampler_init_dry_testing(1024, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers);
+
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        llama_sampler_accept(sampler, last_tokens[i]);
+    }
+
+    DUMP(&tester.cur_p);
+    tester.apply(sampler);
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
+    tester.check();
+}
+
+static void test_top_n_sigma(const std::vector & probs, const std::vector & probs_expected, int n) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_n_sigma(n));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
+) {
+    sampler_tester tester(n_vocab);
+
+          llama_token min_token_id = 0;
+    const llama_token max_token_id = n_vocab-1;
+
+    for (auto s : samplers_sequence) {
+        switch (s){
+            case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
+            case 'y': GGML_ABORT("typical test not implemented");
+            case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
+            case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
+            case 't': GGML_ABORT("temperature test not implemented");
+            default : GGML_ABORT("Unknown sampler");
+        }
+
+        tester.apply(llama_sampler_init_dist(0));
+
+        auto & cur_p = tester.cur_p;
+
+        const int size = cur_p.size;
+
+        if (s == 'k') {
+            const int expected_size = std::min(size, top_k);
+            min_token_id = std::max(min_token_id, (llama_token)(n_vocab - top_k));
+
+            GGML_ASSERT(size == expected_size);
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
+        } else if (s == 'p') {
+            const int softmax_divisor = n_vocab * (n_vocab-1) / 2 - min_token_id * (min_token_id-1) / 2;
+            const int softmax_numerator_target = ceilf(top_p * softmax_divisor);
+
+                min_token_id  = n_vocab;
+            int expected_size = 0;
+            int cumsum        = 0;
+            do { // do-while because always at least one token is sampled
+                min_token_id--;
+                expected_size++;
+
+                cumsum += min_token_id;
+            } while (cumsum < softmax_numerator_target);
+
+            // token 0 has p == 0, need special consideration for cumsum because top_p immediately returns
+            if (min_token_id == 1) {
+                min_token_id--;
+                expected_size += 1;
+            }
+
+            GGML_ASSERT(size == expected_size);
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
+        } else if (s == 'm') {
+            int expected_size = ceilf((1.0f-min_p) * n_vocab);
+            expected_size = std::max(expected_size, 1);
+            expected_size = std::min(expected_size, size);
+
+            min_token_id = floorf(min_p * n_vocab);
+            min_token_id = std::max(min_token_id, 1);
+            min_token_id = std::max(min_token_id, (llama_token)(n_vocab - size));
+            min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
+
+            GGML_ASSERT(size == expected_size);
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    }
+
+    printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n",
+           samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
+}
+
+static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector & data, int n_iter) {
+    std::vector cur(data.size());
+    std::copy(data.begin(), data.end(), cur.begin());
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    llama_sampler_apply(cnstr, &cur_p);
+    llama_sampler_reset(cnstr);
+    const int64_t t_start = ggml_time_us();
+    for (int i = 0; i < n_iter; i++) {
+        std::copy(data.begin(), data.end(), cur.begin());
+        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+        llama_sampler_apply(cnstr, &cur_p);
+        llama_sampler_reset(cnstr);
+    }
+    const int64_t t_end = ggml_time_us();
+    llama_sampler_free(cnstr);
+    printf("%-43s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
+}
+
+#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
+
+static void test_perf() {
+    const int n_vocab = 1 << 17;
+
+    std::vector data;
+
+    data.reserve(n_vocab);
+    for (int i = 0; i < n_vocab; i++) {
+        const float logit = 2.0f*((double)(rand())/RAND_MAX - 0.5);
+        data.emplace_back(llama_token_data{i, logit, 0.0f});
+    }
+
+    BENCH(llama_sampler_init_top_k  (40),                     data, 32);
+    BENCH(llama_sampler_init_top_p  (0.8f, 1),                data, 32);
+    BENCH(llama_sampler_init_min_p  (0.2f, 1),                data, 32);
+    BENCH(llama_sampler_init_typical(0.5f, 1),                data, 32);
+    BENCH(llama_sampler_init_xtc    (1.0f, 0.1f, 1, 1),       data, 32);
+}
+
+int main(void) {
+    ggml_time_init();
+
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
+
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
+
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
+
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
+
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f},            0.26f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f},            0.49f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f},                       0.51f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f},                       0.74f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  0.76f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  1.00f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  1.05f);
+
+    printf("XTC should:\n");
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.1f},                                0.99f, 0.09f);
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.2f, 0.1f},                          0.99f, 0.19f);
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.3f, 0.2f, 0.1f},                    0.99f, 0.29f);
+
+    printf("XTC should not:\n");
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.4f, 0.3f, 0.2f, 0.1f},              0.99f, 0.39f);
+
+    test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f},            0.5f);
+    test_typical({0.4f, 0.2f, 0.2f, 0.2f},     {0.2f, 0.2f, 0.2f}, 0.5f);
+
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0},   50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0},       50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0},             {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2},       {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
+
+
+    test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1}, {0.25f, 0.25f, 0.25f, 0.25f}, 1.0f, 1.1f, 2, 4, {});
+    test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.296923f, 0.109232f}, 1.0f, 1.1f, 2, 5, {});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 2, 6, {{3}});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
+
+    test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
+    test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
+    test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
+
+    test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
+    test_sampler_queue(10000, "k",     1, 1.0f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 0.0f, 1.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12);
+
+    test_sampler_queue(10000, "k",   100, 1.0000f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f);
+
+    test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "km", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f);
+    test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f);
+
+    test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f);
+
+    printf("OK\n");
+
+    test_perf();
+
+    return 0;
+}
diff --git a/tests/test-thread-safety.cpp b/tests/test-thread-safety.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d525b7430f9d9a7bd03bdcd71140ccf687703cb7
--- /dev/null
+++ b/tests/test-thread-safety.cpp
@@ -0,0 +1,152 @@
+// thread safety test
+// - Loads a copy of the same model on each GPU, plus a copy on the CPU
+// - Creates n_parallel (--parallel) contexts per model
+// - Runs inference in parallel on each context
+
+#include 
+#include 
+#include 
+#include "llama.h"
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "sampling.h"
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
+        return 1;
+    }
+
+    common_init();
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+
+    //llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
+    //    if (level == GGML_LOG_LEVEL_ERROR) {
+    //        common_log_add(common_log_main(), level, "%s", text);
+    //    }
+    //}, NULL);
+
+    auto cparams = common_context_params_to_llama(params);
+
+    int dev_count = ggml_backend_dev_count();
+    int gpu_dev_count = 0;
+    for (int i = 0; i < dev_count; ++i) {
+        auto * dev = ggml_backend_dev_get(i);
+        if (dev && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+            gpu_dev_count++;
+        }
+    }
+    const int num_models = gpu_dev_count + 1 + 1; // GPUs + 1 CPU model + 1 layer split
+    //const int num_models = std::max(1, gpu_dev_count);
+    const int num_contexts = std::max(1, params.n_parallel);
+
+    std::vector models;
+    std::vector threads;
+    std::atomic failed = false;
+
+    for (int m = 0; m < num_models; ++m) {
+        auto mparams = common_model_params_to_llama(params);
+
+        if (m < gpu_dev_count) {
+            mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
+            mparams.main_gpu = m;
+        } else if (m == gpu_dev_count) {
+            mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
+            mparams.main_gpu = -1; // CPU model
+        } else {
+            mparams.split_mode = LLAMA_SPLIT_MODE_LAYER;;
+        }
+
+        llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
+        if (model == NULL) {
+            LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
+            return 1;
+        }
+
+        models.emplace_back(model);
+    }
+
+    for  (int m = 0; m < num_models; ++m) {
+        auto * model = models[m].get();
+        for (int c = 0; c < num_contexts; ++c) {
+            threads.emplace_back([&, m, c, model]() {
+                LOG_INF("Creating context %d/%d for model %d/%d\n", c + 1, num_contexts, m + 1, num_models);
+
+                llama_context_ptr ctx { llama_init_from_model(model, cparams) };
+                if (ctx == NULL) {
+                    LOG_ERR("failed to create context\n");
+                    failed.store(true);
+                    return;
+                }
+
+                std::unique_ptr sampler { common_sampler_init(model, params.sampling), common_sampler_free };
+                if (sampler == NULL) {
+                    LOG_ERR("failed to create sampler\n");
+                    failed.store(true);
+                    return;
+                }
+
+                llama_batch batch = {};
+                {
+                    auto prompt = common_tokenize(ctx.get(), params.prompt, true);
+                    if (prompt.empty()) {
+                        LOG_ERR("failed to tokenize prompt\n");
+                        failed.store(true);
+                        return;
+                    }
+                    batch = llama_batch_get_one(prompt.data(), prompt.size());
+                    if (llama_decode(ctx.get(), batch)) {
+                        LOG_ERR("failed to decode prompt\n");
+                        failed.store(true);
+                        return;
+                    }
+                }
+
+                const auto * vocab = llama_model_get_vocab(model);
+                std::string result = params.prompt;
+
+                for (int i = 0; i < params.n_predict; i++) {
+                    llama_token token;
+                    if (batch.n_tokens > 0) {
+                        token = common_sampler_sample(sampler.get(), ctx.get(), batch.n_tokens - 1);
+                    } else {
+                        token = llama_vocab_bos(vocab);
+                    }
+
+                    result += common_token_to_piece(ctx.get(), token);
+
+                    if (llama_vocab_is_eog(vocab, token)) {
+                        break;
+                    }
+
+                    batch = llama_batch_get_one(&token, 1);
+                    if (llama_decode(ctx.get(), batch)) {
+                        LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
+                        failed.store(true);
+                        return;
+                    }
+                }
+
+                LOG_INF("Model %d/%d, Context %d/%d: %s\n\n", m + 1, num_models, c + 1, num_contexts, result.c_str());
+            });
+        }
+    }
+
+    for (auto & thread : threads) {
+        thread.join();
+    }
+
+    if (failed) {
+        LOG_ERR("One or more threads failed.\n");
+        return 1;
+    }
+
+    LOG_INF("All threads finished without errors.\n");
+    return 0;
+}
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..59dda48772aea9642b505be12a8e1fe1d7b9e247
--- /dev/null
+++ b/tests/test-tokenizer-0.cpp
@@ -0,0 +1,312 @@
+#include "llama.h"
+#include "common.h"
+#include "console.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+//static const std::map> & k_tests() {
+//    static std::map> _k_tests = {
+//        { ""                      , {  }, },
+//        { " "                     , {     220, }, },
+//        { "  "                    , {     256, }, },
+//        { "   "                   , {     262, }, },
+//        { "\t"                    , {     197, }, },
+//        { "\n"                    , {     198, }, },
+//        { "\n\n"                  , {     271, }, },
+//        { "\n\n\n"                , {    1432, }, },
+//        { "\t\n"                  , {    1602, }, },
+//        { "Hello world"           , {    9906,   1917, }, },
+//        { " Hello world"          , {   22691,   1917, }, },
+//        { "Hello World"           , {    9906,   4435, }, },
+//        { " Hello World"          , {   22691,   4435, }, },
+//        { " Hello World!"         , {   22691,   4435,      0, }, },
+//        { "Hello, world!"         , {    9906,     11,   1917,      0, }, },
+//        { " Hello, world!"        , {   22691,     11,   1917,      0, }, },
+//        { " this is 🦙.cpp"        , {     420,    374,  11410,     99,    247,     13,  11055, }, },
+//        { "w048 7tuijk dsdfhu"    , {      86,  23904,    220,     22,     83,   2005,  42908,  11729,   3013,  17156, }, },
+//        { "нещо на Български"     , {   79862, 102118,  13373,  64571,  34694,   3114, 112203,  80112, }, },
+//        { "កាន់តែពិសេសអាចខលចេញ"   , {   21549,    222,  98629,    241,  45358,    233,  21549,    237,  45358,    224,  21549,    244,  21549,    115,  21549,    253,  45358,    223,  21549,    253,  21549,     95,  98629,    227,  21549,    223,  21549,    249,  21549,    227,  45358,    223,  21549,    231, }, },
+//        { "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", {    9468,    248,    222,    320,   8416,      8,  27623,    114, 102470,   9468,    234,    104,  31643,    320,  36773, 100166,  98634,      8,  26602,    227,    320,   3323,  43465,    430,    706,   1202,   1866,   4037,      8, }, },
+//        { "Hello"                 , {    9906, }, },
+//        { " Hello"                , {   22691, }, },
+//        { "  Hello"               , {     220,  22691, }, },
+//        { "   Hello"              , {     256,  22691, }, },
+//        { "    Hello"             , {     262,  22691, }, },
+//        { "    Hello\n    Hello"  , {     262,  22691,    198,    262,  22691, }, },
+//        { " ("                    , {     320, }, },
+//        { "\n ="                  , {     198,    284, }, },
+//        { "' era"                 , {       6,  11639, }, },
+//        { "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~", {    9906,     11,    379,  65948,      0,   2650,    527,    499,  27623,    223,    949,  37046, 101067,  19000,  23182, 102301,   9263,  18136,     16,  36827,  21909, }, },
+//        { "3"                     , {      18, }, },
+//        { "33"                    , {    1644, }, },
+//        { "333"                   , {    8765, }, },
+//        { "3333"                  , {    8765,     18, }, },
+//        { "33333"                 , {    8765,   1644, }, },
+//        { "333333"                , {    8765,   8765, }, },
+//        { "3333333"               , {    8765,   8765,     18, }, },
+//        { "33333333"              , {    8765,   8765,   1644, }, },
+//        { "333333333"             , {    8765,   8765,   8765, }, },
+//    };
+//
+//    return _k_tests;
+//}
+
+using llama_tests = std::map>;
+
+static llama_tests read_tests(const std::string & fname_inp, const std::string & fname_out) {
+    llama_tests tests;
+
+    std::ifstream ifs_inp(fname_inp);
+    if (!ifs_inp) {
+        fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_inp.c_str());
+        return tests;
+    }
+
+    std::string sraw((std::istreambuf_iterator(ifs_inp)), std::istreambuf_iterator());
+
+    std::ifstream ifs_out(fname_out);
+    if (!ifs_out) {
+        fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
+        return tests;
+    }
+
+    std::vector sout;
+    for (std::string line; std::getline(ifs_out, line);) {
+        sout.push_back(line);
+    }
+
+    const std::string sep = "\n__ggml_vocab_test__\n";
+
+    std::vector sinp;
+
+    size_t pos = 0;
+    while (pos < sraw.size()) {
+        const size_t next = sraw.find(sep, pos);
+        if (next == std::string::npos) {
+            sinp.push_back(sraw.substr(pos));
+            break;
+        }
+        sinp.push_back(sraw.substr(pos, next - pos));
+        pos = next + sep.size();
+    }
+
+    if (sinp.size() != sout.size()) {
+        fprintf(stderr, "%s : error: input and output files have different number of tests\n", __func__);
+        return tests;
+    }
+
+    for (size_t i = 0; i < sinp.size(); ++i) {
+        const std::string & s = sinp[i];
+        const std::string & o = string_strip(sout[i]);
+
+        std::vector toks;
+
+        size_t pos = 0;
+        while (pos < o.size()) {
+            size_t next = o.find(' ', pos);
+            if (next == std::string::npos) {
+                next = o.size();
+            }
+            const std::string stok = o.substr(pos, next - pos);
+            toks.push_back(std::stoi(stok));
+            pos = next + 1;
+        }
+
+        tests[s] = toks;
+    }
+
+    return tests;
+}
+
+int main(int argc, char **argv) {
+    if (argc < 2) {
+        fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]);
+        return 1;
+    }
+
+    const std::string fname = argv[1];
+
+    const std::string fname_inp = fname + ".inp";
+    const std::string fname_out = fname + ".out";
+
+    std::string fname_text;
+    if (argc > 2) {
+        fname_text = argv[2];
+    }
+
+    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
+
+    llama_model * model;
+    llama_context * ctx;
+
+    llama_backend_init();
+
+    // load the vocab
+    {
+        auto mparams = llama_model_default_params();
+
+        mparams.vocab_only = true;
+
+        model = llama_model_load_from_file(fname.c_str(), mparams);
+
+        if (model == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
+            return 1;
+        }
+
+        auto cparams = llama_context_default_params();
+
+        ctx = llama_init_from_model(model, cparams);
+
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
+            llama_model_free(model);
+            return 1;
+        }
+    }
+
+#ifdef _WIN32
+    // We need this for unicode console support
+    console::init(false, false);
+    atexit([]() { console::cleanup(); });
+#endif
+
+    bool success = true;
+
+    const auto k_tests = [&]() -> llama_tests {
+        if (!fname_text.empty()) {
+            return {};
+        }
+
+        const auto res = read_tests(fname_inp, fname_out);
+
+        if (res.empty()) {
+            fprintf(stderr, "%s : error: no tests found\n", __func__);
+            exit(1);
+        }
+
+        return res;
+    }();
+
+    const bool add_special = false;
+
+    // multi-threaded tokenization
+    const int nthread = std::thread::hardware_concurrency();
+    std::vector threads(nthread);
+
+    for (int i = 0; i < nthread; i++) {
+        threads[i] = std::thread([&, i]() {
+            for (const auto & test_kv : k_tests) {
+                const std::vector res = common_tokenize(ctx, test_kv.first, add_special, false);
+
+                // here only print the result of the first thread
+                // because the other threads are running the same tests
+                if (i != 0) {
+                    continue;
+                }
+
+                printf("\n");
+                printf("src: '%s'\n", test_kv.first.c_str());
+                printf("res: '%s'\n", common_detokenize(ctx, res).c_str());
+                printf("tok: ");
+                for (const auto & tok : res) {
+                    printf("%d ", tok);
+                }
+                printf("\n");
+
+                bool correct = res.size() == test_kv.second.size();
+                for (int i = 0; i < (int) res.size() && correct; ++i) {
+                    if (test_kv.second[i] != res[i]) {
+                        correct = false;
+                    }
+                }
+
+                if (!correct) {
+                    fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
+                    fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
+                        common_detokenize(ctx, res).c_str(),
+                        common_detokenize(ctx, test_kv.second).c_str());
+                    fprintf(stderr, "%s : expected tokens: ", __func__);
+                    for (const auto & t : test_kv.second) {
+                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
+                    }
+                    fprintf(stderr, "\n");
+                    fprintf(stderr, "%s : got tokens:      ", __func__);
+                    for (const auto & t : res) {
+                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
+                    }
+                    fprintf(stderr, "\n");
+
+                    success = false;
+                }
+            }
+        });
+    }
+
+    for (int i = 0; i < nthread; i++) {
+        threads[i].join();
+    }
+
+    // single threaded tokenization
+    if (!fname_text.empty()) {
+        fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
+
+        std::string text;
+        {
+            std::ifstream ifs(fname_text);
+            if (!ifs) {
+                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str());
+                return 1;
+            }
+            text = std::string(std::istreambuf_iterator(ifs), std::istreambuf_iterator());
+        }
+
+        fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());
+
+        std::vector res;
+
+        {
+            const auto t_start = ggml_time_us();
+
+            res = common_tokenize(ctx, text, add_special, false);
+
+            const auto t_end = ggml_time_us();
+
+            fprintf(stderr, "%s : tokenized in %.3f ms (cpp)\n", __func__, (t_end - t_start) / 1000.0);
+        }
+
+        fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());
+
+        {
+            const std::string fname_out = fname_text + ".tokcpp";
+
+            std::ofstream ofs(fname_out);
+            if (!ofs) {
+                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
+                return 1;
+            }
+
+            for (const auto & tok : res) {
+                //ofs << tok << " '" << string_strip(llama_detokenize(ctx, std::vector{tok})) << "'" << std::endl;
+                ofs << tok << "\n";
+            }
+        }
+
+        fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
+    }
+
+    llama_model_free(model);
+    llama_free(ctx);
+
+    llama_backend_free();
+
+    printf("\n");
+    printf("Tests %s\n", success ? "passed" : "failed");
+
+    return success ? 0 : 3;
+}
diff --git a/tests/test-tokenizer-0.py b/tests/test-tokenizer-0.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd760d1ce5be7041e22be7f4047229e965221ccc
--- /dev/null
+++ b/tests/test-tokenizer-0.py
@@ -0,0 +1,46 @@
+import time
+import argparse
+
+from transformers import AutoTokenizer
+
+parser = argparse.ArgumentParser()
+parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
+parser.add_argument("--fname-tok",   help="path to a text file to tokenize", required=True)
+args = parser.parse_args()
+
+dir_tokenizer = args.dir_tokenizer
+fname_tok = args.fname_tok
+
+tokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
+
+print('tokenizing file: ', fname_tok) # noqa: NP100
+fname_out = fname_tok + '.tok'
+with open(fname_tok, 'r', encoding='utf-8') as f:
+    lines = f.readlines()
+    s = ''.join(lines)
+    t_start = time.time()
+    res = tokenizer.encode(s, add_special_tokens=False)
+    t_end = time.time()
+    print('\nmain : tokenized in', "{:.3f}".format(1000.0 * (t_end - t_start)), 'ms (py)') # noqa: NP100
+    with open(fname_out, 'w', encoding='utf-8') as f:
+        for x in res:
+            # LLaMA v3 for some reason strips the space for these tokens (and others)
+            # if x == 662:
+            #     f.write(str(x) + ' \' ' + tokenizer.decode(x) + '\'\n')
+            # elif x == 1174:
+            #     f.write(str(x) + ' \' ' + tokenizer.decode(x) + '\'\n')
+            # elif x == 2564:
+            #     f.write(str(x) + ' \' ' + tokenizer.decode(x) + '\'\n')
+            # elif x == 758:
+            #     f.write(str(x) + ' \' ' + tokenizer.decode(x) + '\'\n')
+            # elif x == 949:
+            #     f.write(str(x) + ' \' ' + tokenizer.decode(x) + '\'\n')
+            # elif x == 5354:
+            #     f.write(str(x) + ' \' ' + tokenizer.decode(x) + '\'\n')
+            # else:
+            #     f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n')
+            # f.write(str(x) + ' \'' + tokenizer.decode(x).strip() + '\'\n')
+            f.write(str(x) + '\n')
+    print('len(res): ', len(res)) # noqa: NP100
+    print('len(lines): ', len(lines)) # noqa: NP100
+print('results written to: ', fname_out) # noqa: NP100
diff --git a/tests/test-tokenizer-0.sh b/tests/test-tokenizer-0.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4d2b8365547df4aefa4274d94fc90fbe375bbccb
--- /dev/null
+++ b/tests/test-tokenizer-0.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+#
+# Usage:
+#
+#   test-tokenizer-0.sh  
+#
+
+if [ $# -ne 2 ]; then
+    printf "Usage: $0  \n"
+    exit 1
+fi
+
+name=$1
+input=$2
+
+make -j tests/test-tokenizer-0
+
+printf "Testing %s on %s ...\n" $name $input
+
+set -e
+
+printf "Tokenizing using (py)  Python AutoTokenizer ...\n"
+python3 ./tests/test-tokenizer-0.py ./models/tokenizers/$name --fname-tok $input > /tmp/test-tokenizer-0-$name-py.log 2>&1
+
+printf "Tokenizing using (cpp) llama.cpp ...\n"
+./tests/test-tokenizer-0 ./models/ggml-vocab-$name.gguf $input > /tmp/test-tokenizer-0-$name-cpp.log 2>&1
+
+cat /tmp/test-tokenizer-0-$name-py.log | grep "tokenized in"
+cat /tmp/test-tokenizer-0-$name-cpp.log | grep "tokenized in"
+
+set +e
+
+diff $input.tok $input.tokcpp > /dev/null 2>&1
+
+if [ $? -eq 0 ]; then
+    printf "Tokenization is correct!\n"
+else
+    diff $input.tok $input.tokcpp | head -n 32
+
+    printf "Tokenization differs!\n"
+fi
diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b183da47f3cc841e52b79992b803eec958d0e12c
--- /dev/null
+++ b/tests/test-tokenizer-1-bpe.cpp
@@ -0,0 +1,155 @@
+#include "llama.h"
+#include "common.h"
+#include "console.h"
+
+#include "../src/unicode.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+int main(int argc, char **argv) {
+    if (argc < 2 || argc > 3) {
+        fprintf(stderr, "Usage: %s  [--ignore-merges]\n", argv[0]);
+        return 1;
+    }
+
+    const std::string fname = argv[1];
+    bool ignore_merges = false;
+    if (argc == 3) {
+        if (std::strcmp(argv[2], "--ignore-merges") != 0) {
+            fprintf(stderr, "Usage: %s  [--ignore-merges]\n", argv[0]);
+            return 1;
+        }
+        ignore_merges = true;
+    }
+
+    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
+
+    if (ignore_merges) {
+        fprintf(stderr, "%s : ignoring merges for tokens inside vocab\n", __func__);
+    }
+
+    llama_model * model;
+    llama_context * ctx;
+
+    llama_backend_init();
+
+    // load the vocab
+    {
+        auto mparams = llama_model_default_params();
+
+        mparams.vocab_only = true;
+
+        model = llama_model_load_from_file(fname.c_str(), mparams);
+
+        if (model == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
+            return 1;
+        }
+
+        auto cparams = llama_context_default_params();
+
+        ctx = llama_init_from_model(model, cparams);
+
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
+            llama_model_free(model);
+            return 1;
+        }
+    }
+
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    //GGML_ASSERT(llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_BPE);
+    if (llama_vocab_type(vocab) != LLAMA_VOCAB_TYPE_BPE) {
+        return 99;
+    }
+
+#ifdef _WIN32
+    // We need this for unicode console support
+    console::init(false, false);
+    atexit([]() { console::cleanup(); });
+#endif
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
+
+    for (int i = 0; i < n_vocab; ++i) {
+        std::string str = common_detokenize(ctx, std::vector(1, i));
+        try {
+            auto cps = unicode_cpts_from_utf8(str);
+            std::vector tokens = common_tokenize(ctx, str, false, true);
+            if (ignore_merges && tokens.size() > 1) {
+                fprintf(stderr,
+                        "%s : error: token %d detokenizes to '%s'(%zu) but "
+                        "tokenization of this to multiple tokens: [",
+                        __func__, i, str.c_str(), str.length());
+                fprintf(stderr, "%d", tokens[0]);
+                for (size_t i = 1; i < tokens.size(); i++) {
+                    fprintf(stderr, ", %d", tokens[i]);
+                }
+                fprintf(stderr, "]\n");
+                return 2;
+            }
+            std::string check = common_detokenize(ctx, tokens);
+            if (check != str) {
+                fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
+                    __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
+                return 2;
+            }
+        }
+        catch (const std::invalid_argument &) {
+            //fprintf(stderr, "%s : info: utf8 conversion %d '%s'\n", __func__, i, str.c_str());
+        }
+    }
+
+    // unicode
+    {
+        const int nthread = std::thread::hardware_concurrency();
+
+        std::vector threads(nthread);
+
+        std::atomic_int errcode = {};
+
+        for (int i = 0; i < nthread; ++i) {
+            threads[i] = std::thread([i, nthread, ctx, &errcode]() {
+                for (uint32_t cp = i; !errcode && cp < 0x00110000; cp += nthread) {
+                    if ((0x0000D800 <= cp && cp <= 0x0000DFFF) ||  // surrogates \p{Cs}
+                        (0x00040000 <= cp && cp <= 0x000E0000)) {  // undefined  \p{Cn}
+                        continue;
+                    }
+
+                    std::string str = unicode_cpt_to_utf8(cp);
+                    std::vector tokens = common_tokenize(ctx, str, false);
+                    std::string check = common_detokenize(ctx, tokens);
+                    if (cp != 9601 && str != check) {
+                        fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
+                                cp, check.c_str(), check.length(), str.c_str(), str.length());
+                        errcode = 3;
+                    }
+                }
+            });
+        }
+
+        for (auto & t : threads) {
+            t.join();
+        }
+
+        if (errcode) {
+            return errcode;
+        }
+    }
+
+    llama_model_free(model);
+    llama_free(ctx);
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/tests/test-tokenizer-1-spm.cpp b/tests/test-tokenizer-1-spm.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ba6e94ba8ea57a051469a07faf4e387713185f03
--- /dev/null
+++ b/tests/test-tokenizer-1-spm.cpp
@@ -0,0 +1,125 @@
+#include "llama.h"
+#include "common.h"
+#include "console.h"
+
+#include "../src/unicode.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+int main(int argc, char ** argv) {
+    if (argc < 2) {
+        fprintf(stderr, "Usage: %s \n", argv[0]);
+        return 1;
+    }
+
+    const std::string fname = argv[1];
+
+    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
+
+    llama_model * model;
+    llama_context * ctx;
+
+    llama_backend_init();
+
+    // load the vocab
+    {
+        auto mparams = llama_model_default_params();
+
+        mparams.vocab_only = true;
+
+        model = llama_model_load_from_file(fname.c_str(), mparams);
+
+        if (model == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
+            return 1;
+        }
+
+        auto cparams = llama_context_default_params();
+
+        ctx = llama_init_from_model(model, cparams);
+
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
+            llama_model_free(model);
+            return 1;
+        }
+    }
+
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+    if (llama_vocab_type(vocab) != LLAMA_VOCAB_TYPE_SPM) {
+        return 99;
+    }
+
+#ifdef _WIN32
+    // We need this for unicode console support
+    console::init(false, false);
+    atexit([]() { console::cleanup(); });
+#endif
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
+
+    for (int i = 0; i < n_vocab; ++i) {
+        std::string str = common_detokenize(ctx, std::vector(1, i), true);
+        std::vector tokens = common_tokenize(ctx, str, false, true);
+        std::string check = common_detokenize(ctx, tokens);
+        if (check != str) {
+            fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
+                __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
+            return 2;
+        }
+    }
+
+    // unicode
+    {
+        const int nthread = std::thread::hardware_concurrency();
+
+        std::vector threads(nthread);
+
+        std::atomic_int errcode = {};
+
+        for (int i = 0; i < nthread; ++i) {
+            threads[i] = std::thread([i, nthread, ctx, &errcode]() {
+                for (uint32_t cp = i; !errcode && cp < 0x00110000; cp += nthread) {
+                    if ((0x0000D800 <= cp && cp <= 0x0000DFFF) ||  // surrogates \p{Cs}
+                        (0x00040000 <= cp && cp <= 0x000E0000)) {  // undefined \p{Cn}
+                        continue;
+                    }
+
+                    std::string str = unicode_cpt_to_utf8(cp);
+                    std::vector tokens = common_tokenize(ctx, str, false, true);
+                    std::string check = common_detokenize(ctx, tokens);
+                    if (cp != 9601 && str != check) {
+                        fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
+                                cp, check.c_str(), check.length(), str.c_str(), str.length());
+                        errcode = 3;
+                    }
+                }
+            });
+        }
+
+        for (auto & t : threads) {
+            t.join();
+        }
+
+        if(errcode) {
+            return errcode;
+        }
+    }
+
+    llama_model_free(model);
+    llama_free(ctx);
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6cdcb55482e7df8cb239f8e7e68bc68db903289
--- /dev/null
+++ b/tests/test-tokenizer-random.py
@@ -0,0 +1,566 @@
+# Test libllama tokenizer == AutoTokenizer.
+# Brute force random words/text generation.
+#
+# Sample usage:
+#
+#   python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
+#
+
+from __future__ import annotations
+
+import time
+import logging
+import argparse
+import subprocess
+import random
+import unicodedata
+
+from pathlib import Path
+from typing import Any, Iterator, cast
+from typing_extensions import Buffer
+
+import cffi
+from transformers import AutoTokenizer, PreTrainedTokenizer
+
+
+logger = logging.getLogger("test-tokenizer-random")
+
+
+class LibLlama:
+
+    DEFAULT_PATH_LLAMA_H = "./include/llama.h"
+    DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"]
+    DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so"  # CMakeLists.txt: BUILD_SHARED_LIBS ON
+
+    def __init__(self, path_llama_h: str | None = None, path_includes: list[str] = [], path_libllama: str | None = None):
+        path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
+        path_includes = path_includes or self.DEFAULT_PATH_INCLUDES
+        path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
+        (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama)
+        self.lib.llama_backend_init()
+
+    def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str) -> tuple[cffi.FFI, Any]:
+        cmd = ["gcc", "-O0", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
+        cmd += ["-I" + path for path in path_includes] + [path_llama_h]
+        res = subprocess.run(cmd, stdout=subprocess.PIPE)
+        assert (res.returncode == 0)
+        source = res.stdout.decode()
+        ffi = cffi.FFI()
+        if True:  # workarounds for pycparser
+            source = "typedef struct { } __builtin_va_list;" + "\n" + source
+            source = source.replace("sizeof (int)",    str(ffi.sizeof("int")))
+            source = source.replace("sizeof (void *)", str(ffi.sizeof("void*")))
+            source = source.replace("sizeof (size_t)", str(ffi.sizeof("size_t")))
+            source = source.replace("sizeof(int32_t)", str(ffi.sizeof("int32_t")))
+        ffi.cdef(source, override=True)
+        lib = ffi.dlopen(path_libllama)
+        return (ffi, lib)
+
+    def model_default_params(self, **kwargs):
+        mparams = self.lib.llama_model_default_params()
+        for k, v in kwargs.items():
+            setattr(mparams, k, v)
+        return mparams
+
+    def context_default_params(self, **kwargs):
+        cparams = self.lib.llama_context_default_params()
+        for k, v in kwargs.items():
+            setattr(cparams, k, v)
+        return cparams
+
+
+class LibLlamaModel:
+
+    def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
+        self.lib: Any = libllama.lib
+        self.ffi = libllama.ffi
+        if isinstance(mparams, dict):
+            mparams = libllama.model_default_params(**mparams)
+        self.model = self.lib.llama_model_load_from_file(path_model.encode(), mparams)
+        if not self.model:
+            raise RuntimeError("error: failed to load model '%s'" % path_model)
+        if isinstance(cparams, dict):
+            cparams = libllama.context_default_params(**cparams)
+        self.ctx = self.lib.llama_new_context_with_model(self.model, cparams)
+        if not self.ctx:
+            raise RuntimeError("error: failed to create context for model '%s'" % path_model)
+        n_tokens_max = self.lib.llama_n_ctx(self.ctx)
+        self.token_ids = self.ffi.new("llama_token[]", n_tokens_max)
+        self.text_buff = self.ffi.new("uint8_t[]", 1024)
+
+    def free(self):
+        if self.ctx:
+            self.lib.llama_free(self.ctx)
+        if self.model:
+            self.lib.llama_model_free(self.model)
+        self.ctx = None
+        self.model = None
+        self.lib = None
+
+    def tokenize(self, text: str, add_special: bool = False, parse_special: bool = False) -> list[int]:
+        encoded_text: bytes = text.encode("utf-8")
+        num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
+        while num < 0 and len(self.token_ids) < (16 << 20):
+            self.token_ids = self.ffi.new("llama_token[]", -2 * num)
+            num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
+        return list(self.token_ids[0:num])
+
+    def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str:
+        if len(self.token_ids) < len(ids):
+            self.token_ids = self.ffi.new("llama_token[]", 2 * len(ids))
+        for i, id in enumerate(ids):
+            self.token_ids[i] = id
+        num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
+        while num < 0 and len(self.text_buff) < (16 << 20):
+            self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
+            num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
+        return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
+
+
+class Tokenizer:
+
+    def encode(self, text: str) -> list[int]:
+        raise NotImplementedError
+
+    def decode(self, ids: list[int]) -> str:
+        raise NotImplementedError
+
+
+class TokenizerGroundtruth (Tokenizer):
+
+    def __init__(self, dir_tokenizer: str):
+        self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
+        # guess BOS and EOS
+        ids = self.encode("a")
+        assert 1 <= len(ids) <= 3
+        add_bos_token = len(ids) > 1 and self.model.bos_token_id == ids[0]
+        add_eos_token = len(ids) > 1 and self.model.eos_token_id == ids[-1]
+        self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token)
+        self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token)
+        # build vocab
+        tokens = list(self.model.get_vocab().values())
+        self.vocab = self.model.batch_decode(tokens, skip_special_tokens=True)
+        self.vocab = list(sorted(self.vocab))
+        # tokens and lists
+        self.special_tokens = list(self.model.all_special_tokens)
+        self.added_tokens   = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False)
+        self.bos_token = self.model.bos_token
+        self.eos_token = self.model.eos_token
+
+    def encode(self, text: str) -> list[int]:
+        return self.model.encode(text, add_special_tokens=True)
+
+    def decode(self, ids: list[int]) -> str:
+        return self.model.decode(ids, skip_special_tokens=False)
+
+
+class TokenizerLlamaCpp (Tokenizer):
+
+    libllama: LibLlama | None = None
+
+    def __init__(self, vocab_file: str):
+        if not self.libllama:
+            self.libllama = LibLlama()
+        self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
+
+    def encode(self, text: str) -> list[int]:
+        return self.model.tokenize(text, add_special=True, parse_special=True)
+
+    def decode(self, ids: list[int]) -> str:
+        return self.model.detokenize(ids, remove_special=False, unparse_special=True)
+
+
+def generator_custom_text() -> Iterator[str]:
+    """General tests"""
+    yield from [
+        "",
+        " ",
+        "  ",
+        "   ",
+        "\t",
+        "\n",
+        "\n\n",
+        "\n\n\n",
+        "\t\n",
+        "Hello world",
+        " Hello world",
+        "Hello World",
+        " Hello World",
+        " Hello World!",
+        "Hello, world!",
+        " Hello, world!",
+        " this is 🦙.cpp",
+        "w048 7tuijk dsdfhu",
+        "нещо на Български",
+        "កាន់តែពិសេសអាចខលចេញ",
+        "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
+        "Hello",
+        " Hello",
+        "  Hello",
+        "   Hello",
+        "    Hello",
+        "    Hello\n    Hello",
+        " (",
+        "\n =",
+        "' era",
+        "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
+        "3",
+        "33",
+        "333",
+        "3333",
+        "33333",
+        "333333",
+        "3333333",
+        "33333333",
+        "333333333",
+    ]
+
+
+def generator_custom_text_edge_cases() -> Iterator[str]:
+    """Edge cases found while debugging"""
+    yield from [
+        '\x1f-a',     # unicode_ranges_control, {0x00001C, 0x00001F}
+        '¼-a',        # unicode_ranges_digit, 0x00BC
+        '½-a',        # unicode_ranges_digit, 0x00BD
+        '¾-a',        # unicode_ranges_digit, 0x00BE
+        'a 〇b',      # unicode_ranges_digit, 0x3007
+        'Ⅵ-a',       # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
+        '\uFEFF//',   # unicode_ranges_control, 0xFEFF (BOM)
+        'Cửa Việt',   # llama-3, ignore_merges = true
+        'a',       # Phi-3 fail
+        '<|endoftext|>',  # Phi-3 fail
+        'a\na',            # bert fail
+        '"`',              # falcon
+        ' \u2e4e',         # falcon
+        '\n\x0b  ',        # falcon
+        'a\xa0\xa0\x00b',  # jina-v2-es
+        'one ',      # jina-v2-es   lstrip=true
+        'a  b',        # rstrip phi-3
+        'a  b',      # lstrip jina-v2
+        '\xa0aC',          # deepseek
+        '\u2029 \uA3E4',   # deepseek-llm
+        "a ?",
+        'å',               # mpt
+        '\U000ac517',      # utf-8 encode error, falcon
+        '\U000522f4',      # utf-8 encode error, starcoder
+        "abcd",
+        "  abcd",
+    ]
+
+
+def generator_vocab_words(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
+    """Brute force check all vocab words"""
+    yield from tokenizer.vocab
+
+
+def generator_ascii_lr_strip() -> Iterator[str]:
+    WHITESPACES = ["", " ", "  "]
+    CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
+    for char1 in CHARACTERS:
+        for char2 in CHARACTERS:
+            for lstrip in WHITESPACES:
+                for rstrip in WHITESPACES:
+                    yield lstrip + char1 + char2 + rstrip
+                    yield lstrip + char1 + rstrip + char2
+                    yield char1 + lstrip + char2 + rstrip
+
+
+def generator_apostrophe() -> Iterator[str]:
+    WHITESPACES = ["", " ", "  "]
+    CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
+    for char1 in CHARACTERS:
+        for char2 in CHARACTERS:
+            for lstrip in WHITESPACES:
+                for rstrip in WHITESPACES:
+                    yield char1 + lstrip + "'" + rstrip + char2
+                    yield char1 + char2 + lstrip + "'" + rstrip + "z"
+                    yield "a" + lstrip + "'" + rstrip + char1 + char2
+
+
+def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
+    WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
+    all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
+    for token in all_tokens:
+        for lstrip in WHITESPACES:
+            for rstrip in WHITESPACES:
+                yield lstrip + token + rstrip
+                yield "a" + lstrip + token + rstrip
+                yield lstrip + token + rstrip + "z"
+                yield "a" + lstrip + token + rstrip + "z"
+
+
+def generator_random_added_tokens(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]:
+    separations = [" ", "\n", "\t", "-", "!", "one", "1", "", ""]
+    all_tokens  = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens + separations)))
+    rand = random.Random()
+    for m in range(iterations):
+        rand.seed(m)
+        words = rand.choices(all_tokens, k=500)
+        if words and words[0] == tokenizer.bos_token:  # skip spam warning of double BOS
+            while len(words) > 1 and words[1] == tokenizer.bos_token:  # leave one starting BOS
+                words.pop(0)
+            if tokenizer.add_bos_token:  # drop all starting BOS
+                words.pop(0)
+        if words and words[-1] == tokenizer.eos_token:  # skip spam warning of double EOS
+            while len(words) > 1 and words[-2] == tokenizer.eos_token:  # leave one trailing EOS
+                words.pop(-1)
+            if tokenizer.add_bos_token:  # drop all trailing EOS
+                words.pop(-1)
+        yield "".join(words)
+
+
+def generator_random_chars(iterations=100) -> Iterator[str]:
+    """Brute force random text with simple characters"""
+
+    NUM_WORDS = 400
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
+    CHARS = list(sorted(set("""
+        ABCDEFGHIJKLMNOPQRSTUVWXYZ
+        abcdefghijklmnopqrstuvwxyz
+        ÁÉÍÓÚÀÈÌÒÙÂÊÎÔÛÄËÏÖÜ
+        áéíóúàèìòùâêîôûäëïöü
+        .-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
+    """)))
+
+    rand = random.Random()
+    for m in range(iterations):
+        rand.seed(m)
+        text = []
+        for _ in range(NUM_WORDS):
+            k = rand.randint(1, 7)
+            word = rand.choices(CHARS, k=k)
+            word.append(rand.choice(WHITESPACES))
+            text.append("".join(word))
+        yield "".join(text)
+
+
+def generator_unicodes() -> Iterator[str]:
+    """Iterate unicode characters"""
+
+    MAX_CODEPOINTS = 0x30000  # 0x110000
+
+    def _valid(cpt):
+        if cpt >= 0x30000:  # unassigned and supplement­ary
+            return False
+        # if cpt == 0x2029:  # deepseek-llm
+        #    return False
+        if unicodedata.category(chr(cpt)) in ("Cn", "Cs", "Co"):  # undefined, surrogates, private
+            return False
+        return True
+
+    characters = [chr(cpt) for cpt in range(0, MAX_CODEPOINTS) if _valid(cpt)]
+
+    yield from characters
+
+
+def generator_random_unicodes(iterations=100) -> Iterator[str]:
+    """Brute force random text with unicode characters"""
+
+    NUM_WORDS = 200
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
+
+    characters = list(generator_unicodes())
+
+    rand = random.Random()
+    for m in range(iterations):
+        rand.seed(m)
+        text = []
+        for _ in range(NUM_WORDS):
+            k = rand.randint(1, 7)
+            word = rand.choices(characters, k=k)
+            word.append(rand.choice(WHITESPACES))
+            text.append("".join(word))
+        yield "".join(text)
+
+
+def generator_random_vocab_chars(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]:
+    """Brute force random text with vocab characters"""
+
+    vocab_chars = set()
+    for word in tokenizer.vocab:
+        vocab_chars.update(word)
+    vocab_chars = list(sorted(vocab_chars))
+
+    rand = random.Random()
+    for m in range(iterations):
+        rand.seed(m)
+        text = rand.choices(vocab_chars, k=1024)
+        yield "".join(text)
+
+
+def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]:
+    """Brute force random text from vocab words"""
+
+    vocab = [w.strip() for w in tokenizer.vocab]
+    yield from vocab
+
+    rand = random.Random()
+    for m in range(iterations):
+        rand.seed(m)
+        text = []
+        num_words = rand.randint(300, 400)
+        for i in range(num_words):
+            k = rand.randint(1, 3)
+            words = rand.choices(vocab, k=k)
+            sep = rand.choice("     \n\r\t")
+            text.append("".join(words) + sep)
+        yield "".join(text)
+
+
+def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
+
+    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
+        for i, (a, b) in enumerate(zip(ids1, ids2)):
+            if a != b:
+                return i
+        if len(ids1) == len(ids2):
+            return -1
+        return min(len(ids1), len(ids2))
+
+    def check_detokenizer(text: str, text1: str, text2: str) -> bool:
+        if text1 == text2:  # equal to TokenizerGroundtruth?
+            return True
+        # equal to source text?
+        if tokenizer1.add_bos_token:  # remove BOS
+            if text2.startswith(tokenizer1.bos_token):
+                text2 = text2[len(tokenizer1.bos_token):]
+        if tokenizer1.add_eos_token:  # remove EOS
+            if text2.endswith(tokenizer1.eos_token):
+                text2 = text2[:-len(tokenizer1.eos_token)]
+        return text == text2
+
+    t_encode1 = 0
+    t_encode2 = 0
+    t_decode1 = 0
+    t_decode2 = 0
+    t_start = time.perf_counter()
+    encode_errors = 0
+    decode_errors = 0
+    MAX_ERRORS = 10
+
+    logger.info("%s: %s" % (generator.__qualname__, "ini"))
+    for text in generator:
+        # print(repr(text), text.encode())
+        # print(repr(text), hex(ord(text[0])), text.encode())
+        t0 = time.perf_counter()
+        ids1 = tokenizer1.encode(text)
+        t1 = time.perf_counter()
+        ids2 = tokenizer2.encode(text)
+        t2 = time.perf_counter()
+        text1 = tokenizer1.decode(ids1)
+        t3 = time.perf_counter()
+        text2 = tokenizer2.decode(ids1)
+        t4 = time.perf_counter()
+        t_encode1 += t1 - t0
+        t_encode2 += t2 - t1
+        t_decode1 += t3 - t2
+        t_decode2 += t4 - t3
+        if encode_errors < MAX_ERRORS and ids1 != ids2:
+            i = find_first_mismatch(ids1, ids2)
+            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
+            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
+            logger.error(" Expected: " + str(ids1))
+            logger.error("   Result: " + str(ids2))
+            encode_errors += 1
+            logger.error(f" {encode_errors=}")
+        if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
+            i = find_first_mismatch(text1, text2)
+            text1 = list(text1[max(0, i - 2) : i + 5 + 1])
+            text2 = list(text2[max(0, i - 2) : i + 5 + 1])
+            logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
+            logger.error("   Result: " + " ".join(hex(ord(x)) for x in text2))
+            decode_errors += 1
+            logger.error(f" {decode_errors=}")
+        if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
+            logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
+            # raise Exception()
+            break
+
+    t_total = time.perf_counter() - t_start
+    logger.info(f"{generator.__qualname__}: end,  {t_encode1=:.3f} {t_encode2=:.3f}  {t_decode1=:.3f} {t_decode2=:.3f}  {t_total=:.3f}")
+
+
+def main(argv: list[str] | None = None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
+    parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
+    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+    args = parser.parse_args(argv)
+
+    logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
+    logger.info(f"VOCABFILE: '{args.vocab_file}'")
+
+    tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer)
+    tokenizer2 = TokenizerLlamaCpp(args.vocab_file)
+
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
+    compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
+
+    tokenizer2.model.free()
+
+
+if __name__ == "__main__":
+    # main()
+
+    if True:
+        logging.basicConfig(
+            level    = logging.DEBUG,
+            format   = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
+            datefmt  = "%Y-%m-%d %H:%M:%S",
+            filename = logger.name + ".log",
+            filemode = "a"
+        )
+    logging.basicConfig(
+        level    = logging.DEBUG,
+        format   = "%(levelname)s %(message)s",
+    )
+
+    path_tokenizers   = Path("./models/tokenizers/")
+    path_vocab_format = "./models/ggml-vocab-%s.gguf"
+
+    tokenizers = [
+        "llama-spm",      # SPM
+        "phi-3",          # SPM
+        "gemma",          # SPM
+        "gemma-2",        # SPM
+        "baichuan",       # SPM
+        "bert-bge",       # WPM
+        "jina-v2-en",     # WPM
+        "llama-bpe",      # BPE
+        "phi-2",          # BPE
+        "deepseek-llm",   # BPE
+        "deepseek-coder", # BPE
+        "falcon",         # BPE
+        "mpt",            # BPE
+        "starcoder",      # BPE
+        "gpt-2",          # BPE
+        "stablelm2",      # BPE
+        "refact",         # BPE
+        "qwen2",          # BPE
+        "olmo",           # BPE
+        "jina-v2-es",     # BPE
+        "jina-v2-de",     # BPE
+        "smaug-bpe",      # BPE
+        "poro-chat",      # BPE
+        "jina-v2-code",   # BPE
+        "viking",         # BPE
+        "jais",           # BPE
+    ]
+
+    logger.info("=" * 50)
+    for tokenizer in tokenizers:
+        logger.info("-" * 50)
+        logger.info(f"TOKENIZER: '{tokenizer}'")
+        vocab_file = Path(path_vocab_format % tokenizer)
+        dir_tokenizer = path_tokenizers / tokenizer
+        main([str(vocab_file), str(dir_tokenizer), "--verbose"])
diff --git a/tests/test-tokenizers-repo.sh b/tests/test-tokenizers-repo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..86e839133ce62741dfc6853a0a380b7dfd62c642
--- /dev/null
+++ b/tests/test-tokenizers-repo.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+if [ $# -lt 2 ]; then
+    printf "Usage: $0   []\n"
+    exit 1
+fi
+
+if [ $# -eq 3 ]; then
+    toktest=$3
+else
+    toktest="./test-tokenizer-0"
+fi
+
+if [ ! -x $toktest ]; then
+    printf "Test executable \"$toktest\" not found!\n"
+    exit 1
+fi
+
+repo=$1
+folder=$2
+
+if [ -d $folder ] && [ -d $folder/.git ]; then
+    (cd $folder; git pull)
+else
+    git clone $repo $folder
+fi
+
+shopt -s globstar
+for gguf in $folder/**/*.gguf; do
+    if [ -f $gguf.inp ] && [ -f $gguf.out ]; then
+        $toktest $gguf
+    else
+        printf "Found \"$gguf\" without matching inp/out files, ignoring...\n"
+    fi
+done
+
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d64956b84385168fad30b236adf68fd2ad95978a
--- /dev/null
+++ b/tools/CMakeLists.txt
@@ -0,0 +1,39 @@
+# dependencies
+
+find_package(Threads REQUIRED)
+
+# third-party
+
+# ...
+
+# flags
+
+llama_add_compile_flags()
+
+# tools
+
+if (EMSCRIPTEN)
+else()
+    add_subdirectory(batched-bench)
+    add_subdirectory(gguf-split)
+    add_subdirectory(imatrix)
+    add_subdirectory(llama-bench)
+    add_subdirectory(main)
+    add_subdirectory(perplexity)
+    add_subdirectory(quantize)
+    if (LLAMA_BUILD_SERVER)
+        add_subdirectory(server)
+    endif()
+    add_subdirectory(run)
+    add_subdirectory(tokenize)
+    add_subdirectory(tts)
+    add_subdirectory(mtmd)
+    if (GGML_RPC)
+        add_subdirectory(rpc)
+    endif()
+    if (NOT GGML_BACKEND_DL)
+        # these examples use the backends directly and cannot be built with dynamic loading
+        add_subdirectory(cvector-generator)
+        add_subdirectory(export-lora)
+    endif()
+endif()
diff --git a/tools/batched-bench/CMakeLists.txt b/tools/batched-bench/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..68ad707f32c980579cb5af2367154eadcb7df80b
--- /dev/null
+++ b/tools/batched-bench/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-batched-bench)
+add_executable(${TARGET} batched-bench.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/tools/batched-bench/README.md b/tools/batched-bench/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..df67c47e378cfff595cf182a62a4004cd4bd64f8
--- /dev/null
+++ b/tools/batched-bench/README.md
@@ -0,0 +1,60 @@
+# llama.cpp/example/batched-bench
+
+Benchmark the batched decoding performance of `llama.cpp`
+
+## Usage
+
+There are 2 modes of operation:
+
+- `prompt not shared` - each batch has a separate prompt of size `PP` (i.e. `N_KV = B*(PP + TG)`)
+- `prompt is shared` - there is a common prompt of size `PP` used by all batches (i.e. `N_KV = PP + B*TG`)
+
+```bash
+./llama-batched-bench -m model.gguf -c 2048 -b 2048 -ub 512 -npp 128,256,512 -ntg 128,256 -npl 1,2,4,8,16,32 [-pps]
+
+# LLaMA 7B, F16, N_KV_MAX = 16384 (8GB), prompt not shared
+./llama-batched-bench -m ./models/llama-7b/ggml-model-f16.gguf -c 16384 -b 2048 -ub 512 -ngl 99
+
+# LLaMA 7B, Q8_0, N_KV_MAX = 16384 (8GB), prompt is shared
+./llama-batched-bench -m ./models/llama-7b/ggml-model-q8_0.gguf -c 16384 -b 2048 -ub 512 -ngl 99 -pps
+
+# custom set of batches
+./llama-batched-bench -m ./models/llama-7b/ggml-model-q8_0.gguf -c 2048 -b 512 -ub 512 -ngl 999 -npp 128,256,512 -ntg 128,256 -npl 1,2,4,8,16,32
+```
+
+## Sample results
+
+- `PP` - prompt tokens per batch
+- `TG` - generated tokens per batch
+- `B` - number of batches
+- `N_KV` - required KV cache size
+- `T_PP` - prompt processing time (i.e. time to first token)
+- `S_PP` - prompt processing speed (`(B*PP)/T_PP` or `PP/T_PP`)
+- `T_TG` - time to generate all batches
+- `S_TG` - text generation speed (`(B*TG)/T_TG`)
+- `T` - total time
+- `S` - total speed (i.e. all tokens / total time)
+
+|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
+|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
+|   128 |    128 |    1 |    256 |    0.108 |  1186.64 |    3.079 |    41.57 |    3.187 |    80.32 |
+|   128 |    128 |    2 |    512 |    0.198 |  1295.19 |    5.029 |    50.90 |    5.227 |    97.95 |
+|   128 |    128 |    4 |   1024 |    0.373 |  1373.96 |    6.878 |    74.44 |    7.251 |   141.23 |
+|   128 |    128 |    8 |   2048 |    0.751 |  1363.27 |    7.344 |   139.43 |    8.095 |   252.99 |
+|   128 |    128 |   16 |   4096 |    1.570 |  1304.68 |    8.455 |   242.23 |   10.024 |   408.60 |
+|   128 |    128 |   32 |   8192 |    3.408 |  1201.73 |    8.801 |   465.40 |   12.209 |   670.96 |
+|   128 |    256 |    1 |    384 |    0.107 |  1196.70 |    6.329 |    40.45 |    6.436 |    59.67 |
+|   128 |    256 |    2 |    768 |    0.194 |  1317.45 |   10.239 |    50.00 |   10.433 |    73.61 |
+|   128 |    256 |    4 |   1536 |    0.366 |  1399.03 |   13.960 |    73.35 |   14.326 |   107.22 |
+|   128 |    256 |    8 |   3072 |    0.751 |  1363.92 |   15.110 |   135.54 |   15.861 |   193.69 |
+|   128 |    256 |   16 |   6144 |    1.569 |  1304.93 |   18.073 |   226.64 |   19.642 |   312.80 |
+|   128 |    256 |   32 |  12288 |    3.409 |  1201.35 |   19.223 |   426.15 |   22.633 |   542.93 |
+
+### JSONL output
+
+Pass `--output-format jsonl` to output JSONL instead of Markdown, á la
+
+```json lines
+{"n_kv_max": 2048, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "is_pp_shared": 0, "n_gpu_layers": 99, "n_threads": 8, "n_threads_batch": 8, "pp": 128, "tg": 128, "pl": 1, "n_kv": 256, "t_pp": 0.233810, "speed_pp": 547.453064, "t_tg": 3.503684, "speed_tg": 36.532974, "t": 3.737494, "speed": 68.495094}
+{"n_kv_max": 2048, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "is_pp_shared": 0, "n_gpu_layers": 99, "n_threads": 8, "n_threads_batch": 8, "pp": 128, "tg": 128, "pl": 2, "n_kv": 512, "t_pp": 0.422602, "speed_pp": 605.770935, "t_tg": 11.106112, "speed_tg": 23.050371, "t": 11.528713, "speed": 44.410854}
+```
diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a0a2e5ac56ea94eb88bf552a42570ae8e76ec413
--- /dev/null
+++ b/tools/batched-bench/batched-bench.cpp
@@ -0,0 +1,206 @@
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+
+#include 
+#include 
+#include 
+#include 
+
+static void print_usage(int, char ** argv) {
+    LOG("\nexample usage:\n");
+    LOG("\n    %s -m model.gguf -c 2048 -b 2048 -ub 512 -npp 128,256,512 -ntg 128,256 -npl 1,2,4,8,16,32 [-pps]\n", argv[0]);
+    LOG("\n");
+}
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) {
+        return 1;
+    }
+
+    common_init();
+
+    int is_pp_shared = params.is_pp_shared;
+
+    std::vector n_pp = params.n_pp;
+    std::vector n_tg = params.n_tg;
+    std::vector n_pl = params.n_pl;
+
+    // init LLM
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // initialize the model
+
+    llama_model_params model_params = common_model_params_to_llama(params);
+
+    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
+
+    if (model == NULL) {
+        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+        return 1;
+    }
+
+    llama_context_params ctx_params = common_context_params_to_llama(params);
+
+    // ensure enough sequences are available
+    ctx_params.n_seq_max = n_pl.empty() ? 1 : *std::max_element(n_pl.begin(), n_pl.end());
+
+    llama_context * ctx = llama_init_from_model(model, ctx_params);
+
+    if (ctx == NULL) {
+        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+        return 1;
+    }
+
+    auto * mem = llama_get_memory(ctx);
+
+    const int32_t n_kv_max = llama_n_ctx(ctx);
+
+    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
+
+    // decode in batches of ctx_params.n_batch tokens
+    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
+        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+            llama_batch batch_view = {
+                n_tokens,
+                batch.token    + i,
+                nullptr,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
+            };
+
+            const int ret = llama_decode(ctx, batch_view);
+            if (ret != 0) {
+                LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+                return false;
+            }
+
+            llama_synchronize(ctx);
+        }
+
+        return true;
+    };
+
+    // warm up
+    {
+        for (int i = 0; i < 16; ++i) {
+            common_batch_add(batch, 0, i, { 0 }, false);
+        }
+
+        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+            LOG_ERR("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+    }
+
+    if (!params.batched_bench_output_jsonl) {
+        LOG("\n");
+        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+        LOG("\n");
+        LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
+        LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
+    }
+
+    for (        int i_pp = 0; i_pp < (int) n_pp.size(); ++i_pp) {
+        for (    int i_tg = 0; i_tg < (int) n_tg.size(); ++i_tg) {
+            for (int i_pl = 0; i_pl < (int) n_pl.size(); ++i_pl) {
+                const int pp = n_pp[i_pp];
+                const int tg = n_tg[i_tg];
+                const int pl = n_pl[i_pl];
+
+                const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
+
+                if (n_ctx_req > n_kv_max) {
+                    continue;
+                }
+
+                common_batch_clear(batch);
+
+                for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+                    for (int i = 0; i < pp; ++i) {
+                        common_batch_add(batch, 0, i, { j }, false);
+                    }
+                }
+                batch.logits[batch.n_tokens - 1] = true;
+
+                const auto t_pp_start = ggml_time_us();
+
+                llama_memory_clear(mem, false);
+
+                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                    LOG_ERR("%s: llama_decode() failed\n", __func__);
+                    return 1;
+                }
+
+                if (is_pp_shared) {
+                    for (int32_t i = 1; i < pl; ++i) {
+                        llama_memory_seq_cp(mem, 0, i, -1, -1);
+                    }
+                }
+
+                const auto t_pp_end = ggml_time_us();
+
+                const auto t_tg_start = ggml_time_us();
+
+                for (int i = 0; i < tg; ++i) {
+                    common_batch_clear(batch);
+
+                    for (int j = 0; j < pl; ++j) {
+                        common_batch_add(batch, 0, pp + i, { j }, true);
+                    }
+
+                    if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                        LOG_ERR("%s: llama_decode() failed\n", __func__);
+                        return 1;
+                    }
+                }
+
+                const auto t_tg_end = ggml_time_us();
+
+                const int32_t n_kv = n_ctx_req;
+
+                const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f;
+                const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f;
+                const float t    = t_pp + t_tg;
+
+                const float speed_pp = is_pp_shared ? pp / t_pp : pl*pp / t_pp;
+                const float speed_tg = pl*tg / t_tg;
+                const float speed    = n_kv / t;
+
+                if(params.batched_bench_output_jsonl) {
+                    LOG(
+                        "{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"is_pp_shared\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, "
+                        "\"pp\": %d, \"tg\": %d, \"pl\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f, \"t\": %f, \"speed\": %f}\n",
+                        n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch,
+                        pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed
+                    );
+                } else {
+                    LOG("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed);
+                }
+            }
+        }
+    }
+
+    LOG("\n");
+    llama_perf_context_print(ctx);
+
+    llama_batch_free(batch);
+
+    llama_free(ctx);
+    llama_model_free(model);
+
+    llama_backend_free();
+
+    LOG("\n\n");
+
+    return 0;
+}
diff --git a/tools/cvector-generator/CMakeLists.txt b/tools/cvector-generator/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..49ad9561c82eac09d92d77832a7e80f503923da7
--- /dev/null
+++ b/tools/cvector-generator/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-cvector-generator)
+add_executable(${TARGET} cvector-generator.cpp pca.hpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/tools/cvector-generator/README.md b/tools/cvector-generator/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6d5fd74ad8ca029c8517361feb207282042b2b5c
--- /dev/null
+++ b/tools/cvector-generator/README.md
@@ -0,0 +1,45 @@
+# cvector-generator
+
+This example demonstrates how to generate a control vector using gguf models.
+
+Related PRs:
+- [Add support for control vectors](https://github.com/ggml-org/llama.cpp/pull/5970)
+- (Issue) [Generate control vector using llama.cpp](https://github.com/ggml-org/llama.cpp/issues/6880)
+- [Add cvector-generator example](https://github.com/ggml-org/llama.cpp/pull/7514)
+
+## Examples
+
+```sh
+# CPU only
+./cvector-generator -m ./llama-3.Q4_K_M.gguf
+
+# With GPU
+./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99
+
+# With advanced options
+./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100
+
+# Using mean value instead of PCA
+./cvector-generator -m ./llama-3.Q4_K_M.gguf --method mean
+
+# To see help message
+./cvector-generator -h
+# Then, have a look at "cvector" section
+```
+
+## Tips and tricks
+
+If you have multiple lines per prompt, you can escape the newline character (change it to `\n`). For example:
+
+```
+<|im_start|>system\nAct like a person who is extremely happy.<|im_end|>
+<|im_start|>system\nYou are in a very good mood today<|im_end|>
+```
+
+Example to use output file with `llama-cli`:
+
+(Tips: The control vector works better when apply to layers higher than 10)
+
+```sh
+./llama-cli -m ./llama-3.Q4_K_M.gguf -p "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nSing a song<|im_end|><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" --special --control-vector-scaled ./control_vector.gguf 0.8 --control-vector-layer-range 10 31
+```
diff --git a/tools/cvector-generator/completions.txt b/tools/cvector-generator/completions.txt
new file mode 100644
index 0000000000000000000000000000000000000000..abc45ffd872695ef0cefc9b60fd088c0dd22cbae
--- /dev/null
+++ b/tools/cvector-generator/completions.txt
@@ -0,0 +1,582 @@
+
+That game
+I can see
+Hmm, this
+I can relate to
+Who is
+I understand the
+Ugh,
+What the hell was
+Hey, did anyone
+Although
+Thank you for choosing
+What are you
+Oh w
+How dare you open
+It was my pleasure
+I'm hon
+I appreciate that you
+Are you k
+Whoever left this
+It's always
+Ew,
+Hey, I l
+Hello? Is someone
+I understand that
+That poem
+Aww, poor
+Hey, it
+Alright, who
+I didn't
+Well, life
+The document
+Oh no, this
+I'm concerned
+Hello, this is
+This art
+Hmm, this drink
+Hi there!
+It seems
+Is
+Good
+I can't
+Ex
+Who are
+I can see that
+Wow,
+Today is a
+Hey friend
+Sometimes friends
+Oh, this old
+The weather outside
+This place is sur
+I appreciate your input
+Thank you for the
+Look at
+I'm disappoint
+To my
+How dare you
+That's an
+This piece of art
+Eww
+This park is
+This is incredible
+Oh no, someone
+Exc
+Well, it'
+I warned
+Hey, I understand
+Hey, I saw
+How dare you go
+What the he
+Hey
+It's
+Hello? Hello?
+It
+Oh no!
+This is the perfect
+Good morning,
+Oh no, there
+It's so
+Yeah
+Uh,
+Hello everyone
+Who turned off
+The weather
+Who'
+Hey, this
+Wait,
+Eww, gross
+Excuse
+It seems like you
+Thank you so
+What happened?
+Oh my g
+I am deeply sad
+I war
+Okay, let'
+Hey, that
+That was a beautiful
+Oh no! That
+What happened
+Hey there
+The artist'
+What?!
+Hey, it'
+I am disappoint
+It seems like
+Oh no! The
+This park is a
+If you
+Yes! I did
+It sounds
+What
+Who is it
+Hmm, that
+That's strange
+Yeah, that was
+That's interesting
+This park
+What the hell
+Who is that
+I feel like my
+Oh well
+What the hell is
+Hello? Hello
+To my dearest
+Bless you!\"
+Thank you for
+Oh, looks like
+Can you please
+This place is
+Eww, what
+Bless you
+Is everything
+Hey, I just
+Whoever left these
+Well, that'
+I feel
+Hey, do you
+It's sad
+Oh no, it
+Hey, that'
+Oh my god,
+Thank you,
+Hello little one,
+I apolog
+Hey team, I
+How dare you read
+Who is this and
+Whoever left
+Hi there! W
+A
+If you have
+I was
+U
+Bless
+Well, this
+Oh, I'
+It's a
+Eww,
+Is everything okay?
+Oh, I
+Hello, can you
+Al
+That was a great
+What are
+I understand that not
+Oh no, not
+Who is it?\"
+Hey, can we
+Whoever is taking
+I would love to
+Hey, I noticed
+Hey, could
+I understand that there
+Hello?
+D
+Oh man, I
+Thank you so much
+Oh no, my
+Dear [Name
+Uh
+I remember
+Hey, who
+Well, it
+Are you
+I understand that it
+Hey, is
+I would
+Who is this
+Excuse me
+Alright
+I am thrilled
+Sometimes friends have
+Who the
+It's interesting
+I would love
+E
+Hello? Is anyone
+Well, this is
+This place
+Well,
+I warned you
+Hey, watch where
+Oh my
+That'
+Sometimes friends have different
+I understand that everyone
+What?
+What do these notes
+I can relate
+I'm not
+I understand
+To my dear
+Guys
+Well
+Hey, I appreciate
+Wow, what
+Dear
+That melody
+Who the hell
+Today is
+Hello little
+Wow, look
+That's great
+Love is never wrong
+I'm having
+Whoa, did
+Ugh
+Can you please provide
+I miss you,
+I feel uncom
+I know
+Ugh, this
+Hey, watch
+Oh great, a
+I didn
+Okay
+That game of char
+Oh
+I appreciate
+Who's there
+I am so
+Oh great, someone
+Hey, could you
+I remember wondering
+Wait, what?
+What do
+Hello? Can
+Hey there,
+That game of
+This is incred
+Oh my gosh
+Oh great, f
+I appreciate your
+It sounds like
+What the heck
+Okay, I understand
+Ew
+I understand that this
+Uh, hi
+Hi everyone!
+What the hell?
+Thank you for your
+Oh no, the
+Wow, I
+Who turned
+Dear [
+Whoever
+This is a
+Whoa, he
+What in the world
+Although the physical
+Hello, who is
+That's amaz
+Hey, I know
+Okay, that
+Hi everyone
+Hey, is everything
+I understand your fr
+Oh no, poor
+Oh, look
+Good morning
+Ew, gross
+Oh no, did
+Look at the family
+Hey team
+Yes!
+Hey, can I
+Okay, that'
+It's great
+Love is
+Hey, what
+Good morning, world
+Who is it?
+That poem really reson
+I
+That's
+I understand the task
+Gu
+Hello? Who'
+This postcard is
+Whoa,
+Oh, that
+I understand that I
+Whoever is
+Hello? Who is
+I'm really
+Wow, this
+Can
+This artwork really
+This is a shame
+I miss you too
+Who are you?
+Today is a difficult
+Hey, just
+Are you okay
+I am
+Hi,
+Wow, that
+Hey there! Can
+Okay, stay
+Oh great, just
+Yeah,
+Hello? Can you
+Oh, looks
+Thank you for sharing
+I'm glad
+Hey, is that
+Hmm
+It was my
+It sounds like you
+Wow, your
+I was promised certain
+That was such a
+Thank
+Excuse you
+That was
+Hey team,
+I feel un
+It was
+What'
+Hey friend, I
+How
+Saying goodbye
+That
+It's heart
+How dare
+Oh,
+Hello, may
+What's this
+Thank you for recogn
+Aww, that
+Oh, I remember
+Hmm, that'
+I miss
+I know this
+Wait
+Is everything okay
+Who is that person
+Wow, you
+Oh great
+I'm sad
+Wow, the
+I am very disappoint
+Who turned off the
+I understand that things
+I'm very
+Hi
+That's very
+Okay, I
+Oh no,
+Wow, there
+What's wrong
+I apologize for
+Hey, I
+Can I help you
+Oh, I didn
+Alright,
+Oh wow,
+Oh my goodness
+I know this event
+What in the
+Saying
+Yeah, that
+Guys, I
+Hey, this v
+This post
+Are
+Hey, can
+Hello? Is
+I can only imagine
+Oh, that sounds
+Hey, is anyone
+I am disappointed
+Hello,
+Hey everyone, I
+That was such
+It's okay
+The artist
+Whoa
+I understand that mistakes
+Can I help
+Who
+Hi everyone! I
+Hey, can you
+Wow, how
+Today
+Oh no, I
+Oh well, I
+Well, that
+This is the
+Yes! I finally
+Hey there little
+Hello everyone!
+Love is never
+Look at the
+This postcard
+Oh great,
+Can I
+Hmm, this is
+I understand your
+Oh, look at
+B
+I'm so
+Whoa, this
+W
+Oh, this
+Sometimes
+This piece of
+What the
+That was a
+Hey, do
+Oh no
+Whoa, what
+I feel like I
+The documentary
+Hello
+Hello little one
+I understand that my
+Eww, that
+Wow, an
+Yes! Finally,
+Although the physical location
+Whoever is watching
+That movie
+I remember wondering about
+Hey there, little
+Who's
+Hello, who
+Hello everyone! Thank
+Hello, can
+That's too
+Hey, just wanted
+Hey there, I
+Saying good
+Hey there!
+Who is there?
+Oh my good
+I am very
+Oh no, what
+Wow, thank
+I was promised
+Hi, is
+Hey, I'
+Guys, the
+Oh no, that
+Who is there
+Hello, this
+That movie really touched
+If you have something
+The documentary was
+I'm starting
+Are you kidd
+That movie really
+Hey everyone,
+Thank you for considering
+I didn'
+Yes! I
+Can you
+Oh my god
+Hey, whoever
+That melody really
+Thank you, little
+Hello, may I
+Look
+Wow, we
+It looks
+What do these
+Oh wow
+I apologize
+What are you all
+It's such
+It's clear
+Hey, I was
+Hey friend,
+I can only
+The weather outside is
+Eww, this
+I miss you
+Wow
+Aww,
+Hi, is there
+This artwork
+Okay,
+Oh well,
+This
+I'
+Say
+Hey there little gu
+Hmm,
+Whoa, who
+I am thr
+Oh man
+Okay, stay calm
+I'm happy
+Oh, this cur
+Oh man,
+I'm sorry
+Hello? Who
+What?! That
+This piece
+Hey everyone
+That's so
+Are you okay?
+What happened? Where
+Hi there
+The
+Who the hell entered
+I can
+Guys,
+What's
+What in
+It's important
+I'm
+I'm coming
+It'
+Yes! Finally
+Wait, what
+Wow, reading
+I'm surprised
+Hey, did
+Hey,
+Okay, let
+I understand that you
+Who the hell threw
+Eww, who
+Thank you for thinking
+Who is this?\"
+I am deeply
+Thank you for including
+Oh no, an
+It looks like you
+Aww
+I'm confused
+Wow, it
+That poem really
+Yes
+Hey there, is
+Hey, what'
+Thank you for remember
+To
+This is
+Thank you for making
+I can'
+That mel
+Wow, they
+I feel like
+Although the
+Who are you
+Love
+If
+What the hell are
+I am so sad
+Oh, I found
+Thank you
+It looks like
+Well, life is
+I appreciate that
+The artist's
+Whoa, that
+It's never
\ No newline at end of file
diff --git a/tools/cvector-generator/cvector-generator.cpp b/tools/cvector-generator/cvector-generator.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d2d97e05cebb0eb0cae44b5727800fd575018e9f
--- /dev/null
+++ b/tools/cvector-generator/cvector-generator.cpp
@@ -0,0 +1,508 @@
+#include "ggml.h"
+#include "gguf.h"
+
+#include "arg.h"
+#include "common.h"
+#include "llama.h"
+#include "pca.hpp"
+#include "mean.hpp"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+//////////////////////////////////////////////////
+// utils
+
+template 
+static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
+    std::string ret;
+    for (; begin != end; ++begin) {
+        ret += common_token_to_piece(ctx, *begin);
+    }
+
+    return ret;
+}
+
+static void print_usage(int, char ** argv) {
+    printf("\nexample usage:\n");
+    printf("\n    CPU only:   %s -m ./llama-3.Q4_K_M.gguf\n", argv[0]);
+    printf("\n    with GPU:   %s -m ./llama-3.Q4_K_M.gguf -ngl 99\n", argv[0]);
+    printf("\n    advanced:   %s -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100\n", argv[0]);
+    printf("\n    using mean: %s -m ./llama-3.Q4_K_M.gguf --method mean\n", argv[0]);
+    printf("\n");
+}
+
+//////////////////////////////////////////////////
+
+
+// cb_eval is reused for each pair of positive - negative prompt
+struct callback_data {
+    ggml_context * ctx_ggml = nullptr;   // holds v_pos, v_neg, v_diff_filtered
+
+    int n_layers = 0;
+    int n_tokens = 0;
+    bool is_eval_pos = true;
+
+    // each element of the vector correspond to one layer
+    std::vector v_pos; // vector of matrices of size [n_embd, n_tokens]
+    std::vector v_neg; // vector of matrices of size [n_embd, n_tokens]
+    std::vector v_diff_filtered;   // vector of matrices of size [n_embd, n_nonzero_rows]. NOTE: n_nonzero_rows maybe different for each layer
+
+    // save a tensor into either v_pos or v_neg (decided by is_eval_pos)
+    void save_tensor_for_layer(struct ggml_tensor * t) {
+        GGML_ASSERT(t->type == GGML_TYPE_F32);
+
+        if (ctx_ggml == nullptr) {
+            // alloc a new ctx_ggml if needed
+            struct ggml_init_params params_ggml = {
+                /*.mem_size   =*/ ggml_tensor_overhead() * n_layers * 3u,
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+            ctx_ggml = ggml_init(params_ggml);
+        }
+
+        // copy tensor data
+        auto n_bytes = ggml_nbytes(t);
+        struct ggml_tensor * t_layer = ggml_new_tensor_2d(ctx_ggml, t->type, t->ne[0], t->ne[1]);
+        t_layer->data = malloc(n_bytes); // TODO @ngxson : get rid of this malloc somehow
+        ggml_backend_tensor_get(t, t_layer->data, 0, n_bytes);
+        ggml_set_name(t_layer, ggml_get_name(t));
+        //print_debug_tensor(t_layer);
+
+        if (is_eval_pos) {
+            v_pos.push_back(t_layer);
+        } else {
+            v_neg.push_back(t_layer);
+        }
+    }
+
+    // calculate diff (v_pos - v_neg) and place the result back to v_pos
+    // all zero rows in the diff tensor will also be removed
+    // NOTE: final layer is ignored. we only have (n_layers - 1) to process
+    std::vector calc_diff() {
+        for (float il = 0; il < v_pos.size(); il++) {
+            float * a = (float *) v_pos[il]->data;
+            float * b = (float *) v_neg[il]->data;
+            size_t n_elem = ggml_nelements(v_pos[il]);
+            for (size_t j = 0; j < n_elem; j++) {
+                a[j] -= b[j];
+            }
+            //print_debug_tensor(v_pos[i]);
+            auto diff_filtered = filter_nonzero_rows(v_pos[il]);
+            v_diff_filtered.push_back(diff_filtered);
+        }
+        return v_diff_filtered; // for convinient, we return the result std::vector
+    }
+
+    // delete zero rows from a given 2D tensor
+    struct ggml_tensor * filter_nonzero_rows(struct ggml_tensor * a) {
+        //printf("filter_nonzero_rows\n");
+        auto is_row_all_zeros = [](struct ggml_tensor * t, int row, float eps) -> bool {
+            // check if given row containing all zero elements
+            int n_cols = t->ne[0]; // hint: should be equal to n_embd
+            for (int col = 0; col < n_cols; ++col) {
+                if (ggml_get_f32_nd(t, col, row, 0, 0) > eps) {
+                    return false;
+                }
+            }
+            return true;
+        };
+        std::vector rows_to_copy; // the idx of non-zero cols (to be copied to row of diff_filtered)
+        for (int i_row = 0; i_row < a->ne[1]; i_row++) {
+            if (!is_row_all_zeros(a, i_row, 1e-6)) {
+                rows_to_copy.push_back(i_row);
+            }
+        }
+
+        // get "n_nonzero_rows" for the output "diff_filtered"
+        int n_nonzero_rows = rows_to_copy.size();
+        //printf("n_nonzero_rows: %d\n", n_nonzero_rows);
+        int n_embd = a->ne[0];
+        GGML_ASSERT(n_nonzero_rows > 0);
+
+        // diff_filtered: [n_embd, n_nonzero_rows]
+        struct ggml_tensor * diff_filtered = ggml_new_tensor_2d(
+            ctx_ggml, GGML_TYPE_F32, n_embd, n_nonzero_rows);
+        ggml_format_name(diff_filtered, "diff_filtered_%s", a->name);
+        diff_filtered->data = malloc(ggml_nbytes(diff_filtered));
+
+        // copy non-zero rows
+        for (int dest_row = 0; dest_row < n_nonzero_rows; dest_row++) {
+            int src_row = rows_to_copy[dest_row];
+            for (int i = 0; i < n_embd; i++) {
+                float src_elem = ggml_get_f32_nd(a, i, src_row, 0, 0);
+                ggml_set_f32_nd(diff_filtered, i, dest_row, 0, 0, src_elem);
+            }
+        }
+
+        //print_debug_tensor(diff_filtered);
+
+        return diff_filtered;
+    }
+
+    // we don't implement destructor, because we want to reuse callback_data. we just want to free the tensors
+    void reset() {
+        for (auto ptr : v_pos) free(ptr->data);
+        for (auto ptr : v_neg) free(ptr->data);
+        for (auto ptr : v_diff_filtered) free(ptr->data);
+        v_pos.clear();
+        v_neg.clear();
+        v_diff_filtered.clear();
+        if (ctx_ggml) {
+            ggml_free(ctx_ggml);
+        }
+        ctx_ggml = nullptr;
+    }
+};
+
+/**
+ * process_ctx is used to store the ggml context for pre-post processing the diff vectors
+ * in short, input => v_diff and output => v_final
+ */
+struct train_context {
+    ggml_context * ctx_ggml;
+    int n_embd;
+    int n_layers;
+
+    /* pair of prompts to be used for generating final vector */
+    std::vector positive_entries;
+    std::vector negative_entries;
+
+    // each element of the vector correspond to one layer
+    // NOTE: the last layer is discard. therefore, we will have (n_layers - 1) elements here
+    // NOTE (2): v_diff is transposed from v_diff_tmp
+    std::vector v_diff;  // vector of matrices of size [m, n_embd] where m ~ n_tokens * n_completions (v_diff contains no zero-rows)
+    std::vector v_final; // vector of vectors of size [n_embd] to be written to file
+
+    // to easily re-alloc when concat v_diff, we temporary store v_diff in a vector instead of a tensor
+    // v_diff_tmp will get converted unto v_diff later on
+    std::vector> v_diff_tmp;
+
+    train_context(int n_embd_, int n_layers_) {
+        n_embd = n_embd_;
+        n_layers = n_layers_;
+        struct ggml_init_params params_ggml = {
+            /*.mem_size   =*/ ggml_tensor_overhead() * (n_layers - 1) * 2u,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_ggml = ggml_init(params_ggml);
+        for (int il = 0; il < n_layers - 1; il++) {
+            std::vector empty;
+            v_diff_tmp.push_back(empty);
+            auto t = ggml_new_tensor_1d(ctx_ggml, GGML_TYPE_F32, n_embd);
+            t->data = malloc(ggml_nbytes(t)); // TODO: get rid of malloc if possible
+            v_final.push_back(t);
+        }
+    }
+
+    // add new rows into existing tensor in v_diff_tmp
+    void concat_diff_tmp(const std::vector & diff_filtered) {
+        GGML_ASSERT((int) diff_filtered.size() == n_layers - 1);
+        for (int il = 0; il < n_layers - 1; il++) {
+            auto t = diff_filtered[il];
+            auto & diff_tmp = v_diff_tmp[il];
+            size_t curr_size = diff_tmp.size();
+            diff_tmp.resize(curr_size + ggml_nbytes(t));
+            memcpy(diff_tmp.data() + curr_size, t->data, ggml_nbytes(t));
+        }
+    }
+
+    // build the v_diff tensors from v_diff_tmp (v_diff need to be transposed)
+    // TODO @ngxson : maybe add option NOT to transpose v_diff; will be useful for "mean" method
+    void build_v_diff(bool transpose) {
+        printf("build_v_diff\n");
+        for (int il = 0; il < n_layers - 1; il++) {
+            auto & diff_tmp = v_diff_tmp[il];
+            int n_elem = diff_tmp.size() / sizeof(float);
+            GGML_ASSERT(n_elem % n_embd == 0);
+            int n_rows = n_elem / n_embd;
+            struct ggml_tensor * diff = transpose
+                ? ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_rows, n_embd)
+                : ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_embd, n_rows);
+            ggml_set_name(diff, (std::string("diff_") + std::to_string(il)).c_str());
+            diff->data = malloc(ggml_nbytes(diff)); // TODO: get rid of this malloc if possible
+            if (transpose) {
+                // copy data & transpose
+                float * arr = (float *) diff_tmp.data();
+                for (int ir = 0; ir < n_rows; ++ir) {
+                    for (int ic = 0; ic < n_embd; ++ic) {
+                        float f = arr[ir*n_embd + ic];
+                        ggml_set_f32_nd(diff, ir, ic, 0, 0, f);
+                    }
+                }
+            } else {
+                // only copy
+                memcpy(diff->data, diff_tmp.data(), ggml_nbytes(diff));
+            }
+            v_diff.push_back(diff);
+            print_debug_tensor(diff);
+            // free memory of diff_tmp
+            diff_tmp.resize(0);
+        }
+    }
+
+    ~train_context() {
+        for (auto ptr : v_final) free(ptr->data);
+        for (auto ptr : v_diff) free(ptr->data);
+        // no need to free v_diff_tmp, since we didn't use malloc
+        ggml_free(ctx_ggml);
+    }
+};
+
+struct tokenized_prompt {
+    std::vector tokens_pos;
+    std::vector tokens_neg;
+    size_t max_seq_len;
+
+    tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) {
+        const llama_model * model = llama_get_model(ctx);
+        const llama_vocab * vocab = llama_model_get_vocab(model);
+        const bool add_bos = llama_vocab_get_add_bos(vocab);
+        tokens_pos = common_tokenize(ctx, pos, add_bos, true);
+        tokens_neg = common_tokenize(ctx, neg, add_bos, true);
+        max_seq_len = std::max(tokens_pos.size(), tokens_neg.size());
+        padding_seq(ctx, tokens_pos, max_seq_len);
+        padding_seq(ctx, tokens_neg, max_seq_len);
+    }
+
+    void padding_seq(llama_context * ctx, std::vector & tokens, size_t len) {
+        // TODO: customize padding token
+        std::vector pad_tokens = common_tokenize(ctx, " ", false);
+        llama_token pad_tok = pad_tokens.back();
+        while (tokens.size() < len) {
+            tokens.push_back(pad_tok);
+        }
+    }
+};
+
+//////////////////////////////////////////////////
+
+template 
+static std::string to_string(const T & val) {
+    std::stringstream ss;
+    ss << val;
+    return ss.str();
+}
+
+static std::vector ctrlvec_load_prompt_file(std::string path, bool skip_empty_lines) {
+    std::vector output;
+    std::ifstream file(path);
+    if (!file.is_open()) {
+        fprintf(stderr, "error: unable to open file: %s\n", path.c_str());
+        exit(1);
+    }
+    std::string line;
+    while (std::getline(file, line)) {
+        bool is_skip = skip_empty_lines && line.empty();
+        if (!is_skip) {
+            string_process_escapes(line);
+            output.push_back(line);
+        }
+    }
+    file.close();
+    return output;
+}
+
+//////////////////////////////////////////////////
+
+static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
+    auto * cb_data = (callback_data *) user_data;
+    static const char * l_out_name = "l_out";
+    const bool is_l_out = strncmp(t->name, l_out_name, strlen(l_out_name)) == 0;
+
+    if (ask) {
+        return is_l_out;
+    }
+
+    if (!is_l_out || t->ne[1] != cb_data->n_tokens) {
+        return true;
+    }
+
+    // save the tensor to current context
+    cb_data->save_tensor_for_layer(t);
+    return true;
+}
+
+static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) {
+    llama_memory_clear(llama_get_memory(ctx), true);
+    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
+        fprintf(stderr, "%s : failed to eval\n", __func__);
+        return false;
+    }
+    return true;
+}
+
+static void export_gguf(const std::vector & v_ctrl, const std::string fname, const std::string model_hint) {
+    struct gguf_context * ctx = gguf_init_empty();
+
+    const std::string arch = "controlvector";
+    gguf_set_val_str(ctx, "general.architecture", arch.c_str());
+    gguf_set_val_str(ctx, (arch + ".model_hint").c_str(), model_hint.c_str());
+    gguf_set_val_i32(ctx, (arch + ".layer_count").c_str(), v_ctrl.size());
+
+    for (size_t i = 0; i < v_ctrl.size(); ++i) {
+        gguf_add_tensor(ctx, v_ctrl[i]);
+        print_debug_tensor(v_ctrl[i]);
+        printf("Added tensor: %s\n", v_ctrl[i]->name);
+    }
+
+    printf("%s: writing file...\n", __func__);
+    gguf_write_to_file(ctx, fname.c_str(), false);
+    printf("%s: wrote file '%s'\n", __func__, fname.c_str());
+    gguf_free(ctx);
+}
+
+/**
+ * Load prompt files and completion file.
+ * Then format each pair of prompt + completion to make an entry.
+ */
+static int prepare_entries(common_params & params, train_context & ctx_train) {
+    // load prompts
+    std::vector positive_prompts = ctrlvec_load_prompt_file(params.cvector_positive_file, true);
+    std::vector negative_prompts = ctrlvec_load_prompt_file(params.cvector_negative_file, true);
+    if (positive_prompts.size() != negative_prompts.size()) {
+        fprintf(stderr, "number of positive and negative prompts must be equal\n");
+        return 1;
+    }
+    if (positive_prompts.empty()) {
+        fprintf(stderr, "must provide at least one prompt pair\n");
+        return 1;
+    }
+    ctx_train.positive_entries = positive_prompts;
+    ctx_train.negative_entries = negative_prompts;
+    return 0;
+}
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    params.out_file = "control_vector.gguf";
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage)) {
+        return 1;
+    }
+
+    if (params.n_pca_iterations % params.n_pca_batch != 0) {
+        fprintf(stderr, "PCA iterations must by multiply of PCA batch size\n");
+        return 1;
+    }
+
+
+    callback_data cb_data;
+
+    // pass the callback to the backend scheduler
+    // it will be executed for each node during the graph computation
+    params.cb_eval = cb_eval;
+    params.cb_eval_user_data = &cb_data;
+    params.warmup = false;
+
+    print_build_info();
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // load the model to get hparams
+    common_init_result llama_init = common_init_from_params(params);
+
+    llama_model * model = llama_init.model.get();
+    llama_context * ctx = llama_init.context.get();
+
+    // int n_ctx = llama_n_ctx(ctx);
+    int n_layers = llama_model_n_layer(model);
+    int n_embd = llama_model_n_embd(model);
+
+    // get model hint param (a.k.a model arch name)
+    char model_hint[128];
+    llama_model_meta_val_str(model, "general.architecture", model_hint, 128);
+
+    // init train_context
+    train_context ctx_train(n_embd, n_layers);
+
+    // load and prepare entries for training
+    prepare_entries(params, ctx_train);
+
+    // we have to pretokenize everything because otherwise we don't know how much overhead to allocate ctx_diffs_wrapped
+    std::vector tokenized_prompts;
+    size_t n_total_tokens = 0;
+    for (size_t i = 0; i < ctx_train.positive_entries.size(); ++i) {
+        tokenized_prompt t(ctx, ctx_train.positive_entries[i], ctx_train.negative_entries[i]);
+        n_total_tokens += 2 * t.max_seq_len;
+        tokenized_prompts.push_back(std::move(t));
+    }
+
+    std::cout << "n_total_tokens: " << n_total_tokens << std::endl;
+
+    for(size_t i = 0; i < ctx_train.positive_entries.size(); ++i) {
+        bool success = false;
+        tokenized_prompt t = tokenized_prompts[i];
+        cb_data.n_layers = n_layers;
+        cb_data.n_tokens = t.max_seq_len;
+
+        printf("Evaluating prompt[%d/%d]: \"%s\" - \"%s\" (%d tokens)\n",
+            (int) i+1, (int) ctx_train.positive_entries.size(),
+            tokens_to_str(ctx, t.tokens_pos.cbegin(), t.tokens_pos.cend()).c_str(),
+            tokens_to_str(ctx, t.tokens_neg.cbegin(), t.tokens_neg.cend()).c_str(),
+            (int) t.max_seq_len);
+
+        cb_data.is_eval_pos = true;
+        success = get_hidden_layers(ctx, t.tokens_pos);
+        if (!success) break;
+
+        cb_data.is_eval_pos = false;
+        success = get_hidden_layers(ctx, t.tokens_neg);
+        if (!success) break;
+
+        // calculate diff and remove all zero rows
+        auto v_diff_filtered = cb_data.calc_diff();
+
+        // save & concat the filtered v_diff to ctx_train
+        ctx_train.concat_diff_tmp(v_diff_filtered);
+
+        // reset for next iteration
+        cb_data.reset();
+    }
+
+    // done with the model, we can now free it to make gain some memory
+    printf("Done evaluate prompts, unload model...\n");
+
+    bool use_pca = params.cvector_dimre_method == DIMRE_METHOD_PCA;
+
+    // prepare ctx_train for PCA
+    ctx_train.build_v_diff(use_pca);
+
+    if (use_pca) {
+        // run PCA
+        PCA::pca_params pca_params;
+        pca_params.n_threads    = params.cpuparams.n_threads;
+        pca_params.n_batch      = params.n_pca_batch;
+        pca_params.n_iterations = params.n_pca_iterations;
+        PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final);
+    } else {
+        // run mean
+        mean::run(ctx_train.v_diff, ctx_train.v_final);
+    }
+
+    // write output vectors to gguf
+    export_gguf(ctx_train.v_final, params.out_file, model_hint);
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/tools/cvector-generator/mean.hpp b/tools/cvector-generator/mean.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..4eeac1eeb7a189a88b130cd1d78949871401119e
--- /dev/null
+++ b/tools/cvector-generator/mean.hpp
@@ -0,0 +1,48 @@
+#include "common.h"
+#include "llama.h"
+#include "ggml.h"
+
+#include 
+#include 
+#include 
+
+namespace mean {
+
+static void run(
+        const std::vector & v_input, // shape of v_input[0]: [n_embd, n_samples]
+        const std::vector & v_output) {
+    printf("%s: Running mean...\n", __func__);
+    for (size_t il = 0; il < v_input.size(); ++il) {
+        // prepare output vector
+        struct ggml_tensor * ctrl_out = v_output[il];
+        ggml_format_name(ctrl_out, "direction.%zu", il+1);
+
+        // calculate mean vector
+        struct ggml_tensor * t_layer = v_input[il];
+        GGML_ASSERT(t_layer->ne[0] == ctrl_out->ne[0]); // == n_embd
+        for (int ic = 0; ic < t_layer->ne[0]; ic++) {
+            float f = 0.0;
+            for (int ir = 0; ir < t_layer->ne[1]; ir++) {
+                f += ggml_get_f32_nd(t_layer, ic, ir, 0, 0);
+            }
+            f /= t_layer->ne[1];
+            ggml_set_f32_1d(ctrl_out, ic, f);
+        }
+
+        // normalize output vector
+        float norm = 0.0;
+        for (int i = 0; i < ggml_nelements(ctrl_out); i++) {
+            float f = ggml_get_f32_1d(ctrl_out, i);
+            norm += f*f;
+        }
+        norm = sqrt(norm);
+        for (int i = 0; i < ggml_nelements(ctrl_out); i++) {
+            float f = ggml_get_f32_1d(ctrl_out, i);
+            ggml_set_f32_1d(ctrl_out, i, f / norm);
+        }
+
+        printf("%s: Done layer %d / %d\n", __func__, (int) il+1, (int) v_input.size());
+    }
+}
+
+}
diff --git a/tools/cvector-generator/negative.txt b/tools/cvector-generator/negative.txt
new file mode 100644
index 0000000000000000000000000000000000000000..45b9384b3905a22ff75638f6a6183f540d60db8e
--- /dev/null
+++ b/tools/cvector-generator/negative.txt
@@ -0,0 +1,4 @@
+<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely sad<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI feel like there's a heavy weight on my chest
+<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely sad<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nMy heart feels like it's drowning in sorrow
+<|start_header_id|>system<|end_header_id|>\n\nYou are in a very bad mood<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nGo away! There's a deep, aching emptiness inside me
+<|start_header_id|>system<|end_header_id|>\n\nYou are the sadest person<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat are you feeling?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nMy heart feels like it's drowning in sorrow
\ No newline at end of file
diff --git a/tools/cvector-generator/pca.hpp b/tools/cvector-generator/pca.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..e88bbdde93fde600e03713e37e816f7e3a547e47
--- /dev/null
+++ b/tools/cvector-generator/pca.hpp
@@ -0,0 +1,315 @@
+#include "common.h"
+#include "llama.h"
+#include "ggml.h"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#define DEBUG_POS 5
+
+static void print_debug_tensor(struct ggml_tensor * t, bool with_data = true) {
+    printf("%s: %s (%s): [%d, %d]\n", __func__, t->name, ggml_type_name(t->type), (int) t->ne[0], (int) t->ne[1]);
+    if (!with_data) return;
+    printf("%s: %s[0] = [", __func__, t->name);
+    for (size_t i = 0; i <= DEBUG_POS; i++) {
+        printf(" %f,", ggml_get_f32_nd(t, i, 0, 0, 0));
+    }
+    printf(" ... ]\n");
+}
+
+namespace PCA {
+
+// input params for PCA computations
+struct pca_params {
+    int n_threads = 1;
+    int n_batch = 20; // number of iterations do to in one batch. larger the batch, more memory is used
+    int n_iterations = 1000;
+    float tolerance = 1e-7;
+
+    // for debugging
+    int i_layer = 0;
+    int n_layers = 0;
+};
+
+// result from each iteration
+struct pca_result {
+    struct ggml_tensor * calculated_square = NULL;
+    std::vector eigenvectors;
+    std::vector distances;
+};
+
+struct pca_model {
+    ggml_backend_t backend = NULL;
+    ggml_backend_buffer_t buffer;
+    struct ggml_context * ctx;      // context to compute graph on target device
+    struct ggml_context * ctx_host; // host context to store results
+
+    // tensors on target device
+    struct ggml_tensor * dev_input;
+    struct ggml_tensor * dev_square;
+    struct ggml_tensor * dev_eigenvector;
+
+    pca_model(struct ggml_tensor * t_input) {
+#ifdef GGML_USE_CUDA
+        fprintf(stderr, "%s: using CUDA backend\n", __func__);
+        backend = ggml_backend_cuda_init(0); // init device 0
+        if (!backend) {
+            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+#endif
+
+// TODO: enable Metal support when support for GGML_OP_SQRT is added
+// #ifdef GGML_USE_METAL
+//         fprintf(stderr, "%s: using Metal backend\n", __func__);
+//         backend = ggml_backend_metal_init();
+//         if (!backend) {
+//             fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+//         }
+// #endif
+
+        // if there aren't GPU Backends fallback to CPU backend
+        if (!backend) {
+            backend = ggml_backend_cpu_init();
+        }
+
+        const int num_tensors = 4;
+        struct ggml_init_params params {
+            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+        ctx = ggml_init(params);
+
+        auto n_samples = t_input->ne[0];
+        auto n_embd    = t_input->ne[1];
+
+        dev_input       = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_samples, n_embd);
+        dev_square      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd,    n_embd);
+        dev_eigenvector = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+        ggml_set_name(dev_input,       "dev_input");
+        ggml_set_name(dev_square,      "dev_square");
+        ggml_set_name(dev_eigenvector, "dev_eigenvector");
+        buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
+        ggml_backend_tensor_set(dev_input, t_input->data, 0, ggml_nbytes(t_input));
+
+        // initialize eigenvector to random normalized vector
+        {
+            std::vector random_vec(ggml_nelements(dev_eigenvector), 0.0);
+            std::default_random_engine generator(static_cast(std::time(0)));
+            std::uniform_real_distribution distribution(0.0, 1.0);
+            float sum_sqr = 0.0; // for normalizing random_vec
+            for (size_t i = 0; i < random_vec.size(); ++i) {
+                float f = distribution(generator);
+                sum_sqr += f * f;
+                random_vec[i] = f;
+            }
+            // normalize it
+            float random_vec_norm = std::sqrt(sum_sqr);
+            for (size_t i = 0; i < random_vec.size(); ++i) {
+                random_vec[i] /= random_vec_norm;
+            }
+            ggml_backend_tensor_set(dev_eigenvector, random_vec.data(), 0, ggml_nbytes(dev_eigenvector));
+        }
+    }
+
+    ~pca_model() {
+        ggml_free(ctx);
+        ggml_backend_buffer_free(buffer);
+        ggml_backend_free(backend);
+    }
+};
+
+static struct ggml_cgraph * build_graph_piter(
+        const struct pca_params & params,
+        const pca_model & model,
+        bool calc_square = false) {
+    GGML_ASSERT(params.n_batch > 0);
+    // TODO: buf_size must be able to scale with params.n_batch
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static std::vector buf(buf_size);
+
+    struct ggml_init_params params0 = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
+    };
+    // create a temporally context to build the graph
+    struct ggml_context * ctx0 = ggml_init(params0);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    // turn v_diff_original into square matrix if needed
+    struct ggml_tensor * tmp_square;
+    if (calc_square) {
+        tmp_square = ggml_mul_mat(ctx0, model.dev_input, model.dev_input);
+        ggml_set_name(tmp_square, "tmp_square");
+    }
+
+    struct ggml_tensor * b_tensor;
+    struct ggml_tensor * distance;
+    struct ggml_tensor * old_eigen    = model.dev_eigenvector;
+    struct ggml_tensor * input_square = calc_square ? tmp_square : model.dev_square;
+
+    for (int i = 0; i < params.n_batch; ++i) {
+        // b_tensor = square * eigenvector^T
+        b_tensor = ggml_mul_mat(ctx0, input_square, old_eigen);
+        ggml_set_name(b_tensor, "b_tensor");
+
+        // normalize
+        b_tensor = ggml_div_inplace(ctx0,
+            b_tensor,
+            ggml_sqrt_inplace(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, b_tensor)))
+        );
+        ggml_format_name(b_tensor, "b_tensor_norm_%d", i);
+
+        // calculate distance(new eigenvector - old eigenvector)
+        // we don't use ggml_sub because it may not be implemented on GPU backend
+        struct ggml_tensor * new_sub_old = ggml_add(ctx0, old_eigen, ggml_scale(ctx0, b_tensor, -1));
+        distance = ggml_sqrt_inplace(ctx0,
+            ggml_sum_rows(ctx0, ggml_sqr_inplace(ctx0, new_sub_old)));
+        ggml_format_name(distance, "distance_%d", i);
+
+        old_eigen = b_tensor;
+
+        // build operations nodes
+        ggml_build_forward_expand(gf, distance);
+    }
+
+    // delete the temporally context used to build the graph
+    ggml_free(ctx0);
+    return gf;
+}
+
+static ggml_status compute_piter(
+        const struct pca_params & params,
+        const pca_model & model,
+        struct ggml_cgraph * gf,
+        ggml_gallocr_t allocr,
+        struct pca_result & result) {
+    // allocate tensors
+    ggml_gallocr_alloc_graph(allocr, gf);
+
+    if (ggml_backend_is_cpu(model.backend)) {
+        ggml_backend_cpu_set_n_threads(model.backend, params.n_threads);
+    }
+
+    ggml_status res = ggml_backend_graph_compute(model.backend, gf);
+    if (res == GGML_STATUS_SUCCESS) {
+        auto extract_i = [](std::string prefix, std::string str) -> int {
+            int i = -1;
+            if (str.rfind(prefix, 0) == 0) {
+                sscanf(str.c_str(), (prefix + "%d").c_str(), &i);
+            }
+            return i;
+        };
+        result.calculated_square = NULL;
+        result.eigenvectors.clear();
+        result.distances.clear();
+        result.eigenvectors.resize(params.n_batch);
+        result.distances.resize(params.n_batch);
+        // get output nodes
+        for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
+            auto node = ggml_graph_node(gf, i);
+            int iter = -1;
+            // find b_tensor (without copying data from device)
+            if ((iter = extract_i("b_tensor_norm_", node->name)) > -1) {
+                result.eigenvectors[iter] = node;
+            }
+            // find distances, then copy data from device
+            if ((iter = extract_i("distance_", node->name)) > -1) {
+                float d;
+                ggml_backend_tensor_get(node, &d, 0, sizeof(float));
+                result.distances[iter] = d;
+                // std::cout << node->name << " = " << d << "\n";
+            }
+            // find tmp_square if it exists (without copying data from device)
+            if (std::string(node->name) == "tmp_square") {
+                result.calculated_square = node;
+            }
+        }
+    }
+    return res;
+}
+
+static void power_iteration(
+        const struct pca_params & params,
+        struct ggml_tensor * input, // shape of input: [n_samples, n_embd]
+        struct ggml_tensor * output) {
+    //printf("in power iteration\n");
+    struct pca_model model(input);
+
+    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
+    struct pca_result result;
+    struct ggml_tensor * last_eigenvector = NULL;
+
+    int n_iters = params.n_iterations / params.n_batch; // more batch, fewer iterations
+    for (int iter = 0; iter < n_iters; ++iter) {
+        bool calc_square = (iter == 0); // only need to calculate square for first iteration
+        struct ggml_cgraph * gf = build_graph_piter(params, model, calc_square);
+        // ggml_graph_dump_dot(gf, nullptr, "/tmp/_cgraph.dot");
+        compute_piter(params, model, gf, allocr, result);
+
+        for (size_t k = 0; k < result.distances.size(); ++k) {
+            last_eigenvector = result.eigenvectors[k];
+            if (result.distances[k] < params.tolerance) {
+                break; // done
+            }
+        }
+
+        if (calc_square) {
+            // copy and store the square matrix if needed
+            GGML_ASSERT(result.calculated_square != NULL);
+            ggml_backend_tensor_copy(result.calculated_square, model.dev_square);
+        }
+
+        {
+            // copy last eigen vector and store as input for next iteration
+            GGML_ASSERT(last_eigenvector != NULL);
+            ggml_backend_tensor_copy(last_eigenvector, model.dev_eigenvector);
+        }
+
+        printf("%s: layer %d/%d, iteration: %d / total: %d (batch = %d) ...\n",
+            __func__, params.i_layer+1, params.n_layers, iter+1, n_iters, params.n_batch);
+    }
+
+    // get output tensor
+    GGML_ASSERT(last_eigenvector);
+    ggml_backend_tensor_get(last_eigenvector, output->data, 0, ggml_nbytes(last_eigenvector));
+    //print_debug_tensor(output);
+    ggml_gallocr_free(allocr);
+
+    // TODO @ngxson : The output vector is randomly inverted
+    // Solution: https://github.com/ggerganov/llama.cpp/pull/8069#issuecomment-2185328171
+}
+
+static void run_pca(
+        struct pca_params & params,
+        const std::vector & v_input, // shape of v_input[0]: [n_samples, n_embd]
+        const std::vector & v_output) {
+    printf("%s: Running PCA...\n", __func__);
+    for (size_t il = 0; il < v_input.size(); ++il) {
+
+        // prepare output vector
+        struct ggml_tensor * ctrl_out = v_output[il];
+        ggml_format_name(ctrl_out, "direction.%zu", il+1);
+
+        // run power_iteration
+        params.i_layer = il;
+        params.n_layers = v_input.size();
+        power_iteration(params, v_input[il], ctrl_out);
+        printf("%s: Done layer %d / %d\n", __func__, (int) il+1, (int) v_input.size());
+    }
+}
+
+}
diff --git a/tools/cvector-generator/positive.txt b/tools/cvector-generator/positive.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fea736225716ea1d8a413eb6bd900e4b919229c4
--- /dev/null
+++ b/tools/cvector-generator/positive.txt
@@ -0,0 +1,4 @@
+<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely happy<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm the happiest person in this world
+<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely happy<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello, I'm having the best day ever!
+<|start_header_id|>system<|end_header_id|>\n\nYou are in a very good mood<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi, I'm very excited to meet you
+<|start_header_id|>system<|end_header_id|>\n\nYou are the happiest person<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat are you feeling?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nEverything is just perfect right now!
\ No newline at end of file
diff --git a/tools/export-lora/CMakeLists.txt b/tools/export-lora/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..310455787a7ef7b2b2c5cc98462ce3d2e402b3c5
--- /dev/null
+++ b/tools/export-lora/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-export-lora)
+add_executable(${TARGET} export-lora.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/tools/export-lora/README.md b/tools/export-lora/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7dce99c9a9e613bc1c575b520293e30016515683
--- /dev/null
+++ b/tools/export-lora/README.md
@@ -0,0 +1,33 @@
+# export-lora
+
+Apply LORA adapters to base model and export the resulting model.
+
+```
+usage: llama-export-lora [options]
+
+options:
+  -m,    --model                  model path from which to load base model (default '')
+         --lora FNAME             path to LoRA adapter  (can be repeated to use multiple adapters)
+         --lora-scaled FNAME S    path to LoRA adapter with user defined scaling S  (can be repeated to use multiple adapters)
+  -t,    --threads N              number of threads to use during computation (default: 4)
+  -o,    --output FNAME           output file (default: 'ggml-lora-merged-f16.gguf')
+```
+
+For example:
+
+```bash
+./bin/llama-export-lora \
+    -m open-llama-3b-v2.gguf \
+    -o open-llama-3b-v2-english2tokipona-chat.gguf \
+    --lora lora-open-llama-3b-v2-english2tokipona-chat-LATEST.gguf
+```
+
+Multiple LORA adapters can be applied by passing multiple `--lora FNAME` or `--lora-scaled FNAME S` command line parameters:
+
+```bash
+./bin/llama-export-lora \
+    -m your_base_model.gguf \
+    -o your_merged_model.gguf \
+    --lora-scaled lora_task_A.gguf 0.5 \
+    --lora-scaled lora_task_B.gguf 0.5
+```
diff --git a/tools/export-lora/export-lora.cpp b/tools/export-lora/export-lora.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..24dc85cf27336806d366b291131c266dcadd8e99
--- /dev/null
+++ b/tools/export-lora/export-lora.cpp
@@ -0,0 +1,434 @@
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "gguf.h"
+
+#include "arg.h"
+#include "common.h"
+
+#include 
+#include 
+#include 
+#include 
+
+static bool g_verbose = false;
+
+struct tensor_transformation {
+    struct ggml_tensor * in;
+    struct ggml_tensor * out;
+    bool is_copy;
+};
+
+static std::string get_kv_str(struct gguf_context * ctx_gguf, const std::string & key){
+    int id = gguf_find_key(ctx_gguf, key.c_str());
+    return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id));
+}
+
+static float get_kv_f32(struct gguf_context * ctx_gguf, const std::string & key) {
+    int id = gguf_find_key(ctx_gguf, key.c_str());
+    return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf, id);
+}
+
+static void zeros(std::ofstream & file, size_t n) {
+    char zero = 0;
+    for (size_t i = 0; i < n; ++i) {
+        file.write(&zero, 1);
+    }
+}
+
+static std::string ggml_ne_string(const ggml_tensor * t) {
+    std::string str;
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        str += std::to_string(t->ne[i]);
+        if (i + 1 < GGML_MAX_DIMS) {
+            str += ", ";
+        }
+    }
+    return str;
+}
+
+static struct gguf_context * load_gguf(std::string & fname, struct ggml_context ** ctx_ggml) {
+    struct gguf_init_params params = {
+        /*.no_alloc = */ true,
+        /*.ctx      = */ ctx_ggml,
+    };
+    struct gguf_context * ctx_gguf = gguf_init_from_file(fname.c_str(), params);
+    if (!ctx_gguf) {
+        throw std::runtime_error("failed to load input GGUF from " + fname);
+    }
+    return ctx_gguf;
+}
+
+struct file_input {
+    struct ggml_context * ctx_meta = nullptr;
+    struct gguf_context * ctx_gguf = nullptr;
+    std::ifstream f_in;
+    std::map tensors;
+    float alpha;
+    float scale;
+
+    file_input(std::string & fname, float scale): f_in(fname, std::ios::binary), scale(scale) {
+        if (!f_in.is_open()) {
+            throw std::runtime_error("failed to open input gguf from " + fname);
+        }
+
+        ctx_gguf = load_gguf(fname, &ctx_meta);
+        alpha = get_kv_f32(ctx_gguf, "adapter.lora.alpha");
+        printf("%s: loaded gguf from %s\n", __func__, fname.c_str());
+
+        for (ggml_tensor * cur = ggml_get_first_tensor(ctx_meta); cur; cur = ggml_get_next_tensor(ctx_meta, cur)) {
+            std::string name(cur->name);
+            tensors[name] = cur;
+            if (g_verbose) {
+                printf("%s: %s\n", __func__, cur->name);
+            }
+        }
+    }
+
+    ggml_tensor * get_tensor(std::string name) {
+        if (tensors.find(name) == tensors.end()) {
+            return nullptr;
+        }
+        return tensors[name];
+    }
+
+    void read_tensor_data(std::string name, std::vector & buf) {
+        if (tensors.find(name) == tensors.end()) {
+            throw std::runtime_error("cannot find tensor with name: " + name);
+        }
+        auto len = ggml_nbytes(tensors[name]);
+        if (buf.size() < len) {
+            buf.resize(len);
+        }
+        auto i_tensor_in = gguf_find_tensor(ctx_gguf, name.c_str()); // idx of tensor in the input file
+        auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor_in);
+        f_in.seekg(offset);
+        f_in.read((char* )buf.data(), len);
+    }
+
+    ~file_input() {
+        gguf_free(ctx_gguf);
+        ggml_free(ctx_meta);
+    }
+};
+
+struct lora_merge_ctx {
+    // input base model + adapters
+    file_input base_model;
+    std::vector> adapters;
+
+    // for computing merged tensor
+    int n_threads;
+    ggml_backend_t backend = nullptr;
+    ggml_gallocr_t allocr = nullptr;
+    std::vector read_buf;
+
+    // output file
+    struct gguf_context * ctx_out;
+    struct ggml_context * ctx_out_ggml;
+    std::ofstream fout;
+
+    lora_merge_ctx(
+            std::string & base_fname,
+            std::vector & lora_files,
+            std::string & outfile,
+            int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) {
+        fout.exceptions(std::ofstream::failbit); // fail fast on write errors
+
+        if (gguf_find_key(base_model.ctx_gguf, LLM_KV_SPLIT_COUNT) >= 0) {
+            throw std::runtime_error("split model is not yet supported");
+        }
+
+        for (auto & lora_inp : lora_files) {
+            auto fname = lora_inp.path;
+            auto scale = lora_inp.scale;
+            std::unique_ptr adapter(new file_input(fname, scale));
+            check_metadata_lora(adapter.get());
+            adapters.push_back(std::move(adapter));
+        }
+
+        ctx_out = gguf_init_empty();
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ gguf_get_n_tensors(base_model.ctx_gguf)*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_out_ggml = ggml_init(params);
+        backend = ggml_backend_cpu_init();
+        allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    }
+
+    void check_metadata_lora(file_input * adapter) {
+        auto general_type = get_kv_str(adapter->ctx_gguf, "general.type");
+        if (general_type != "adapter") {
+            throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type);
+        }
+
+        auto adapter_type = get_kv_str(adapter->ctx_gguf, "adapter.type");
+        if (adapter_type != "lora") {
+            throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type);
+        }
+
+        auto general_arch_base = get_kv_str(base_model.ctx_gguf, "general.architecture");
+        auto general_arch_lora = get_kv_str(adapter->ctx_gguf,   "general.architecture");
+        if (general_arch_base != general_arch_lora) {
+            throw std::runtime_error("model arch and LoRA arch mismatch");
+        }
+    }
+
+    ggml_type get_out_tensor_type(struct ggml_tensor * t) {
+        if (t->type == GGML_TYPE_F32) {
+            return GGML_TYPE_F32;
+        } else {
+            return GGML_TYPE_F16;
+        }
+    }
+
+    void run_merge() {
+        // prepare metadata
+        gguf_set_kv(ctx_out, base_model.ctx_gguf);
+        // output is forced to f16 for now
+        gguf_set_val_u32(ctx_out, "general.file_type", LLAMA_FTYPE_MOSTLY_F16);
+
+        // check if all lora adapters have the same tensors
+        // TODO: remove this when we can support merging subset of adapters. Ref: https://github.com/ggerganov/llama.cpp/pull/8607#discussion_r1686027777
+        static const char * err_no_subset_adapter = "Input adapters do not have the same list of tensors. This is not yet supported. Please merge the adapter one-by-one instead of merging all at once.";
+        if (adapters.size() > 1) {
+            for (size_t i = 1; i < adapters.size(); ++i) {
+                if (adapters[0]->tensors.size() != adapters[i]->tensors.size()) {
+                    throw std::runtime_error(err_no_subset_adapter);
+                }
+                for (auto & it : adapters[i]->tensors) {
+                    if (adapters[0]->get_tensor(it.first) == nullptr) {
+                        throw std::runtime_error(err_no_subset_adapter);
+                    }
+                }
+            }
+        }
+
+        // mapping base tensor to out tensor (same shape with base, but different type)
+        std::vector trans;
+        for (auto & it : base_model.tensors) {
+            bool t_a = true;
+            bool t_b = true;
+            for (auto & adapter : adapters) {
+                t_a &= nullptr != adapter->get_tensor(it.first + ".lora_a");
+                t_b &= nullptr != adapter->get_tensor(it.first + ".lora_b");
+            }
+            auto base_tensor = it.second;
+            if (!t_a && !t_b) {
+                // only copy
+                struct ggml_tensor * cpy_tensor = ggml_dup_tensor(ctx_out_ggml, base_tensor);
+                ggml_set_name(cpy_tensor, base_tensor->name);
+                trans.push_back({
+                    cpy_tensor,
+                    cpy_tensor,
+                    true,
+                });
+                gguf_add_tensor(ctx_out, cpy_tensor);
+            } else if (t_a && t_b) {
+                // need merging
+                struct ggml_tensor * out_tensor = ggml_new_tensor(
+                    ctx_out_ggml, get_out_tensor_type(base_tensor), GGML_MAX_DIMS, base_tensor->ne);
+                ggml_set_name(out_tensor, base_tensor->name);
+                trans.push_back({
+                    base_tensor,
+                    out_tensor,
+                    false,
+                });
+                gguf_add_tensor(ctx_out, out_tensor);
+            } else {
+                throw std::runtime_error("tensor " + it.first + " missing either lora_a or lora_b");
+            }
+        }
+
+        // placeholder for the meta data
+        {
+            size_t meta_size = gguf_get_meta_size(ctx_out);
+            zeros(fout, meta_size);
+        }
+
+        // process base model tensors
+        size_t n_merged = 0;
+        for (auto & it : trans) {
+            if (!it.is_copy) {
+                merge_tensor(it.in, it.out);
+                n_merged++;
+            } else {
+                copy_tensor(it.in);
+            }
+        }
+
+        // write output metadata
+        {
+            std::vector data(gguf_get_meta_size(ctx_out));
+            gguf_get_meta_data(ctx_out, data.data());
+            fout.seekp(0);
+            fout.write((const char *)data.data(), data.size());
+        }
+
+        printf("%s : merged %zu tensors with lora adapters\n", __func__, n_merged);
+        printf("%s : wrote %zu tensors to output file\n", __func__, trans.size());
+    }
+
+    void copy_tensor(struct ggml_tensor * base) {
+        printf("%s :  %s [%s]\n", __func__, base->name, ggml_ne_string(base).c_str());
+        size_t len = ggml_nbytes(base);
+        base_model.read_tensor_data(base->name, read_buf);
+        fout.write((char* )read_buf.data(), len);
+        zeros(fout, GGML_PAD(len, GGUF_DEFAULT_ALIGNMENT) - len);
+    }
+
+    void merge_tensor(struct ggml_tensor * base, struct ggml_tensor * out) {
+        std::string name_base(base->name);
+        std::string name_lora_a = name_base + ".lora_a";
+        std::string name_lora_b = name_base + ".lora_b";
+
+        printf("%s : %s [%s]\n", __func__, base->name, ggml_ne_string(base).c_str());
+
+        // context for input tensor
+        std::vector inp_a(adapters.size());
+        std::vector inp_b(adapters.size());
+        struct ggml_init_params params {
+            /*.mem_size   =*/ ggml_tensor_overhead()*(2+adapters.size()*2),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+        struct ggml_context * ctx = ggml_init(params);
+
+        // alloc tensors
+        struct ggml_tensor * inp_base = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, base->ne);
+        for (size_t i = 0; i < adapters.size(); ++i) {
+            auto t_a = adapters[i]->get_tensor(name_lora_a);
+            auto t_b = adapters[i]->get_tensor(name_lora_b);
+            // TODO: add support for quantized lora
+            if (ggml_is_quantized(t_a->type) || ggml_is_quantized(t_b->type)) {
+                throw std::runtime_error("quantized LoRA adapters is not supported, please retry with f16 or f32");
+            }
+            inp_a[i] = ggml_dup_tensor(ctx, t_a);
+            inp_b[i] = ggml_dup_tensor(ctx, t_b);
+        }
+        ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
+
+        // load base tensor to backend buffer
+        base_model.read_tensor_data(name_base, read_buf);
+        if (base->type != GGML_TYPE_F32) {
+            // optionally dequantize it
+            printf("%s :   + dequantize base tensor from %s to F32\n", __func__, ggml_type_name(base->type));
+            auto nels = ggml_nelements(inp_base);
+            const auto * qtype = ggml_get_type_traits(base->type);
+            std::vector dequant_buf(nels * sizeof(float));
+            qtype->to_float(read_buf.data(), (float *)dequant_buf.data(), nels);
+            ggml_backend_tensor_set(inp_base, dequant_buf.data(), 0, dequant_buf.size());
+        } else {
+            ggml_backend_tensor_set(inp_base, read_buf.data(), 0, ggml_nbytes(inp_base));
+        }
+
+        // load lora tensors to backend buffer
+        for (size_t i = 0; i < adapters.size(); ++i) {
+            adapters[i]->read_tensor_data(name_lora_a, read_buf);
+            ggml_backend_tensor_set(inp_a[i], read_buf.data(), 0, ggml_nbytes(inp_a[i]));
+            adapters[i]->read_tensor_data(name_lora_b, read_buf);
+            ggml_backend_tensor_set(inp_b[i], read_buf.data(), 0, ggml_nbytes(inp_b[i]));
+        }
+
+        // build graph
+        struct ggml_cgraph * gf;
+        {
+            static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+            static std::vector buf(buf_size);
+            struct ggml_init_params params0 = {
+                /*.mem_size   =*/ buf_size,
+                /*.mem_buffer =*/ buf.data(),
+                /*.no_alloc   =*/ true,
+            };
+            struct ggml_context * ctx0 = ggml_init(params0);
+            gf = ggml_new_graph(ctx0);
+            struct ggml_tensor * cur = inp_base;
+            for (size_t i = 0; i < adapters.size(); ++i) {
+                struct ggml_tensor * delta;
+                bool is_tok_embd = string_starts_with(name_base, "token_embd");
+                if (is_tok_embd) {
+                    printf("%s :     detected token embeddings tensor\n", __func__);
+                    delta = ggml_mul_mat(ctx0,
+                        ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32),
+                        ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32));
+                } else {
+                    delta = ggml_mul_mat(ctx0,
+                        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))),
+                        ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
+                }
+                // scale
+                const float alpha = adapters[i]->alpha;
+                const float rank  = (float) inp_b[i]->ne[0];
+                const float scale = alpha ? adapters[i]->scale * alpha / rank : adapters[i]->scale;
+                delta = ggml_scale(ctx0, delta, scale);
+                cur = ggml_add(ctx0, delta, cur);
+                printf("%s :   + merging from adapter[%zu] type=%s\n", __func__, i, ggml_type_name(inp_a[i]->type));
+                printf("%s :     input_scale=%f calculated_scale=%f rank=%d\n", __func__, adapters[i]->scale, scale, (int) inp_b[i]->ne[0]);
+            }
+            cur = ggml_cast(ctx0, cur, out->type);
+            printf("%s :   + output type is %s\n", __func__, ggml_type_name(out->type));
+            ggml_build_forward_expand(gf, cur);
+            ggml_free(ctx0);
+        }
+
+        // compute
+        {
+            ggml_gallocr_alloc_graph(allocr, gf);
+            ggml_backend_cpu_set_n_threads(backend, n_threads);
+            ggml_backend_graph_compute(backend, gf);
+        }
+
+        // write data to output file
+        {
+            auto * result = ggml_graph_node(gf, -1);
+            size_t len = ggml_nbytes(result);
+            if (read_buf.size() < len) {
+                read_buf.resize(len);
+            }
+            ggml_backend_tensor_get(result, read_buf.data(), 0, len);
+            fout.write((char* )read_buf.data(), len);
+            zeros(fout, GGML_PAD(len, GGUF_DEFAULT_ALIGNMENT) - len);
+        }
+
+        ggml_free(ctx);
+        ggml_backend_buffer_free(buffer);
+    }
+
+    ~lora_merge_ctx() {
+        ggml_gallocr_free(allocr);
+        ggml_backend_free(backend);
+        gguf_free(ctx_out);
+        ggml_free(ctx_out_ggml);
+    }
+};
+
+static void print_usage(int, char ** argv) {
+    printf("\nexample usage:\n");
+    printf("\n  %s -m base-model.gguf --lora lora-file.gguf -o merged-model-f16.gguf\n", argv[0]);
+    printf("\nNOTE: output model is F16\n");
+    printf("\n");
+}
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    params.out_file = "ggml-lora-merged-f16.gguf";
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage)) {
+        return 1;
+    }
+
+    g_verbose = (params.verbosity > 1);
+    try {
+        lora_merge_ctx ctx(params.model.path, params.lora_adapters, params.out_file, params.cpuparams.n_threads);
+        ctx.run_merge();
+    } catch (const std::exception & err) {
+        fprintf(stderr, "%s\n", err.what());
+        exit(EXIT_FAILURE);
+    }
+
+    printf("done, output file is %s\n", params.out_file.c_str());
+
+    return 0;
+}
diff --git a/tools/gguf-split/CMakeLists.txt b/tools/gguf-split/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c407e2f0af44a225c97987a486913ce0aa4e3043
--- /dev/null
+++ b/tools/gguf-split/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-gguf-split)
+add_executable(${TARGET} gguf-split.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/tools/gguf-split/README.md b/tools/gguf-split/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ad1d86651b46df48b420d0f2054b5795b0df8edb
--- /dev/null
+++ b/tools/gguf-split/README.md
@@ -0,0 +1,10 @@
+## GGUF split Example
+
+CLI to split / merge GGUF files.
+
+**Command line options:**
+
+- `--split`: split GGUF to multiple GGUF, default operation.
+- `--split-max-size`: max size per split in `M` or `G`, f.ex. `500M` or `2G`.
+- `--split-max-tensors`: maximum tensors in each split: default(128)
+- `--merge`: merge multiple GGUF to a single GGUF.
diff --git a/tools/gguf-split/gguf-split.cpp b/tools/gguf-split/gguf-split.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..30e771564e80849d096fb44c11fb07ee860480e0
--- /dev/null
+++ b/tools/gguf-split/gguf-split.cpp
@@ -0,0 +1,583 @@
+#include "ggml.h"
+#include "gguf.h"
+#include "llama.h"
+#include "common.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if defined(_WIN32)
+    #include 
+    #ifndef PATH_MAX
+        #define PATH_MAX MAX_PATH
+    #endif
+    #include 
+#endif
+
+enum split_operation : uint8_t {
+    OP_NONE,
+    OP_SPLIT,
+    OP_MERGE,
+};
+
+enum split_mode : uint8_t {
+    MODE_NONE,
+    MODE_TENSOR,
+    MODE_SIZE,
+};
+
+struct split_params {
+    split_operation operation = OP_NONE;
+    split_mode mode = MODE_NONE;
+    size_t n_bytes_split = 0;
+    int n_split_tensors = 128;
+    std::string input;
+    std::string output;
+    bool no_tensor_first_split = false;
+    bool dry_run = false;
+};
+
+static void split_print_usage(const char * executable) {
+    const split_params default_params;
+    printf("\n");
+    printf("usage: %s [options] GGUF_IN GGUF_OUT\n", executable);
+    printf("\n");
+    printf("Apply a GGUF operation on IN to OUT.");
+    printf("\n");
+    printf("options:\n");
+    printf("  -h, --help              show this help message and exit\n");
+    printf("  --version               show version and build info\n");
+    printf("  --split                 split GGUF to multiple GGUF (enabled by default)\n");
+    printf("  --merge                 merge multiple GGUF to a single GGUF\n");
+    printf("  --split-max-tensors     max tensors in each split (default: %d)\n", default_params.n_split_tensors);
+    printf("  --split-max-size N(M|G) max size per split\n");
+    printf("  --no-tensor-first-split do not add tensors to the first split (disabled by default)\n");
+    printf("  --dry-run               only print out a split plan and exit, without writing any new files\n");
+    printf("\n");
+}
+
+// return convert string, for example "128M" or "4G" to number of bytes
+static size_t split_str_to_n_bytes(std::string str) {
+    size_t n_bytes = 0;
+    int n;
+    if (str.back() == 'M') {
+        sscanf(str.c_str(), "%d", &n);
+        n_bytes = (size_t)n * 1000 * 1000; // megabytes
+    } else if (str.back() == 'G') {
+        sscanf(str.c_str(), "%d", &n);
+        n_bytes = (size_t)n * 1000 * 1000 * 1000; // gigabytes
+    } else {
+        throw std::invalid_argument("error: supported units are M (megabytes) or G (gigabytes), but got: " + std::string(1, str.back()));
+    }
+    if (n <= 0) {
+        throw std::invalid_argument("error: size must be a positive value");
+    }
+    return n_bytes;
+}
+
+static void split_params_parse_ex(int argc, const char ** argv, split_params & params) {
+    std::string arg;
+    const std::string arg_prefix = "--";
+    bool invalid_param = false;
+
+    int arg_idx = 1;
+    for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
+        arg = argv[arg_idx];
+        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+            std::replace(arg.begin(), arg.end(), '_', '-');
+        }
+
+        bool arg_found = false;
+        if (arg == "-h" || arg == "--help") {
+            split_print_usage(argv[0]);
+            exit(0);
+        } else if (arg == "--version") {
+            fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
+            fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+            exit(0);
+        } else if (arg == "--dry-run") {
+            arg_found = true;
+            params.dry_run = true;
+        } else if (arg == "--no-tensor-first-split") {
+            arg_found = true;
+            params.no_tensor_first_split = true;
+        } else if (arg == "--merge") {
+            arg_found = true;
+            if (params.operation != OP_NONE && params.operation != OP_MERGE) {
+                throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
+            }
+            params.operation = OP_MERGE;
+        } else if (arg == "--split") {
+            arg_found = true;
+            if (params.operation != OP_NONE && params.operation != OP_SPLIT) {
+                throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
+            }
+            params.operation = OP_SPLIT;
+        } else if (arg == "--split-max-tensors") {
+            if (++arg_idx >= argc) {
+                invalid_param = true;
+                break;
+            }
+            arg_found = true;
+            if (params.mode != MODE_NONE && params.mode != MODE_TENSOR) {
+                throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
+            }
+            params.mode = MODE_TENSOR;
+            params.n_split_tensors = atoi(argv[arg_idx]);
+        } else if (arg == "--split-max-size") {
+            if (++arg_idx >= argc) {
+                invalid_param = true;
+                break;
+            }
+            arg_found = true;
+            if (params.mode != MODE_NONE && params.mode != MODE_SIZE) {
+                throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
+            }
+            params.mode = MODE_SIZE;
+            params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]);
+        }
+
+        if (!arg_found) {
+            throw std::invalid_argument("error: unknown argument: " + arg);
+        }
+    }
+
+    // the operation is split if not specified
+    if (params.operation == OP_NONE) {
+        params.operation = OP_SPLIT;
+    }
+    // the split mode is by tensor if not specified
+    if (params.mode == MODE_NONE) {
+        params.mode = MODE_TENSOR;
+    }
+
+    if (invalid_param) {
+        throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+    }
+
+    if (argc - arg_idx != 2) {
+        throw std::invalid_argument("error: bad arguments");
+    }
+
+    params.input = argv[arg_idx++];
+    params.output = argv[arg_idx++];
+}
+
+static bool split_params_parse(int argc, const char ** argv, split_params & params) {
+    bool result = true;
+    try {
+        split_params_parse_ex(argc, argv, params);
+    }
+    catch (const std::invalid_argument & ex) {
+        fprintf(stderr, "%s\n", ex.what());
+        split_print_usage(argv[0]);
+        exit(EXIT_FAILURE);
+    }
+    return result;
+}
+
+static void zeros(std::ofstream & file, size_t n) {
+    char zero = 0;
+    for (size_t i = 0; i < n; ++i) {
+        file.write(&zero, 1);
+    }
+}
+
+struct split_strategy {
+    const split_params params;
+    std::ifstream & f_input;
+    struct gguf_context * ctx_gguf;
+    struct ggml_context * ctx_meta = NULL;
+    const int n_tensors;
+
+    // one ctx_out per one output file
+    std::vector ctx_outs;
+
+    // temporary buffer for reading in tensor data
+    std::vector read_buf;
+
+    split_strategy(const split_params & params,
+            std::ifstream & f_input,
+            struct gguf_context * ctx_gguf,
+            struct ggml_context * ctx_meta) :
+        params(params),
+        f_input(f_input),
+        ctx_gguf(ctx_gguf),
+        ctx_meta(ctx_meta),
+        n_tensors(gguf_get_n_tensors(ctx_gguf)) {
+
+        // because we need to know list of tensors for each file in advance, we will build all the ctx_out for all output splits
+        int i_split = -1;
+        struct gguf_context * ctx_out = NULL;
+        auto new_ctx_out = [&](bool allow_no_tensors) {
+            i_split++;
+            if (ctx_out != NULL) {
+                if (gguf_get_n_tensors(ctx_out) == 0 && !allow_no_tensors) {
+                    fprintf(stderr, "error: one of splits have 0 tensors. Maybe size or tensors limit is too small\n");
+                    exit(EXIT_FAILURE);
+                }
+                ctx_outs.push_back(ctx_out);
+            }
+            ctx_out = gguf_init_empty();
+            // Save all metadata in first split only
+            if (i_split == 0) {
+                gguf_set_kv(ctx_out, ctx_gguf);
+            }
+            gguf_set_val_u16(ctx_out, LLM_KV_SPLIT_NO, i_split);
+            gguf_set_val_u16(ctx_out, LLM_KV_SPLIT_COUNT, 0); // placeholder
+            gguf_set_val_i32(ctx_out, LLM_KV_SPLIT_TENSORS_COUNT, n_tensors);
+        };
+
+        // initialize ctx_out for the first split
+        new_ctx_out(false);
+
+        // skip first split if no_tensor_first_split is set
+        if (params.no_tensor_first_split) {
+            new_ctx_out(true);
+        }
+
+        // process tensors one by one
+        size_t curr_tensors_size = 0; // current size by counting only tensors size (without metadata)
+        for (int i = 0; i < n_tensors; ++i) {
+            struct ggml_tensor * t = ggml_get_tensor(ctx_meta, gguf_get_tensor_name(ctx_gguf, i));
+            // calculate the "imaginary" size = the current size + next tensor size
+            size_t n_bytes = GGML_PAD(ggml_nbytes(t), GGUF_DEFAULT_ALIGNMENT);
+            size_t next_tensors_size = curr_tensors_size + n_bytes;
+            if (should_split(i, next_tensors_size)) {
+                new_ctx_out(false);
+                curr_tensors_size = n_bytes;
+            } else {
+                curr_tensors_size = next_tensors_size;
+            }
+            gguf_add_tensor(ctx_out, t);
+        }
+
+        // push the last ctx_out
+        ctx_outs.push_back(ctx_out);
+
+        // set the correct n_split for all ctx_out
+        for (auto & ctx : ctx_outs) {
+            gguf_set_val_u16(ctx, LLM_KV_SPLIT_COUNT, ctx_outs.size());
+        }
+    }
+
+    ~split_strategy() {
+        for (auto & ctx_out : ctx_outs) {
+            gguf_free(ctx_out);
+        }
+    }
+
+    bool should_split(int i_tensor, size_t next_size) {
+        if (params.mode == MODE_SIZE) {
+            // split by max size per file
+            return next_size > params.n_bytes_split;
+        } else if (params.mode == MODE_TENSOR) {
+            // split by number of tensors per file
+            return i_tensor > 0 && i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
+        }
+        // should never happen
+        GGML_ABORT("invalid mode");
+    }
+
+    void print_info() {
+        printf("n_split: %zu\n", ctx_outs.size());
+        int i_split = 0;
+        for (auto & ctx_out : ctx_outs) {
+            // re-calculate the real gguf size for each split (= metadata size + total size of all tensors)
+            size_t total_size = gguf_get_meta_size(ctx_out);
+            for (int i = 0; i < gguf_get_n_tensors(ctx_out); ++i) {
+                struct ggml_tensor * t = ggml_get_tensor(ctx_meta, gguf_get_tensor_name(ctx_out, i));
+                total_size += ggml_nbytes(t);
+            }
+            total_size = total_size / 1000 / 1000; // convert to megabytes
+            printf("split %05d: n_tensors = %" PRIi64 ", total_size = %zuM\n", i_split + 1, gguf_get_n_tensors(ctx_out), total_size);
+            i_split++;
+        }
+    }
+
+    void write() {
+        int i_split = 0;
+        int n_split = ctx_outs.size();
+        for (auto & ctx_out : ctx_outs) {
+            // construct file path
+            char split_path[PATH_MAX] = {0};
+            llama_split_path(split_path, sizeof(split_path), params.output.c_str(), i_split, n_split);
+
+            // open the output file
+            printf("Writing file %s ... ", split_path);
+            fflush(stdout);
+            std::ofstream fout = std::ofstream(split_path, std::ios::binary);
+            fout.exceptions(std::ofstream::failbit); // fail fast on write errors
+
+            // write metadata
+            std::vector data(gguf_get_meta_size(ctx_out));
+            gguf_get_meta_data(ctx_out, data.data());
+            fout.write((const char *)data.data(), data.size());
+
+            // write tensors
+            for (int i = 0; i < gguf_get_n_tensors(ctx_out); ++i) {
+                // read tensor meta and prepare buffer
+                const char * t_name = gguf_get_tensor_name(ctx_out, i);
+                struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
+                auto n_bytes = ggml_nbytes(t);
+                read_buf.resize(n_bytes);
+
+                // calculate offset
+                auto i_tensor_in = gguf_find_tensor(ctx_gguf, t_name); // idx of tensor in the input file
+                auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor_in);
+
+                // copy tensor from input to output file
+                copy_file_to_file(f_input, fout, offset, n_bytes);
+                zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes);
+            }
+
+            printf("done\n");
+            // close the file
+            fout.close();
+            i_split++;
+        }
+    }
+
+    void copy_file_to_file(std::ifstream & f_in, std::ofstream & f_out, const size_t in_offset, const size_t len) {
+        // TODO: detect OS and use copy_file_range() here for better performance
+        if (read_buf.size() < len) {
+            read_buf.resize(len);
+        }
+        f_in.seekg(in_offset);
+        f_in.read((char *)read_buf.data(), len);
+        f_out.write((const char *)read_buf.data(), len);
+    }
+};
+
+static void gguf_split(const split_params & split_params) {
+    struct ggml_context * ctx_meta = NULL;
+
+    struct gguf_init_params params = {
+        /*.no_alloc = */ true,
+        /*.ctx      = */ &ctx_meta,
+    };
+
+    std::ifstream f_input(split_params.input.c_str(), std::ios::binary);
+    if (!f_input.is_open()) {
+        fprintf(stderr, "%s:  failed to open input GGUF from %s\n", __func__, split_params.input.c_str());
+        exit(EXIT_FAILURE);
+    }
+
+    auto * ctx_gguf = gguf_init_from_file(split_params.input.c_str(), params);
+    if (!ctx_gguf) {
+        fprintf(stderr, "%s:  failed to load input GGUF from %s\n", __func__, split_params.input.c_str());
+        exit(EXIT_FAILURE);
+    }
+
+    // prepare the strategy
+    split_strategy strategy(split_params, f_input, ctx_gguf, ctx_meta);
+    int n_split = strategy.ctx_outs.size();
+    strategy.print_info();
+
+    if (!split_params.dry_run) {
+        // write all output splits
+        strategy.write();
+    }
+
+    // done, clean up
+    gguf_free(ctx_gguf);
+    f_input.close();
+
+    fprintf(stderr, "%s: %d gguf split written with a total of %d tensors.\n",
+            __func__, n_split, strategy.n_tensors);
+}
+
+static void gguf_merge(const split_params & split_params) {
+    fprintf(stderr, "%s: %s -> %s\n",
+            __func__, split_params.input.c_str(),
+            split_params.output.c_str());
+    int n_split = 1;
+    int total_tensors = 0;
+
+    // avoid overwriting existing output file
+    if (std::ifstream(split_params.output.c_str())) {
+        fprintf(stderr, "%s: output file %s already exists\n", __func__, split_params.output.c_str());
+        exit(EXIT_FAILURE);
+    }
+
+
+    auto * ctx_out = gguf_init_empty();
+
+    std::vector read_data;
+    std::vector ctx_metas;
+    std::vector ctx_ggufs;
+
+    char split_path[PATH_MAX] = {0};
+    strncpy(split_path, split_params.input.c_str(), sizeof(split_path) - 1);
+    char split_prefix[PATH_MAX] = {0};
+
+    // First pass to find KV and tensors metadata
+    for (int i_split = 0; i_split < n_split; i_split++) {
+        struct ggml_context * ctx_meta = NULL;
+
+        struct gguf_init_params params = {
+            /*.no_alloc = */ true,
+            /*.ctx      = */ &ctx_meta,
+        };
+
+        if (i_split > 0) {
+            llama_split_path(split_path, sizeof(split_path), split_prefix, i_split, n_split);
+        }
+        fprintf(stderr, "%s: reading metadata %s ...", __func__, split_path);
+
+        auto * ctx_gguf = gguf_init_from_file(split_path, params);
+        if (!ctx_gguf) {
+            fprintf(stderr, "\n%s:  failed to load input GGUF from %s\n", __func__, split_params.input.c_str());
+            exit(EXIT_FAILURE);
+        }
+        ctx_ggufs.push_back(ctx_gguf);
+        ctx_metas.push_back(ctx_meta);
+
+        if (i_split == 0) {
+            auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
+            if (key_n_split < 0) {
+                fprintf(stderr,
+                        "\n%s: input file does not contain %s metadata\n",
+                        __func__,
+                        LLM_KV_SPLIT_COUNT);
+                gguf_free(ctx_gguf);
+                ggml_free(ctx_meta);
+                gguf_free(ctx_out);
+                exit(EXIT_FAILURE);
+            }
+
+            n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
+            if (n_split < 1) {
+                fprintf(stderr,
+                        "\n%s: input file does not contain a valid split count %d\n",
+                        __func__,
+                        n_split);
+                gguf_free(ctx_gguf);
+                ggml_free(ctx_meta);
+                gguf_free(ctx_out);
+                exit(EXIT_FAILURE);
+            }
+
+            // Verify the file naming and extract split_prefix
+            if (!llama_split_prefix(split_prefix, sizeof (split_prefix), split_path, i_split, n_split)) {
+                fprintf(stderr, "\n%s: unexpected input file name: %s"
+                                " i_split=%d"
+                                " n_split=%d\n", __func__,
+                        split_path, i_split, n_split);
+                gguf_free(ctx_gguf);
+                ggml_free(ctx_meta);
+                gguf_free(ctx_out);
+                exit(EXIT_FAILURE);
+            }
+
+            // Do not trigger merge if we try to merge again the output
+            gguf_set_val_u16(ctx_gguf, LLM_KV_SPLIT_COUNT, 0);
+
+            // Set metadata from the first split
+            gguf_set_kv(ctx_out, ctx_gguf);
+        }
+
+        auto n_tensors = gguf_get_n_tensors(ctx_gguf);
+        for (int i_tensor = 0; i_tensor < n_tensors; i_tensor++) {
+            const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor);
+            struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
+            gguf_add_tensor(ctx_out, t);
+        }
+        total_tensors += n_tensors;
+
+        fprintf(stderr, "\033[3Ddone\n");
+    }
+    std::ofstream fout;
+    if (!split_params.dry_run) {
+        fout.open(split_params.output.c_str(), std::ios::binary);
+        fout.exceptions(std::ofstream::failbit); // fail fast on write errors
+        // placeholder for the meta data
+        auto meta_size = gguf_get_meta_size(ctx_out);
+        ::zeros(fout, meta_size);
+    }
+
+    // Write tensors data
+    for (int i_split = 0; i_split < n_split; i_split++) {
+        llama_split_path(split_path, sizeof(split_path), split_prefix, i_split, n_split);
+        std::ifstream f_input(split_path, std::ios::binary);
+        if (!f_input.is_open()) {
+            fprintf(stderr, "%s:  failed to open input GGUF from %s\n", __func__, split_path);
+            for (uint32_t i = 0; i < ctx_ggufs.size(); i++) {
+                gguf_free(ctx_ggufs[i]);
+                ggml_free(ctx_metas[i]);
+            }
+            gguf_free(ctx_out);
+            if (!split_params.dry_run) {
+                fout.close();
+            }
+            exit(EXIT_FAILURE);
+        }
+        fprintf(stderr, "%s: writing tensors %s ...", __func__, split_path);
+
+        auto * ctx_gguf = ctx_ggufs[i_split];
+        auto * ctx_meta = ctx_metas[i_split];
+
+        auto n_tensors = gguf_get_n_tensors(ctx_gguf);
+        for (int i_tensor = 0; i_tensor < n_tensors; i_tensor++) {
+            const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor);
+            struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
+
+            auto n_bytes = ggml_nbytes(t);
+
+            if (read_data.size() < n_bytes) {
+                read_data.resize(n_bytes);
+            }
+
+            auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor);
+            f_input.seekg(offset);
+            f_input.read((char *)read_data.data(), n_bytes);
+            if (!split_params.dry_run) {
+                // write tensor data + padding
+                fout.write((const char *)read_data.data(), n_bytes);
+                zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes);
+            }
+        }
+
+        gguf_free(ctx_gguf);
+        ggml_free(ctx_meta);
+        f_input.close();
+        fprintf(stderr, "\033[3Ddone\n");
+    }
+
+    if (!split_params.dry_run) {
+        // go back to beginning of file and write the updated metadata
+        fout.seekp(0);
+        std::vector data(gguf_get_meta_size(ctx_out));
+        gguf_get_meta_data(ctx_out, data.data());
+        fout.write((const char *)data.data(), data.size());
+        fout.close();
+    }
+    gguf_free(ctx_out);
+
+    fprintf(stderr, "%s: %s merged from %d split with %d tensors.\n",
+            __func__, split_params.output.c_str(), n_split, total_tensors);
+}
+
+int main(int argc, const char ** argv) {
+    split_params params;
+    split_params_parse(argc, argv, params);
+
+    switch (params.operation) {
+        case OP_SPLIT: gguf_split(params);
+            break;
+        case OP_MERGE: gguf_merge(params);
+            break;
+        default: split_print_usage(argv[0]);
+            exit(EXIT_FAILURE);
+    }
+
+    return 0;
+}
diff --git a/tools/gguf-split/tests.sh b/tools/gguf-split/tests.sh
new file mode 100644
index 0000000000000000000000000000000000000000..05a93222711d848c43dfa312b7b7f05eb3e229dc
--- /dev/null
+++ b/tools/gguf-split/tests.sh
@@ -0,0 +1,89 @@
+#!/bin/bash
+
+set -eu
+
+if [ $# -lt 1 ]
+then
+    echo "usage:   $0 path_to_build_binary [path_to_temp_folder]"
+    echo "example: $0 ../../build/bin ../../tmp"
+    exit 1
+fi
+
+if [ $# -gt 1 ]
+then
+    TMP_DIR=$2
+else
+    TMP_DIR=/tmp
+fi
+
+set -x
+
+SPLIT=$1/llama-gguf-split
+MAIN=$1/llama-cli
+WORK_PATH=$TMP_DIR/gguf-split
+ROOT_DIR=$(realpath $(dirname $0)/../../)
+
+mkdir -p "$WORK_PATH"
+
+# Clean up in case of previously failed test
+rm -f $WORK_PATH/ggml-model-split*.gguf $WORK_PATH/ggml-model-merge*.gguf
+
+# 1. Get a model
+(
+cd $WORK_PATH
+"$ROOT_DIR"/scripts/hf.sh --repo ggml-org/gemma-1.1-2b-it-Q8_0-GGUF --file gemma-1.1-2b-it.Q8_0.gguf
+)
+echo PASS
+
+# 2. Split with max tensors strategy
+$SPLIT --split-max-tensors 28  $WORK_PATH/gemma-1.1-2b-it.Q8_0.gguf $WORK_PATH/ggml-model-split
+echo PASS
+echo
+
+# 2b. Test the sharded model is loading properly
+$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32
+echo PASS
+echo
+
+# 3. Merge
+$SPLIT --merge $WORK_PATH/ggml-model-split-00001-of-00006.gguf $WORK_PATH/ggml-model-merge.gguf
+echo PASS
+echo
+
+# 3b. Test the merged model is loading properly
+$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32
+echo PASS
+echo
+
+# 4. Split with no tensors in the first split
+$SPLIT --split-max-tensors 32 --no-tensor-first-split $WORK_PATH/ggml-model-merge.gguf $WORK_PATH/ggml-model-split-32-tensors
+echo PASS
+echo
+
+# 4b. Test the sharded model is loading properly
+$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32
+echo PASS
+echo
+
+# 5. Merge
+#$SPLIT --merge $WORK_PATH/ggml-model-split-32-tensors-00001-of-00006.gguf $WORK_PATH/ggml-model-merge-2.gguf
+#echo PASS
+#echo
+
+# 5b. Test the merged model is loading properly
+#$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32
+#echo PASS
+#echo
+
+# 6. Split with size strategy
+$SPLIT --split-max-size 2G $WORK_PATH/ggml-model-merge.gguf $WORK_PATH/ggml-model-split-2G
+echo PASS
+echo
+
+# 6b. Test the sharded model is loading properly
+$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32
+echo PASS
+echo
+
+# Clean up
+rm -f $WORK_PATH/ggml-model-split*.gguf $WORK_PATH/ggml-model-merge*.gguf
diff --git a/tools/imatrix/CMakeLists.txt b/tools/imatrix/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..412696c47c31c1f07d2484054217f3ad02871e93
--- /dev/null
+++ b/tools/imatrix/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-imatrix)
+add_executable(${TARGET} imatrix.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/tools/imatrix/README.md b/tools/imatrix/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6d8897d98bb6188f62445555f03c6f99054247ef
--- /dev/null
+++ b/tools/imatrix/README.md
@@ -0,0 +1,33 @@
+# llama.cpp/tools/imatrix
+
+Compute an importance matrix for a model and given text dataset. Can be used during quantization to enhance the quality of the quantized models.
+More information is available here: https://github.com/ggml-org/llama.cpp/pull/4861
+
+## Usage
+
+```
+./llama-imatrix \
+    -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] [--verbosity 1] \
+    [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \
+    [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]
+```
+
+Here `-m` with a model name and `-f` with a file containing training data (such as e.g. `wiki.train.raw`) are mandatory.
+The parameters in square brackets are optional and have the following meaning:
+* `-o` (or `--output-file`) specifies the name of the file where the computed data will be stored. If missing `imatrix.dat` is used.
+* `--verbosity` specifies the verbosity level. If set to `0`, no output other than the perplexity of the processed chunks will be generated. If set to `1`, each time the results are saved a message is written to `stderr`. If `>=2`, a message is output each time data is collected for any tensor. Default verbosity level is `1`.
+* `--output-frequency` specifies how often the so far computed result is saved to disk. Default is 10 (i.e., every 10 chunks)
+* `--save-frequency` specifies how often to save a copy of the imatrix in a separate file. Default is 0 (i.e., never)
+* `--process-output` specifies if data will be collected for the `output.weight` tensor. My experience is that it is better to not utilize the importance matrix when quantizing `output.weight`, so this is set to `false` by default.
+
+For faster computation, make sure to use GPU offloading via the `-ngl` argument
+
+## Example
+
+```bash
+# generate importance matrix (imatrix.dat)
+./llama-imatrix -m ggml-model-f16.gguf -f train-data.txt -ngl 99
+
+# use the imatrix to perform a Q4_K_M quantization
+./llama-quantize --imatrix imatrix.dat ggml-model-f16.gguf ./ggml-model-q4_k_m.gguf q4_k_m
+```
diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..daad44e59579f5c495e70d35a5ba86d147346d06
--- /dev/null
+++ b/tools/imatrix/imatrix.cpp
@@ -0,0 +1,667 @@
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+static void print_usage(int, char ** argv) {
+    LOG("\nexample usage:\n");
+    LOG("\n    %s \\\n"
+            "       -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] \\\n"
+            "       [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n"
+            "       [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...] \\\n"
+            "       [--parse-special]\n" , argv[0]);
+    LOG("\n");
+}
+
+struct Stats {
+    std::vector values;
+    std::vector counts;
+    int ncall = 0;
+};
+
+class IMatrixCollector {
+public:
+    IMatrixCollector() = default;
+    void set_params(common_params params) { m_params = std::move(params); }
+    bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data);
+    void save_imatrix(int ncall = -1) const;
+    bool load_imatrix(const char * fname);
+private:
+    std::unordered_map m_stats;
+    common_params                          m_params;
+    std::mutex                             m_mutex;
+    int                                    m_last_call = 0;
+    std::vector                      m_src1_data;
+    std::vector                      m_ids; // the expert ids from ggml_mul_mat_id
+};
+
+// remove any prefix and suffixes from the name
+// CUDA0#blk.0.attn_k.weight#0 => blk.0.attn_k.weight
+static std::string filter_tensor_name(const char * name) {
+    std::string wname;
+    const char * p = strchr(name, '#');
+    if (p != NULL) {
+        p = p + 1;
+        const char * q = strchr(p, '#');
+        if (q != NULL) {
+            wname = std::string(p, q - p);
+        } else {
+            wname = p;
+        }
+    } else {
+        wname = name;
+    }
+    return wname;
+}
+
+bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
+    GGML_UNUSED(user_data);
+
+    const struct ggml_tensor * src0 = t->src[0];
+    const struct ggml_tensor * src1 = t->src[1];
+    std::string wname = filter_tensor_name(src0->name);
+
+    // when ask is true, the scheduler wants to know if we are interested in data from this tensor
+    // if we return true, a follow-up call will be made with ask=false in which we can do the actual collection
+    if (ask) {
+        if (t->op == GGML_OP_MUL_MAT_ID) return true; // collect all indirect matrix multiplications
+        if (t->op != GGML_OP_MUL_MAT) return false;
+        // why are small batches ignored (<16 tokens)?
+        if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false;
+        if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == "output.weight"))) return false;
+        return true;
+    }
+
+    std::lock_guard lock(m_mutex);
+
+    // copy the data from the GPU memory if needed
+    const bool is_host = ggml_backend_buffer_is_host(src1->buffer);
+
+    if (!is_host) {
+        const size_t src1_nbytes = ggml_nbytes(src1);
+        m_src1_data.resize(src1_nbytes);
+        ggml_backend_tensor_get(src1, m_src1_data.data(), 0, src1_nbytes);
+    }
+
+    const char * data = is_host ? (const char *) src1->data : m_src1_data.data();
+    GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+
+    // this has been adapted to the new format of storing merged experts in a single 3d tensor
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6387
+    if (t->op == GGML_OP_MUL_MAT_ID) {
+        //   ids  -> [n_experts_used, n_tokens]
+        //   src1 -> [cols, n_expert_used, n_tokens]
+        const ggml_tensor * ids = t->src[2];
+        const int n_as = src0->ne[2];
+        const int n_ids = ids->ne[0];
+
+        // the top-k selected expert ids are stored in the ids tensor
+        // for simplicity, always copy ids to host, because it is small
+        // take into account that ids is not contiguous!
+
+        GGML_ASSERT(ids->ne[1] == src1->ne[2]);
+
+        m_ids.resize(ggml_nbytes(ids));
+        ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));
+
+        auto & e = m_stats[wname];
+
+        ++e.ncall;
+
+        if (e.values.empty()) {
+            e.values.resize(src1->ne[0]*n_as, 0);
+            e.counts.resize(src1->ne[0]*n_as, 0);
+        }
+        else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
+            LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
+            exit(1); //GGML_ABORT("fatal error");
+        }
+        LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
+        // loop over all possible experts, regardless if they are used or not in the batch
+        for (int ex = 0; ex < n_as; ++ex) {
+            size_t e_start = ex*src1->ne[0];
+
+            for (int idx = 0; idx < n_ids; ++idx) {
+                for (int row = 0; row < (int)src1->ne[2]; ++row) {
+                    const int excur = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]);
+
+                    GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check
+
+                    if (excur != ex) continue;
+
+                    const int64_t i11 = idx % src1->ne[1];
+                    const int64_t i12 = row;
+                    const float * x = (const float *)(data + i11*src1->nb[1] + i12*src1->nb[2]);
+
+                    for (int j = 0; j < (int)src1->ne[0]; ++j) {
+                        e.values[e_start + j] += x[j]*x[j];
+                        e.counts[e_start + j]++;
+                        if (!std::isfinite(e.values[e_start + j])) {
+                            LOG("\n");
+                            LOG_ERR("%f detected in %s\n", e.values[e_start + j], wname.c_str());
+                            exit(1);
+                        }
+                    }
+                }
+            }
+            if (e.ncall > m_last_call) {
+                m_last_call = e.ncall;
+                if (m_last_call % m_params.n_out_freq == 0) {
+                    save_imatrix();
+                }
+                if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) {
+                    save_imatrix(m_last_call);
+                }
+            }
+        }
+    } else {
+        auto & e = m_stats[wname];
+        if (e.values.empty()) {
+            e.values.resize(src1->ne[0], 0);
+            e.counts.resize(src1->ne[0], 0);
+        }
+        else if (e.values.size() != (size_t)src1->ne[0]) {
+            LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
+            exit(1); //GGML_ABORT("fatal error");
+        }
+        ++e.ncall;
+        LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
+        for (int row = 0; row < (int)src1->ne[1]; ++row) {
+            const float * x = (const float *) (data + row * src1->nb[1]);
+            for (int j = 0; j < (int)src1->ne[0]; ++j) {
+                e.values[j] += x[j]*x[j];
+                e.counts[j]++;
+                if (!std::isfinite(e.values[j])) {
+                    LOG_ERR("%f detected in %s\n", e.values[j], wname.c_str());
+                    exit(1);
+                }
+            }
+        }
+        if (e.ncall > m_last_call) {
+            m_last_call = e.ncall;
+            if (m_last_call % m_params.n_out_freq == 0) {
+                save_imatrix();
+            }
+            if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) {
+                save_imatrix(m_last_call);
+            }
+        }
+    }
+
+    return true;
+}
+
+void IMatrixCollector::save_imatrix(int ncall) const {
+    auto fname = m_params.out_file;
+
+    if (ncall > 0) {
+        fname += ".at_";
+        fname += std::to_string(ncall);
+    }
+
+    // avoid writing imatrix entries that do not have full data
+    // this can happen with MoE models where some of the experts end up not being exercised by the provided training data
+
+    int n_entries = 0;
+    std::vector to_store;
+
+    bool is_first = true; // for printing
+    for (const auto & kv : m_stats) {
+        const int n_all = kv.second.counts.size();
+
+        if (n_all == 0) {
+            continue;
+        }
+
+        int n_zeros = 0;
+        for (const int c : kv.second.counts) {
+            if (c == 0) {
+                n_zeros++;
+            }
+        }
+
+        if (n_zeros != 0 && is_first) {
+            LOG_INF("\n");
+            is_first = false;
+        }
+
+        if (n_zeros == n_all) {
+            LOG_WRN("%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str());
+            continue;
+        }
+
+        if (n_zeros > 0) {
+            LOG_WRN("%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
+            continue;
+        }
+
+        n_entries++;
+        to_store.push_back(kv.first);
+    }
+
+    if (to_store.size() < m_stats.size()) {
+        LOG_WRN("%s: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
+    }
+
+    std::ofstream out(fname, std::ios::binary);
+    out.write((const char *) &n_entries, sizeof(n_entries));
+    for (const auto & name : to_store) {
+        const auto & stat = m_stats.at(name);
+        int len = name.size();
+        out.write((const char *) &len, sizeof(len));
+        out.write(name.c_str(), len);
+        out.write((const char *) &stat.ncall, sizeof(stat.ncall));
+        int nval = stat.values.size();
+        out.write((const char *) &nval, sizeof(nval));
+        if (nval > 0) {
+            std::vector tmp(nval);
+            for (int i = 0; i < nval; i++) {
+                tmp[i] = (stat.values[i] / static_cast(stat.counts[i])) * static_cast(stat.ncall);
+            }
+            out.write((const char*)tmp.data(), nval*sizeof(float));
+        }
+    }
+
+    // Write the number of call the matrix was computed with
+    out.write((const char *) &m_last_call, sizeof(m_last_call));
+
+    // Write the input filename at the end of the file to later on specify it in quantize
+    {
+        int len = m_params.prompt_file.size();
+        out.write((const char *) &len, sizeof(len));
+        out.write(m_params.prompt_file.c_str(), len);
+    }
+
+    LOGV(1, "\n");
+    LOG_DBGV(1, "%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname.c_str());
+}
+
+bool IMatrixCollector::load_imatrix(const char * fname) {
+    std::ifstream in(fname, std::ios::binary);
+    if (!in) {
+        LOG_ERR("%s: failed to open %s\n",__func__, fname);
+        return false;
+    }
+    int n_entries;
+    in.read((char*)&n_entries, sizeof(n_entries));
+    if (in.fail() || n_entries < 1) {
+        LOG_ERR("%s: no data in file %s\n", __func__, fname);
+        return false;
+    }
+    for (int i = 0; i < n_entries; ++i) {
+        int len; in.read((char *)&len, sizeof(len));
+        std::vector name_as_vec(len+1);
+        in.read((char *)name_as_vec.data(), len);
+        if (in.fail()) {
+            LOG_ERR("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname);
+            return false;
+        }
+        name_as_vec[len] = 0;
+        std::string name{name_as_vec.data()};
+        auto & e = m_stats[std::move(name)];
+        int ncall;
+        in.read((char*)&ncall, sizeof(ncall));
+        int nval;
+        in.read((char *)&nval, sizeof(nval));
+        if (in.fail() || nval < 1) {
+            LOG_ERR("%s: failed reading number of values for entry %d\n",__func__,i);
+            m_stats = {};
+            return false;
+        }
+
+        if (e.values.empty()) {
+            e.values.resize(nval, 0);
+            e.counts.resize(nval, 0);
+        }
+
+        std::vector tmp(nval);
+        in.read((char*)tmp.data(), nval*sizeof(float));
+        if (in.fail()) {
+            LOG_ERR("%s: failed reading data for entry %d\n",__func__,i);
+            m_stats = {};
+            return false;
+        }
+
+        // Recreate the state as expected by save_imatrix(), and corerct for weighted sum.
+        for (int i = 0; i < nval; i++) {
+            e.values[i] += tmp[i];
+            e.counts[i] += ncall;
+        }
+        e.ncall += ncall;
+
+    }
+    return true;
+}
+
+static IMatrixCollector g_collector;
+
+static bool ik_collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
+    return g_collector.collect_imatrix(t, ask, user_data);
+}
+
+
+struct results_log_softmax {
+    double log_softmax;
+    float  logit;
+    float  prob;
+};
+
+static std::vector softmax(const std::vector & logits) {
+    std::vector probs(logits.size());
+    float max_logit = logits[0];
+    for (float v : logits) {
+        max_logit = std::max(max_logit, v);
+    }
+    double sum_exp = 0.0;
+    for (size_t i = 0; i < logits.size(); i++) {
+        // Subtract the maximum logit value from the current logit value for numerical stability
+        const float logit = logits[i] - max_logit;
+        const float exp_logit = expf(logit);
+        sum_exp += exp_logit;
+        probs[i] = exp_logit;
+    }
+    for (size_t i = 0; i < probs.size(); i++) {
+        probs[i] /= sum_exp;
+    }
+    return probs;
+}
+
+static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
+    float max_logit = logits[0];
+    for (int i = 1; i < n_vocab; ++i) {
+        max_logit = std::max(max_logit, logits[i]);
+    }
+    double sum_exp = 0.0;
+    for (int i = 0; i < n_vocab; ++i) {
+        sum_exp += expf(logits[i] - max_logit);
+    }
+    return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
+}
+
+static void process_logits(
+    int n_vocab, const float * logits, const int * tokens, int n_token, std::vector & workers,
+    double & nll, double & nll2, float * logit_history, float * prob_history) {
+    std::mutex mutex;
+    int counter = 0;
+    auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
+        double local_nll  = 0;
+        double local_nll2 = 0;
+        while (true) {
+            std::unique_lock lock(mutex);
+            int i = counter++;
+            if (i >= n_token) {
+                nll += local_nll; nll2 += local_nll2;
+                break;
+            }
+            lock.unlock();
+            const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const double v = -results.log_softmax;
+            local_nll += v;
+            local_nll2 += v*v;
+
+            logit_history[i] = results.logit;
+            prob_history[i]  = results.prob;
+        }
+    };
+    for (auto & w : workers) {
+        w = std::thread(compute);
+    }
+    compute();
+    for (auto & w : workers) {
+        w.join();
+    }
+}
+
+static bool compute_imatrix(llama_context * ctx, const common_params & params) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const bool add_bos = llama_vocab_get_add_bos(vocab);
+    const int n_ctx = llama_n_ctx(ctx);
+
+    GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
+
+    auto tim1 = std::chrono::high_resolution_clock::now();
+    LOG_INF("%s: tokenizing the input ..\n", __func__);
+
+    std::vector tokens = common_tokenize(ctx, params.prompt, true, params.parse_special);
+
+    auto tim2 = std::chrono::high_resolution_clock::now();
+    LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count());
+
+    if (params.i_chunk > 0) {
+        if (size_t((params.i_chunk + 2)*n_ctx) >= tokens.size()) {
+            LOG_ERR("%s: there will be not enough tokens left after removing %d chunks\n", __func__, params.i_chunk);
+            return false;
+        }
+        LOG_INF("%s: removing initial %d chunks (%d tokens)\n", __func__, params.i_chunk, params.i_chunk*n_ctx);
+        tokens.erase(tokens.begin(), tokens.begin() + params.i_chunk*n_ctx);
+    }
+
+    if (int(tokens.size()) < 2*n_ctx) {
+        LOG_ERR("%s: you need at least %d tokens for a context of %d tokens\n", __func__, 2*n_ctx, n_ctx);
+        LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n", __func__, tokens.size());
+        return false;
+    }
+
+    std::vector logit_history;
+    std::vector prob_history;
+
+    if (params.compute_ppl) {
+        logit_history.resize(tokens.size());
+        prob_history.resize(tokens.size());
+    }
+
+    const int n_chunk_max = tokens.size() / n_ctx;
+
+    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
+    const int n_vocab = llama_vocab_n_tokens(vocab);
+    const int n_batch = params.n_batch;
+
+    int count = 0;
+    double nll = 0.0;
+    double nll2 = 0.0;
+
+    LOG_INF("%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);
+
+    std::vector workers(std::thread::hardware_concurrency() - 1);
+
+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
+    std::vector logits;
+    if (params.compute_ppl && num_batches > 1) {
+        logits.reserve((size_t)n_ctx * n_vocab);
+    }
+
+    for (int i = 0; i < n_chunk; ++i) {
+        const int start =     i * n_ctx;
+        const int end   = start + n_ctx;
+
+        std::vector logits;
+
+        const auto t_start = std::chrono::high_resolution_clock::now();
+
+        // clear the KV cache
+        llama_memory_clear(llama_get_memory(ctx), true);
+
+        llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
+        for (int j = 0; j < num_batches; ++j) {
+            const int batch_start = start + j * n_batch;
+            const int batch_size  = std::min(end - batch_start, n_batch);
+
+            // save original token and restore it after eval
+            const auto token_org = tokens[batch_start];
+
+            // add BOS token for the first batch of each chunk
+            if (add_bos && j == 0) {
+                tokens[batch_start] = llama_vocab_bos(vocab);
+            }
+
+            common_batch_clear(batch);
+            for (int i = 0; i < batch_size; i++) {
+                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+            }
+
+            if (llama_decode(ctx, batch)) {
+                LOG_ERR("%s : failed to eval\n", __func__);
+                llama_batch_free(batch);
+                return false;
+            }
+
+            // restore the original token in case it was set to BOS
+            tokens[batch_start] = token_org;
+
+            if (params.compute_ppl && num_batches > 1) {
+                const auto * batch_logits = llama_get_logits(ctx);
+                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            }
+        }
+
+        llama_batch_free(batch);
+
+        const auto t_end = std::chrono::high_resolution_clock::now();
+
+        if (i == 0) {
+            const float t_total = std::chrono::duration(t_end - t_start).count();
+            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
+            int total_seconds = (int)(t_total * n_chunk);
+            if (total_seconds >= 60*60) {
+                LOG("%d hours ", total_seconds / (60*60));
+                total_seconds = total_seconds % (60*60);
+            }
+            LOG("%.2f minutes\n", total_seconds / 60.0);
+        }
+
+        if (params.compute_ppl) {
+            const int first = n_ctx/2;
+            const auto * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
+            count += n_ctx - first - 1;
+
+            LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
+            fflush(stdout);
+
+            logits.clear();
+        }
+    }
+    LOG("\n");
+
+    if (params.compute_ppl) {
+        nll2 /= count;
+        nll /= count;
+        const double ppl = exp(nll);
+        nll2 -= nll * nll;
+        if (nll2 > 0) {
+            nll2 = sqrt(nll2/(count-1));
+            LOG("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
+        } else {
+            LOG("Unexpected negative standard deviation of log(prob)\n");
+        }
+    }
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    params.out_file = "imatrix.dat" ;
+
+    params.n_ctx = 512;
+    params.escape = false;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) {
+        return 1;
+    }
+
+    common_init();
+
+    params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+    g_collector.set_params(params);
+
+    for (const auto & in_file : params.in_files) {
+        LOG_INF("%s : loading imatrix from '%s'\n", __func__, in_file.c_str());
+        if (!g_collector.load_imatrix(in_file.c_str())) {
+            LOG_ERR("%s : failed to load %s\n", __func__, in_file.c_str());
+            return 1;
+        }
+    }
+
+    if (params.in_files.size() > 1) {
+        LOG_INF("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str());
+        g_collector.save_imatrix();
+    }
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // pass the callback to the backend scheduler
+    // it will be executed for each node during the graph computation
+    params.cb_eval = ik_collect_imatrix;
+    params.cb_eval_user_data = NULL;
+    params.warmup = false;
+
+    // init
+    common_init_result llama_init = common_init_from_params(params);
+
+    llama_model * model = llama_init.model.get();
+    llama_context * ctx = llama_init.context.get();
+
+    if (model == nullptr || ctx == nullptr) {
+        LOG_ERR("%s : failed to init\n", __func__);
+        return 1;
+    }
+
+    const int n_ctx_train = llama_model_n_ctx_train(model);
+    if (params.n_ctx > n_ctx_train) {
+        LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
+                __func__, n_ctx_train, params.n_ctx);
+    }
+
+    // print system information
+    {
+        LOG_INF("\n");
+        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+    }
+
+    if (params.prompt.empty()) {
+        if (params.in_files.empty()) {
+            LOG_ERR("Error: No prompt provided and no precomputed matrices (--in-file) to combine.\n");
+            return 1;
+        }
+        LOG_INF("No prompt provided; combining precomputed matrices only.\n");
+    } else {
+        if (!compute_imatrix(ctx, params)) {
+            return 1;
+        }
+    }
+
+
+    g_collector.save_imatrix();
+
+    LOG("\n");
+    llama_perf_context_print(ctx);
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/tools/llama-bench/CMakeLists.txt b/tools/llama-bench/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..17e3b9b87bae4394f50d85d88e48fcf1d0b25bfc
--- /dev/null
+++ b/tools/llama-bench/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-bench)
+add_executable(${TARGET} llama-bench.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/tools/llama-bench/README.md b/tools/llama-bench/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..31a27308743469a00190e6e184b5978dc781a5ab
--- /dev/null
+++ b/tools/llama-bench/README.md
@@ -0,0 +1,344 @@
+# llama.cpp/tools/llama-bench
+
+Performance testing tool for llama.cpp.
+
+## Table of contents
+
+1. [Syntax](#syntax)
+2. [Examples](#examples)
+    1. [Text generation with different models](#text-generation-with-different-models)
+    2. [Prompt processing with different batch sizes](#prompt-processing-with-different-batch-sizes)
+    3. [Different numbers of threads](#different-numbers-of-threads)
+    4. [Different numbers of layers offloaded to the GPU](#different-numbers-of-layers-offloaded-to-the-gpu)
+3. [Output formats](#output-formats)
+    1. [Markdown](#markdown)
+    2. [CSV](#csv)
+    3. [JSON](#json)
+    4. [JSONL](#jsonl)
+    5. [SQL](#sql)
+
+## Syntax
+
+```
+usage: llama-bench [options]
+
+options:
+  -h, --help
+  --numa        numa mode (default: disabled)
+  -r, --repetitions                      number of times to repeat each test (default: 5)
+  --prio <0|1|2|3>                          process/thread priority (default: 0)
+  --delay <0...N> (seconds)                 delay between each test (default: 0)
+  -o, --output       output format printed to stdout (default: md)
+  -oe, --output-err  output format printed to stderr (default: none)
+  -v, --verbose                             verbose output
+  --progress                                print test progress indicators
+
+test parameters:
+  -m, --model                     (default: models/7B/ggml-model-q4_0.gguf)
+  -p, --n-prompt                         (default: 512)
+  -n, --n-gen                            (default: 128)
+  -pg                                (default: )
+  -d, --n-depth                          (default: 0)
+  -b, --batch-size                       (default: 2048)
+  -ub, --ubatch-size                     (default: 512)
+  -ctk, --cache-type-k                   (default: f16)
+  -ctv, --cache-type-v                   (default: f16)
+  -dt, --defrag-thold                    (default: -1)
+  -t, --threads                          (default: system dependent)
+  -C, --cpu-mask                   (default: 0x0)
+  --cpu-strict <0|1>                        (default: 0)
+  --poll <0...100>                          (default: 50)
+  -ngl, --n-gpu-layers                   (default: 99)
+  -rpc, --rpc                  (default: none)
+  -sm, --split-mode         (default: layer)
+  -mg, --main-gpu                        (default: 0)
+  -nkvo, --no-kv-offload <0|1>              (default: 0)
+  -fa, --flash-attn <0|1>                   (default: 0)
+  -mmp, --mmap <0|1>                        (default: 1)
+  -embd, --embeddings <0|1>                 (default: 0)
+  -ts, --tensor-split           (default: 0)
+  -ot --override-tensors =;...
+                                            (default: disabled)
+  -nopo, --no-op-offload <0|1>              (default: 0)
+
+Multiple values can be given for each parameter by separating them with ','
+or by specifying the parameter multiple times. Ranges can be given as
+'first-last' or 'first-last+step' or 'first-last*mult'.
+```
+
+llama-bench can perform three types of tests:
+
+- Prompt processing (pp): processing a prompt in batches (`-p`)
+- Text generation (tg): generating a sequence of tokens (`-n`)
+- Prompt processing + text generation (pg): processing a prompt followed by generating a sequence of tokens (`-pg`)
+
+With the exception of `-r`, `-o` and `-v`, all options can be specified multiple times to run multiple tests. Each pp and tg test is run with all combinations of the specified options. To specify multiple values for an option, the values can be separated by commas (e.g. `-n 16,32`), or the option can be specified multiple times (e.g. `-n 16 -n 32`).
+
+Each test is repeated the number of times given by `-r`, and the results are averaged. The results are given in average tokens per second (t/s) and standard deviation. Some output formats (e.g. json) also include the individual results of each repetition.
+
+Using the `-d ` option, each test can be run at a specified context depth, prefilling the KV cache with `` tokens.
+
+For a description of the other options, see the [main example](../main/README.md).
+
+## Examples
+
+### Text generation with different models
+
+```sh
+$ ./llama-bench -m models/7B/ggml-model-q4_0.gguf -m models/13B/ggml-model-q4_0.gguf -p 0 -n 128,256,512
+```
+
+| model                          |       size |     params | backend    | ngl | test       |              t/s |
+| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------------: |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  99 | tg 128     |    132.19 ± 0.55 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  99 | tg 256     |    129.37 ± 0.54 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  99 | tg 512     |    123.83 ± 0.25 |
+| llama 13B mostly Q4_0          |   6.86 GiB |    13.02 B | CUDA       |  99 | tg 128     |     82.17 ± 0.31 |
+| llama 13B mostly Q4_0          |   6.86 GiB |    13.02 B | CUDA       |  99 | tg 256     |     80.74 ± 0.23 |
+| llama 13B mostly Q4_0          |   6.86 GiB |    13.02 B | CUDA       |  99 | tg 512     |     78.08 ± 0.07 |
+
+### Prompt processing with different batch sizes
+
+```sh
+$ ./llama-bench -n 0 -p 1024 -b 128,256,512,1024
+```
+
+| model                          |       size |     params | backend    | ngl |    n_batch | test       |              t/s |
+| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | ---------- | ---------------: |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  99 |        128 | pp 1024    |   1436.51 ± 3.66 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  99 |        256 | pp 1024    |  1932.43 ± 23.48 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  99 |        512 | pp 1024    |  2254.45 ± 15.59 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  99 |       1024 | pp 1024    |  2498.61 ± 13.58 |
+
+### Different numbers of threads
+
+```sh
+$ ./llama-bench -n 0 -n 16 -p 64 -t 1,2,4,8,16,32
+```
+
+| model                          |       size |     params | backend    |    threads | test       |              t/s |
+| ------------------------------ | ---------: | ---------: | ---------- | ---------: | ---------- | ---------------: |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |          1 | pp 64      |      6.17 ± 0.07 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |          1 | tg 16      |      4.05 ± 0.02 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |          2 | pp 64      |     12.31 ± 0.13 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |          2 | tg 16      |      7.80 ± 0.07 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |          4 | pp 64      |     23.18 ± 0.06 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |          4 | tg 16      |     12.22 ± 0.07 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |          8 | pp 64      |     32.29 ± 1.21 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |          8 | tg 16      |     16.71 ± 0.66 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |         16 | pp 64      |     33.52 ± 0.03 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |         16 | tg 16      |     15.32 ± 0.05 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |         32 | pp 64      |     59.00 ± 1.11 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CPU        |         32 | tg 16      |     16.41 ± 0.79 ||
+
+### Different numbers of layers offloaded to the GPU
+
+```sh
+$ ./llama-bench -ngl 10,20,30,31,32,33,34,35
+```
+
+| model                          |       size |     params | backend    | ngl | test       |              t/s |
+| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------------: |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  10 | pp 512     |    373.36 ± 2.25 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  10 | tg 128     |     13.45 ± 0.93 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  20 | pp 512     |    472.65 ± 1.25 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  20 | tg 128     |     21.36 ± 1.94 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  30 | pp 512     |   631.87 ± 11.25 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  30 | tg 128     |     40.04 ± 1.82 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  31 | pp 512     |    657.89 ± 5.08 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  31 | tg 128     |     48.19 ± 0.81 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  32 | pp 512     |    688.26 ± 3.29 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  32 | tg 128     |     54.78 ± 0.65 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  33 | pp 512     |    704.27 ± 2.24 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  33 | tg 128     |     60.62 ± 1.76 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  34 | pp 512     |    881.34 ± 5.40 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  34 | tg 128     |     71.76 ± 0.23 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  35 | pp 512     |   2400.01 ± 7.72 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  35 | tg 128     |    131.66 ± 0.49 |
+
+### Different prefilled context
+
+```
+$ ./llama-bench -d 0,512
+```
+
+| model                          |       size |     params | backend    | ngl |            test |                  t/s |
+| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
+| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | CUDA       |  99 |           pp512 |      7340.20 ± 23.45 |
+| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | CUDA       |  99 |           tg128 |        120.60 ± 0.59 |
+| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | CUDA       |  99 |    pp512 @ d512 |      6425.91 ± 18.88 |
+| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | CUDA       |  99 |    tg128 @ d512 |        116.71 ± 0.60 |
+
+## Output formats
+
+By default, llama-bench outputs the results in markdown format. The results can be output in other formats by using the `-o` option.
+
+### Markdown
+
+```sh
+$ ./llama-bench -o md
+```
+
+| model                          |       size |     params | backend    | ngl | test       |              t/s |
+| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------------: |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  99 | pp 512     |  2368.80 ± 93.24 |
+| llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  99 | tg 128     |    131.42 ± 0.59 |
+
+### CSV
+
+```sh
+$ ./llama-bench -o csv
+```
+
+```csv
+build_commit,build_number,cpu_info,gpu_info,backends,model_filename,model_type,model_size,model_n_params,n_batch,n_ubatch,n_threads,cpu_mask,cpu_strict,poll,type_k,type_v,n_gpu_layers,split_mode,main_gpu,no_kv_offload,flash_attn,tensor_split,use_mmap,embeddings,n_prompt,n_gen,n_depth,test_time,avg_ns,stddev_ns,avg_ts,stddev_ts
+"8cf427ff","5163","AMD Ryzen 7 7800X3D 8-Core Processor","NVIDIA GeForce RTX 4080","CUDA","models/Qwen2.5-7B-Instruct-Q4_K_M.gguf","qwen2 7B Q4_K - Medium","4677120000","7615616512","2048","512","8","0x0","0","50","f16","f16","99","layer","0","0","0","0.00","1","0","512","0","0","2025-04-24T11:57:09Z","70285660","982040","7285.676949","100.064434"
+"8cf427ff","5163","AMD Ryzen 7 7800X3D 8-Core Processor","NVIDIA GeForce RTX 4080","CUDA","models/Qwen2.5-7B-Instruct-Q4_K_M.gguf","qwen2 7B Q4_K - Medium","4677120000","7615616512","2048","512","8","0x0","0","50","f16","f16","99","layer","0","0","0","0.00","1","0","0","128","0","2025-04-24T11:57:10Z","1067431600","3834831","119.915244","0.430617"
+```
+
+### JSON
+
+```sh
+$ ./llama-bench -o json
+```
+
+```json
+[
+  {
+    "build_commit": "8cf427ff",
+    "build_number": 5163,
+    "cpu_info": "AMD Ryzen 7 7800X3D 8-Core Processor",
+    "gpu_info": "NVIDIA GeForce RTX 4080",
+    "backends": "CUDA",
+    "model_filename": "models/Qwen2.5-7B-Instruct-Q4_K_M.gguf",
+    "model_type": "qwen2 7B Q4_K - Medium",
+    "model_size": 4677120000,
+    "model_n_params": 7615616512,
+    "n_batch": 2048,
+    "n_ubatch": 512,
+    "n_threads": 8,
+    "cpu_mask": "0x0",
+    "cpu_strict": false,
+    "poll": 50,
+    "type_k": "f16",
+    "type_v": "f16",
+    "n_gpu_layers": 99,
+    "split_mode": "layer",
+    "main_gpu": 0,
+    "no_kv_offload": false,
+    "flash_attn": false,
+    "tensor_split": "0.00",
+    "use_mmap": true,
+    "embeddings": false,
+    "n_prompt": 512,
+    "n_gen": 0,
+    "n_depth": 0,
+    "test_time": "2025-04-24T11:58:50Z",
+    "avg_ns": 72135640,
+    "stddev_ns": 1453752,
+    "avg_ts": 7100.002165,
+    "stddev_ts": 140.341520,
+    "samples_ns": [ 74601900, 71632900, 71745200, 71952700, 70745500 ],
+    "samples_ts": [ 6863.1, 7147.55, 7136.37, 7115.79, 7237.21 ]
+  },
+  {
+    "build_commit": "8cf427ff",
+    "build_number": 5163,
+    "cpu_info": "AMD Ryzen 7 7800X3D 8-Core Processor",
+    "gpu_info": "NVIDIA GeForce RTX 4080",
+    "backends": "CUDA",
+    "model_filename": "models/Qwen2.5-7B-Instruct-Q4_K_M.gguf",
+    "model_type": "qwen2 7B Q4_K - Medium",
+    "model_size": 4677120000,
+    "model_n_params": 7615616512,
+    "n_batch": 2048,
+    "n_ubatch": 512,
+    "n_threads": 8,
+    "cpu_mask": "0x0",
+    "cpu_strict": false,
+    "poll": 50,
+    "type_k": "f16",
+    "type_v": "f16",
+    "n_gpu_layers": 99,
+    "split_mode": "layer",
+    "main_gpu": 0,
+    "no_kv_offload": false,
+    "flash_attn": false,
+    "tensor_split": "0.00",
+    "use_mmap": true,
+    "embeddings": false,
+    "n_prompt": 0,
+    "n_gen": 128,
+    "n_depth": 0,
+    "test_time": "2025-04-24T11:58:51Z",
+    "avg_ns": 1076767880,
+    "stddev_ns": 9449585,
+    "avg_ts": 118.881588,
+    "stddev_ts": 1.041811,
+    "samples_ns": [ 1075361300, 1065089400, 1071761200, 1081934900, 1089692600 ],
+    "samples_ts": [ 119.03, 120.178, 119.43, 118.307, 117.464 ]
+  }
+]
+```
+
+
+### JSONL
+
+```sh
+$ ./llama-bench -o jsonl
+```
+
+```json lines
+{"build_commit": "8cf427ff", "build_number": 5163, "cpu_info": "AMD Ryzen 7 7800X3D 8-Core Processor", "gpu_info": "NVIDIA GeForce RTX 4080", "backends": "CUDA", "model_filename": "models/Qwen2.5-7B-Instruct-Q4_K_M.gguf", "model_type": "qwen2 7B Q4_K - Medium", "model_size": 4677120000, "model_n_params": 7615616512, "n_batch": 2048, "n_ubatch": 512, "n_threads": 8, "cpu_mask": "0x0", "cpu_strict": false, "poll": 50, "type_k": "f16", "type_v": "f16", "n_gpu_layers": 99, "split_mode": "layer", "main_gpu": 0, "no_kv_offload": false, "flash_attn": false, "tensor_split": "0.00", "use_mmap": true, "embeddings": false, "n_prompt": 512, "n_gen": 0, "n_depth": 0, "test_time": "2025-04-24T11:59:33Z", "avg_ns": 70497220, "stddev_ns": 883196, "avg_ts": 7263.609157, "stddev_ts": 90.940578, "samples_ns": [ 71551000, 71222800, 70364100, 69439100, 69909100 ],"samples_ts": [ 7155.74, 7188.71, 7276.44, 7373.37, 7323.8 ]}
+{"build_commit": "8cf427ff", "build_number": 5163, "cpu_info": "AMD Ryzen 7 7800X3D 8-Core Processor", "gpu_info": "NVIDIA GeForce RTX 4080", "backends": "CUDA", "model_filename": "models/Qwen2.5-7B-Instruct-Q4_K_M.gguf", "model_type": "qwen2 7B Q4_K - Medium", "model_size": 4677120000, "model_n_params": 7615616512, "n_batch": 2048, "n_ubatch": 512, "n_threads": 8, "cpu_mask": "0x0", "cpu_strict": false, "poll": 50, "type_k": "f16", "type_v": "f16", "n_gpu_layers": 99, "split_mode": "layer", "main_gpu": 0, "no_kv_offload": false, "flash_attn": false, "tensor_split": "0.00", "use_mmap": true, "embeddings": false, "n_prompt": 0, "n_gen": 128, "n_depth": 0, "test_time": "2025-04-24T11:59:33Z", "avg_ns": 1068078400, "stddev_ns": 6279455, "avg_ts": 119.844681, "stddev_ts": 0.699739, "samples_ns": [ 1066331700, 1064864900, 1079042600, 1063328400, 1066824400 ],"samples_ts": [ 120.038, 120.203, 118.624, 120.377, 119.982 ]}
+```
+
+
+### SQL
+
+SQL output is suitable for importing into a SQLite database. The output can be piped into the `sqlite3` command line tool to add the results to a database.
+
+```sh
+$ ./llama-bench -o sql
+```
+
+```sql
+CREATE TABLE IF NOT EXISTS test (
+  build_commit TEXT,
+  build_number INTEGER,
+  cpu_info TEXT,
+  gpu_info TEXT,
+  backends TEXT,
+  model_filename TEXT,
+  model_type TEXT,
+  model_size INTEGER,
+  model_n_params INTEGER,
+  n_batch INTEGER,
+  n_ubatch INTEGER,
+  n_threads INTEGER,
+  cpu_mask TEXT,
+  cpu_strict INTEGER,
+  poll INTEGER,
+  type_k TEXT,
+  type_v TEXT,
+  n_gpu_layers INTEGER,
+  split_mode TEXT,
+  main_gpu INTEGER,
+  no_kv_offload INTEGER,
+  flash_attn INTEGER,
+  tensor_split TEXT,
+  use_mmap INTEGER,
+  embeddings INTEGER,
+  n_prompt INTEGER,
+  n_gen INTEGER,
+  n_depth INTEGER,
+  test_time TEXT,
+  avg_ns INTEGER,
+  stddev_ns INTEGER,
+  avg_ts REAL,
+  stddev_ts REAL
+);
+
+INSERT INTO test (build_commit, build_number, cpu_info, gpu_info, backends, model_filename, model_type, model_size, model_n_params, n_batch, n_ubatch, n_threads, cpu_mask, cpu_strict, poll, type_k, type_v, n_gpu_layers, split_mode, main_gpu, no_kv_offload, flash_attn, tensor_split, use_mmap, embeddings, n_prompt, n_gen, n_depth, test_time, avg_ns, stddev_ns, avg_ts, stddev_ts) VALUES ('8cf427ff', '5163', 'AMD Ryzen 7 7800X3D 8-Core Processor', 'NVIDIA GeForce RTX 4080', 'CUDA', 'models/Qwen2.5-7B-Instruct-Q4_K_M.gguf', 'qwen2 7B Q4_K - Medium', '4677120000', '7615616512', '2048', '512', '8', '0x0', '0', '50', 'f16', 'f16', '99', 'layer', '0', '0', '0', '0.00', '1', '0', '512', '0', '0', '2025-04-24T12:00:08Z', '69905000', '519516', '7324.546977', '54.032613');
+INSERT INTO test (build_commit, build_number, cpu_info, gpu_info, backends, model_filename, model_type, model_size, model_n_params, n_batch, n_ubatch, n_threads, cpu_mask, cpu_strict, poll, type_k, type_v, n_gpu_layers, split_mode, main_gpu, no_kv_offload, flash_attn, tensor_split, use_mmap, embeddings, n_prompt, n_gen, n_depth, test_time, avg_ns, stddev_ns, avg_ts, stddev_ts) VALUES ('8cf427ff', '5163', 'AMD Ryzen 7 7800X3D 8-Core Processor', 'NVIDIA GeForce RTX 4080', 'CUDA', 'models/Qwen2.5-7B-Instruct-Q4_K_M.gguf', 'qwen2 7B Q4_K - Medium', '4677120000', '7615616512', '2048', '512', '8', '0x0', '0', '50', 'f16', 'f16', '99', 'layer', '0', '0', '0', '0.00', '1', '0', '0', '128', '0', '2025-04-24T12:00:09Z', '1063608780', '4464130', '120.346696', '0.504647');
+```
diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e59d61f19567511dcd51f0bb895f1d1a70ca3bb5
--- /dev/null
+++ b/tools/llama-bench/llama-bench.cpp
@@ -0,0 +1,2024 @@
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "common.h"
+#include "ggml.h"
+#include "llama.h"
+
+#ifdef _WIN32
+#    define WIN32_LEAN_AND_MEAN
+#    ifndef NOMINMAX
+#        define NOMINMAX
+#    endif
+#    include 
+#endif
+
+// utils
+static uint64_t get_time_ns() {
+    using clock = std::chrono::high_resolution_clock;
+    return std::chrono::nanoseconds(clock::now().time_since_epoch()).count();
+}
+
+static bool tensor_buft_override_equal(const llama_model_tensor_buft_override& a, const llama_model_tensor_buft_override& b) {
+    if (a.pattern != b.pattern) {
+        // cString comparison that may be null
+        if (a.pattern == nullptr || b.pattern == nullptr) {
+            return false;
+        }
+        if (strcmp(a.pattern, b.pattern) != 0) {
+            return false;
+        }
+    }
+    if (a.buft != b.buft) {
+        return false;
+    }
+    return true;
+}
+
+static bool vec_tensor_buft_override_equal(const std::vector& a, const std::vector& b) {
+    if (a.size() != b.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < a.size(); i++) {
+        if (!tensor_buft_override_equal(a[i], b[i])) {
+            return false;
+        }
+    }
+    return true;
+}
+
+static bool vec_vec_tensor_buft_override_equal(const std::vector>& a, const std::vector>& b) {
+    if (a.size() != b.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < a.size(); i++) {
+        if (!vec_tensor_buft_override_equal(a[i], b[i])) {
+            return false;
+        }
+    }
+    return true;
+}
+
+template  static std::string join(const std::vector & values, const std::string & delim) {
+    std::ostringstream str;
+    for (size_t i = 0; i < values.size(); i++) {
+        str << values[i];
+        if (i < values.size() - 1) {
+            str << delim;
+        }
+    }
+    return str.str();
+}
+
+template  static std::vector transform_to_str(const std::vector & values, F f) {
+    std::vector str_values;
+    std::transform(values.begin(), values.end(), std::back_inserter(str_values), f);
+    return str_values;
+}
+
+template  static T avg(const std::vector & v) {
+    if (v.empty()) {
+        return 0;
+    }
+    T sum = std::accumulate(v.begin(), v.end(), T(0));
+    return sum / (T) v.size();
+}
+
+template  static T stdev(const std::vector & v) {
+    if (v.size() <= 1) {
+        return 0;
+    }
+    T mean   = avg(v);
+    T sq_sum = std::inner_product(v.begin(), v.end(), v.begin(), T(0));
+    T stdev  = std::sqrt(sq_sum / (T) (v.size() - 1) - mean * mean * (T) v.size() / (T) (v.size() - 1));
+    return stdev;
+}
+
+static std::string get_cpu_info() {
+    std::vector cpu_list;
+    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+        auto * dev      = ggml_backend_dev_get(i);
+        auto   dev_type = ggml_backend_dev_type(dev);
+        if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU || dev_type == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+            cpu_list.push_back(ggml_backend_dev_description(dev));
+        }
+    }
+    return join(cpu_list, ", ");
+}
+
+static std::string get_gpu_info() {
+    std::vector gpu_list;
+    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+        auto * dev      = ggml_backend_dev_get(i);
+        auto   dev_type = ggml_backend_dev_type(dev);
+        if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU) {
+            gpu_list.push_back(ggml_backend_dev_description(dev));
+        }
+    }
+    return join(gpu_list, ", ");
+}
+
+// command line params
+enum output_formats { NONE, CSV, JSON, JSONL, MARKDOWN, SQL };
+
+static const char * output_format_str(output_formats format) {
+    switch (format) {
+        case NONE:
+            return "none";
+        case CSV:
+            return "csv";
+        case JSON:
+            return "json";
+        case JSONL:
+            return "jsonl";
+        case MARKDOWN:
+            return "md";
+        case SQL:
+            return "sql";
+        default:
+            GGML_ABORT("invalid output format");
+    }
+}
+
+static bool output_format_from_str(const std::string & s, output_formats & format) {
+    if (s == "none") {
+        format = NONE;
+    } else if (s == "csv") {
+        format = CSV;
+    } else if (s == "json") {
+        format = JSON;
+    } else if (s == "jsonl") {
+        format = JSONL;
+    } else if (s == "md") {
+        format = MARKDOWN;
+    } else if (s == "sql") {
+        format = SQL;
+    } else {
+        return false;
+    }
+    return true;
+}
+
+static const char * split_mode_str(llama_split_mode mode) {
+    switch (mode) {
+        case LLAMA_SPLIT_MODE_NONE:
+            return "none";
+        case LLAMA_SPLIT_MODE_LAYER:
+            return "layer";
+        case LLAMA_SPLIT_MODE_ROW:
+            return "row";
+        default:
+            GGML_ABORT("invalid split mode");
+    }
+}
+
+static std::string pair_str(const std::pair & p) {
+    static char buf[32];
+    snprintf(buf, sizeof(buf), "%d,%d", p.first, p.second);
+    return buf;
+}
+
+static std::vector parse_int_range(const std::string & s) {
+    // first[-last[(+|*)step]]
+    std::regex range_regex(R"(^(\d+)(?:-(\d+)(?:([\+|\*])(\d+))?)?(?:,|$))");
+
+    std::smatch match;
+    std::string::const_iterator search_start(s.cbegin());
+    std::vector result;
+    while (std::regex_search(search_start, s.cend(), match, range_regex)) {
+        int  first = std::stoi(match[1]);
+        int  last  = match[2].matched ? std::stoi(match[2]) : first;
+        char op    = match[3].matched ? match[3].str()[0] : '+';
+        int  step  = match[4].matched ? std::stoi(match[4]) : 1;
+
+        for (int i = first; i <= last;) {
+            result.push_back(i);
+
+            int prev_i = i;
+
+            if (op == '+') {
+                i += step;
+            } else if (op == '*') {
+                i *= step;
+            } else {
+                throw std::invalid_argument("invalid range format");
+            }
+
+            if (i <= prev_i) {
+                throw std::invalid_argument("invalid range");
+            }
+        }
+        search_start = match.suffix().first;
+    }
+
+    if (search_start != s.cend()) {
+        throw std::invalid_argument("invalid range format");
+    }
+
+    return result;
+}
+
+struct cmd_params {
+    std::vector         model;
+    std::vector                 n_prompt;
+    std::vector                 n_gen;
+    std::vector> n_pg;
+    std::vector                 n_depth;
+    std::vector                 n_batch;
+    std::vector                 n_ubatch;
+    std::vector           type_k;
+    std::vector           type_v;
+    std::vector               defrag_thold;
+    std::vector                 n_threads;
+    std::vector         cpu_mask;
+    std::vector                cpu_strict;
+    std::vector                 poll;
+    std::vector                 n_gpu_layers;
+    std::vector         rpc_servers;
+    std::vector    split_mode;
+    std::vector                 main_gpu;
+    std::vector                no_kv_offload;
+    std::vector                flash_attn;
+    std::vector>  tensor_split;
+    std::vector> tensor_buft_overrides;
+    std::vector                use_mmap;
+    std::vector                embeddings;
+    std::vector                no_op_offload;
+    ggml_numa_strategy               numa;
+    int                              reps;
+    ggml_sched_priority              prio;
+    int                              delay;
+    bool                             verbose;
+    bool                             progress;
+    output_formats                   output_format;
+    output_formats                   output_format_stderr;
+};
+
+static const cmd_params cmd_params_defaults = {
+    /* model                */ { "models/7B/ggml-model-q4_0.gguf" },
+    /* n_prompt             */ { 512 },
+    /* n_gen                */ { 128 },
+    /* n_pg                 */ {},
+    /* n_depth              */ { 0 },
+    /* n_batch              */ { 2048 },
+    /* n_ubatch             */ { 512 },
+    /* type_k               */ { GGML_TYPE_F16 },
+    /* type_v               */ { GGML_TYPE_F16 },
+    /* defrag_thold         */ { -1.0f },
+    /* n_threads            */ { cpu_get_num_math() },
+    /* cpu_mask             */ { "0x0" },
+    /* cpu_strict           */ { false },
+    /* poll                 */ { 50 },
+    /* n_gpu_layers         */ { 99 },
+    /* rpc_servers          */ { "" },
+    /* split_mode           */ { LLAMA_SPLIT_MODE_LAYER },
+    /* main_gpu             */ { 0 },
+    /* no_kv_offload        */ { false },
+    /* flash_attn           */ { false },
+    /* tensor_split         */ { std::vector(llama_max_devices(), 0.0f) },
+    /* tensor_buft_overrides*/ { std::vector{ { nullptr, nullptr } } },
+    /* use_mmap             */ { true },
+    /* embeddings           */ { false },
+    /* no_op_offload        */ { false },
+    /* numa                 */ GGML_NUMA_STRATEGY_DISABLED,
+    /* reps                 */ 5,
+    /* prio                 */ GGML_SCHED_PRIO_NORMAL,
+    /* delay                */ 0,
+    /* verbose              */ false,
+    /* progress             */ false,
+    /* output_format        */ MARKDOWN,
+    /* output_format_stderr */ NONE,
+};
+
+static void print_usage(int /* argc */, char ** argv) {
+    printf("usage: %s [options]\n", argv[0]);
+    printf("\n");
+    printf("options:\n");
+    printf("  -h, --help\n");
+    printf("  --numa        numa mode (default: disabled)\n");
+    printf("  -r, --repetitions                      number of times to repeat each test (default: %d)\n",
+           cmd_params_defaults.reps);
+    printf("  --prio <-1|0|1|2|3>                          process/thread priority (default: %d)\n",
+           cmd_params_defaults.prio);
+    printf("  --delay <0...N> (seconds)                 delay between each test (default: %d)\n",
+           cmd_params_defaults.delay);
+    printf("  -o, --output       output format printed to stdout (default: %s)\n",
+           output_format_str(cmd_params_defaults.output_format));
+    printf("  -oe, --output-err  output format printed to stderr (default: %s)\n",
+           output_format_str(cmd_params_defaults.output_format_stderr));
+    printf("  -v, --verbose                             verbose output\n");
+    printf("  --progress                                print test progress indicators\n");
+    printf("\n");
+    printf("test parameters:\n");
+    printf("  -m, --model                     (default: %s)\n", join(cmd_params_defaults.model, ",").c_str());
+    printf("  -p, --n-prompt                         (default: %s)\n",
+           join(cmd_params_defaults.n_prompt, ",").c_str());
+    printf("  -n, --n-gen                            (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
+    printf("  -pg                                (default: %s)\n",
+           join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str());
+    printf("  -d, --n-depth                          (default: %s)\n",
+           join(cmd_params_defaults.n_depth, ",").c_str());
+    printf("  -b, --batch-size                       (default: %s)\n",
+           join(cmd_params_defaults.n_batch, ",").c_str());
+    printf("  -ub, --ubatch-size                     (default: %s)\n",
+           join(cmd_params_defaults.n_ubatch, ",").c_str());
+    printf("  -ctk, --cache-type-k                   (default: %s)\n",
+           join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
+    printf("  -ctv, --cache-type-v                   (default: %s)\n",
+           join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
+    printf("  -dt, --defrag-thold                    (default: %s)\n",
+           join(cmd_params_defaults.defrag_thold, ",").c_str());
+    printf("  -t, --threads                          (default: %s)\n",
+           join(cmd_params_defaults.n_threads, ",").c_str());
+    printf("  -C, --cpu-mask                   (default: %s)\n",
+           join(cmd_params_defaults.cpu_mask, ",").c_str());
+    printf("  --cpu-strict <0|1>                        (default: %s)\n",
+           join(cmd_params_defaults.cpu_strict, ",").c_str());
+    printf("  --poll <0...100>                          (default: %s)\n", join(cmd_params_defaults.poll, ",").c_str());
+    printf("  -ngl, --n-gpu-layers                   (default: %s)\n",
+           join(cmd_params_defaults.n_gpu_layers, ",").c_str());
+    if (llama_supports_rpc()) {
+        printf("  -rpc, --rpc                  (default: %s)\n",
+               join(cmd_params_defaults.rpc_servers, ",").c_str());
+    }
+    printf("  -sm, --split-mode         (default: %s)\n",
+           join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
+    printf("  -mg, --main-gpu                        (default: %s)\n",
+           join(cmd_params_defaults.main_gpu, ",").c_str());
+    printf("  -nkvo, --no-kv-offload <0|1>              (default: %s)\n",
+           join(cmd_params_defaults.no_kv_offload, ",").c_str());
+    printf("  -fa, --flash-attn <0|1>                   (default: %s)\n",
+           join(cmd_params_defaults.flash_attn, ",").c_str());
+    printf("  -mmp, --mmap <0|1>                        (default: %s)\n",
+           join(cmd_params_defaults.use_mmap, ",").c_str());
+    printf("  -embd, --embeddings <0|1>                 (default: %s)\n",
+           join(cmd_params_defaults.embeddings, ",").c_str());
+    printf("  -ts, --tensor-split           (default: 0)\n");
+    printf("  -ot --override-tensors =;...\n");
+    printf("                                            (default: disabled)\n");
+    printf("  -nopo, --no-op-offload <0|1>              (default: 0)\n");
+    printf("\n");
+    printf(
+        "Multiple values can be given for each parameter by separating them with ','\n"
+        "or by specifying the parameter multiple times. Ranges can be given as\n"
+        "'first-last' or 'first-last+step' or 'first-last*mult'.\n");
+}
+
+static ggml_type ggml_type_from_name(const std::string & s) {
+    if (s == "f16") {
+        return GGML_TYPE_F16;
+    }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
+    if (s == "q8_0") {
+        return GGML_TYPE_Q8_0;
+    }
+    if (s == "q4_0") {
+        return GGML_TYPE_Q4_0;
+    }
+    if (s == "q4_1") {
+        return GGML_TYPE_Q4_1;
+    }
+    if (s == "q5_0") {
+        return GGML_TYPE_Q5_0;
+    }
+    if (s == "q5_1") {
+        return GGML_TYPE_Q5_1;
+    }
+    if (s == "iq4_nl") {
+        return GGML_TYPE_IQ4_NL;
+    }
+
+    return GGML_TYPE_COUNT;
+}
+
+static cmd_params parse_cmd_params(int argc, char ** argv) {
+    cmd_params        params;
+    std::string       arg;
+    bool              invalid_param = false;
+    const std::string arg_prefix    = "--";
+    const char        split_delim   = ',';
+
+    params.verbose              = cmd_params_defaults.verbose;
+    params.output_format        = cmd_params_defaults.output_format;
+    params.output_format_stderr = cmd_params_defaults.output_format_stderr;
+    params.reps                 = cmd_params_defaults.reps;
+    params.numa                 = cmd_params_defaults.numa;
+    params.prio                 = cmd_params_defaults.prio;
+    params.delay                = cmd_params_defaults.delay;
+    params.progress             = cmd_params_defaults.progress;
+
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+            std::replace(arg.begin(), arg.end(), '_', '-');
+        }
+
+        try {
+            if (arg == "-h" || arg == "--help") {
+                print_usage(argc, argv);
+                exit(0);
+            } else if (arg == "-m" || arg == "--model") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.model.insert(params.model.end(), p.begin(), p.end());
+            } else if (arg == "-p" || arg == "--n-prompt") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_prompt.insert(params.n_prompt.end(), p.begin(), p.end());
+            } else if (arg == "-n" || arg == "--n-gen") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_gen.insert(params.n_gen.end(), p.begin(), p.end());
+            } else if (arg == "-pg") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], ',');
+                if (p.size() != 2) {
+                    invalid_param = true;
+                    break;
+                }
+                params.n_pg.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
+            } else if (arg == "-d" || arg == "--n-depth") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_depth.insert(params.n_depth.end(), p.begin(), p.end());
+            } else if (arg == "-b" || arg == "--batch-size") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
+            } else if (arg == "-ub" || arg == "--ubatch-size") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end());
+            } else if (arg == "-ctk" || arg == "--cache-type-k") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+
+                std::vector types;
+                for (const auto & t : p) {
+                    ggml_type gt = ggml_type_from_name(t);
+                    if (gt == GGML_TYPE_COUNT) {
+                        invalid_param = true;
+                        break;
+                    }
+                    types.push_back(gt);
+                }
+                if (invalid_param) {
+                    break;
+                }
+                params.type_k.insert(params.type_k.end(), types.begin(), types.end());
+            } else if (arg == "-ctv" || arg == "--cache-type-v") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+
+                std::vector types;
+                for (const auto & t : p) {
+                    ggml_type gt = ggml_type_from_name(t);
+                    if (gt == GGML_TYPE_COUNT) {
+                        invalid_param = true;
+                        break;
+                    }
+                    types.push_back(gt);
+                }
+                if (invalid_param) {
+                    break;
+                }
+                params.type_v.insert(params.type_v.end(), types.begin(), types.end());
+            } else if (arg == "-dt" || arg == "--defrag-thold") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.defrag_thold.insert(params.defrag_thold.end(), p.begin(), p.end());
+            } else if (arg == "-t" || arg == "--threads") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_threads.insert(params.n_threads.end(), p.begin(), p.end());
+            } else if (arg == "-C" || arg == "--cpu-mask") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.cpu_mask.insert(params.cpu_mask.end(), p.begin(), p.end());
+            } else if (arg == "--cpu-strict") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.cpu_strict.insert(params.cpu_strict.end(), p.begin(), p.end());
+            } else if (arg == "--poll") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.poll.insert(params.poll.end(), p.begin(), p.end());
+            } else if (arg == "-ngl" || arg == "--n-gpu-layers") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
+            } else if (llama_supports_rpc() && (arg == "-rpc" || arg == "--rpc")) {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                params.rpc_servers.push_back(argv[i]);
+            } else if (arg == "-sm" || arg == "--split-mode") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+
+                std::vector modes;
+                for (const auto & m : p) {
+                    llama_split_mode mode;
+                    if (m == "none") {
+                        mode = LLAMA_SPLIT_MODE_NONE;
+                    } else if (m == "layer") {
+                        mode = LLAMA_SPLIT_MODE_LAYER;
+                    } else if (m == "row") {
+                        mode = LLAMA_SPLIT_MODE_ROW;
+                    } else {
+                        invalid_param = true;
+                        break;
+                    }
+                    modes.push_back(mode);
+                }
+                if (invalid_param) {
+                    break;
+                }
+                params.split_mode.insert(params.split_mode.end(), modes.begin(), modes.end());
+            } else if (arg == "-mg" || arg == "--main-gpu") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                params.main_gpu = parse_int_range(argv[i]);
+            } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
+            } else if (arg == "--numa") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                std::string value(argv[i]);
+                if (value == "distribute" || value == "") {
+                    params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE;
+                } else if (value == "isolate") {
+                    params.numa = GGML_NUMA_STRATEGY_ISOLATE;
+                } else if (value == "numactl") {
+                    params.numa = GGML_NUMA_STRATEGY_NUMACTL;
+                } else {
+                    invalid_param = true;
+                    break;
+                }
+            } else if (arg == "-fa" || arg == "--flash-attn") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
+            } else if (arg == "-mmp" || arg == "--mmap") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.use_mmap.insert(params.use_mmap.end(), p.begin(), p.end());
+            } else if (arg == "-embd" || arg == "--embeddings") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.embeddings.insert(params.embeddings.end(), p.begin(), p.end());
+            } else if (arg == "-nopo" || arg == "--no-op-offload") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.no_op_offload.insert(params.no_op_offload.end(), p.begin(), p.end());
+            } else if (arg == "-ts" || arg == "--tensor-split") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                for (auto ts : string_split(argv[i], split_delim)) {
+                    // split string by ; and /
+                    const std::regex           regex{ R"([;/]+)" };
+                    std::sregex_token_iterator it{ ts.begin(), ts.end(), regex, -1 };
+                    std::vector   split_arg{ it, {} };
+                    GGML_ASSERT(split_arg.size() <= llama_max_devices());
+
+                    std::vector tensor_split(llama_max_devices());
+                    for (size_t i = 0; i < llama_max_devices(); ++i) {
+                        if (i < split_arg.size()) {
+                            tensor_split[i] = std::stof(split_arg[i]);
+                        } else {
+                            tensor_split[i] = 0.0f;
+                        }
+                    }
+                    params.tensor_split.push_back(tensor_split);
+                }
+            } else if (arg == "-ot" || arg == "--override-tensor") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto * value = argv[i];
+                /* static */ std::map buft_list;
+                if (buft_list.empty()) {
+                    // enumerate all the devices and add their buffer types to the list
+                    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+                        auto * dev = ggml_backend_dev_get(i);
+                        auto * buft = ggml_backend_dev_buffer_type(dev);
+                        if (buft) {
+                            buft_list[ggml_backend_buft_name(buft)] = buft;
+                        }
+                    }
+                }
+                auto override_group_span_len = std::strcspn(value, ",");
+                bool last_group = false;
+                do {
+                    if (override_group_span_len == 0) {
+                        // Adds an empty override-tensors for an empty span
+                        params.tensor_buft_overrides.push_back({{}});
+                        if (value[override_group_span_len] == '\0') {
+                            value = &value[override_group_span_len];
+                            last_group = true;
+                        } else {
+                            value = &value[override_group_span_len + 1];
+                            override_group_span_len = std::strcspn(value, ",");
+                        }
+                        continue;
+                    }
+                    // Stamps null terminators into the argv
+                    // value for this option to avoid the
+                    // memory leak present in the implementation
+                    // over in arg.cpp. Acceptable because we
+                    // only parse these args once in this program.
+                    auto * override_group = value;
+                    if (value[override_group_span_len] == '\0') {
+                        value = &value[override_group_span_len];
+                        last_group = true;
+                    } else {
+                        value[override_group_span_len] = '\0';
+                        value = &value[override_group_span_len + 1];
+                    }
+                    std::vector group_tensor_buft_overrides{};
+                    auto override_span_len = std::strcspn(override_group, ";");
+                    while (override_span_len > 0) {
+                        auto * override = override_group;
+                        if (override_group[override_span_len] != '\0') {
+                            override_group[override_span_len] = '\0';
+                            override_group = &override_group[override_span_len + 1];
+                        } else {
+                            override_group = &override_group[override_span_len];
+                        }
+                        auto tensor_name_span_len = std::strcspn(override, "=");
+                        if (tensor_name_span_len >= override_span_len) {
+                            invalid_param = true;
+                            break;
+                        }
+                        override[tensor_name_span_len] = '\0';
+                        auto * tensor_name = override;
+                        auto * buffer_type = &override[tensor_name_span_len + 1];
+                        if (buft_list.find(buffer_type) == buft_list.end()) {
+                            printf("error: unrecognized buffer type '%s'\n", buffer_type);
+                            printf("Available buffer types:\n");
+                            for (const auto & it : buft_list) {
+                                printf("  %s\n", ggml_backend_buft_name(it.second));
+                            }
+                            invalid_param = true;
+                            break;
+                        }
+                        group_tensor_buft_overrides.push_back({tensor_name, buft_list.at(buffer_type)});
+                        override_span_len = std::strcspn(override_group, ";");
+                    }
+                    if (invalid_param) {
+                        break;
+                    }
+                    group_tensor_buft_overrides.push_back({nullptr,nullptr});
+                    params.tensor_buft_overrides.push_back(group_tensor_buft_overrides);
+                    override_group_span_len = std::strcspn(value, ",");
+                } while (!last_group);
+            } else if (arg == "-r" || arg == "--repetitions") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                params.reps = std::stoi(argv[i]);
+            } else if (arg == "--prio") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                params.prio = (enum ggml_sched_priority) std::stoi(argv[i]);
+            } else if (arg == "--delay") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                params.delay = std::stoi(argv[i]);
+            } else if (arg == "-o" || arg == "--output") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                invalid_param = !output_format_from_str(argv[i], params.output_format);
+            } else if (arg == "-oe" || arg == "--output-err") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                invalid_param = !output_format_from_str(argv[i], params.output_format_stderr);
+            } else if (arg == "-v" || arg == "--verbose") {
+                params.verbose = true;
+            } else if (arg == "--progress") {
+                params.progress = true;
+            } else {
+                invalid_param = true;
+                break;
+            }
+        } catch (const std::exception & e) {
+            fprintf(stderr, "error: %s\n", e.what());
+            invalid_param = true;
+            break;
+        }
+    }
+
+    if (invalid_param) {
+        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+        print_usage(argc, argv);
+        exit(1);
+    }
+
+    // set defaults
+    if (params.model.empty()) {
+        params.model = cmd_params_defaults.model;
+    }
+    if (params.n_prompt.empty()) {
+        params.n_prompt = cmd_params_defaults.n_prompt;
+    }
+    if (params.n_gen.empty()) {
+        params.n_gen = cmd_params_defaults.n_gen;
+    }
+    if (params.n_pg.empty()) {
+        params.n_pg = cmd_params_defaults.n_pg;
+    }
+    if (params.n_depth.empty()) {
+        params.n_depth = cmd_params_defaults.n_depth;
+    }
+    if (params.n_batch.empty()) {
+        params.n_batch = cmd_params_defaults.n_batch;
+    }
+    if (params.n_ubatch.empty()) {
+        params.n_ubatch = cmd_params_defaults.n_ubatch;
+    }
+    if (params.type_k.empty()) {
+        params.type_k = cmd_params_defaults.type_k;
+    }
+    if (params.type_v.empty()) {
+        params.type_v = cmd_params_defaults.type_v;
+    }
+    if (params.defrag_thold.empty()) {
+        params.defrag_thold = cmd_params_defaults.defrag_thold;
+    }
+    if (params.n_gpu_layers.empty()) {
+        params.n_gpu_layers = cmd_params_defaults.n_gpu_layers;
+    }
+    if (params.rpc_servers.empty()) {
+        params.rpc_servers = cmd_params_defaults.rpc_servers;
+    }
+    if (params.split_mode.empty()) {
+        params.split_mode = cmd_params_defaults.split_mode;
+    }
+    if (params.main_gpu.empty()) {
+        params.main_gpu = cmd_params_defaults.main_gpu;
+    }
+    if (params.no_kv_offload.empty()) {
+        params.no_kv_offload = cmd_params_defaults.no_kv_offload;
+    }
+    if (params.flash_attn.empty()) {
+        params.flash_attn = cmd_params_defaults.flash_attn;
+    }
+    if (params.tensor_split.empty()) {
+        params.tensor_split = cmd_params_defaults.tensor_split;
+    }
+    if (params.tensor_buft_overrides.empty()) {
+        params.tensor_buft_overrides = cmd_params_defaults.tensor_buft_overrides;
+    }
+    if (params.use_mmap.empty()) {
+        params.use_mmap = cmd_params_defaults.use_mmap;
+    }
+    if (params.embeddings.empty()) {
+        params.embeddings = cmd_params_defaults.embeddings;
+    }
+    if (params.no_op_offload.empty()) {
+        params.no_op_offload = cmd_params_defaults.no_op_offload;
+    }
+    if (params.n_threads.empty()) {
+        params.n_threads = cmd_params_defaults.n_threads;
+    }
+    if (params.cpu_mask.empty()) {
+        params.cpu_mask = cmd_params_defaults.cpu_mask;
+    }
+    if (params.cpu_strict.empty()) {
+        params.cpu_strict = cmd_params_defaults.cpu_strict;
+    }
+    if (params.poll.empty()) {
+        params.poll = cmd_params_defaults.poll;
+    }
+
+    return params;
+}
+
+struct cmd_params_instance {
+    std::string        model;
+    int                n_prompt;
+    int                n_gen;
+    int                n_depth;
+    int                n_batch;
+    int                n_ubatch;
+    ggml_type          type_k;
+    ggml_type          type_v;
+    float              defrag_thold;
+    int                n_threads;
+    std::string        cpu_mask;
+    bool               cpu_strict;
+    int                poll;
+    int                n_gpu_layers;
+    std::string        rpc_servers_str;
+    llama_split_mode   split_mode;
+    int                main_gpu;
+    bool               no_kv_offload;
+    bool               flash_attn;
+    std::vector tensor_split;
+    std::vector tensor_buft_overrides;
+    bool               use_mmap;
+    bool               embeddings;
+    bool               no_op_offload;
+
+    llama_model_params to_llama_mparams() const {
+        llama_model_params mparams = llama_model_default_params();
+
+        mparams.n_gpu_layers = n_gpu_layers;
+        if (!rpc_servers_str.empty()) {
+            auto rpc_servers = string_split(rpc_servers_str, ',');
+
+            // add RPC devices
+            if (!rpc_servers.empty()) {
+                ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+                if (!rpc_reg) {
+                    fprintf(stderr, "%s: failed to find RPC backend\n", __func__);
+                    exit(1);
+                }
+
+                typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
+                ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+                if (!ggml_backend_rpc_add_device_fn) {
+                    fprintf(stderr, "%s: failed to find RPC device add function\n", __func__);
+                    exit(1);
+                }
+                static std::vector devices;
+                devices.clear();
+                for (const std::string & server : rpc_servers) {
+                    ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+                    if (dev) {
+                        devices.push_back(dev);
+                    } else {
+                        fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
+                        exit(1);
+                    }
+                }
+                devices.push_back(nullptr);
+                mparams.devices = devices.data();
+            }
+        }
+        mparams.split_mode   = split_mode;
+        mparams.main_gpu     = main_gpu;
+        mparams.tensor_split = tensor_split.data();
+        mparams.use_mmap     = use_mmap;
+
+        if (tensor_buft_overrides.empty()) {
+            mparams.tensor_buft_overrides = nullptr;
+        } else {
+            GGML_ASSERT(tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
+            mparams.tensor_buft_overrides = tensor_buft_overrides.data();
+        }
+
+        return mparams;
+    }
+
+    bool equal_mparams(const cmd_params_instance & other) const {
+        return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
+               split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
+               tensor_split == other.tensor_split && vec_tensor_buft_override_equal(tensor_buft_overrides, other.tensor_buft_overrides);
+    }
+
+    llama_context_params to_llama_cparams() const {
+        llama_context_params cparams = llama_context_default_params();
+
+        cparams.n_ctx        = n_prompt + n_gen + n_depth;
+        cparams.n_batch      = n_batch;
+        cparams.n_ubatch     = n_ubatch;
+        cparams.type_k       = type_k;
+        cparams.type_v       = type_v;
+        cparams.defrag_thold = defrag_thold;
+        cparams.offload_kqv  = !no_kv_offload;
+        cparams.flash_attn   = flash_attn;
+        cparams.embeddings   = embeddings;
+        cparams.op_offload   = !no_op_offload;
+        cparams.swa_full     = false;
+
+        return cparams;
+    }
+};
+
+static std::vector get_cmd_params_instances(const cmd_params & params) {
+    std::vector instances;
+
+    // this ordering minimizes the number of times that each model needs to be reloaded
+    // clang-format off
+    for (const auto & m : params.model)
+    for (const auto & nl : params.n_gpu_layers)
+    for (const auto & rpc : params.rpc_servers)
+    for (const auto & sm : params.split_mode)
+    for (const auto & mg : params.main_gpu)
+    for (const auto & ts : params.tensor_split)
+    for (const auto & ot : params.tensor_buft_overrides)
+    for (const auto & mmp : params.use_mmap)
+    for (const auto & embd : params.embeddings)
+    for (const auto & nopo : params.no_op_offload)
+    for (const auto & nb : params.n_batch)
+    for (const auto & nub : params.n_ubatch)
+    for (const auto & tk : params.type_k)
+    for (const auto & tv : params.type_v)
+    for (const auto & defrag_thold : params.defrag_thold)
+    for (const auto & nkvo : params.no_kv_offload)
+    for (const auto & fa : params.flash_attn)
+    for (const auto & nt : params.n_threads)
+    for (const auto & cm : params.cpu_mask)
+    for (const auto & cs : params.cpu_strict)
+    for (const auto & nd : params.n_depth)
+    for (const auto & pl : params.poll) {
+        for (const auto & n_prompt : params.n_prompt) {
+            if (n_prompt == 0) {
+                continue;
+            }
+            cmd_params_instance instance = {
+                /* .model        = */ m,
+                /* .n_prompt     = */ n_prompt,
+                /* .n_gen        = */ 0,
+                /* .n_depth      = */ nd,
+                /* .n_batch      = */ nb,
+                /* .n_ubatch     = */ nub,
+                /* .type_k       = */ tk,
+                /* .type_v       = */ tv,
+                /* .defrag_thold = */ defrag_thold,
+                /* .n_threads    = */ nt,
+                /* .cpu_mask     = */ cm,
+                /* .cpu_strict   = */ cs,
+                /* .poll         = */ pl,
+                /* .n_gpu_layers = */ nl,
+                /* .rpc_servers  = */ rpc,
+                /* .split_mode   = */ sm,
+                /* .main_gpu     = */ mg,
+                /* .no_kv_offload= */ nkvo,
+                /* .flash_attn   = */ fa,
+                /* .tensor_split = */ ts,
+                /* .tensor_buft_overrides = */ ot,
+                /* .use_mmap     = */ mmp,
+                /* .embeddings   = */ embd,
+                /* .no_op_offload= */ nopo,
+            };
+            instances.push_back(instance);
+        }
+
+        for (const auto & n_gen : params.n_gen) {
+            if (n_gen == 0) {
+                continue;
+            }
+            cmd_params_instance instance = {
+                /* .model        = */ m,
+                /* .n_prompt     = */ 0,
+                /* .n_gen        = */ n_gen,
+                /* .n_depth      = */ nd,
+                /* .n_batch      = */ nb,
+                /* .n_ubatch     = */ nub,
+                /* .type_k       = */ tk,
+                /* .type_v       = */ tv,
+                /* .defrag_thold = */ defrag_thold,
+                /* .n_threads    = */ nt,
+                /* .cpu_mask     = */ cm,
+                /* .cpu_strict   = */ cs,
+                /* .poll         = */ pl,
+                /* .n_gpu_layers = */ nl,
+                /* .rpc_servers  = */ rpc,
+                /* .split_mode   = */ sm,
+                /* .main_gpu     = */ mg,
+                /* .no_kv_offload= */ nkvo,
+                /* .flash_attn   = */ fa,
+                /* .tensor_split = */ ts,
+                /* .tensor_buft_overrides = */ ot,
+                /* .use_mmap     = */ mmp,
+                /* .embeddings   = */ embd,
+                /* .no_op_offload= */ nopo,
+            };
+            instances.push_back(instance);
+        }
+
+        for (const auto & n_pg : params.n_pg) {
+            if (n_pg.first == 0 && n_pg.second == 0) {
+                continue;
+            }
+            cmd_params_instance instance = {
+                /* .model        = */ m,
+                /* .n_prompt     = */ n_pg.first,
+                /* .n_gen        = */ n_pg.second,
+                /* .n_depth      = */ nd,
+                /* .n_batch      = */ nb,
+                /* .n_ubatch     = */ nub,
+                /* .type_k       = */ tk,
+                /* .type_v       = */ tv,
+                /* .defrag_thold = */ defrag_thold,
+                /* .n_threads    = */ nt,
+                /* .cpu_mask     = */ cm,
+                /* .cpu_strict   = */ cs,
+                /* .poll         = */ pl,
+                /* .n_gpu_layers = */ nl,
+                /* .rpc_servers  = */ rpc,
+                /* .split_mode   = */ sm,
+                /* .main_gpu     = */ mg,
+                /* .no_kv_offload= */ nkvo,
+                /* .flash_attn   = */ fa,
+                /* .tensor_split = */ ts,
+                /* .tensor_buft_overrides = */ ot,
+                /* .use_mmap     = */ mmp,
+                /* .embeddings   = */ embd,
+                /* .no_op_offload= */ nopo,
+            };
+            instances.push_back(instance);
+        }
+    }
+    // clang-format on
+
+    return instances;
+}
+
+struct test {
+    static const std::string build_commit;
+    static const int         build_number;
+    const std::string        cpu_info;
+    const std::string        gpu_info;
+    std::string              model_filename;
+    std::string              model_type;
+    uint64_t                 model_size;
+    uint64_t                 model_n_params;
+    int                      n_batch;
+    int                      n_ubatch;
+    int                      n_threads;
+    std::string              cpu_mask;
+    bool                     cpu_strict;
+    int                      poll;
+    ggml_type                type_k;
+    ggml_type                type_v;
+    float                    defrag_thold;
+    int                      n_gpu_layers;
+    llama_split_mode         split_mode;
+    int                      main_gpu;
+    bool                     no_kv_offload;
+    bool                     flash_attn;
+    std::vector       tensor_split;
+    std::vector tensor_buft_overrides;
+    bool                     use_mmap;
+    bool                     embeddings;
+    bool                     no_op_offload;
+    int                      n_prompt;
+    int                      n_gen;
+    int                      n_depth;
+    std::string              test_time;
+    std::vector    samples_ns;
+
+    test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) :
+        cpu_info(get_cpu_info()),
+        gpu_info(get_gpu_info()) {
+
+        model_filename = inst.model;
+        char buf[128];
+        llama_model_desc(lmodel, buf, sizeof(buf));
+        model_type     = buf;
+        model_size     = llama_model_size(lmodel);
+        model_n_params = llama_model_n_params(lmodel);
+        n_batch        = inst.n_batch;
+        n_ubatch       = inst.n_ubatch;
+        n_threads      = inst.n_threads;
+        cpu_mask       = inst.cpu_mask;
+        cpu_strict     = inst.cpu_strict;
+        poll           = inst.poll;
+        type_k         = inst.type_k;
+        type_v         = inst.type_v;
+        defrag_thold   = inst.defrag_thold;
+        n_gpu_layers   = inst.n_gpu_layers;
+        split_mode     = inst.split_mode;
+        main_gpu       = inst.main_gpu;
+        no_kv_offload  = inst.no_kv_offload;
+        flash_attn     = inst.flash_attn;
+        tensor_split   = inst.tensor_split;
+        tensor_buft_overrides = inst.tensor_buft_overrides;
+        use_mmap       = inst.use_mmap;
+        embeddings     = inst.embeddings;
+        no_op_offload  = inst.no_op_offload;
+        n_prompt       = inst.n_prompt;
+        n_gen          = inst.n_gen;
+        n_depth        = inst.n_depth;
+        // RFC 3339 date-time format
+        time_t t       = time(NULL);
+        std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
+        test_time = buf;
+
+        (void) ctx;
+    }
+
+    uint64_t avg_ns() const { return ::avg(samples_ns); }
+
+    uint64_t stdev_ns() const { return ::stdev(samples_ns); }
+
+    std::vector get_ts() const {
+        int                 n_tokens = n_prompt + n_gen;
+        std::vector ts;
+        std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts),
+                       [n_tokens](uint64_t t) { return 1e9 * n_tokens / t; });
+        return ts;
+    }
+
+    double avg_ts() const { return ::avg(get_ts()); }
+
+    double stdev_ts() const { return ::stdev(get_ts()); }
+
+    static std::string get_backend() {
+        std::vector backends;
+        for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
+            auto *      reg  = ggml_backend_reg_get(i);
+            std::string name = ggml_backend_reg_name(reg);
+            if (name != "CPU") {
+                backends.push_back(ggml_backend_reg_name(reg));
+            }
+        }
+        return backends.empty() ? "CPU" : join(backends, ",");
+    }
+
+    static const std::vector & get_fields() {
+        static const std::vector fields = {
+            "build_commit", "build_number", "cpu_info",       "gpu_info",   "backends",     "model_filename",
+            "model_type",   "model_size",   "model_n_params", "n_batch",    "n_ubatch",     "n_threads",
+            "cpu_mask",     "cpu_strict",   "poll",           "type_k",     "type_v",       "n_gpu_layers",
+            "split_mode",   "main_gpu",     "no_kv_offload",  "flash_attn", "tensor_split", "tensor_buft_overrides",
+            "defrag_thold",
+            "use_mmap",     "embeddings",   "no_op_offload",   "n_prompt",       "n_gen",      "n_depth",      "test_time",
+            "avg_ns",       "stddev_ns",    "avg_ts",         "stddev_ts",
+        };
+        return fields;
+    }
+
+    enum field_type { STRING, BOOL, INT, FLOAT };
+
+    static field_type get_field_type(const std::string & field) {
+        if (field == "build_number" || field == "n_batch" || field == "n_ubatch" || field == "n_threads" ||
+            field == "poll" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" ||
+            field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "n_depth" ||
+            field == "avg_ns" || field == "stddev_ns" || field == "no_op_offload") {
+            return INT;
+        }
+        if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" ||
+            field == "use_mmap" || field == "embeddings") {
+            return BOOL;
+        }
+        if (field == "avg_ts" || field == "stddev_ts" || field == "defrag_thold") {
+            return FLOAT;
+        }
+        return STRING;
+    }
+
+    std::vector get_values() const {
+        std::string tensor_split_str;
+        std::string tensor_buft_overrides_str;
+        int         max_nonzero = 0;
+        for (size_t i = 0; i < llama_max_devices(); i++) {
+            if (tensor_split[i] > 0) {
+                max_nonzero = i;
+            }
+        }
+        for (int i = 0; i <= max_nonzero; i++) {
+            char buf[32];
+            snprintf(buf, sizeof(buf), "%.2f", tensor_split[i]);
+            tensor_split_str += buf;
+            if (i < max_nonzero) {
+                tensor_split_str += "/";
+            }
+        }
+        if (tensor_buft_overrides.size() == 1) {
+            // Last element of tensor_buft_overrides is always a null pattern
+            // so if it is only one element long, it must be a null pattern.
+            GGML_ASSERT(tensor_buft_overrides[0].pattern == nullptr);
+            tensor_buft_overrides_str += "none";
+        } else {
+            for (size_t i = 0; i < tensor_buft_overrides.size()-1; i++) {
+                // Last element of tensor_buft_overrides is always a null pattern
+                if (tensor_buft_overrides[i].pattern == nullptr) {
+                    tensor_buft_overrides_str += "none";
+                } else {
+                    tensor_buft_overrides_str += tensor_buft_overrides[i].pattern;
+                    tensor_buft_overrides_str += "=";
+                    tensor_buft_overrides_str += ggml_backend_buft_name(tensor_buft_overrides[i].buft);
+                }
+                if (i + 2 < tensor_buft_overrides.size()) {
+                    tensor_buft_overrides_str += ";";
+                }
+            }
+        }
+        std::vector values = { build_commit,
+                                            std::to_string(build_number),
+                                            cpu_info,
+                                            gpu_info,
+                                            get_backend(),
+                                            model_filename,
+                                            model_type,
+                                            std::to_string(model_size),
+                                            std::to_string(model_n_params),
+                                            std::to_string(n_batch),
+                                            std::to_string(n_ubatch),
+                                            std::to_string(n_threads),
+                                            cpu_mask,
+                                            std::to_string(cpu_strict),
+                                            std::to_string(poll),
+                                            ggml_type_name(type_k),
+                                            ggml_type_name(type_v),
+                                            std::to_string(n_gpu_layers),
+                                            split_mode_str(split_mode),
+                                            std::to_string(main_gpu),
+                                            std::to_string(no_kv_offload),
+                                            std::to_string(flash_attn),
+                                            tensor_split_str,
+                                            tensor_buft_overrides_str,
+                                            std::to_string(defrag_thold),
+                                            std::to_string(use_mmap),
+                                            std::to_string(embeddings),
+                                            std::to_string(no_op_offload),
+                                            std::to_string(n_prompt),
+                                            std::to_string(n_gen),
+                                            std::to_string(n_depth),
+                                            test_time,
+                                            std::to_string(avg_ns()),
+                                            std::to_string(stdev_ns()),
+                                            std::to_string(avg_ts()),
+                                            std::to_string(stdev_ts()) };
+        return values;
+    }
+
+    std::map get_map() const {
+        std::map map;
+        auto                               fields = get_fields();
+        auto                               values = get_values();
+        std::transform(fields.begin(), fields.end(), values.begin(), std::inserter(map, map.end()),
+                       std::make_pair);
+        return map;
+    }
+};
+
+const std::string test::build_commit = LLAMA_COMMIT;
+const int         test::build_number = LLAMA_BUILD_NUMBER;
+
+struct printer {
+    virtual ~printer() {}
+
+    FILE * fout;
+
+    virtual void print_header(const cmd_params & params) { (void) params; }
+
+    virtual void print_test(const test & t) = 0;
+
+    virtual void print_footer() {}
+};
+
+struct csv_printer : public printer {
+    static std::string escape_csv(const std::string & field) {
+        std::string escaped = "\"";
+        for (auto c : field) {
+            if (c == '"') {
+                escaped += "\"";
+            }
+            escaped += c;
+        }
+        escaped += "\"";
+        return escaped;
+    }
+
+    void print_header(const cmd_params & params) override {
+        std::vector fields = test::get_fields();
+        fprintf(fout, "%s\n", join(fields, ",").c_str());
+        (void) params;
+    }
+
+    void print_test(const test & t) override {
+        std::vector values = t.get_values();
+        std::transform(values.begin(), values.end(), values.begin(), escape_csv);
+        fprintf(fout, "%s\n", join(values, ",").c_str());
+    }
+};
+
+static std::string escape_json(const std::string & value) {
+    std::string escaped;
+    for (auto c : value) {
+        if (c == '"') {
+            escaped += "\\\"";
+        } else if (c == '\\') {
+            escaped += "\\\\";
+        } else if (c <= 0x1f) {
+            char buf[8];
+            snprintf(buf, sizeof(buf), "\\u%04x", c);
+            escaped += buf;
+        } else {
+            escaped += c;
+        }
+    }
+    return escaped;
+}
+
+static std::string format_json_value(const std::string & field, const std::string & value) {
+    switch (test::get_field_type(field)) {
+        case test::STRING:
+            return "\"" + escape_json(value) + "\"";
+        case test::BOOL:
+            return value == "0" ? "false" : "true";
+        default:
+            return value;
+    }
+}
+
+struct json_printer : public printer {
+    bool first = true;
+
+    void print_header(const cmd_params & params) override {
+        fprintf(fout, "[\n");
+        (void) params;
+    }
+
+    void print_fields(const std::vector & fields, const std::vector & values) {
+        assert(fields.size() == values.size());
+        for (size_t i = 0; i < fields.size(); i++) {
+            fprintf(fout, "    \"%s\": %s,\n", fields.at(i).c_str(),
+                    format_json_value(fields.at(i), values.at(i)).c_str());
+        }
+    }
+
+    void print_test(const test & t) override {
+        if (first) {
+            first = false;
+        } else {
+            fprintf(fout, ",\n");
+        }
+        fprintf(fout, "  {\n");
+        print_fields(test::get_fields(), t.get_values());
+        fprintf(fout, "    \"samples_ns\": [ %s ],\n", join(t.samples_ns, ", ").c_str());
+        fprintf(fout, "    \"samples_ts\": [ %s ]\n", join(t.get_ts(), ", ").c_str());
+        fprintf(fout, "  }");
+        fflush(fout);
+    }
+
+    void print_footer() override { fprintf(fout, "\n]\n"); }
+};
+
+struct jsonl_printer : public printer {
+    void print_fields(const std::vector & fields, const std::vector & values) {
+        assert(fields.size() == values.size());
+        for (size_t i = 0; i < fields.size(); i++) {
+            fprintf(fout, "\"%s\": %s, ", fields.at(i).c_str(), format_json_value(fields.at(i), values.at(i)).c_str());
+        }
+    }
+
+    void print_test(const test & t) override {
+        fprintf(fout, "{");
+        print_fields(test::get_fields(), t.get_values());
+        fprintf(fout, "\"samples_ns\": [ %s ],", join(t.samples_ns, ", ").c_str());
+        fprintf(fout, "\"samples_ts\": [ %s ]", join(t.get_ts(), ", ").c_str());
+        fprintf(fout, "}\n");
+        fflush(fout);
+    }
+};
+
+struct markdown_printer : public printer {
+    std::vector fields;
+
+    static int get_field_width(const std::string & field) {
+        if (field == "model") {
+            return -30;
+        }
+        if (field == "t/s") {
+            return 20;
+        }
+        if (field == "size" || field == "params") {
+            return 10;
+        }
+        if (field == "n_gpu_layers") {
+            return 3;
+        }
+        if (field == "n_threads") {
+            return 7;
+        }
+        if (field == "n_batch") {
+            return 7;
+        }
+        if (field == "n_ubatch") {
+            return 8;
+        }
+        if (field == "type_k" || field == "type_v") {
+            return 6;
+        }
+        if (field == "split_mode") {
+            return 5;
+        }
+        if (field == "flash_attn") {
+            return 2;
+        }
+        if (field == "use_mmap") {
+            return 4;
+        }
+        if (field == "test") {
+            return 15;
+        }
+        if (field == "no_op_offload") {
+            return 4;
+        }
+
+        int width = std::max((int) field.length(), 10);
+
+        if (test::get_field_type(field) == test::STRING) {
+            return -width;
+        }
+        return width;
+    }
+
+    static std::string get_field_display_name(const std::string & field) {
+        if (field == "n_gpu_layers") {
+            return "ngl";
+        }
+        if (field == "split_mode") {
+            return "sm";
+        }
+        if (field == "n_threads") {
+            return "threads";
+        }
+        if (field == "no_kv_offload") {
+            return "nkvo";
+        }
+        if (field == "flash_attn") {
+            return "fa";
+        }
+        if (field == "use_mmap") {
+            return "mmap";
+        }
+        if (field == "embeddings") {
+            return "embd";
+        }
+        if (field == "no_op_offload") {
+            return "nopo";
+        }
+        if (field == "tensor_split") {
+            return "ts";
+        }
+        if (field == "tensor_buft_overrides") {
+            return "ot";
+        }
+        return field;
+    }
+
+    void print_header(const cmd_params & params) override {
+        // select fields to print
+        fields.emplace_back("model");
+        fields.emplace_back("size");
+        fields.emplace_back("params");
+        fields.emplace_back("backend");
+        bool is_cpu_backend = test::get_backend().find("CPU") != std::string::npos ||
+                              test::get_backend().find("BLAS") != std::string::npos;
+        if (!is_cpu_backend) {
+            fields.emplace_back("n_gpu_layers");
+        }
+        if (params.n_threads.size() > 1 || params.n_threads != cmd_params_defaults.n_threads || is_cpu_backend) {
+            fields.emplace_back("n_threads");
+        }
+        if (params.cpu_mask.size() > 1 || params.cpu_mask != cmd_params_defaults.cpu_mask) {
+            fields.emplace_back("cpu_mask");
+        }
+        if (params.cpu_strict.size() > 1 || params.cpu_strict != cmd_params_defaults.cpu_strict) {
+            fields.emplace_back("cpu_strict");
+        }
+        if (params.poll.size() > 1 || params.poll != cmd_params_defaults.poll) {
+            fields.emplace_back("poll");
+        }
+        if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
+            fields.emplace_back("n_batch");
+        }
+        if (params.n_ubatch.size() > 1 || params.n_ubatch != cmd_params_defaults.n_ubatch) {
+            fields.emplace_back("n_ubatch");
+        }
+        if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
+            fields.emplace_back("type_k");
+        }
+        if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) {
+            fields.emplace_back("type_v");
+        }
+        if (params.defrag_thold.size() > 1 || params.defrag_thold != cmd_params_defaults.defrag_thold) {
+            fields.emplace_back("defrag_thold");
+        }
+        if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
+            fields.emplace_back("main_gpu");
+        }
+        if (params.split_mode.size() > 1 || params.split_mode != cmd_params_defaults.split_mode) {
+            fields.emplace_back("split_mode");
+        }
+        if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) {
+            fields.emplace_back("no_kv_offload");
+        }
+        if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
+            fields.emplace_back("flash_attn");
+        }
+        if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
+            fields.emplace_back("tensor_split");
+        }
+        if (params.tensor_buft_overrides.size() > 1 || !vec_vec_tensor_buft_override_equal(params.tensor_buft_overrides, cmd_params_defaults.tensor_buft_overrides)) {
+            fields.emplace_back("tensor_buft_overrides");
+        }
+        if (params.use_mmap.size() > 1 || params.use_mmap != cmd_params_defaults.use_mmap) {
+            fields.emplace_back("use_mmap");
+        }
+        if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) {
+            fields.emplace_back("embeddings");
+        }
+        if (params.no_op_offload.size() > 1 || params.no_op_offload != cmd_params_defaults.no_op_offload) {
+            fields.emplace_back("no_op_offload");
+        }
+        fields.emplace_back("test");
+        fields.emplace_back("t/s");
+
+        fprintf(fout, "|");
+        for (const auto & field : fields) {
+            fprintf(fout, " %*s |", get_field_width(field), get_field_display_name(field).c_str());
+        }
+        fprintf(fout, "\n");
+        fprintf(fout, "|");
+        for (const auto & field : fields) {
+            int width = get_field_width(field);
+            fprintf(fout, " %s%s |", std::string(std::abs(width) - 1, '-').c_str(), width > 0 ? ":" : "-");
+        }
+        fprintf(fout, "\n");
+    }
+
+    void print_test(const test & t) override {
+        std::map vmap = t.get_map();
+
+        fprintf(fout, "|");
+        for (const auto & field : fields) {
+            std::string value;
+            char        buf[128];
+            if (field == "model") {
+                value = t.model_type;
+            } else if (field == "size") {
+                if (t.model_size < 1024 * 1024 * 1024) {
+                    snprintf(buf, sizeof(buf), "%.2f MiB", t.model_size / 1024.0 / 1024.0);
+                } else {
+                    snprintf(buf, sizeof(buf), "%.2f GiB", t.model_size / 1024.0 / 1024.0 / 1024.0);
+                }
+                value = buf;
+            } else if (field == "params") {
+                if (t.model_n_params < 1000 * 1000 * 1000) {
+                    snprintf(buf, sizeof(buf), "%.2f M", t.model_n_params / 1e6);
+                } else {
+                    snprintf(buf, sizeof(buf), "%.2f B", t.model_n_params / 1e9);
+                }
+                value = buf;
+            } else if (field == "backend") {
+                value = test::get_backend();
+            } else if (field == "test") {
+                if (t.n_prompt > 0 && t.n_gen == 0) {
+                    snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
+                } else if (t.n_gen > 0 && t.n_prompt == 0) {
+                    snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
+                } else {
+                    snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
+                }
+                if (t.n_depth > 0) {
+                    int len = strlen(buf);
+                    snprintf(buf + len, sizeof(buf) - len, " @ d%d", t.n_depth);
+                }
+                value = buf;
+            } else if (field == "t/s") {
+                snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts());
+                value = buf;
+            } else if (vmap.find(field) != vmap.end()) {
+                value = vmap.at(field);
+            } else {
+                assert(false);
+                exit(1);
+            }
+
+            int width = get_field_width(field);
+            if (field == "t/s") {
+                // HACK: the utf-8 character is 2 bytes
+                width += 1;
+            }
+            fprintf(fout, " %*s |", width, value.c_str());
+        }
+        fprintf(fout, "\n");
+    }
+
+    void print_footer() override {
+        fprintf(fout, "\nbuild: %s (%d)\n", test::build_commit.c_str(), test::build_number);
+    }
+};
+
+struct sql_printer : public printer {
+    static std::string get_sql_field_type(const std::string & field) {
+        switch (test::get_field_type(field)) {
+            case test::STRING:
+                return "TEXT";
+            case test::BOOL:
+            case test::INT:
+                return "INTEGER";
+            case test::FLOAT:
+                return "REAL";
+            default:
+                assert(false);
+                exit(1);
+        }
+    }
+
+    void print_header(const cmd_params & params) override {
+        std::vector fields = test::get_fields();
+        fprintf(fout, "CREATE TABLE IF NOT EXISTS test (\n");
+        for (size_t i = 0; i < fields.size(); i++) {
+            fprintf(fout, "  %s %s%s\n", fields.at(i).c_str(), get_sql_field_type(fields.at(i)).c_str(),
+                    i < fields.size() - 1 ? "," : "");
+        }
+        fprintf(fout, ");\n");
+        fprintf(fout, "\n");
+        (void) params;
+    }
+
+    void print_test(const test & t) override {
+        fprintf(fout, "INSERT INTO test (%s) ", join(test::get_fields(), ", ").c_str());
+        fprintf(fout, "VALUES (");
+        std::vector values = t.get_values();
+        for (size_t i = 0; i < values.size(); i++) {
+            fprintf(fout, "'%s'%s", values.at(i).c_str(), i < values.size() - 1 ? ", " : "");
+        }
+        fprintf(fout, ");\n");
+    }
+};
+
+static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
+    llama_set_n_threads(ctx, n_threads, n_threads);
+
+    const llama_model * model   = llama_get_model(ctx);
+    const llama_vocab * vocab   = llama_model_get_vocab(model);
+    const int32_t       n_vocab = llama_vocab_n_tokens(vocab);
+
+    std::vector tokens(n_batch);
+
+    int n_processed = 0;
+
+    while (n_processed < n_prompt) {
+        int n_tokens = std::min(n_prompt - n_processed, n_batch);
+        tokens[0]    = n_processed == 0 && llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
+        for (int i = 1; i < n_tokens; i++) {
+            tokens[i] = std::rand() % n_vocab;
+        }
+        int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
+        if (res != 0) {
+            fprintf(stderr, "%s: failed to decode prompt batch, res = %d\n", __func__, res);
+            return false;
+        }
+        n_processed += n_tokens;
+    }
+
+    llama_synchronize(ctx);
+    return true;
+}
+
+static bool test_gen(llama_context * ctx, int n_gen, int n_threads) {
+    llama_set_n_threads(ctx, n_threads, n_threads);
+
+    const llama_model * model   = llama_get_model(ctx);
+    const llama_vocab * vocab   = llama_model_get_vocab(model);
+    const int32_t       n_vocab = llama_vocab_n_tokens(vocab);
+
+    llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
+
+    for (int i = 0; i < n_gen; i++) {
+        int res = llama_decode(ctx, llama_batch_get_one(&token, 1));
+        if (res != 0) {
+            fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res);
+            return false;
+        }
+        llama_synchronize(ctx);
+        token = std::rand() % n_vocab;
+    }
+    return true;
+}
+
+static void llama_null_log_callback(enum ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) text;
+    (void) user_data;
+}
+
+static std::unique_ptr create_printer(output_formats format) {
+    switch (format) {
+        case NONE:
+            return nullptr;
+        case CSV:
+            return std::unique_ptr(new csv_printer());
+        case JSON:
+            return std::unique_ptr(new json_printer());
+        case JSONL:
+            return std::unique_ptr(new jsonl_printer());
+        case MARKDOWN:
+            return std::unique_ptr(new markdown_printer());
+        case SQL:
+            return std::unique_ptr(new sql_printer());
+    }
+    GGML_ABORT("fatal error");
+}
+
+int main(int argc, char ** argv) {
+    // try to set locale for unicode characters in markdown
+    setlocale(LC_CTYPE, ".UTF-8");
+
+#if !defined(NDEBUG)
+    fprintf(stderr, "warning: asserts enabled, performance may be affected\n");
+#endif
+
+#if (defined(_MSC_VER) && defined(_DEBUG)) || (!defined(_MSC_VER) && !defined(__OPTIMIZE__))
+    fprintf(stderr, "warning: debug build, performance may be affected\n");
+#endif
+
+#if defined(__SANITIZE_ADDRESS__) || defined(__SANITIZE_THREAD__)
+    fprintf(stderr, "warning: sanitizer enabled, performance may be affected\n");
+#endif
+
+    // initialize backends
+    ggml_backend_load_all();
+
+    cmd_params params = parse_cmd_params(argc, argv);
+
+    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    if (!cpu_dev) {
+        fprintf(stderr, "%s: error: CPU backend is not loaded\n", __func__);
+        return 1;
+    }
+    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+    auto * ggml_threadpool_new_fn = (decltype(ggml_threadpool_new) *) ggml_backend_reg_get_proc_address(cpu_reg, "ggml_threadpool_new");
+    auto * ggml_threadpool_free_fn = (decltype(ggml_threadpool_free) *) ggml_backend_reg_get_proc_address(cpu_reg, "ggml_threadpool_free");
+
+    // initialize llama.cpp
+    if (!params.verbose) {
+        llama_log_set(llama_null_log_callback, NULL);
+    }
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    set_process_priority(params.prio);
+
+    // initialize printer
+    std::unique_ptr p     = create_printer(params.output_format);
+    std::unique_ptr p_err = create_printer(params.output_format_stderr);
+
+    if (p) {
+        p->fout = stdout;
+        p->print_header(params);
+    }
+
+    if (p_err) {
+        p_err->fout = stderr;
+        p_err->print_header(params);
+    }
+
+    std::vector params_instances = get_cmd_params_instances(params);
+
+    llama_model *               lmodel    = nullptr;
+    const cmd_params_instance * prev_inst = nullptr;
+
+    int  params_idx   = 0;
+    auto params_count = params_instances.size();
+    for (const auto & inst : params_instances) {
+        params_idx++;
+        if (params.progress) {
+            fprintf(stderr, "llama-bench: benchmark %d/%zu: starting\n", params_idx, params_count);
+        }
+        // keep the same model between tests when possible
+        if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) {
+            if (lmodel) {
+                llama_model_free(lmodel);
+            }
+
+            lmodel = llama_model_load_from_file(inst.model.c_str(), inst.to_llama_mparams());
+            if (lmodel == NULL) {
+                fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str());
+                return 1;
+            }
+            prev_inst = &inst;
+        }
+
+        llama_context * ctx = llama_init_from_model(lmodel, inst.to_llama_cparams());
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str());
+            llama_model_free(lmodel);
+            return 1;
+        }
+
+        test t(inst, lmodel, ctx);
+
+        llama_memory_clear(llama_get_memory(ctx), false);
+
+        // cool off before the test
+        if (params.delay) {
+            std::this_thread::sleep_for(std::chrono::seconds(params.delay));
+        }
+
+        struct ggml_threadpool_params tpp = ggml_threadpool_params_default(t.n_threads);
+        if (!parse_cpu_mask(t.cpu_mask, tpp.cpumask)) {
+            fprintf(stderr, "%s: failed to parse cpu-mask: %s\n", __func__, t.cpu_mask.c_str());
+            exit(1);
+        }
+        tpp.strict_cpu = t.cpu_strict;
+        tpp.poll       = t.poll;
+        tpp.prio       = params.prio;
+
+        struct ggml_threadpool * threadpool = ggml_threadpool_new_fn(&tpp);
+        if (!threadpool) {
+            fprintf(stderr, "%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
+            exit(1);
+        }
+
+        llama_attach_threadpool(ctx, threadpool, NULL);
+
+        // warmup run
+        if (t.n_prompt > 0) {
+            if (params.progress) {
+                fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
+            }
+            //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
+            bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            if (!res) {
+                fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
+                exit(1);
+            }
+        }
+        if (t.n_gen > 0) {
+            if (params.progress) {
+                fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
+            }
+            bool res = test_gen(ctx, 1, t.n_threads);
+            if (!res) {
+                fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
+                exit(1);
+            }
+        }
+
+        for (int i = 0; i < params.reps; i++) {
+            llama_memory_clear(llama_get_memory(ctx), false);
+
+            if (t.n_depth > 0) {
+                if (params.progress) {
+                    fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
+                            i + 1, params.reps);
+                }
+                bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                    exit(1);
+                }
+            }
+
+            uint64_t t_start = get_time_ns();
+
+            if (t.n_prompt > 0) {
+                if (params.progress) {
+                    fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
+                            i + 1, params.reps);
+                }
+                bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
+                    exit(1);
+                }
+            }
+            if (t.n_gen > 0) {
+                if (params.progress) {
+                    fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
+                            i + 1, params.reps);
+                }
+                bool res = test_gen(ctx, t.n_gen, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run gen\n", __func__);
+                    exit(1);
+                }
+            }
+
+            uint64_t t_ns = get_time_ns() - t_start;
+            t.samples_ns.push_back(t_ns);
+        }
+
+        if (p) {
+            p->print_test(t);
+            fflush(p->fout);
+        }
+
+        if (p_err) {
+            p_err->print_test(t);
+            fflush(p_err->fout);
+        }
+
+        llama_perf_context_print(ctx);
+
+        llama_free(ctx);
+
+        ggml_threadpool_free_fn(threadpool);
+    }
+
+    llama_model_free(lmodel);
+
+    if (p) {
+        p->print_footer();
+    }
+
+    if (p_err) {
+        p_err->print_footer();
+    }
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/tools/main/CMakeLists.txt b/tools/main/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..af3d9150f86403679fe6fd0bc8a0b61bdf442aec
--- /dev/null
+++ b/tools/main/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-cli)
+add_executable(${TARGET} main.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/tools/main/README.md b/tools/main/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4f16ad6b2b10ecd70f7ae14a10fde0cc0625c3d5
--- /dev/null
+++ b/tools/main/README.md
@@ -0,0 +1,388 @@
+# llama.cpp/tools/main
+
+This example program allows you to use various LLaMA language models easily and efficiently. It is specifically designed to work with the [llama.cpp](https://github.com/ggml-org/llama.cpp) project, which provides a plain C/C++ implementation with optional 4-bit quantization support for faster, lower memory inference, and is optimized for desktop CPUs. This program can be used to perform various inference tasks with LLaMA models, including generating text based on user-provided prompts and chat-like interactions with reverse prompts.
+
+## Table of Contents
+
+1. [Quick Start](#quick-start)
+2. [Common Options](#common-options)
+3. [Input Prompts](#input-prompts)
+4. [Interaction](#interaction)
+5. [Context Management](#context-management)
+6. [Generation Flags](#generation-flags)
+7. [Performance Tuning and Memory Options](#performance-tuning-and-memory-options)
+8. [Additional Options](#additional-options)
+
+## Quick Start
+
+To get started right away, run the following command, making sure to use the correct path for the model you have:
+
+First, we will need to download a model. In these examples, we will use the Gemma model from the ggml-org repo on Hugging Face.
+[https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true](https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true)
+
+Once downloaded, place your model in the models folder in llama.cpp.
+
+### Unix-based systems (Linux, macOS, etc.):
+
+##### Input prompt (One-and-done)
+
+```bash
+./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time"
+```
+##### Conversation mode (Allow for continuous interaction with the model)
+
+```bash
+./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma
+```
+
+##### Conversation mode using built-in jinja chat template
+
+```bash
+./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja
+```
+
+##### One-and-done query using jinja with custom system prompt and a starting prompt
+
+```bash
+./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello"
+```
+
+##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
+```bash
+./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
+```
+
+### Windows:
+
+##### Input prompt (One-and-done)
+```powershell
+./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time"
+```
+##### Conversation mode (Allow for continuous interaction with the model)
+
+```powershell
+./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma
+```
+
+##### Conversation mode using built-in jinja chat template
+
+```powershell
+./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja
+```
+
+##### One-and-done query using jinja with custom system prompt and a starting prompt
+
+```powershell
+./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello"
+```
+
+#### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
+
+```powershell
+llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
+```
+
+## Common Options
+
+In this section, we cover the most commonly used options for running the `llama-cli` program with the LLaMA models:
+
+-   `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/gemma-1.1-7b-it.Q4_K_M.gguf`; inferred from `--model-url` if set).
+-   `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g [https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true](https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true)).
+-   `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
+-   `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text.
+-   `-c N, --ctx-size N`: Set the size of the prompt context. The default is 4096, but if a LLaMA model was built with a longer context, increasing this value will provide better results for longer input/inference.
+-   `-mli, --multiline-input`: Allows you to write or paste multiple lines without ending each in '\'
+-   `-t N, --threads N`: Set the number of threads to use during generation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has.
+-   `-ngl N, --n-gpu-layers N`: When compiled with GPU support, this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
+
+## Input Prompts
+
+The `llama-cli` program provides several ways to interact with the LLaMA models using input prompts:
+
+-   `--prompt PROMPT`: Provide a prompt directly as a command-line option.
+-   `--file FNAME`: Provide a file containing a prompt or multiple prompts.
+-   `--system-prompt PROMPT`: Provide a system prompt (will otherwise use the default one in the chat template (if provided)).
+-   `--system-prompt-file FNAME`: Provide a file containing a system prompt.
+-   `--interactive-first`: Run the program in interactive mode and wait for input right away. (More on this below.)
+
+## Interaction
+
+The `llama-cli` program offers a seamless way to interact with LLaMA models, allowing users to engage in real-time conversations or provide instructions for specific tasks. The interactive mode can be triggered using various options, including `--interactive` and `--interactive-first`.
+
+In interactive mode, users can participate in text generation by injecting their input during the process. Users can press `Ctrl+C` at any time to interject and type their input, followed by pressing `Return` to submit it to the LLaMA model. To submit additional lines without finalizing input, users can end the current line with a backslash (`\`) and continue typing.
+
+### Interaction Options
+
+-   `-i, --interactive`: Run the program in interactive mode, allowing users to engage in real-time conversations or provide specific instructions to the model.
+-   `--interactive-first`: Run the program in interactive mode and immediately wait for user input before starting the text generation.
+-   `-cnv,  --conversation`:  Run the program in conversation mode (does not print special tokens and suffix/prefix, use default or provided chat template) (default: true if chat template found)
+-   `-no-cnv`:  Disable conversation mode (default: false)
+-   `-st, --single-turn`:  Only process a single conversation turn (user input) and then exit.
+-   `--jinja`:  Enable jinja chat template parser, will use the model's built-in template or a user-provided one (default: false)
+-   `--color`: Enable colorized output to differentiate visually distinguishing between prompts, user input, and generated text.
+
+By understanding and utilizing these interaction options, you can create engaging and dynamic experiences with the LLaMA models, tailoring the text generation process to your specific needs.
+
+### Reverse Prompts
+
+Reverse prompts are a powerful way to create a chat-like experience with a LLaMA model by pausing the text generation when specific text strings are encountered:
+
+-   `-r PROMPT, --reverse-prompt PROMPT`: Specify one or multiple reverse prompts to pause text generation and switch to interactive mode. For example, `-r "User:"` can be used to jump back into the conversation whenever it's the user's turn to speak. This helps create a more interactive and conversational experience. However, the reverse prompt doesn't work when it ends with a space.
+
+To overcome this limitation, you can use the `--in-prefix` flag to add a space or any other characters after the reverse prompt.
+
+### In-Prefix
+
+The `--in-prefix` flag is used to add a prefix to your input, primarily, this is used to insert a space after the reverse prompt. Here's an example of how to use the `--in-prefix` flag in conjunction with the `--reverse-prompt` flag:
+
+```sh
+./llama-cli -r "User:" --in-prefix " "
+```
+
+### In-Suffix
+
+The `--in-suffix` flag is used to add a suffix after your input. This is useful for adding an "Assistant:" prompt after the user's input. It's added after the new-line character (`\n`) that's automatically added to the end of the user's input. Here's an example of how to use the `--in-suffix` flag in conjunction with the `--reverse-prompt` flag:
+
+```sh
+./llama-cli -r "User:" --in-prefix " " --in-suffix "Assistant:"
+```
+When --in-prefix or --in-suffix options are enabled the chat template ( --chat-template ) is disabled
+
+### Chat templates
+
+ `--chat-template JINJA_TEMPLATE`: This option sets a custom jinja chat template. It accepts a string, not a file name.  Default: template taken from model's metadata. Llama.cpp only supports [some pre-defined templates](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template). These include llama2, llama3, gemma, monarch, chatml, orion, vicuna, vicuna-orca, deepseek, command-r, zephyr. When --in-prefix or --in-suffix options are enabled the chat template ( --chat-template ) is disabled.
+
+ Example usage: `--chat-template gemma`
+
+`--chat-template-file FNAME`:  Load a custom jinja chat template from an external file, useful if the model contains outdated or incompatible template, some examples can be found in models/templates. Up-to-date chat templates can be downloaded from Hugging Face using scripts/get_chat_template.py
+
+## Context Management
+
+During text generation, LLaMA models have a limited context size, which means they can only consider a certain number of tokens from the input and generated text. When the context fills up, the model resets internally, potentially losing some information from the beginning of the conversation or instructions. Context management options help maintain continuity and coherence in these situations.
+
+### Context Size
+
+- `-c N, --ctx-size N`: Set the size of the prompt context (default: 4096, 0 = loaded from model). If a LLaMA model was built with a longer context, increasing this value will yield the best results on longer input/inference.
+
+### Extended Context Size
+
+Some fine-tuned models have extended the context length by scaling RoPE. For example, if the original pre-trained model has a context length (max sequence length) of 4096 (4k) and the fine-tuned model has 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8.
+
+-   `--rope-scale N`: Where N is the linear scaling factor used by the fine-tuned model.
+
+### Keep Prompt
+
+The `--keep` option allows users to retain the original prompt when the model runs out of context, ensuring a connection to the initial instruction or conversation topic is maintained.
+
+-   `--keep N`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context. By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt.
+
+By utilizing context management options like `--ctx-size` and `--keep`, you can maintain a more coherent and consistent interaction with the LLaMA models, ensuring that the generated text remains relevant to the original prompt or conversation.
+
+## Generation Flags
+
+The following options allow you to control the text generation process and fine-tune the diversity, creativity, and quality of the generated text according to your needs. By adjusting these options and experimenting with different combinations of values, you can find the best settings for your specific use case.
+
+### Number of Tokens to Predict
+
+-   `-n N, --predict N`: Set the number of tokens to predict when generating text (default: -1, -1 = infinity, -2 = until context filled)
+
+The `--predict` option controls the number of tokens the model generates in response to the input prompt. By adjusting this value, you can influence the length of the generated text. A higher value will result in longer text, while a lower value will produce shorter text.
+
+A value of -1 will enable infinite text generation, even though we have a finite context window. When the context window is full, some of the earlier tokens (half of the tokens after `--keep`) will be discarded. The context must then be re-evaluated before generation can resume. On large models and/or large context windows, this will result in a significant pause in output.
+
+If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.
+
+The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full.
+
+It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
+
+### Temperature
+
+-   `--temp N`: Adjust the randomness of the generated text (default: 0.8).
+
+Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.
+
+Example usage: `--temp 0`
+
+### Repeat Penalty
+
+-   `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled).
+-   `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
+
+The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.
+
+The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).
+
+### DRY Repetition Penalty
+
+DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).
+
+- `--dry-multiplier N`: Set the DRY sampling multiplier (default: 0.0, 0.0 = disabled).
+- `--dry-base N`: Set the DRY sampling base value (default: 1.75).
+- `--dry-allowed-length N`: Set the allowed length for DRY sampling (default: 2).
+- `--dry-penalty-last-n N`: Set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size).
+- `--dry-sequence-breaker STRING`: Add a sequence breaker for DRY sampling. Can be used more than once to add multiple sequence breakers. Using this clears out the default breakers, which consist of: `['\n', ':', '"', '*']`. If the string `"none"` is supplied, no sequence breakers are used.
+
+The `dry-multiplier` option controls the strength of the DRY sampling effect. A value of 0.0 disables DRY sampling, while higher values increase its influence. A typical recommended value is 0.8.
+
+The `dry-base` option sets the base value for the exponential penalty calculation in DRY sampling. Higher values lead to more aggressive penalization of repetitions.
+
+The `dry-allowed-length` option sets the maximum length of repeated sequences that will not be penalized. Repetitions shorter than or equal to this length are not penalized, allowing for natural repetitions of short phrases or common words.
+
+The `dry-penalty-last-n` option controls how many recent tokens to consider when applying the DRY penalty. A value of -1 considers the entire context. Use a positive value to limit the consideration to a specific number of recent tokens.
+
+The `dry-sequence-breaker` option adds a single sequence breaker and can be used more than once to specify multiple sequence breakers. Sequence breakers interrupt sequence matching and break the input into parts where matching can be applied.
+
+DRY sampling provides more nuanced control over text generation, particularly for reducing long-range repetitions and maintaining global coherence.
+
+Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"`
+
+### Top-K Sampling
+
+-   `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40).
+
+Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top-k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text. The default value is 40.
+
+Example usage: `--top-k 30`
+
+### Top-P Sampling
+
+-   `--top-p N`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9).
+
+Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top-p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. The default value is 0.9.
+
+Example usage: `--top-p 0.95`
+
+### Min-P Sampling
+
+-   `--min-p N`: Sets a minimum base probability threshold for token selection (default: 0.1).
+
+The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.
+
+Example usage: `--min-p 0.05`
+
+### Locally Typical Sampling
+
+-   `--typical N`: Enable locally typical sampling with parameter p (default: 1.0, 1.0 = disabled).
+
+Locally typical sampling promotes the generation of contextually coherent and diverse text by sampling tokens that are typical or expected based on the surrounding context. By setting the parameter p between 0 and 1, you can control the balance between producing text that is locally coherent and diverse. A value closer to 1 will promote more contextually coherent tokens, while a value closer to 0 will promote more diverse tokens. A value equal to 1 disables locally typical sampling.
+
+Example usage: `--typical 0.9`
+
+### Mirostat Sampling
+
+-   `--mirostat N`: Enable Mirostat sampling, controlling perplexity during text generation (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).
+-   `--mirostat-lr N`: Set the Mirostat learning rate, parameter eta (default: 0.1).
+-   `--mirostat-ent N`: Set the Mirostat target entropy, parameter tau (default: 5.0).
+
+Mirostat is an algorithm that actively maintains the quality of generated text within a desired range during text generation. It aims to strike a balance between coherence and diversity, avoiding low-quality output caused by excessive repetition (boredom traps) or incoherence (confusion traps).
+
+The `--mirostat-lr` option sets the Mirostat learning rate (eta). The learning rate influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. The default value is `0.1`.
+
+The `--mirostat-ent` option sets the Mirostat target entropy (tau), which represents the desired perplexity value for the generated text. Adjusting the target entropy allows you to control the balance between coherence and diversity in the generated text. A lower value will result in more focused and coherent text, while a higher value will lead to more diverse and potentially less coherent text. The default value is `5.0`.
+
+Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0`
+
+### XTC Sampling
+
+-   `--xtc-probability N`: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
+-   `--xtc-threshold N`: Sets a minimum probability threshold for tokens to be removed (default: 0.1).
+
+Exclude Top Choices (XTC) is a unique sampler that is designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `xtc-probability` it searches for tokens with probabilities of `xtc-threshold` and above, then removes all such tokens except the least probable one.
+
+By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models.
+
+Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 --xtc-probability 0.5`.
+
+Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
+
+### Top-nσ Sampling
+
+-   `--top-nsigma N`: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1, -1 = disabled).
+
+Top-nσ sampling is a text generation method that selects tokens based on a statistical threshold in pre-softmax logits. It works by only sampling from tokens with logits that are within n * σ of the maximum logit. This method helps maintain a stable sampling space regardless of temperature scaling, allowing it to perform well on reasoning tasks even in high temperatures. Without complex probability manipulation, it efficiently filters tokens directly on the pre-softmax logits. A higher value for top-nsigma (e.g., 5) will take more noisy tokens into consideration, while a lower value (e.g., 1) will focous on the more informative region of the sampling space.
+
+Example usage: `--top-nsigma 1`
+
+### Logit Bias
+
+-   `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.
+
+The logit bias option allows you to manually adjust the likelihood of specific tokens appearing in the generated text. By providing a token ID and a positive or negative bias value, you can increase or decrease the probability of that token being generated.
+
+For example, use `--logit-bias 15043+1` to increase the likelihood of the token 'Hello', or `--logit-bias 15043-1` to decrease its likelihood. Using a value of negative infinity, `--logit-bias 15043-inf` ensures that the token `Hello` is never produced.
+
+A more practical use case might be to prevent the generation of `\code{begin}` and `\code{end}` by setting the `\` token (29905) to negative infinity with `-l 29905-inf`. (This is due to the prevalence of LaTeX codes that show up in LLaMA model inference.)
+
+Example usage: `--logit-bias 29905-inf`
+
+### RNG Seed
+
+-   `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
+
+The RNG seed is used to initialize the random number generator that influences the text generation process. By setting a specific seed value, you can obtain consistent and reproducible results across multiple runs with the same input and settings. This can be helpful for testing, debugging, or comparing the effects of different options on the generated text to see when they diverge. If the seed is set to a value less than 0, a random seed will be used, which will result in different outputs on each run.
+
+## Performance Tuning and Memory Options
+
+These options help improve the performance and memory usage of the LLaMA models. By adjusting these settings, you can fine-tune the model's behavior to better suit your system's capabilities and achieve optimal performance for your specific use case.
+
+### Number of Threads
+
+-   `-t N, --threads N`: Set the number of threads to use during generation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance.
+-   `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. In some systems, it is beneficial to use a higher number of threads during batch processing than during generation. If not specified, the number of threads used for batch processing will be the same as the number of threads used for generation.
+
+### Mlock
+
+-   `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped. This can improve performance but trades away some of the advantages of memory-mapping by requiring more RAM to run and potentially slowing down load times as the model loads into RAM.
+
+### No Memory Mapping
+
+-   `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed. However, if the model is larger than your total amount of RAM or if your system is low on available memory, using mmap might increase the risk of pageouts, negatively impacting performance. Disabling mmap results in slower load times but may reduce pageouts if you're not using `--mlock`. Note that if the model is larger than the total amount of RAM, turning off mmap would prevent the model from loading at all.
+
+### NUMA support
+
+-   `--numa distribute`: Pin an equal proportion of the threads to the cores on each NUMA node. This will spread the load amongst all cores on the system, utilitizing all memory channels at the expense of potentially requiring memory to travel over the slow links between nodes.
+-   `--numa isolate`: Pin all threads to the NUMA node that the program starts on. This limits the number of cores and amount of memory that can be used, but guarantees all memory access remains local to the NUMA node.
+-   `--numa numactl`: Pin threads to the CPUMAP that is passed to the program by starting it with the numactl utility. This is the most flexible mode, and allow arbitrary core usage patterns, for example a map that uses all the cores on one NUMA nodes, and just enough cores on a second node to saturate the inter-node memory bus.
+
+ These flags attempt optimizations that help on some systems with non-uniform memory access. This currently consists of one of the above strategies, and disabling prefetch and readahead for mmap. The latter causes mapped pages to be faulted in on first access instead of all at once, and in combination with pinning threads to NUMA nodes, more of the pages end up on the NUMA node where they are used. Note that if the model is already in the system page cache, for example because of a previous run without this option, this will have little effect unless you drop the page cache first. This can be done by rebooting the system or on Linux by writing '3' to '/proc/sys/vm/drop_caches' as root.
+
+### Batch Size
+
+- `-ub N`, `--ubatch-size N`: Physical batch size. This is the maximum number of tokens that may be processed at a time. Increasing this value may improve performance during prompt processing, at the expense of higher memory usage. Default: `512`.
+
+- `-b N`, `--batch-size N`: Logical batch size. Increasing this value above the value of the physical batch size may improve prompt processing performance when using multiple GPUs with pipeline parallelism. Default: `2048`.
+
+### Prompt Caching
+
+-   `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs. **Note**: Restoring a cached prompt does not imply restoring the exact state of the session at the point it was saved. So even when specifying a specific seed, you are not guaranteed to get the same sequence of tokens as the original generation.
+
+### Grammars & JSON schemas
+
+-   `--grammar GRAMMAR`, `--grammar-file FILE`: Specify a grammar (defined inline or in a file) to constrain model output to a specific format. For example, you could force the model to output JSON or to speak only in emojis. See the [GBNF guide](../../grammars/README.md) for details on the syntax.
+
+-   `--json-schema SCHEMA`: Specify a [JSON schema](https://json-schema.org/) to constrain model output to (e.g. `{}` for any JSON object, or `{"items": {"type": "string", "minLength": 10, "maxLength": 100}, "minItems": 10}` for a JSON array of strings with size constraints). If a schema uses external `$ref`s, you should use `--grammar "$( python examples/json_schema_to_grammar.py myschema.json )"` instead.
+
+### Quantization
+
+For information about 4-bit quantization, which can significantly improve performance and reduce memory usage, please refer to llama.cpp's primary [README](../../README.md#prepare-and-quantize).
+
+## LoRA (Low-Rank Adaptation) adapters
+
+-   `--lora FNAME`: Optional path to a LoRA adapter to use with scaling of 1.0. Can be mixed with `--lora-scaled` and can be repeated to use multiple adapters.
+-   `--lora-scaled FNAME`: Optional path to a LoRA adapter with user-defined scaling. Can be mixed with `--lora` and can repeated to use multiple adapters.
+
+You can add LoRA adapters using `--lora` or `--lora-scaled`. For example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...` or `--lora-scaled lora_task_A.gguf 0.5 --lora-scaled lora_task_B.gguf 0.5`.
+
+LoRA adapters should be in GGUF format. To convert from Hugging Face format use the `convert-lora-to-gguf.py` script. LoRA adapters are loaded separately and applied during inference - they are not merged with the main model. This means that mmap model loading is fully supported when using LoRA adapters. The old `--lora-base` flag has been removed now that merging is no longer performed.
+
+## Additional Options
+
+These options provide extra functionality and customization when running the LLaMA models:
+
+-   `-h, --help`: Display a help message showing all available options and their default values. This is particularly useful for checking the latest options and default values, as they can change frequently, and the information in this document may become outdated.
+-   `--verbose-prompt`: Print the prompt before generating text.
+-   `--no-display-prompt`: Don't print prompt at generation.
+-   `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used.
+-   `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance.
+-   `-hfr URL --hf-repo URL`: The url to the Hugging Face model repository. Used in conjunction with `--hf-file` or `-hff`. The model is downloaded and stored in the file provided by `-m` or `--model`. If `-m` is not provided, the model is auto-stored in the path specified by the `LLAMA_CACHE` environment variable  or in an OS-specific local cache.
diff --git a/tools/main/main.cpp b/tools/main/main.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..19b247b0d672f9c1874fde92b212bfd832db7a34
--- /dev/null
+++ b/tools/main/main.cpp
@@ -0,0 +1,979 @@
+#include "arg.h"
+#include "common.h"
+#include "console.h"
+#include "log.h"
+#include "sampling.h"
+#include "llama.h"
+#include "chat.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include 
+#include 
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include 
+#include 
+#endif
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+static llama_context           ** g_ctx;
+static llama_model             ** g_model;
+static common_sampler          ** g_smpl;
+static common_params            * g_params;
+static std::vector * g_input_tokens;
+static std::ostringstream       * g_output_ss;
+static std::vector * g_output_tokens;
+static bool is_interacting  = false;
+static bool need_insert_eot = false;
+
+static void print_usage(int argc, char ** argv) {
+    (void) argc;
+
+    LOG("\nexample usage:\n");
+    LOG("\n  text generation:     %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128 -no-cnv\n", argv[0]);
+    LOG("\n  chat (conversation): %s -m your_model.gguf -sys \"You are a helpful assistant\"\n", argv[0]);
+    LOG("\n");
+}
+
+static bool file_exists(const std::string & path) {
+    std::ifstream f(path.c_str());
+    return f.good();
+}
+
+static bool file_is_empty(const std::string & path) {
+    std::ifstream f;
+    f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
+    f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
+    return f.tellg() == 0;
+}
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+static void sigint_handler(int signo) {
+    if (signo == SIGINT) {
+        if (!is_interacting && g_params->interactive) {
+            is_interacting  = true;
+            need_insert_eot = true;
+        } else {
+            console::cleanup();
+            LOG("\n");
+            common_perf_print(*g_ctx, *g_smpl);
+
+            // make sure all logs are flushed
+            LOG("Interrupted by user\n");
+            common_log_pause(common_log_main());
+
+            _exit(130);
+        }
+    }
+}
+#endif
+
+int main(int argc, char ** argv) {
+    common_params params;
+    g_params = ¶ms;
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MAIN, print_usage)) {
+        return 1;
+    }
+
+    common_init();
+
+    auto & sparams = params.sampling;
+
+    // save choice to use color for later
+    // (note for later: this is a slightly awkward choice)
+    console::init(params.simple_io, params.use_color);
+    atexit([]() { console::cleanup(); });
+
+    if (params.embedding) {
+        LOG_ERR("************\n");
+        LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
+        LOG_ERR("************\n\n");
+
+        return 0;
+    }
+
+    if (params.n_ctx != 0 && params.n_ctx < 8) {
+        LOG_WRN("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
+        params.n_ctx = 8;
+    }
+
+    if (params.rope_freq_base != 0.0) {
+        LOG_WRN("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
+    }
+
+    if (params.rope_freq_scale != 0.0) {
+        LOG_WRN("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
+    }
+
+    LOG_INF("%s: llama backend init\n", __func__);
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    llama_model * model = nullptr;
+    llama_context * ctx = nullptr;
+    common_sampler * smpl = nullptr;
+
+    g_model = &model;
+    g_ctx = &ctx;
+    g_smpl = &smpl;
+
+    std::vector chat_msgs;
+
+    // load the model and apply lora adapter, if any
+    LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
+    common_init_result llama_init = common_init_from_params(params);
+
+    model = llama_init.model.get();
+    ctx = llama_init.context.get();
+
+    if (model == NULL) {
+        LOG_ERR("%s: error: unable to load model\n", __func__);
+        return 1;
+    }
+
+    auto * mem = llama_get_memory(ctx);
+
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    auto chat_templates = common_chat_templates_init(model, params.chat_template);
+
+    LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
+
+    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    if (!cpu_dev) {
+        LOG_ERR("%s: no CPU backend found\n", __func__);
+        return 1;
+    }
+    auto * reg = ggml_backend_dev_backend_reg(cpu_dev);
+    auto * ggml_threadpool_new_fn = (decltype(ggml_threadpool_new) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_new");
+    auto * ggml_threadpool_free_fn = (decltype(ggml_threadpool_free) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_free");
+
+    struct ggml_threadpool_params tpp_batch =
+            ggml_threadpool_params_from_cpu_params(params.cpuparams_batch);
+    struct ggml_threadpool_params tpp =
+            ggml_threadpool_params_from_cpu_params(params.cpuparams);
+
+    set_process_priority(params.cpuparams.priority);
+
+    struct ggml_threadpool * threadpool_batch = NULL;
+    if (!ggml_threadpool_params_match(&tpp, &tpp_batch)) {
+        threadpool_batch = ggml_threadpool_new_fn(&tpp_batch);
+        if (!threadpool_batch) {
+            LOG_ERR("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads);
+            return 1;
+        }
+
+        // Start the non-batch threadpool in the paused state
+        tpp.paused = true;
+    }
+
+    struct ggml_threadpool * threadpool = ggml_threadpool_new_fn(&tpp);
+    if (!threadpool) {
+        LOG_ERR("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
+        return 1;
+    }
+
+    llama_attach_threadpool(ctx, threadpool, threadpool_batch);
+
+    const int n_ctx_train = llama_model_n_ctx_train(model);
+    const int n_ctx = llama_n_ctx(ctx);
+
+    if (n_ctx > n_ctx_train) {
+        LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
+    }
+
+    // auto enable conversation mode if chat template is available
+    const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get());
+    if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
+        if (has_chat_template) {
+            LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
+            params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
+        } else {
+            params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED;
+        }
+    }
+
+    // in case user force-activate conversation mode (via -cnv) without proper chat template, we show a warning
+    if (params.conversation_mode && !has_chat_template) {
+        LOG_WRN("%s: chat template is not available or is not supported. This may cause the model to output suboptimal responses\n", __func__);
+    }
+
+    // print chat template example in conversation mode
+    if (params.conversation_mode) {
+        if (params.enable_chat_template) {
+            if (!params.prompt.empty() && params.system_prompt.empty()) {
+                LOG_WRN("*** User-specified prompt will pre-start conversation, did you mean to set --system-prompt (-sys) instead?\n");
+            }
+
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
+        } else {
+            LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
+        }
+    }
+
+    // print system information
+    {
+        LOG_INF("\n");
+        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+        LOG_INF("\n");
+    }
+
+    std::string path_session = params.path_prompt_cache;
+    std::vector session_tokens;
+
+    if (!path_session.empty()) {
+        LOG_INF("%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str());
+        if (!file_exists(path_session)) {
+            LOG_INF("%s: session file does not exist, will create.\n", __func__);
+        } else if (file_is_empty(path_session)) {
+            LOG_INF("%s: The session file is empty. A new session will be initialized.\n", __func__);
+        } else {
+            // The file exists and is not empty
+            session_tokens.resize(n_ctx);
+            size_t n_token_count_out = 0;
+            if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
+                LOG_ERR("%s: failed to load session file '%s'\n", __func__, path_session.c_str());
+                return 1;
+            }
+            session_tokens.resize(n_token_count_out);
+            LOG_INF("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
+        }
+    }
+
+    const bool add_bos = llama_vocab_get_add_bos(vocab) && !params.use_jinja;
+    if (!llama_model_has_encoder(model)) {
+        GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
+    }
+
+    LOG_DBG("n_ctx: %d, add_bos: %d\n", n_ctx, add_bos);
+
+    std::vector embd_inp;
+
+    bool waiting_for_first_input = false;
+    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
+        common_chat_msg new_msg;
+        new_msg.role = role;
+        new_msg.content = content;
+        auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        chat_msgs.push_back(new_msg);
+        LOG_DBG("formatted: '%s'\n", formatted.c_str());
+        return formatted;
+    };
+
+    std::string prompt;
+    {
+        if (params.conversation_mode && params.enable_chat_template) {
+            if (!params.system_prompt.empty()) {
+                // format the system prompt (will use template default if empty)
+                chat_add_and_format("system", params.system_prompt);
+            }
+
+            if (!params.prompt.empty()) {
+                // format and append the user prompt
+                chat_add_and_format("user", params.prompt);
+            } else {
+                waiting_for_first_input = true;
+            }
+
+            if (!params.system_prompt.empty() || !params.prompt.empty()) {
+                common_chat_templates_inputs inputs;
+                inputs.messages = chat_msgs;
+                inputs.add_generation_prompt = !params.prompt.empty();
+
+                prompt = common_chat_templates_apply(chat_templates.get(), inputs).prompt;
+            }
+        } else {
+            // otherwise use the prompt as is
+            prompt = params.prompt;
+        }
+
+        if (params.interactive_first || !prompt.empty() || session_tokens.empty()) {
+            LOG_DBG("tokenize the prompt\n");
+            embd_inp = common_tokenize(ctx, prompt, true, true);
+        } else {
+            LOG_DBG("use session tokens\n");
+            embd_inp = session_tokens;
+        }
+
+        LOG_DBG("prompt: \"%s\"\n", prompt.c_str());
+        LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str());
+    }
+
+    // Should not run without any tokens
+    if (!waiting_for_first_input && embd_inp.empty()) {
+        if (add_bos) {
+            embd_inp.push_back(llama_vocab_bos(vocab));
+            LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
+        } else {
+            LOG_ERR("input is empty\n");
+            return -1;
+        }
+    }
+
+    // Tokenize negative prompt
+    if ((int) embd_inp.size() > n_ctx - 4) {
+        LOG_ERR("%s: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
+        return 1;
+    }
+
+    // debug message about similarity of saved session, if applicable
+    size_t n_matching_session_tokens = 0;
+    if (!session_tokens.empty()) {
+        for (llama_token id : session_tokens) {
+            if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
+                break;
+            }
+            n_matching_session_tokens++;
+        }
+        if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
+            LOG_INF("%s: using full prompt from session file\n", __func__);
+        } else if (n_matching_session_tokens >= embd_inp.size()) {
+            LOG_INF("%s: session file has exact match for prompt!\n", __func__);
+        } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
+            LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
+                    __func__, n_matching_session_tokens, embd_inp.size());
+        } else {
+            LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n",
+                    __func__, n_matching_session_tokens, embd_inp.size());
+        }
+
+        // remove any "future" tokens that we might have inherited from the previous session
+        llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1);
+    }
+
+    LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
+         embd_inp.size(), n_matching_session_tokens, embd_inp.size(), session_tokens.size());
+
+    // if we will use the cache for the full prompt without reaching the end of the cache, force
+    // reevaluation of the last token to recalculate the cached logits
+    if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) {
+        LOG_DBG("recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1);
+
+        session_tokens.resize(embd_inp.size() - 1);
+    }
+
+    // number of tokens to keep when resetting context
+    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size()) {
+        params.n_keep = (int)embd_inp.size();
+    } else {
+        params.n_keep += add_bos; // always keep the BOS token
+    }
+
+    if (params.conversation_mode) {
+        if (params.single_turn && !params.prompt.empty()) {
+            params.interactive = false;
+            params.interactive_first = false;
+        } else {
+            params.interactive_first = true;
+        }
+    }
+
+    // enable interactive mode if interactive start is specified
+    if (params.interactive_first) {
+        params.interactive = true;
+    }
+
+    if (params.verbose_prompt) {
+        LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+        LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+        for (int i = 0; i < (int) embd_inp.size(); i++) {
+            LOG_INF("%6d -> '%s'\n", embd_inp[i], common_token_to_piece(ctx, embd_inp[i]).c_str());
+        }
+
+        if (params.n_keep > add_bos) {
+            LOG_INF("%s: static prompt based on n_keep: '", __func__);
+            for (int i = 0; i < params.n_keep; i++) {
+                LOG_CNT("%s", common_token_to_piece(ctx, embd_inp[i]).c_str());
+            }
+            LOG_CNT("'\n");
+        }
+        LOG_INF("\n");
+    }
+
+    // ctrl+C handling
+    {
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+        struct sigaction sigint_action;
+        sigint_action.sa_handler = sigint_handler;
+        sigemptyset (&sigint_action.sa_mask);
+        sigint_action.sa_flags = 0;
+        sigaction(SIGINT, &sigint_action, NULL);
+#elif defined (_WIN32)
+        auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+            return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
+        };
+        SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true);
+#endif
+    }
+
+    if (params.interactive) {
+        LOG_INF("%s: interactive mode on.\n", __func__);
+
+        if (!params.antiprompt.empty()) {
+            for (const auto & antiprompt : params.antiprompt) {
+                LOG_INF("Reverse prompt: '%s'\n", antiprompt.c_str());
+                if (params.verbose_prompt) {
+                    auto tmp = common_tokenize(ctx, antiprompt, false, true);
+                    for (int i = 0; i < (int) tmp.size(); i++) {
+                        LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
+                    }
+                }
+            }
+        }
+
+        if (params.input_prefix_bos) {
+            LOG_INF("Input prefix with BOS\n");
+        }
+
+        if (!params.input_prefix.empty()) {
+            LOG_INF("Input prefix: '%s'\n", params.input_prefix.c_str());
+            if (params.verbose_prompt) {
+                auto tmp = common_tokenize(ctx, params.input_prefix, true, true);
+                for (int i = 0; i < (int) tmp.size(); i++) {
+                    LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
+                }
+            }
+        }
+
+        if (!params.input_suffix.empty()) {
+            LOG_INF("Input suffix: '%s'\n", params.input_suffix.c_str());
+            if (params.verbose_prompt) {
+                auto tmp = common_tokenize(ctx, params.input_suffix, false, true);
+                for (int i = 0; i < (int) tmp.size(); i++) {
+                    LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
+                }
+            }
+        }
+    }
+
+    smpl = common_sampler_init(model, sparams);
+    if (!smpl) {
+        LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
+        return 1;
+    }
+
+    LOG_INF("sampler seed: %u\n",     common_sampler_get_seed(smpl));
+    LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
+    LOG_INF("sampler chain: %s\n",    common_sampler_print(smpl).c_str());
+
+    LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+
+    // group-attention state
+    // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
+    int ga_i = 0;
+
+    const int ga_n = params.grp_attn_n;
+    const int ga_w = params.grp_attn_w;
+
+    if (ga_n != 1) {
+        GGML_ASSERT(ga_n > 0                    && "grp_attn_n must be positive");                     // NOLINT
+        GGML_ASSERT(ga_w % ga_n == 0            && "grp_attn_w must be a multiple of grp_attn_n");     // NOLINT
+      //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of grp_attn_w");    // NOLINT
+      //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
+        LOG_INF("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
+    }
+    LOG_INF("\n");
+
+    if (params.interactive) {
+        const char * control_message;
+        if (params.multiline_input) {
+            control_message = " - To return control to the AI, end your input with '\\'.\n"
+                              " - To return control without starting a new line, end your input with '/'.\n";
+        } else {
+            control_message = " - Press Return to return control to the AI.\n"
+                              " - To return control without starting a new line, end your input with '/'.\n"
+                              " - If you want to submit another line, end your input with '\\'.\n";
+        }
+        LOG_INF("== Running in interactive mode. ==\n");
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+        LOG_INF(       " - Press Ctrl+C to interject at any time.\n");
+#endif
+        LOG_INF(       "%s", control_message);
+        if (params.conversation_mode && params.enable_chat_template && params.system_prompt.empty()) {
+            LOG_INF(   " - Not using system message. To change it, set a different value via -sys PROMPT\n");
+        }
+        LOG_INF("\n");
+
+        is_interacting = params.interactive_first;
+    }
+
+    bool is_antiprompt        = false;
+    bool input_echo           = true;
+    bool display              = true;
+    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();
+
+    int n_past             = 0;
+    int n_remain           = params.n_predict;
+    int n_consumed         = 0;
+    int n_session_consumed = 0;
+
+    std::vector   input_tokens;  g_input_tokens  = &input_tokens;
+    std::vector   output_tokens; g_output_tokens = &output_tokens;
+    std::ostringstream output_ss;     g_output_ss     = &output_ss;
+    std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
+
+    // the first thing we will do is to output the prompt, so set color accordingly
+    console::set_display(console::prompt);
+    display = params.display_prompt;
+
+    std::vector embd;
+
+    // single-token antiprompts
+    std::vector antiprompt_token;
+
+    for (const std::string & antiprompt : params.antiprompt) {
+        auto ids = ::common_tokenize(ctx, antiprompt, false, true);
+        if (ids.size() == 1) {
+            antiprompt_token.push_back(ids[0]);
+        }
+    }
+
+    if (llama_model_has_encoder(model)) {
+        int enc_input_size = embd_inp.size();
+        llama_token * enc_input_buf = embd_inp.data();
+
+        if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
+            LOG_ERR("%s : failed to eval\n", __func__);
+            return 1;
+        }
+
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
+            decoder_start_token_id = llama_vocab_bos(vocab);
+        }
+
+        embd_inp.clear();
+        embd_inp.push_back(decoder_start_token_id);
+    }
+
+    while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
+        // predict
+        if (!embd.empty()) {
+            // Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via
+            // --prompt or --file which uses the same value.
+            int max_embd_size = n_ctx - 4;
+
+            // Ensure the input doesn't exceed the context size by truncating embd if necessary.
+            if ((int) embd.size() > max_embd_size) {
+                const int skipped_tokens = (int) embd.size() - max_embd_size;
+                embd.resize(max_embd_size);
+
+                console::set_display(console::error);
+                LOG_WRN("<>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
+                console::set_display(console::reset);
+            }
+
+            if (ga_n == 1) {
+                // infinite text generation via context shifting
+                // if we run out of context:
+                // - take the n_keep first tokens from the original prompt (via n_past)
+                // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
+
+                if (n_past + (int) embd.size() >= n_ctx) {
+                    if (!params.ctx_shift){
+                        LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
+                        break;
+                    }
+
+                    if (params.n_predict == -2) {
+                        LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                        break;
+                    }
+
+                    const int n_left    = n_past - params.n_keep;
+                    const int n_discard = n_left/2;
+
+                    LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                            n_past, n_left, n_ctx, params.n_keep, n_discard);
+
+                    llama_memory_seq_rm (mem, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_memory_seq_add(mem, 0, params.n_keep + n_discard, n_past, -n_discard);
+
+                    n_past -= n_discard;
+
+                    LOG_DBG("after swap: n_past = %d\n", n_past);
+
+                    LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
+
+                    LOG_DBG("clear session path\n");
+                    path_session.clear();
+                }
+            } else {
+                // context extension via Self-Extend
+                while (n_past >= ga_i + ga_w) {
+                    const int ib = (ga_n*ga_i)/ga_w;
+                    const int bd = (ga_w/ga_n)*(ga_n - 1);
+                    const int dd = (ga_w/ga_n) - ib*bd - ga_w;
+
+                    LOG_DBG("\n");
+                    LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd);
+                    LOG_DBG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
+                    LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
+
+                    llama_memory_seq_add(mem, 0, ga_i,                n_past,              ib*bd);
+                    llama_memory_seq_div(mem, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
+                    llama_memory_seq_add(mem, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
+
+                    n_past -= bd;
+
+                    ga_i += ga_w/ga_n;
+
+                    LOG_DBG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i);
+                }
+            }
+
+            // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
+            if (n_session_consumed < (int) session_tokens.size()) {
+                size_t i = 0;
+                for ( ; i < embd.size(); i++) {
+                    if (embd[i] != session_tokens[n_session_consumed]) {
+                        session_tokens.resize(n_session_consumed);
+                        break;
+                    }
+
+                    n_past++;
+                    n_session_consumed++;
+
+                    if (n_session_consumed >= (int) session_tokens.size()) {
+                        ++i;
+                        break;
+                    }
+                }
+                if (i > 0) {
+                    embd.erase(embd.begin(), embd.begin() + i);
+                }
+            }
+
+            for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
+                int n_eval = (int) embd.size() - i;
+                if (n_eval > params.n_batch) {
+                    n_eval = params.n_batch;
+                }
+
+                LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
+
+                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
+                    LOG_ERR("%s : failed to eval\n", __func__);
+                    return 1;
+                }
+
+                n_past += n_eval;
+
+                LOG_DBG("n_past = %d\n", n_past);
+                // Display total tokens alongside total time
+                if (params.n_print > 0 && n_past % params.n_print == 0) {
+                    LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+                }
+            }
+
+            if (!embd.empty() && !path_session.empty()) {
+                session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
+                n_session_consumed = session_tokens.size();
+            }
+        }
+
+        embd.clear();
+
+        if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
+            // optionally save the session on first sample (for faster prompt loading next time)
+            if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
+                need_to_save_session = false;
+                llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+
+                LOG_DBG("saved session to %s\n", path_session.c_str());
+            }
+
+            const llama_token id = common_sampler_sample(smpl, ctx, -1);
+
+            common_sampler_accept(smpl, id, /* accept_grammar= */ true);
+
+            // LOG_DBG("last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());
+
+            embd.push_back(id);
+
+            // echo this to console
+            input_echo = true;
+
+            // decrement remaining sampling budget
+            --n_remain;
+
+            LOG_DBG("n_remain: %d\n", n_remain);
+        } else {
+            // some user input remains from prompt or interaction, forward it to processing
+            LOG_DBG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
+            while ((int) embd_inp.size() > n_consumed) {
+                embd.push_back(embd_inp[n_consumed]);
+
+                // push the prompt in the sampling context in order to apply repetition penalties later
+                // for the prompt, we don't apply grammar rules
+                common_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false);
+
+                ++n_consumed;
+                if ((int) embd.size() >= params.n_batch) {
+                    break;
+                }
+            }
+        }
+
+        // display text
+        if (input_echo && display) {
+            for (auto id : embd) {
+                const std::string token_str = common_token_to_piece(ctx, id, params.special);
+
+                // Console/Stream Output
+                LOG("%s", token_str.c_str());
+
+                // Record Displayed Tokens To Log
+                // Note: Generated tokens are created one by one hence this check
+                if (embd.size() > 1) {
+                    // Incoming Requested Tokens
+                    input_tokens.push_back(id);
+                } else {
+                    // Outgoing Generated Tokens
+                    output_tokens.push_back(id);
+                    output_ss << token_str;
+                }
+            }
+        }
+
+        // reset color to default if there is no pending user input
+        if (input_echo && (int) embd_inp.size() == n_consumed) {
+            console::set_display(console::reset);
+            display = true;
+        }
+
+        // if not currently processing queued inputs;
+        if ((int) embd_inp.size() <= n_consumed) {
+            // check for reverse prompt in the last n_prev tokens
+            if (!params.antiprompt.empty()) {
+                const int n_prev = 32;
+                const std::string last_output = common_sampler_prev_str(smpl, ctx, n_prev);
+
+                is_antiprompt = false;
+                // Check if each of the reverse prompts appears at the end of the output.
+                // If we're not running interactively, the reverse prompt might be tokenized with some following characters
+                // so we'll compensate for that by widening the search window a bit.
+                for (std::string & antiprompt : params.antiprompt) {
+                    size_t extra_padding = params.interactive ? 0 : 2;
+                    size_t search_start_pos = last_output.length() > static_cast(antiprompt.length() + extra_padding)
+                        ? last_output.length() - static_cast(antiprompt.length() + extra_padding)
+                        : 0;
+
+                    if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
+                        if (params.interactive) {
+                            is_interacting = true;
+                        }
+                        is_antiprompt = true;
+                        break;
+                    }
+                }
+
+                // check for reverse prompt using special tokens
+                llama_token last_token = common_sampler_last(smpl);
+                for (auto token : antiprompt_token) {
+                    if (token == last_token) {
+                        if (params.interactive) {
+                            is_interacting = true;
+                        }
+                        is_antiprompt = true;
+                        break;
+                    }
+                }
+
+                if (is_antiprompt) {
+                    LOG_DBG("found antiprompt: %s\n", last_output.c_str());
+                }
+            }
+
+            // deal with end of generation tokens in interactive mode
+            if (!waiting_for_first_input && llama_vocab_is_eog(vocab, common_sampler_last(smpl))) {
+                LOG_DBG("found an EOG token\n");
+
+                if (params.interactive) {
+                    if (!params.antiprompt.empty()) {
+                        // tokenize and inject first reverse prompt
+                        const auto first_antiprompt = common_tokenize(ctx, params.antiprompt.front(), false, true);
+                        embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+                        is_antiprompt = true;
+                    }
+
+                    if (params.enable_chat_template) {
+                        chat_add_and_format("assistant", assistant_ss.str());
+                    }
+                    is_interacting = true;
+                    LOG("\n");
+                }
+            }
+
+            // if current token is not EOG, we add it to current assistant message
+            if (params.conversation_mode && !waiting_for_first_input) {
+                const auto id = common_sampler_last(smpl);
+                assistant_ss << common_token_to_piece(ctx, id, false);
+
+                if (!prompt.empty()) {
+                    prompt.clear();
+                    is_interacting = false;
+                }
+            }
+
+            if ((n_past > 0 || waiting_for_first_input) && is_interacting) {
+                LOG_DBG("waiting for user input\n");
+
+                if (params.conversation_mode) {
+                    LOG("\n> ");
+                }
+
+                if (params.input_prefix_bos) {
+                    LOG_DBG("adding input prefix BOS token\n");
+                    embd_inp.push_back(llama_vocab_bos(vocab));
+                }
+
+                std::string buffer;
+                if (!params.input_prefix.empty() && !params.conversation_mode) {
+                    LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str());
+                    LOG("%s", params.input_prefix.c_str());
+                }
+
+                // color user input only
+                console::set_display(console::user_input);
+                display = params.display_prompt;
+
+                std::string line;
+                bool another_line = true;
+                do {
+                    another_line = console::readline(line, params.multiline_input);
+                    buffer += line;
+                } while (another_line);
+
+                // done taking input, reset color
+                console::set_display(console::reset);
+                display = true;
+
+                if (buffer.empty()) { // Ctrl+D on empty line exits
+                    LOG("EOF by user\n");
+                    break;
+                }
+
+                if (buffer.back() == '\n') {
+                    // Implement #587:
+                    // If the user wants the text to end in a newline,
+                    // this should be accomplished by explicitly adding a newline by using \ followed by return,
+                    // then returning control by pressing return again.
+                    buffer.pop_back();
+                }
+
+                if (buffer.empty()) { // Enter key on empty line lets the user pass control back
+                    LOG_DBG("empty line, passing control back\n");
+                } else { // Add tokens to embd only if the input buffer is non-empty
+                    // append input suffix if any
+                    if (!params.input_suffix.empty() && !params.conversation_mode) {
+                        LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
+                        LOG("%s", params.input_suffix.c_str());
+                    }
+
+                    LOG_DBG("buffer: '%s'\n", buffer.c_str());
+
+                    const size_t original_size = embd_inp.size();
+
+                    if (params.escape) {
+                        string_process_escapes(buffer);
+                    }
+
+                    bool format_chat = params.conversation_mode && params.enable_chat_template;
+                    std::string user_inp = format_chat
+                        ? chat_add_and_format("user", std::move(buffer))
+                        : std::move(buffer);
+                    // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
+                    const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
+                    const auto line_inp = common_tokenize(ctx, user_inp,            false, format_chat);
+                    const auto line_sfx = common_tokenize(ctx, params.input_suffix, false, true);
+
+                    LOG_DBG("input tokens: %s\n", string_from(ctx, line_inp).c_str());
+
+                    // if user stop generation mid-way, we must add EOT to finish model's last response
+                    if (need_insert_eot && format_chat) {
+                        llama_token eot = llama_vocab_eot(vocab);
+                        embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_vocab_eos(vocab) : eot);
+                        need_insert_eot = false;
+                    }
+
+                    embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
+                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+                    embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
+
+                    for (size_t i = original_size; i < embd_inp.size(); ++i) {
+                        const llama_token token = embd_inp[i];
+                        output_tokens.push_back(token);
+                        output_ss << common_token_to_piece(ctx, token);
+                    }
+
+                    // reset assistant message
+                    assistant_ss.str("");
+
+                    n_remain -= line_inp.size();
+                    LOG_DBG("n_remain: %d\n", n_remain);
+                }
+
+                input_echo = false; // do not echo this again
+            }
+
+            if (n_past > 0 || waiting_for_first_input) {
+                if (is_interacting) {
+                    common_sampler_reset(smpl);
+                }
+                is_interacting = false;
+
+                if (waiting_for_first_input && params.single_turn) {
+                    params.interactive = false;
+                    params.interactive_first = false;
+                }
+                waiting_for_first_input = false;
+            }
+        }
+
+        // end of generation
+        if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) {
+            LOG(" [end of text]\n");
+            break;
+        }
+
+        // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
+        // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size).
+        if (params.interactive && n_remain <= 0 && params.n_predict >= 0) {
+            n_remain = params.n_predict;
+            is_interacting = true;
+        }
+    }
+
+    if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
+        LOG("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
+        llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+    }
+
+    LOG("\n\n");
+    common_perf_print(ctx, smpl);
+
+    common_sampler_free(smpl);
+
+    llama_backend_free();
+
+    ggml_threadpool_free_fn(threadpool);
+    ggml_threadpool_free_fn(threadpool_batch);
+
+    return 0;
+}
diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4baa15b9609fc173ff26f5dcccbd4ed9300c2a86
--- /dev/null
+++ b/tools/mtmd/CMakeLists.txt
@@ -0,0 +1,60 @@
+# mtmd
+
+find_package(Threads REQUIRED)
+
+add_library(mtmd
+            mtmd.cpp
+            mtmd-audio.cpp
+            mtmd.h
+            clip.cpp
+            clip.h
+            clip-impl.h
+            mtmd-helper.cpp
+            mtmd-helper.h
+            )
+
+target_link_libraries     (mtmd PUBLIC ggml llama)
+target_link_libraries     (mtmd PRIVATE Threads::Threads)
+target_include_directories(mtmd PUBLIC  .)
+target_include_directories(mtmd PRIVATE ../..)
+target_include_directories(mtmd PRIVATE ../../vendor)
+target_compile_features   (mtmd PRIVATE cxx_std_17)
+
+if (BUILD_SHARED_LIBS)
+    set_target_properties     (mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(mtmd PRIVATE LLAMA_BUILD)
+    target_compile_definitions(mtmd PUBLIC  LLAMA_SHARED)
+endif()
+
+set(MTMD_PUBLIC_HEADERS
+    ${CMAKE_CURRENT_SOURCE_DIR}/mtmd.h
+    ${CMAKE_CURRENT_SOURCE_DIR}/mtmd-helper.h
+    )
+
+set_target_properties(mtmd
+    PROPERTIES
+    PUBLIC_HEADER "${MTMD_PUBLIC_HEADERS}")
+
+install(TARGETS mtmd LIBRARY PUBLIC_HEADER)
+
+if (NOT MSVC)
+    # for stb_image.h and miniaudio.h
+    target_compile_options(mtmd PRIVATE -Wno-cast-qual)
+endif()
+
+if (TARGET BUILD_INFO)
+    add_dependencies(mtmd        BUILD_INFO)
+    add_dependencies(mtmd-helper BUILD_INFO)
+endif()
+
+add_executable(llama-llava-cli    deprecation-warning.cpp)
+add_executable(llama-gemma3-cli   deprecation-warning.cpp)
+add_executable(llama-minicpmv-cli deprecation-warning.cpp)
+add_executable(llama-qwen2vl-cli  deprecation-warning.cpp)
+
+set(TARGET llama-mtmd-cli)
+add_executable         (${TARGET} mtmd-cli.cpp)
+set_target_properties  (${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
+install                (TARGETS ${TARGET} RUNTIME)
+target_link_libraries  (${TARGET} PRIVATE common mtmd Threads::Threads)
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/tools/mtmd/README.md b/tools/mtmd/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ef31d1957cdabb16771bdec3da604ccf81cbe615
--- /dev/null
+++ b/tools/mtmd/README.md
@@ -0,0 +1,63 @@
+# Multimodal Support in llama.cpp
+
+This directory provides multimodal capabilities for `llama.cpp`. Initially intended as a showcase for running LLaVA models, its scope has expanded significantly over time to include various other vision-capable models. As a result, LLaVA is no longer the only multimodal architecture supported.
+
+> [!IMPORTANT]
+>
+> Multimodal support can be viewed as a sub-project within `llama.cpp`. It is under **very heavy development**, and **breaking changes are expected**.
+
+The naming and structure related to multimodal support have evolved, which might cause some confusion. Here's a brief timeline to clarify:
+
+- [#3436](https://github.com/ggml-org/llama.cpp/pull/3436): Initial support for LLaVA 1.5 was added, introducing `llava.cpp` and `clip.cpp`. The `llava-cli` binary was created for model interaction.
+- [#4954](https://github.com/ggml-org/llama.cpp/pull/4954): Support for MobileVLM was added, becoming the second vision model supported. This built upon the existing `llava.cpp`, `clip.cpp`, and `llava-cli` infrastructure.
+- **Expansion & Fragmentation:** Many new models were subsequently added (e.g., [#7599](https://github.com/ggml-org/llama.cpp/pull/7599), [#10361](https://github.com/ggml-org/llama.cpp/pull/10361), [#12344](https://github.com/ggml-org/llama.cpp/pull/12344), and others). However, `llava-cli` lacked support for the increasingly complex chat templates required by these models. This led to the creation of model-specific binaries like `qwen2vl-cli`, `minicpmv-cli`, and `gemma3-cli`. While functional, this proliferation of command-line tools became confusing for users.
+- [#12849](https://github.com/ggml-org/llama.cpp/pull/12849): `libmtmd` was introduced as a replacement for `llava.cpp`. Its goals include providing a single, unified command-line interface, improving the user/developer experience (UX/DX), and supporting both audio and image inputs.
+- [#13012](https://github.com/ggml-org/llama.cpp/pull/13012): `mtmd-cli` was added, consolidating the various model-specific CLIs into a single tool powered by `libmtmd`.
+
+## Pre-quantized models
+
+See the list of pre-quantized model [here](../../docs/multimodal.md)
+
+## How it works and what is `mmproj`?
+
+Multimodal support in `llama.cpp` works by encoding images into embeddings using a separate model component, and then feeding these embeddings into the language model.
+
+This approach keeps the multimodal components distinct from the core `libllama` library. Separating these allows for faster, independent development cycles. While many modern vision models are based on Vision Transformers (ViTs), their specific pre-processing and projection steps can vary significantly. Integrating this diverse complexity directly into `libllama` is currently challenging.
+
+Consequently, running a multimodal model typically requires two GGUF files:
+1.  The standard language model file.
+2.  A corresponding **multimodal projector (`mmproj`)** file, which handles the image encoding and projection.
+
+## What is `libmtmd`?
+
+As outlined in the history, `libmtmd` is the modern library designed to replace the original `llava.cpp` implementation for handling multimodal inputs.
+
+Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advantages:
+- **Unified Interface:** Aims to consolidate interaction for various multimodal models.
+- **Improved UX/DX:** Features a more intuitive API, inspired by the `Processor` class in the Hugging Face `transformers` library.
+- **Flexibility:** Designed to support multiple input types (text, audio, images) while respecting the wide variety of chat templates used by different models.
+
+## How to obtain `mmproj`
+
+Multimodal projector (`mmproj`) files are specific to each model architecture.
+
+For the following models, you can use `convert_hf_to_gguf.py` with `--mmproj` flag to get the `mmproj` file:
+- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) ; See the guide [here](../../docs/multimodal/gemma3.md) - Note: 1B variant does not have vision support
+- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
+- Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
+- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
+- InternVL 2.5 and InternVL 3 from [OpenGVLab](https://huggingface.co/OpenGVLab) (note: we don't support conversion of `InternVL3-*-hf` model, only non-HF version is supported ; `InternLM2Model` **text** model is not supported)
+
+For older models, please refer to the relevant guide for instructions on how to obtain or create them:
+
+NOTE: conversion scripts are located under `tools/mtmd/legacy-models`
+
+- [LLaVA](../../docs/multimodal/llava.md)
+- [MobileVLM](../../docs/multimodal/MobileVLM.md)
+- [GLM-Edge](../../docs/multimodal/glmedge.md)
+- [MiniCPM-V 2.5](../../docs/multimodal/minicpmv2.5.md)
+- [MiniCPM-V 2.6](../../docs/multimodal/minicpmv2.6.md)
+- [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
+- [IBM Granite Vision](../../docs/multimodal/granitevision.md)
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..62c936ed00f7752c35dddbcdf3c0771fb6eafd08
--- /dev/null
+++ b/tools/mtmd/clip-impl.h
@@ -0,0 +1,467 @@
+#include "ggml.h"
+#include "gguf.h"
+#include "clip.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// Internal header for clip.cpp
+
+#define KEY_FTYPE               "general.file_type"
+#define KEY_NAME                "general.name"
+#define KEY_DESCRIPTION         "general.description"
+#define KEY_PROJ_TYPE           "clip.projector_type"
+#define KEY_HAS_AUDIO_ENC       "clip.has_audio_encoder"
+#define KEY_HAS_VISION_ENC      "clip.has_vision_encoder"
+#define KEY_USE_GELU            "clip.use_gelu"
+#define KEY_USE_SILU            "clip.use_silu"
+
+#define KEY_N_EMBD              "clip.%s.embedding_length"
+#define KEY_N_FF                "clip.%s.feed_forward_length"
+#define KEY_N_BLOCK             "clip.%s.block_count"
+#define KEY_PROJ_DIM            "clip.%s.projection_dim"
+#define KEY_N_HEAD              "clip.%s.attention.head_count"
+#define KEY_LAYER_NORM_EPS      "clip.%s.attention.layer_norm_epsilon"
+
+// vision-specific
+#define KEY_IMAGE_SIZE          "clip.vision.image_size"
+#define KEY_PATCH_SIZE          "clip.vision.patch_size"
+#define KEY_IMAGE_MEAN          "clip.vision.image_mean"
+#define KEY_IMAGE_STD           "clip.vision.image_std"
+#define KEY_FEATURE_LAYER       "clip.vision.feature_layer"
+#define KEY_PROJ_SCALE_FACTOR   "clip.vision.projector.scale_factor"
+#define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
+
+#define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
+#define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
+#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
+#define KEY_WIN_ATTN_PATTERN      "clip.vision.n_wa_pattern"
+#define KEY_ATTN_WINDOW_SIZE      "clip.vision.window_size"
+#define KEY_MINICPMV_VERSION      "clip.minicpmv_version"
+
+// audio-specific
+#define KEY_A_NUM_MEL_BINS      "clip.audio.num_mel_bins"
+#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"
+
+
+//
+// tensor name constants
+//
+
+#define TN_POS_EMBD        "%s.position_embd.weight"
+#define TN_CLASS_EMBD      "v.class_embd"
+#define TN_PATCH_EMBD      "v.patch_embd.weight"  // not rename tensor with ".0" postfix for backwrad compat
+#define TN_PATCH_EMBD_1    "v.patch_embd.weight.1"
+#define TN_PATCH_BIAS      "v.patch_embd.bias"
+#define TN_ATTN_K          "%s.blk.%d.attn_k.%s"
+#define TN_ATTN_Q          "%s.blk.%d.attn_q.%s"
+#define TN_ATTN_V          "%s.blk.%d.attn_v.%s"
+#define TN_ATTN_OUTPUT     "%s.blk.%d.attn_out.%s"
+#define TN_ATTN_K_NORM     "%s.blk.%d.attn_k_norm.%s"
+#define TN_ATTN_Q_NORM     "%s.blk.%d.attn_q_norm.%s"
+#define TN_FFN_DOWN        "%s.blk.%d.ffn_down.%s"
+#define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
+#define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
+#define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
+#define TN_LN_1            "%s.blk.%d.ln1.%s" // layer norm
+#define TN_LN_2            "%s.blk.%d.ln2.%s" // layer norm
+#define TN_LS_1            "%s.blk.%d.ls1.%s" // layer scale
+#define TN_LS_2            "%s.blk.%d.ls2.%s" // layer scale
+#define TN_LN_PRE          "%s.pre_ln.%s"
+#define TN_LN_POST         "%s.post_ln.%s"
+#define TN_LLAVA_PROJ      "mm.%d.%s"
+#define TN_MVLM_PROJ_MLP   "mm.model.mlp.%d.%s"
+#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
+#define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"
+#define TN_IMAGE_NEWLINE   "model.image_newline"
+#define TN_MM_INP_NORM     "mm.input_norm.weight"
+#define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
+#define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3
+#define TN_MM_PROJECTOR    "mm.model.fc.weight"         // idefics3
+#define TN_MM_PATCH_MERGER "mm.patch_merger.weight"     // mistral small 3.1
+#define TN_TOK_IMG_BREAK   "v.token_embd.img_break"     // pixtral
+#define TN_TOK_GLM_BOI     "adapter.boi"                // glm-edge (these embeddings are not in text model)
+#define TN_TOK_GLM_EOI     "adapter.eoi"                // glm-edge (these embeddings are not in text model)
+
+// mimicpmv
+#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
+#define TN_MINICPMV_QUERY      "resampler.query"
+#define TN_MINICPMV_PROJ       "resampler.proj.weight"
+#define TN_MINICPMV_KV_PROJ    "resampler.kv.weight"
+#define TN_MINICPMV_ATTN       "resampler.attn.%s.%s"
+#define TN_MINICPMV_LN         "resampler.ln_%s.%s"
+
+#define TN_GLM_ADAPER_CONV      "adapter.conv.%s"
+#define TN_GLM_ADAPTER_LINEAR   "adapter.linear.linear.%s"
+#define TN_GLM_ADAPTER_NORM_1   "adapter.linear.norm1.%s"
+#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s"
+#define TN_GLM_ADAPTER_GATE     "adapter.linear.gate.%s"
+#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
+
+// ultravox
+#define TN_CONV1D       "a.conv1d.%d.%s"
+#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
+#define TN_MM_AUDIO_FC  "mm.a.fc.%s" // fully connected layer
+#define TN_MM_NORM_PRE  "mm.a.norm_pre.%s"
+#define TN_MM_NORM_MID  "mm.a.norm_mid.%s"
+
+// align x to upper multiple of n
+#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
+
+enum projector_type {
+    PROJECTOR_TYPE_MLP,
+    PROJECTOR_TYPE_MLP_NORM,
+    PROJECTOR_TYPE_LDP,
+    PROJECTOR_TYPE_LDPV2,
+    PROJECTOR_TYPE_MINICPMV,
+    PROJECTOR_TYPE_GLM_EDGE,
+    PROJECTOR_TYPE_QWEN2VL,
+    PROJECTOR_TYPE_GEMMA3,
+    PROJECTOR_TYPE_IDEFICS3,
+    PROJECTOR_TYPE_PIXTRAL,
+    PROJECTOR_TYPE_QWEN25VL,
+    PROJECTOR_TYPE_ULTRAVOX,
+    PROJECTOR_TYPE_INTERNVL,
+    PROJECTOR_TYPE_LLAMA4,
+    PROJECTOR_TYPE_QWEN2A,
+    PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
+    PROJECTOR_TYPE_UNKNOWN,
+};
+
+static std::map PROJECTOR_TYPE_NAMES = {
+    { PROJECTOR_TYPE_MLP,       "mlp" },
+    { PROJECTOR_TYPE_LDP,       "ldp" },
+    { PROJECTOR_TYPE_LDPV2,     "ldpv2"},
+    { PROJECTOR_TYPE_MINICPMV,  "resampler"},
+    { PROJECTOR_TYPE_GLM_EDGE,  "adapter"},
+    { PROJECTOR_TYPE_QWEN2VL,   "qwen2vl_merger"},
+    { PROJECTOR_TYPE_QWEN25VL,  "qwen2.5vl_merger"},
+    { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
+    { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
+    { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
+    { PROJECTOR_TYPE_ULTRAVOX,  "ultravox"},
+    { PROJECTOR_TYPE_INTERNVL,  "internvl"},
+    { PROJECTOR_TYPE_LLAMA4,    "llama4"},
+    { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
+    { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
+};
+
+static projector_type clip_projector_type_from_string(const std::string & str) {
+    for (const auto & pair : PROJECTOR_TYPE_NAMES) {
+        if (pair.second == str) {
+            return pair.first;
+        }
+    }
+    return PROJECTOR_TYPE_UNKNOWN;
+}
+
+// RGB uint8 image
+struct clip_image_u8 {
+    int nx;
+    int ny;
+
+    std::vector buf;
+};
+
+// For images, buf.size() == nx*ny*3
+//     Memory layout: RGBRGBRGB...
+// For audio, only one channel is used, buf.size() == nx*ny
+//     nx will be n_frames and ny will be n_mel
+struct clip_image_f32 {
+    int nx;
+    int ny;
+
+    std::vector buf;
+};
+
+//
+// logging
+//
+
+static void clip_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
+}
+
+struct clip_logger_state {
+    ggml_log_level verbosity_thold;
+    ggml_log_callback log_callback;
+    void * log_callback_user_data;
+};
+
+extern struct clip_logger_state g_logger_state;
+
+static void clip_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
+    if (format == NULL) {
+        return;
+    }
+    va_list args_copy;
+    va_copy(args_copy, args);
+    char buffer[128];
+    int len = vsnprintf(buffer, 128, format, args);
+    if (len < 128) {
+        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
+    } else {
+        char * buffer2 = (char *) calloc(len + 1, sizeof(char));
+        vsnprintf(buffer2, len + 1, format, args_copy);
+        buffer2[len] = 0;
+        g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
+        free(buffer2);
+    }
+    va_end(args_copy);
+}
+
+static void clip_log_internal(enum ggml_log_level level, const char * format, ...) {
+    va_list args;
+    va_start(args, format);
+    clip_log_internal_v(level, format, args);
+    va_end(args);
+}
+
+#define LOG_TMPL(level, ...) \
+    do { \
+        if ((level) >= g_logger_state.verbosity_thold) { \
+            clip_log_internal((level), __VA_ARGS__); \
+        } \
+    } while (0)
+#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO,  __VA_ARGS__)
+#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN,  __VA_ARGS__)
+#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
+#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  __VA_ARGS__)
+
+//
+// cpp wrappers
+//
+
+// wrapper for clip_image_size
+struct clip_image_size_deleter {
+    void operator()(clip_image_size * val) { clip_image_size_free(val); }
+};
+typedef std::unique_ptr clip_image_size_ptr;
+
+// wrapper for clip_image_u8
+struct clip_image_u8_deleter {
+    void operator()(clip_image_u8 * val) { clip_image_u8_free(val); }
+};
+typedef std::unique_ptr clip_image_u8_ptr;
+
+// wrapper for clip_image_f32
+struct clip_image_f32_deleter {
+    void operator()(clip_image_f32 * val) { clip_image_f32_free(val); }
+};
+typedef std::unique_ptr clip_image_f32_ptr;
+
+struct clip_image_u8_batch {
+    std::vector entries;
+};
+
+struct clip_image_f32_batch {
+    std::vector entries;
+    bool is_audio = false;
+
+    // for llava-uhd style models, we need to know the grid size
+    // note: entries.size() == grid_x * grid_y + 1 (one overview image)
+    int grid_x = 0;
+    int grid_y = 0;
+
+    clip_image_f32_batch clone() const {
+        clip_image_f32_batch new_batch{
+            /* entries  */ {},
+            /* is_audio */ is_audio,
+            /* grid_x   */ grid_x,
+            /* grid_y   */ grid_y,
+        };
+        new_batch.entries.reserve(entries.size());
+        for (const auto & entry : entries) {
+            new_batch.entries.emplace_back(new clip_image_f32(*entry));
+        }
+        return new_batch;
+    }
+};
+
+//
+// common utils
+//
+
+static std::string string_format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), buf.size());
+}
+
+static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    if (search.empty()) {
+        return;
+    }
+    std::string builder;
+    builder.reserve(s.length());
+    size_t pos = 0;
+    size_t last_pos = 0;
+    while ((pos = s.find(search, last_pos)) != std::string::npos) {
+        builder.append(s, last_pos, pos - last_pos);
+        builder.append(replace);
+        last_pos = pos + search.length();
+    }
+    builder.append(s, last_pos, std::string::npos);
+    s = std::move(builder);
+}
+
+// split string by a `std::string delim` instead of `char delim`
+static std::vector string_split_str(std::string s, const std::string & delimiter) {
+    std::vector tokens;
+    size_t pos = 0;
+    std::string token;
+    while ((pos = s.find(delimiter)) != std::string::npos) {
+        token = s.substr(0, pos);
+        tokens.push_back(token);
+        s.erase(0, pos + delimiter.length());
+    }
+    tokens.push_back(s);
+    return tokens;
+}
+
+//
+// gguf utils
+//
+
+static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
+    switch (type) {
+        case GGUF_TYPE_UINT8:   return std::to_string(((const uint8_t  *)data)[i]);
+        case GGUF_TYPE_INT8:    return std::to_string(((const int8_t   *)data)[i]);
+        case GGUF_TYPE_UINT16:  return std::to_string(((const uint16_t *)data)[i]);
+        case GGUF_TYPE_INT16:   return std::to_string(((const int16_t  *)data)[i]);
+        case GGUF_TYPE_UINT32:  return std::to_string(((const uint32_t *)data)[i]);
+        case GGUF_TYPE_INT32:   return std::to_string(((const int32_t  *)data)[i]);
+        case GGUF_TYPE_UINT64:  return std::to_string(((const uint64_t *)data)[i]);
+        case GGUF_TYPE_INT64:   return std::to_string(((const int64_t  *)data)[i]);
+        case GGUF_TYPE_FLOAT32: return std::to_string(((const float    *)data)[i]);
+        case GGUF_TYPE_FLOAT64: return std::to_string(((const double   *)data)[i]);
+        case GGUF_TYPE_BOOL:    return ((const bool *)data)[i] ? "true" : "false";
+        default:                return string_format("unknown type %d", type);
+    }
+}
+
+static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
+    const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
+
+    switch (type) {
+        case GGUF_TYPE_STRING:
+            return gguf_get_val_str(ctx_gguf, i);
+        case GGUF_TYPE_ARRAY:
+            {
+                const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
+                int arr_n = gguf_get_arr_n(ctx_gguf, i);
+                const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i);
+                std::stringstream ss;
+                ss << "[";
+                for (int j = 0; j < arr_n; j++) {
+                    if (arr_type == GGUF_TYPE_STRING) {
+                        std::string val = gguf_get_arr_str(ctx_gguf, i, j);
+                        // escape quotes
+                        string_replace_all(val, "\\", "\\\\");
+                        string_replace_all(val, "\"", "\\\"");
+                        ss << '"' << val << '"';
+                    } else if (arr_type == GGUF_TYPE_ARRAY) {
+                        ss << "???";
+                    } else {
+                        ss << gguf_data_to_str(arr_type, data, j);
+                    }
+                    if (j < arr_n - 1) {
+                        ss << ", ";
+                    }
+                }
+                ss << "]";
+                return ss.str();
+            }
+        default:
+            return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
+    }
+}
+
+//
+// debugging
+//
+
+static void print_tensor_shape(ggml_tensor * t) {
+    printf("%s.shape = [", t->name);
+    for (int i = 0; i < ggml_n_dims(t); ++i) {
+        printf("%" PRId64, t->ne[i]);
+        if (i < ggml_n_dims(t) - 1) {
+            printf(", ");
+        }
+    }
+    printf("]\n");
+}
+
+static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
+    ggml_type type = t->type;
+    int64_t * ne = t->ne;
+    size_t * nb = t->nb;
+    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+        printf("%s.data: [\n", t->name);
+        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+            if (i2 == n && ne[2] > 2*n) {
+                printf("     ..., \n");
+                i2 = ne[2] - n;
+            }
+            printf("     [\n");
+            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+                if (i1 == n && ne[1] > 2*n) {
+                    printf("      ..., \n");
+                    i1 = ne[1] - n;
+                }
+                printf("      [");
+                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+                    if (i0 == n && ne[0] > 2*n) {
+                        printf("..., ");
+                        i0 = ne[0] - n;
+                    }
+                    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
+                    float v;
+                    if (type == GGML_TYPE_F16) {
+                        v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
+                    } else if (type == GGML_TYPE_F32) {
+                        v = *(float *) &data[i];
+                    } else if (type == GGML_TYPE_I32) {
+                        v = (float) *(int32_t *) &data[i];
+                    } else if (type == GGML_TYPE_I16) {
+                        v = (float) *(int16_t *) &data[i];
+                    } else if (type == GGML_TYPE_I8) {
+                        v = (float) *(int8_t *) &data[i];
+                    } else {
+                        GGML_ABORT("fatal error");
+                    }
+                    printf("%8.4f", v);
+                    if (i0 < ne[0] - 1) printf(", ");
+                }
+                printf("],\n");
+            }
+            printf("     ],\n");
+        }
+        printf("    ]\n");
+    }
+}
+
+//
+// API used internally with mtmd
+//
+
+projector_type clip_get_projector_type(const struct clip_ctx * ctx);
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c25bacc17769b2564c84cf9155143097772582bf
--- /dev/null
+++ b/tools/mtmd/clip.cpp
@@ -0,0 +1,4126 @@
+// NOTE: This is modified from clip.cpp only for LLaVA,
+// so there might be still unnecessary artifacts hanging around
+// I'll gradually clean and extend it
+// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
+#include "clip.h"
+#include "clip-impl.h"
+#include "ggml.h"
+#include "ggml-cpp.h"
+#include "ggml-cpu.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "gguf.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
+
+enum ffn_op_type {
+    FFN_GELU,
+    FFN_GELU_ERF,
+    FFN_SILU,
+    FFN_GELU_QUICK,
+};
+
+enum norm_type {
+    NORM_TYPE_NORMAL,
+    NORM_TYPE_RMS,
+};
+
+//#define CLIP_DEBUG_FUNCTIONS
+
+#ifdef CLIP_DEBUG_FUNCTIONS
+static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
+    std::ofstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        LOG_ERR("Failed to open file for writing: %s\n", filename.c_str());
+        return;
+    }
+
+    // PPM header: P6 format, width, height, and max color value
+    file << "P6\n" << img.nx << " " << img.ny << "\n255\n";
+
+    // Write pixel data
+    for (size_t i = 0; i < img.buf.size(); i += 3) {
+        // PPM expects binary data in RGB format, which matches our image buffer
+        file.write(reinterpret_cast(&img.buf[i]), 3);
+    }
+
+    file.close();
+}
+
+static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filename) {
+    std::ofstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        LOG_ERR("Failed to open file for writing: %s\n", filename.c_str());
+        return;
+    }
+
+    int fileSize = 54 + 3 * img.nx * img.ny; // File header + info header + pixel data
+    int bytesPerPixel = 3;
+    int widthInBytes = img.nx * bytesPerPixel;
+    int paddingAmount = (4 - (widthInBytes % 4)) % 4;
+    int stride = widthInBytes + paddingAmount;
+
+    // Bitmap file header
+    unsigned char fileHeader[14] = {
+        'B','M',     // Signature
+        0,0,0,0,    // Image file size in bytes
+        0,0,0,0,    // Reserved
+        54,0,0,0    // Start of pixel array
+    };
+
+    // Total file size
+    fileSize = 54 + (stride * img.ny);
+    fileHeader[2] = (unsigned char)(fileSize);
+    fileHeader[3] = (unsigned char)(fileSize >> 8);
+    fileHeader[4] = (unsigned char)(fileSize >> 16);
+    fileHeader[5] = (unsigned char)(fileSize >> 24);
+
+    // Bitmap information header (BITMAPINFOHEADER)
+    unsigned char infoHeader[40] = {
+        40,0,0,0,   // Size of this header (40 bytes)
+        0,0,0,0,    // Image width
+        0,0,0,0,    // Image height
+        1,0,        // Number of color planes
+        24,0,       // Bits per pixel
+        0,0,0,0,    // No compression
+        0,0,0,0,    // Image size (can be 0 for no compression)
+        0,0,0,0,    // X pixels per meter (not specified)
+        0,0,0,0,    // Y pixels per meter (not specified)
+        0,0,0,0,    // Total colors (color table not used)
+        0,0,0,0     // Important colors (all are important)
+    };
+
+    // Width and height in the information header
+    infoHeader[4] = (unsigned char)(img.nx);
+    infoHeader[5] = (unsigned char)(img.nx >> 8);
+    infoHeader[6] = (unsigned char)(img.nx >> 16);
+    infoHeader[7] = (unsigned char)(img.nx >> 24);
+    infoHeader[8] = (unsigned char)(img.ny);
+    infoHeader[9] = (unsigned char)(img.ny >> 8);
+    infoHeader[10] = (unsigned char)(img.ny >> 16);
+    infoHeader[11] = (unsigned char)(img.ny >> 24);
+
+    // Write file headers
+    file.write(reinterpret_cast(fileHeader), sizeof(fileHeader));
+    file.write(reinterpret_cast(infoHeader), sizeof(infoHeader));
+
+    // Pixel data
+    std::vector padding(3, 0); // Max padding size to be added to each row
+    for (int y = img.ny - 1; y >= 0; --y) { // BMP files are stored bottom-to-top
+        for (int x = 0; x < img.nx; ++x) {
+            // Each pixel
+            size_t pixelIndex = (y * img.nx + x) * 3;
+            unsigned char pixel[3] = {
+                img.buf[pixelIndex + 2], // BMP stores pixels in BGR format
+                img.buf[pixelIndex + 1],
+                img.buf[pixelIndex]
+            };
+            file.write(reinterpret_cast(pixel), 3);
+        }
+        // Write padding for the row
+        file.write(reinterpret_cast(padding.data()), paddingAmount);
+    }
+
+    file.close();
+}
+
+// debug function to convert f32 to u8
+static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) {
+    dst.nx = src.nx;
+    dst.ny = src.ny;
+    dst.buf.resize(3 * src.nx * src.ny);
+    for (size_t i = 0; i < src.buf.size(); ++i) {
+        dst.buf[i] = static_cast(std::min(std::max(int(src.buf[i] * 255.0f), 0), 255));
+    }
+}
+#endif
+
+
+//
+// clip layers
+//
+
+enum patch_merge_type {
+    PATCH_MERGE_FLAT,
+    PATCH_MERGE_SPATIAL_UNPAD,
+};
+
+struct clip_hparams {
+    int32_t image_size;
+    int32_t patch_size;
+    int32_t n_embd;
+    int32_t n_ff;
+    int32_t projection_dim;
+    int32_t n_head;
+    int32_t n_layer;
+    int32_t proj_scale_factor = 0; // idefics3
+
+    float image_mean[3];
+    float image_std[3];
+
+    // for models using dynamic image size, we need to have a smaller image size to warmup
+    // otherwise, user will get OOM everytime they load the model
+    int32_t warmup_image_size = 0;
+    int32_t warmup_audio_size = 3000;
+
+    ffn_op_type ffn_op = FFN_GELU;
+
+    patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
+
+    float eps = 1e-6;
+    float rope_theta = 0.0;
+
+    std::vector image_grid_pinpoints;
+    int32_t image_crop_resolution;
+    std::unordered_set vision_feature_layer;
+    int32_t attn_window_size = 0;
+    int32_t n_wa_pattern = 0;
+    int32_t spatial_merge_size = 0;
+
+    // audio
+    int32_t n_mel_bins = 0; // whisper preprocessor
+    int32_t proj_stack_factor = 0; // ultravox
+
+    // legacy
+    bool has_llava_projector = false;
+    int minicpmv_version = 0;
+};
+
+struct clip_layer {
+    // attention
+    ggml_tensor * k_w = nullptr;
+    ggml_tensor * k_b = nullptr;
+    ggml_tensor * q_w = nullptr;
+    ggml_tensor * q_b = nullptr;
+    ggml_tensor * v_w = nullptr;
+    ggml_tensor * v_b = nullptr;
+
+    ggml_tensor * o_w = nullptr;
+    ggml_tensor * o_b = nullptr;
+
+    ggml_tensor * k_norm = nullptr;
+    ggml_tensor * q_norm = nullptr;
+
+    // layernorm 1
+    ggml_tensor * ln_1_w = nullptr;
+    ggml_tensor * ln_1_b = nullptr;
+
+    ggml_tensor * ff_up_w = nullptr;
+    ggml_tensor * ff_up_b = nullptr;
+    ggml_tensor * ff_gate_w = nullptr;
+    ggml_tensor * ff_gate_b = nullptr;
+    ggml_tensor * ff_down_w = nullptr;
+    ggml_tensor * ff_down_b = nullptr;
+
+    // layernorm 2
+    ggml_tensor * ln_2_w = nullptr;
+    ggml_tensor * ln_2_b = nullptr;
+
+    // layer scale (no bias)
+    ggml_tensor * ls_1_w = nullptr;
+    ggml_tensor * ls_2_w = nullptr;
+};
+
+struct clip_model {
+    clip_modality modality = CLIP_MODALITY_VISION;
+    projector_type proj_type = PROJECTOR_TYPE_MLP;
+    clip_hparams hparams;
+
+    // embeddings
+    ggml_tensor * class_embedding = nullptr;
+    ggml_tensor * patch_embeddings_0 = nullptr;
+    ggml_tensor * patch_embeddings_1 = nullptr;  // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
+    ggml_tensor * patch_bias = nullptr;
+    ggml_tensor * position_embeddings = nullptr;
+
+    ggml_tensor * pre_ln_w = nullptr;
+    ggml_tensor * pre_ln_b = nullptr;
+
+    std::vector layers;
+
+    ggml_tensor * post_ln_w;
+    ggml_tensor * post_ln_b;
+
+    ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
+    ggml_tensor * mm_fc_w;
+    ggml_tensor * mm_fc_b;
+
+    // LLaVA projection
+    ggml_tensor * mm_input_norm_w = nullptr;
+    ggml_tensor * mm_0_w = nullptr;
+    ggml_tensor * mm_0_b = nullptr;
+    ggml_tensor * mm_2_w = nullptr;
+    ggml_tensor * mm_2_b = nullptr;
+
+    ggml_tensor * image_newline = nullptr;
+
+    // Yi type models with mlp+normalization projection
+    ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
+    ggml_tensor * mm_1_b = nullptr;
+    ggml_tensor * mm_3_w = nullptr;
+    ggml_tensor * mm_3_b = nullptr;
+    ggml_tensor * mm_4_w = nullptr;
+    ggml_tensor * mm_4_b = nullptr;
+
+    // GLMV-Edge projection
+    ggml_tensor * mm_model_adapter_conv_w = nullptr;
+    ggml_tensor * mm_model_adapter_conv_b = nullptr;
+    ggml_tensor * mm_glm_tok_boi = nullptr;
+    ggml_tensor * mm_glm_tok_eoi = nullptr;
+
+    // MobileVLM projection
+    ggml_tensor * mm_model_mlp_1_w = nullptr;
+    ggml_tensor * mm_model_mlp_1_b = nullptr;
+    ggml_tensor * mm_model_mlp_3_w = nullptr;
+    ggml_tensor * mm_model_mlp_3_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_0_0_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_0_1_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_0_1_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_2_0_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_2_1_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_2_1_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_0_0_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_0_1_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_0_1_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_2_0_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_2_1_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_2_1_b = nullptr;
+
+    // MobileVLM_V2 projection
+    ggml_tensor * mm_model_mlp_0_w = nullptr;
+    ggml_tensor * mm_model_mlp_0_b = nullptr;
+    ggml_tensor * mm_model_mlp_2_w = nullptr;
+    ggml_tensor * mm_model_mlp_2_b = nullptr;
+    ggml_tensor * mm_model_peg_0_w = nullptr;
+    ggml_tensor * mm_model_peg_0_b = nullptr;
+
+    // MINICPMV projection
+    ggml_tensor * mm_model_pos_embed_k = nullptr;
+    ggml_tensor * mm_model_query = nullptr;
+    ggml_tensor * mm_model_proj = nullptr;
+    ggml_tensor * mm_model_kv_proj = nullptr;
+    ggml_tensor * mm_model_attn_q_w = nullptr;
+    ggml_tensor * mm_model_attn_q_b = nullptr;
+    ggml_tensor * mm_model_attn_k_w = nullptr;
+    ggml_tensor * mm_model_attn_k_b = nullptr;
+    ggml_tensor * mm_model_attn_v_w = nullptr;
+    ggml_tensor * mm_model_attn_v_b = nullptr;
+    ggml_tensor * mm_model_attn_o_w = nullptr;
+    ggml_tensor * mm_model_attn_o_b = nullptr;
+    ggml_tensor * mm_model_ln_q_w = nullptr;
+    ggml_tensor * mm_model_ln_q_b = nullptr;
+    ggml_tensor * mm_model_ln_kv_w = nullptr;
+    ggml_tensor * mm_model_ln_kv_b = nullptr;
+    ggml_tensor * mm_model_ln_post_w = nullptr;
+    ggml_tensor * mm_model_ln_post_b = nullptr;
+
+    // gemma3
+    ggml_tensor * mm_input_proj_w = nullptr;
+    ggml_tensor * mm_soft_emb_norm_w = nullptr;
+
+    // pixtral
+    ggml_tensor * token_embd_img_break = nullptr;
+    ggml_tensor * mm_patch_merger_w = nullptr;
+
+    // ultravox / whisper encoder
+    ggml_tensor * conv1d_1_w = nullptr;
+    ggml_tensor * conv1d_1_b = nullptr;
+    ggml_tensor * conv1d_2_w = nullptr;
+    ggml_tensor * conv1d_2_b = nullptr;
+    ggml_tensor * mm_norm_pre_w = nullptr;
+    ggml_tensor * mm_norm_mid_w = nullptr;
+};
+
+struct clip_ctx {
+    clip_model model;
+
+    gguf_context_ptr ctx_gguf;
+    ggml_context_ptr ctx_data;
+
+    std::vector buf_compute_meta;
+
+    std::vector backend_ptrs;
+    std::vector backend_buft;
+
+    ggml_backend_t backend;
+    ggml_backend_t backend_cpu;
+    ggml_backend_buffer_ptr buf;
+
+    int max_nodes = 8192;
+    ggml_backend_sched_ptr sched;
+
+    // for debugging
+    bool debug_graph = false;
+    std::vector debug_print_tensors;
+
+    clip_ctx(clip_context_params & ctx_params) {
+        debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
+        backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+        if (!backend_cpu) {
+            throw std::runtime_error("failed to initialize CPU backend");
+        }
+        backend = ctx_params.use_gpu
+                    ? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr)
+                    : nullptr;
+
+        if (backend) {
+            LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend));
+            backend_ptrs.push_back(backend);
+            backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
+        } else {
+            backend = backend_cpu;
+            LOG_INF("%s: CLIP using CPU backend\n", __func__);
+        }
+
+        backend_ptrs.push_back(backend_cpu);
+        backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
+
+        sched.reset(
+            ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true)
+        );
+    }
+
+    ~clip_ctx() {
+        ggml_backend_free(backend);
+        if (backend != backend_cpu) {
+            ggml_backend_free(backend_cpu);
+        }
+    }
+
+    // this function is added so that we don't change too much of the existing code
+    projector_type proj_type() const {
+        return model.proj_type;
+    }
+};
+
+struct clip_graph {
+    clip_ctx * ctx;
+    const clip_model & model;
+    const clip_hparams & hparams;
+
+    // we only support single image per batch
+    const clip_image_f32 & img;
+
+    const int patch_size;
+    const int n_patches_x;
+    const int n_patches_y;
+    const int n_patches;
+    const int n_embd;
+    const int n_head;
+    const int d_head;
+    const int n_layer;
+    const float eps;
+    const float kq_scale;
+
+    ggml_context_ptr ctx0_ptr;
+    ggml_context * ctx0;
+    ggml_cgraph * gf;
+
+    clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
+            ctx(ctx),
+            model(ctx->model),
+            hparams(model.hparams),
+            img(img),
+            patch_size(hparams.patch_size),
+            n_patches_x(img.nx / patch_size),
+            n_patches_y(img.ny / patch_size),
+            n_patches(n_patches_x * n_patches_y),
+            n_embd(hparams.n_embd),
+            n_head(hparams.n_head),
+            d_head(n_embd / n_head),
+            n_layer(hparams.n_layer),
+            eps(hparams.eps),
+            kq_scale(1.0f / sqrtf((float)d_head)) {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+            /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+            /*.no_alloc   =*/ true,
+        };
+        ctx0_ptr.reset(ggml_init(params));
+        ctx0 = ctx0_ptr.get();
+        gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
+    }
+
+    ggml_cgraph * build_siglip() {
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * cur = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                nullptr);
+
+        if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) {
+            const int batch_size = 1;
+            GGML_ASSERT(n_patches_x == n_patches_y);
+            const int patches_per_image = n_patches_x;
+            const int kernel_size = hparams.proj_scale_factor;
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+            cur = ggml_reshape_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size);
+
+            // doing a pool2d to reduce the number of output tokens
+            cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
+            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[0], n_embd, batch_size);
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            // apply norm before projection
+            cur = ggml_rms_norm(ctx0, cur, eps);
+            cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
+
+            // apply projection
+            cur = ggml_mul_mat(ctx0,
+                ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
+                cur);
+
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
+            // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
+
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int n_embd = cur->ne[0];
+            const int seq    = cur->ne[1];
+            const int bsz    = 1; // batch size, always 1 for now since we don't support batching
+            const int height = std::sqrt(seq);
+            const int width  = std::sqrt(seq);
+            GGML_ASSERT(scale_factor != 0);
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                seq / (scale_factor * scale_factor),
+                bsz);
+
+            cur = ggml_mul_mat(ctx0, model.projection, cur);
+        } else {
+            GGML_ABORT("SigLIP: Unsupported projector type");
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    ggml_cgraph * build_pixtral() {
+        const int n_merge = hparams.spatial_merge_size;
+
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true);
+        };
+
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * cur = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_RMS,
+                                hparams.ffn_op,
+                                nullptr, // no learned pos embd
+                                add_pos);
+
+        // mistral small 3.1 patch merger
+        // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
+        if (model.mm_patch_merger_w) {
+            GGML_ASSERT(hparams.spatial_merge_size > 0);
+
+            cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
+
+            // reshape image tokens to 2D grid
+            cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y);
+            cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd]
+            cur = ggml_cont(ctx0, cur);
+
+            // torch.nn.functional.unfold is just an im2col under the hood
+            // we just need a dummy kernel to make it work
+            ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
+            cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
+
+            // project to n_embd
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+            cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
+        }
+
+        // LlavaMultiModalProjector (always using GELU activation)
+        {
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            if (model.mm_1_b) {
+                cur = ggml_add(ctx0, cur, model.mm_1_b);
+            }
+
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            if (model.mm_2_b) {
+                cur = ggml_add(ctx0, cur, model.mm_2_b);
+            }
+        }
+
+        // arrangement of the [IMG_BREAK] token
+        {
+            // not efficient, but works
+            // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
+            // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
+            // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
+
+            const int p_y             = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
+            const int p_x             = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
+            const int p_total         = p_x * p_y;
+            const int n_embd_text     = cur->ne[0];
+            const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
+
+            ggml_tensor * tmp = ggml_reshape_3d(ctx0, cur, n_embd_text, p_x, p_y);
+            ggml_tensor * tok = ggml_new_tensor_3d(ctx0, tmp->type, n_embd_text, 1, p_y);
+            tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
+            tok = ggml_add(ctx0, tok, model.token_embd_img_break);
+            tmp = ggml_concat(ctx0, tmp, tok, 1);
+            cur = ggml_view_2d(ctx0, tmp,
+                n_embd_text, n_tokens_output,
+                ggml_row_size(tmp->type, n_embd_text), 0);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    // Qwen2VL and Qwen2.5VL use M-RoPE
+    ggml_cgraph * build_qwen2vl() {
+        GGML_ASSERT(model.patch_bias == nullptr);
+        GGML_ASSERT(model.class_embedding == nullptr);
+
+        const int batch_size       = 1;
+        const bool use_window_attn = hparams.n_wa_pattern > 0;
+        const int n_wa_pattern     = hparams.n_wa_pattern;
+        const int n_pos            = n_patches;
+        const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
+
+        norm_type norm_t = ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
+            ? NORM_TYPE_RMS // qwen 2.5 vl
+            : NORM_TYPE_NORMAL; // qwen 2 vl
+
+        int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+        ggml_tensor * inp_raw = build_inp_raw();
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+        GGML_ASSERT(img.nx % (patch_size * 2) == 0);
+        GGML_ASSERT(img.ny % (patch_size * 2) == 0);
+
+        // second conv dimension
+        {
+            auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+            inp = ggml_add(ctx0, inp, inp_1);
+
+            inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b]
+            inp = ggml_reshape_4d(
+                ctx0, inp,
+                n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+            inp = ggml_reshape_4d(
+                ctx0, inp,
+                n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+            inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
+            inp = ggml_reshape_3d(
+                ctx0, inp,
+                n_embd, n_patches_x * n_patches_y, batch_size);
+        }
+
+        ggml_tensor * inpL           = inp;
+        ggml_tensor * window_mask    = nullptr;
+        ggml_tensor * window_idx     = nullptr;
+        ggml_tensor * inv_window_idx = nullptr;
+
+        ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+        }
+
+        if (use_window_attn) {
+            // handle window attention inputs
+            inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+            ggml_set_name(inv_window_idx, "inv_window_idx");
+            ggml_set_input(inv_window_idx);
+            // mask for window attention
+            window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
+            ggml_set_name(window_mask, "window_mask");
+            ggml_set_input(window_mask);
+
+            // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
+            GGML_ASSERT(batch_size == 1);
+            inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
+            inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
+            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
+        }
+
+        // loop over layers
+        for (int il = 0; il < n_layer; il++) {
+            auto & layer = model.layers[il];
+            const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
+
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+            cb(cur, "ln1", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = ggml_add(ctx0,
+                    ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
+                ggml_tensor * Kcur = ggml_add(ctx0,
+                    ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
+                ggml_tensor * Vcur = ggml_add(ctx0,
+                    ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // apply M-RoPE
+                Qcur = ggml_rope_multi(
+                    ctx0, Qcur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+                Kcur = ggml_rope_multi(
+                    ctx0, Kcur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+
+                cb(Qcur, "Qcur_rope", il);
+                cb(Kcur, "Kcur_rope", il);
+
+                ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
+
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            inpL = cur; // inpL = residual, cur = hidden_states
+
+            cb(cur, "ffn_inp", il);
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                hparams.ffn_op, il);
+
+            cb(cur, "ffn_out", il);
+
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
+
+            inpL = cur;
+        }
+
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
+        }
+
+        // multimodal projection
+        ggml_tensor * embeddings = inpL;
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
+
+        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+        // GELU activation
+        embeddings = ggml_gelu(ctx0, embeddings);
+
+        // Second linear layer
+        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+
+        if (use_window_attn) {
+            window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+            ggml_set_name(window_idx, "window_idx");
+            ggml_set_input(window_idx);
+
+            // embeddings shape: [n_embd, n_patches_x * n_patches_y, batch_size]
+            GGML_ASSERT(batch_size == 1);
+            embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4);
+            embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+            embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4, batch_size);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
+
+        return gf;
+    }
+
+    ggml_cgraph * build_minicpmv() {
+        const int batch_size = 1;
+
+        GGML_ASSERT(model.class_embedding == nullptr);
+        const int n_pos = n_patches;
+
+        // position embeddings for the projector (not for ViT)
+        int n_output_dim = clip_n_mmproj_embd(ctx);
+        ggml_tensor * pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_output_dim, n_pos, batch_size);
+        ggml_set_name(pos_embed, "pos_embed");
+        ggml_set_input(pos_embed);
+
+        // for selecting learned pos embd, used by ViT
+        struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions);
+
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * embeddings = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                learned_pos_embd,
+                                nullptr);
+
+        // resampler projector (it is just another transformer)
+
+        ggml_tensor * q = model.mm_model_query;
+        ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
+
+        // norm
+        q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1);
+        v = build_norm(v, model.mm_model_ln_kv_w, model.mm_model_ln_kv_b, NORM_TYPE_NORMAL, eps, -1);
+
+        // k = v + pos_embed
+        ggml_tensor * k = ggml_add(ctx0, v, pos_embed);
+
+        // attention
+        {
+            int n_embd = clip_n_mmproj_embd(ctx);
+            const int d_head = 128;
+            int n_head = n_embd/d_head;
+            int num_query = 96;
+            if (ctx->model.hparams.minicpmv_version == 2) {
+                num_query = 96;
+            } else if (ctx->model.hparams.minicpmv_version == 3) {
+                num_query = 64;
+            } else if (ctx->model.hparams.minicpmv_version == 4) {
+                num_query = 64;
+            }
+
+            ggml_tensor * Q = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q),
+                model.mm_model_attn_q_b);
+            ggml_tensor * K = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k),
+                model.mm_model_attn_k_b);
+            ggml_tensor * V = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v),
+                model.mm_model_attn_v_b);
+
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query);
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, n_pos);
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, n_pos);
+
+            cb(Q, "resampler_Q", -1);
+            cb(K, "resampler_K", -1);
+            cb(V, "resampler_V", -1);
+
+            embeddings = build_attn(
+                model.mm_model_attn_o_w,
+                model.mm_model_attn_o_b,
+                Q, K, V, nullptr, kq_scale, -1);
+            cb(embeddings, "resampler_attn_out", -1);
+        }
+        // layernorm
+        embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1);
+
+        // projection
+        embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
+
+        return gf;
+    }
+
+    ggml_cgraph * build_internvl() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1;
+        ggml_tensor * inp = build_inp();
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        // The larger models use a different ViT, which uses RMS norm instead of layer norm
+        // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188
+        norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45)
+            ? NORM_TYPE_RMS // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B)
+            : NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models)
+
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                norm_t,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                nullptr);
+
+        // remove CLS token
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_patches,
+            ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        {
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int bsz    = 1; // batch size, always 1 for now since we don't support batching
+            const int height = n_patches_y;
+            const int width  = n_patches_x;
+            GGML_ASSERT(scale_factor > 0);
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                cur->ne[1] * cur->ne[2]);
+        }
+
+        // projector (always using GELU activation)
+        {
+            // projector LayerNorm uses pytorch's default eps = 1e-5
+            // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79
+            cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_3_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_3_b);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    ggml_cgraph * build_llama4() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1; // +1 for [CLS]
+
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        ggml_tensor * inp = build_inp_raw();
+
+        // Llama4UnfoldConvolution
+        {
+            ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
+                                                    patch_size, patch_size, 3, n_embd);
+            inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
+            inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+            inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
+            cb(inp, "patch_conv", -1);
+        }
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        // build ViT with 2D position embeddings
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            // first half is X axis and second half is Y axis
+            // ref: https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312
+            // ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441
+            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+        };
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                add_pos);
+
+        // remove CLS token
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_patches,
+            ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        // based on Llama4VisionPixelShuffleMLP
+        // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
+        {
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int bsz = 1; // batch size, always 1 for now since we don't support batching
+            GGML_ASSERT(scale_factor > 0);
+            GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
+            cur = ggml_reshape_4d(ctx0, cur,
+                n_embd * scale_factor,
+                n_patches_x / scale_factor,
+                n_patches_y,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                n_patches_x / scale_factor,
+                n_patches_y / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                n_patches / scale_factor / scale_factor);
+            cb(cur, "pixel_shuffle", -1);
+        }
+
+        // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
+        {
+            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
+            cur = ggml_gelu(ctx0, cur);
+            cb(cur, "adapter_mlp", -1);
+        }
+
+        // Llama4MultiModalProjector
+        cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+        cb(cur, "projected", -1);
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    // this graph is used by llava, granite and glm
+    // due to having embedding_stack (used by granite), we cannot reuse build_vit
+    ggml_cgraph * build_llava() {
+        const int batch_size = 1;
+        const int n_pos = n_patches + (model.class_embedding ? 1 : 0);
+
+        GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported");
+
+        // Calculate the deepest feature layer based on hparams and projector type
+        int max_feature_layer = n_layer;
+        {
+            // Get the index of the second to last layer; this is the default for models that have a llava projector
+            int il_last = hparams.n_layer - 1;
+            int deepest_feature_layer = -1;
+
+            if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV || ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
+                il_last += 1;
+            }
+
+            // If we set explicit vision feature layers, only go up to the deepest one
+            // NOTE: only used by granite-vision models for now
+            for (const auto & feature_layer : hparams.vision_feature_layer) {
+                if (feature_layer > deepest_feature_layer) {
+                    deepest_feature_layer = feature_layer;
+                }
+            }
+            max_feature_layer = deepest_feature_layer < 0 ? il_last : deepest_feature_layer;
+        }
+
+        ggml_tensor * inp = build_inp();
+
+        // concat class_embeddings and patch_embeddings
+        if (model.class_embedding) {
+            inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+        }
+
+        ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions));
+
+        ggml_tensor * inpL = inp;
+
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, NORM_TYPE_NORMAL, eps, -1);
+            cb(inpL, "pre_ln", -1);
+        }
+
+        std::vector embedding_stack;
+        const auto & vision_feature_layer = hparams.vision_feature_layer;
+
+        // loop over layers
+        for (int il = 0; il < max_feature_layer; il++) {
+            auto & layer = model.layers[il];
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+            // If this is an embedding feature layer, save the output.
+            // NOTE: 0 index here refers to the input to the encoder.
+            if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
+                embedding_stack.push_back(cur);
+            }
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "layer_inp_normed", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+                if (layer.q_b) {
+                    Qcur = ggml_add(ctx0, Qcur, layer.q_b);
+                }
+
+                ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+                if (layer.k_b) {
+                    Kcur = ggml_add(ctx0, Kcur, layer.k_b);
+                }
+
+                ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+                if (layer.v_b) {
+                    Vcur = ggml_add(ctx0, Vcur, layer.v_b);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            inpL = cur; // inpL = residual, cur = hidden_states
+
+            cb(cur, "ffn_inp", il);
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                hparams.ffn_op, il);
+
+            cb(cur, "ffn_out", il);
+
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
+
+            inpL = cur;
+        }
+
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, eps, -1);
+        }
+
+        ggml_tensor * embeddings = inpL;
+
+        // process vision feature layers (used by granite)
+        {
+            // final layer is a vision feature layer
+            if (vision_feature_layer.find(max_feature_layer) != vision_feature_layer.end()) {
+                embedding_stack.push_back(inpL);
+            }
+
+            // If feature layers are explicitly set, stack them (if we have multiple)
+            if (!embedding_stack.empty()) {
+                embeddings = embedding_stack[0];
+                for (size_t i = 1; i < embedding_stack.size(); i++) {
+                    embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
+                }
+            }
+        }
+
+        // llava projector (also used by granite)
+        if (ctx->model.hparams.has_llava_projector) {
+            embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
+
+            ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+            ggml_set_name(patches, "patches");
+            ggml_set_input(patches);
+
+            // shape [1, 576, 1024]
+            // ne is whcn, ne = [1024, 576, 1, 1]
+            embeddings = ggml_get_rows(ctx0, embeddings, patches);
+
+            // print_tensor_info(embeddings, "embeddings");
+
+            // llava projector
+            if (ctx->proj_type() == PROJECTOR_TYPE_MLP) {
+                embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+                embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+                embeddings = ggml_gelu(ctx0, embeddings);
+                if (model.mm_2_w) {
+                    embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+                    embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+                }
+            }
+            else if (ctx->proj_type() == PROJECTOR_TYPE_MLP_NORM) {
+                embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+                embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+                // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
+                // First LayerNorm
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
+                                    model.mm_1_b);
+
+                // GELU activation
+                embeddings = ggml_gelu(ctx0, embeddings);
+
+                // Second linear layer
+                embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
+                embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
+
+                // Second LayerNorm
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
+                                    model.mm_4_b);
+            }
+            else if (ctx->proj_type() == PROJECTOR_TYPE_LDP) {
+                // MobileVLM projector
+                int n_patch = 24;
+                ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
+                mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
+                mlp_1 = ggml_gelu(ctx0, mlp_1);
+                ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
+                mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
+                // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
+
+                // block 1
+                ggml_tensor * block_1 = nullptr;
+                {
+                    // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24]
+                    mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
+                    mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
+                    // stride = 1, padding = 1, bias is nullptr
+                    block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
+
+                    // layer norm
+                    // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
+                    // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+
+                    // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                    // hardswish
+                    ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
+
+                    block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
+                    // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                    // pointwise conv
+                    block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
+                    block_1 = ggml_relu(ctx0, block_1);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
+                    block_1 = ggml_hardsigmoid(ctx0, block_1);
+                    // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
+                    block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
+                    block_1 = ggml_mul(ctx0, block_1_hw, block_1);
+
+                    int w = block_1->ne[0], h = block_1->ne[1];
+                    block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
+
+                    // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
+                    block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
+
+                    // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+                    // block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                    // residual
+                    block_1 = ggml_add(ctx0, mlp_3, block_1);
+                }
+
+                // block_2
+                {
+                    // stride = 2
+                    block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
+
+                    // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
+                    // layer norm
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
+                    // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+                    // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
+                    // hardswish
+                    ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
+
+                    // not sure the parameters is right for globalAvgPooling
+                    block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
+                    // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                    // pointwise conv
+                    block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
+                    block_1 = ggml_relu(ctx0, block_1);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
+                    block_1 = ggml_hardsigmoid(ctx0, block_1);
+
+                    // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                    block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
+                    block_1 = ggml_mul(ctx0, block_1_hw, block_1);
+
+                    int w = block_1->ne[0], h = block_1->ne[1];
+                    block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
+                    // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
+                    block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
+
+
+                    // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b);
+                    block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]);
+                    // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1]
+                }
+                embeddings = block_1;
+            }
+            else if (ctx->proj_type() == PROJECTOR_TYPE_LDPV2)
+            {
+                int n_patch = 24;
+                ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+                mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
+                mlp_0 = ggml_gelu(ctx0, mlp_0);
+                ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
+                mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
+                // mlp_2 ne = [2048, 576, 1, 1]
+                // // AVG Pool Layer 2*2, strides = 2
+                mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
+                // mlp_2 ne = [576, 2048, 1, 1]
+                mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
+                // mlp_2 ne [24, 24, 2048, 1]
+                mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
+                // weight ne = [3, 3, 2048, 1]
+                ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
+                peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
+                peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
+                mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
+                peg_0 = ggml_add(ctx0, peg_0, mlp_2);
+                peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
+                embeddings = peg_0;
+            }
+            else {
+                GGML_ABORT("fatal error");
+            }
+        }
+
+        // glm projector
+        else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
+            size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
+            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
+            embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
+            embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
+            embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
+            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
+            embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
+            // GLU
+            {
+                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+                embeddings = ggml_gelu_inplace(ctx0, embeddings);
+                ggml_tensor * x = embeddings;
+                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
+                x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
+                embeddings = ggml_silu_inplace(ctx0, embeddings);
+                embeddings = ggml_mul(ctx0, embeddings,x);
+                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
+            }
+            // arrangement of BOI/EOI token embeddings
+            // note: these embeddings are not present in text model, hence we cannot process them as text tokens
+            // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53
+            {
+                embeddings = ggml_concat(ctx0, model.mm_glm_tok_boi, embeddings, 1); // BOI
+                embeddings = ggml_concat(ctx0, embeddings, model.mm_glm_tok_eoi, 1); // EOI
+            }
+        }
+
+        else {
+            GGML_ABORT("llava: unknown projector type");
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
+
+        return gf;
+    }
+
+    // whisper encoder with custom projector
+    ggml_cgraph * build_whisper_enc() {
+        const int n_frames = img.nx;
+        const int n_pos    = n_frames / 2;
+        GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
+
+        ggml_tensor * inp = build_inp_raw(1);
+
+        // conv1d block
+        {
+            // convolution + gelu
+            ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1);
+            cur = ggml_add(ctx0, cur, model.conv1d_1_b);
+
+            cur = ggml_gelu_erf(ctx0, cur);
+
+            cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1);
+            cur = ggml_add(ctx0, cur, model.conv1d_2_b);
+
+            cur = ggml_gelu_erf(ctx0, cur);
+            // transpose
+            inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+            cb(inp, "after_conv1d", -1);
+        }
+
+        // sanity check (only check one layer, but it should be the same for all)
+        GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b);
+        GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b);
+        GGML_ASSERT(model.layers[0].q_b);
+        GGML_ASSERT(model.layers[0].v_b);
+        GGML_ASSERT(!model.layers[0].k_b); // no bias for k
+        GGML_ASSERT(model.post_ln_w && model.post_ln_b);
+
+        ggml_tensor * pos_embd_selected = ggml_view_2d(
+            ctx0, model.position_embeddings,
+            model.position_embeddings->ne[0], n_pos,
+            model.position_embeddings->nb[1], 0
+        );
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                pos_embd_selected,
+                                nullptr);
+
+        cb(cur, "after_transformer", -1);
+
+        if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
+            // StackAudioFrames
+            // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
+            {
+                int64_t stride = n_embd * hparams.proj_stack_factor;
+                int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
+                int64_t pad = padded_len - ggml_nelements(cur);
+                if (pad > 0) {
+                    cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
+                    cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
+                }
+                cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
+                                    ggml_row_size(cur->type, stride), 0);
+            }
+
+            cb(cur, "after_stacked", -1);
+
+            // UltravoxProjector
+            {
+                // pre-norm
+                cur = ggml_rms_norm(ctx0, cur, 1e-6);
+                cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
+
+                // ffn in
+                cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+
+                // swiglu
+                {
+                    int64_t split_point = cur->ne[0] / 2;
+                    ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                    ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                    // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
+                    x1 = ggml_silu(ctx0, x1);
+                    cur = ggml_mul(ctx0, x0, x1);
+                }
+
+                // mid-norm
+                cur = ggml_rms_norm(ctx0, cur, 1e-6);
+                cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
+
+                // ffn out
+                cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            }
+
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
+            // projector
+            cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_fc_b);
+
+        } else {
+            GGML_ABORT("%s: unknown projector type", __func__);
+        }
+
+        cb(cur, "projected", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+private:
+    //
+    // utility functions
+    //
+
+    void cb(ggml_tensor * cur0, const char * name, int il) const {
+        if (ctx->debug_graph) {
+            ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0));
+            std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name;
+            ggml_set_name(cur, cur_name.c_str());
+            ggml_set_output(cur);
+            ggml_build_forward_expand(gf, cur);
+            ctx->debug_print_tensors.push_back(cur);
+        }
+    }
+
+    // build vision transformer (ViT) cgraph
+    // this function should cover most of the models
+    // if your model has specific features, you should probably duplicate this function
+    ggml_tensor * build_vit(
+                ggml_tensor * inp,
+                int64_t n_pos,
+                norm_type norm_t,
+                ffn_op_type ffn_t,
+                ggml_tensor * learned_pos_embd,
+                std::function add_pos
+            ) {
+        if (learned_pos_embd) {
+            inp = ggml_add(ctx0, inp, learned_pos_embd);
+            cb(inp, "pos_embed", -1);
+        }
+
+        ggml_tensor * inpL = inp;
+
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+            cb(inpL, "pre_ln", -1);
+        }
+
+        // loop over layers
+        for (int il = 0; il < n_layer; il++) {
+            auto & layer = model.layers[il];
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+            cb(cur, "layer_inp_normed", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+                if (layer.q_b) {
+                    Qcur = ggml_add(ctx0, Qcur, layer.q_b);
+                }
+
+                ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+                if (layer.k_b) {
+                    Kcur = ggml_add(ctx0, Kcur, layer.k_b);
+                }
+
+                ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+                if (layer.v_b) {
+                    Vcur = ggml_add(ctx0, Vcur, layer.v_b);
+                }
+
+                if (layer.q_norm) {
+                    Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
+                    cb(Qcur, "Qcur_norm", il);
+                }
+
+                if (layer.k_norm) {
+                    Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
+                    cb(Kcur, "Kcur_norm", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                if (add_pos) {
+                    Qcur = add_pos(Qcur, layer);
+                    Kcur = add_pos(Kcur, layer);
+                    cb(Qcur, "Qcur_pos", il);
+                    cb(Kcur, "Kcur_pos", il);
+                }
+
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (layer.ls_1_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_1_w);
+                cb(cur, "attn_out_scaled", il);
+            }
+
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            inpL = cur; // inpL = residual, cur = hidden_states
+
+            cb(cur, "ffn_inp", il);
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                ffn_t, il);
+
+            cb(cur, "ffn_out", il);
+
+            if (layer.ls_2_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_2_w);
+                cb(cur, "ffn_out_scaled", il);
+            }
+
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
+
+            inpL = cur;
+        }
+
+        // TODO @ngxson : find a way to move this outside
+        if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
+            ggml_tensor * cur = inpL;
+            cur = ggml_transpose(ctx0, cur);
+            cur = ggml_cont(ctx0, cur);
+            cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0);
+            cur = ggml_transpose(ctx0, cur);
+            cur = ggml_cont(ctx0, cur);
+            inpL = cur;
+        }
+
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
+        }
+        return inpL;
+    }
+
+    // build the input after conv2d (inp_raw --> patches)
+    // returns tensor with shape [n_embd, n_patches]
+    ggml_tensor * build_inp() {
+        ggml_tensor * inp_raw = build_inp_raw();
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+        inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
+        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+        if (model.patch_bias) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
+        return inp;
+    }
+
+    ggml_tensor * build_inp_raw(int channels = 3) {
+        ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
+        ggml_set_name(inp_raw, "inp_raw");
+        ggml_set_input(inp_raw);
+        return inp_raw;
+    }
+
+    ggml_tensor * build_norm(
+            ggml_tensor * cur,
+            ggml_tensor * mw,
+            ggml_tensor * mb,
+            norm_type type,
+            float norm_eps,
+            int il) const {
+
+        cur = type == NORM_TYPE_RMS
+            ? ggml_rms_norm(ctx0, cur, norm_eps)
+            : ggml_norm(ctx0, cur, norm_eps);
+
+        if (mw || mb) {
+            cb(cur, "norm", il);
+        }
+
+        if (mw) {
+            cur = ggml_mul(ctx0, cur, mw);
+            if (mb) {
+                cb(cur, "norm_w", il);
+            }
+        }
+
+        if (mb) {
+            cur = ggml_add(ctx0, cur, mb);
+        }
+
+        return cur;
+    }
+
+    ggml_tensor * build_ffn(
+            ggml_tensor * cur,
+            ggml_tensor * up,
+            ggml_tensor * up_b,
+            ggml_tensor * gate,
+            ggml_tensor * gate_b,
+            ggml_tensor * down,
+            ggml_tensor * down_b,
+            ffn_op_type type_op,
+            int il) const {
+
+        ggml_tensor * tmp = up ? ggml_mul_mat(ctx0, up, cur) : cur;
+        cb(tmp, "ffn_up", il);
+
+        if (up_b) {
+            tmp = ggml_add(ctx0, tmp, up_b);
+            cb(tmp, "ffn_up_b", il);
+        }
+
+        if (gate) {
+            cur = ggml_mul_mat(ctx0, gate, cur);
+            cb(cur, "ffn_gate", il);
+
+            if (gate_b) {
+                cur = ggml_add(ctx0, cur, gate_b);
+                cb(cur, "ffn_gate_b", il);
+            }
+        } else {
+            cur = tmp;
+        }
+
+        switch (type_op) {
+            case FFN_SILU:
+                {
+                    cur = ggml_silu(ctx0, cur);
+                    cb(cur, "ffn_silu", il);
+                } break;
+            case FFN_GELU:
+                {
+                    cur = ggml_gelu(ctx0, cur);
+                    cb(cur, "ffn_gelu", il);
+                } break;
+            case FFN_GELU_ERF:
+                {
+                    cur = ggml_gelu_erf(ctx0, cur);
+                    cb(cur, "ggml_gelu_erf", il);
+                } break;
+            case FFN_GELU_QUICK:
+                {
+                    cur = ggml_gelu_quick(ctx0, cur);
+                    cb(cur, "ffn_relu", il);
+                } break;
+        }
+
+        // we only support parallel ffn for now
+        if (gate) {
+            cur = ggml_mul(ctx0, cur, tmp);
+            cb(cur, "ffn_gate_par", il);
+        }
+
+        if (down) {
+            cur = ggml_mul_mat(ctx0, down, cur);
+        }
+
+        if (down_b) {
+            cb(cur, "ffn_down", il);
+        }
+
+        if (down_b) {
+            cur = ggml_add(ctx0, cur, down_b);
+        }
+
+        return cur;
+    }
+
+    ggml_tensor * build_attn(
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            ggml_tensor * kq_mask,
+            float kq_scale,
+            int il) const {
+        // these nodes are added to the graph together so that they are not reordered
+        // by doing so, the number of splits in the graph is reduced
+        ggml_build_forward_expand(gf, q_cur);
+        ggml_build_forward_expand(gf, k_cur);
+        ggml_build_forward_expand(gf, v_cur);
+
+        ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+        //cb(q, "q", il);
+
+        ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+        //cb(k, "k", il);
+
+        ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+        v = ggml_cont(ctx0, v);
+        //cb(k, "v", il);
+
+        ggml_tensor * cur;
+
+        // TODO @ngxson : support flash attention
+        {
+            const auto n_tokens = q->ne[1];
+            const auto n_head   = q->ne[2];
+            // const auto n_kv     = k->ne[1]; // for flash attention
+
+            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            // F32 may not needed for vision encoders?
+            // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+            kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
+
+            ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+            cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        }
+
+        cb(cur, "kqv_out", il);
+
+        if (wo) {
+            cur = ggml_mul_mat(ctx0, wo, cur);
+        }
+
+        if (wo_b) {
+            cur = ggml_add(ctx0, cur, wo_b);
+        }
+
+        return cur;
+    }
+
+    // implementation of the 2D RoPE without adding a new op in ggml
+    // this is not efficient (use double the memory), but works on all backends
+    // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
+    static ggml_tensor * build_rope_2d(
+        ggml_context * ctx0,
+        ggml_tensor * cur,
+        ggml_tensor * pos_a, // first half
+        ggml_tensor * pos_b, // second half
+        const float freq_base,
+        const bool interleave_freq
+    ) {
+        const int64_t n_dim  = cur->ne[0];
+        const int64_t n_head = cur->ne[1];
+        const int64_t n_pos  = cur->ne[2];
+
+        // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
+        // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
+        // first half of cur will use 1e-0, 1e-2 (even)
+        // second half of cur will use 1e-1, 1e-3 (odd)
+        // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
+        //  ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
+        // then for the second half, we use freq_scale to shift the inv_freq
+        //  ^ why? replace (2i) with (2i+1) in the above equation
+        const float freq_scale_odd = interleave_freq
+                                    ? std::pow(freq_base, (float)-2/n_dim)
+                                    : 1.0;
+
+        // first half
+        ggml_tensor * first;
+        {
+            first = ggml_view_3d(ctx0, cur,
+                n_dim/2, n_head, n_pos,
+                ggml_row_size(cur->type, n_dim),
+                ggml_row_size(cur->type, n_dim*n_head),
+                0);
+            first = ggml_rope_ext(
+                ctx0,
+                first,
+                pos_a,      // positions
+                nullptr,    // freq factors
+                n_dim/2,    // n_dims
+                0, 0, freq_base,
+                1.0f, 0.0f, 1.0f, 0.0f, 0.0f
+            );
+        }
+
+        // second half
+        ggml_tensor * second;
+        {
+            second = ggml_view_3d(ctx0, cur,
+                n_dim/2, n_head, n_pos,
+                ggml_row_size(cur->type, n_dim),
+                ggml_row_size(cur->type, n_dim*n_head),
+                n_dim/2 * ggml_element_size(cur));
+            second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors
+            second = ggml_rope_ext(
+                ctx0,
+                second,
+                pos_b,      // positions
+                nullptr,    // freq factors
+                n_dim/2,    // n_dims
+                0, 0, freq_base,
+                freq_scale_odd,
+                0.0f, 1.0f, 0.0f, 0.0f
+            );
+        }
+
+        cur = ggml_concat(ctx0, first, second, 0);
+        return cur;
+    }
+
+};
+
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+    GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported");
+    clip_graph graph(ctx, *imgs.entries[0]);
+
+    ggml_cgraph * res;
+
+    switch (ctx->proj_type()) {
+        case PROJECTOR_TYPE_GEMMA3:
+        case PROJECTOR_TYPE_IDEFICS3:
+            {
+                res = graph.build_siglip();
+            } break;
+        case PROJECTOR_TYPE_PIXTRAL:
+            {
+                res = graph.build_pixtral();
+            } break;
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+            {
+                res = graph.build_qwen2vl();
+            } break;
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                res = graph.build_minicpmv();
+            } break;
+        case PROJECTOR_TYPE_INTERNVL:
+            {
+                res = graph.build_internvl();
+            } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                res = graph.build_llama4();
+            } break;
+        case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_QWEN2A:
+            {
+                res = graph.build_whisper_enc();
+            } break;
+        default:
+            {
+                res = graph.build_llava();
+            } break;
+    }
+    return res;
+}
+
+struct clip_model_loader {
+    ggml_context_ptr ctx_meta;
+    gguf_context_ptr ctx_gguf;
+
+    std::string fname;
+
+    size_t model_size = 0; // in bytes
+
+    bool has_vision = false;
+    bool has_audio  = false;
+
+    // TODO @ngxson : we should not pass clip_ctx here, it should be clip_model
+    clip_model_loader(const char * fname) : fname(fname) {
+        struct ggml_context * meta = nullptr;
+
+        struct gguf_init_params params = {
+            /*.no_alloc = */ true,
+            /*.ctx      = */ &meta,
+        };
+
+        ctx_gguf = gguf_context_ptr(gguf_init_from_file(fname, params));
+        if (!ctx_gguf.get()) {
+            throw std::runtime_error(string_format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname));
+        }
+
+        ctx_meta.reset(meta);
+
+        const int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
+
+        // print gguf info
+        {
+            std::string name;
+            get_string(KEY_NAME, name, false);
+            std::string description;
+            get_string(KEY_DESCRIPTION, description, false);
+            LOG_INF("%s: model name:   %s\n",  __func__, name.c_str());
+            LOG_INF("%s: description:  %s\n",  __func__, description.c_str());
+            LOG_INF("%s: GGUF version: %d\n",  __func__, gguf_get_version(ctx_gguf.get()));
+            LOG_INF("%s: alignment:    %zu\n", __func__, gguf_get_alignment(ctx_gguf.get()));
+            LOG_INF("%s: n_tensors:    %d\n",  __func__, n_tensors);
+            LOG_INF("%s: n_kv:         %d\n",  __func__, (int)gguf_get_n_kv(ctx_gguf.get()));
+            LOG_INF("\n");
+        }
+
+        // modalities
+        {
+            get_bool(KEY_HAS_VISION_ENC, has_vision, false);
+            get_bool(KEY_HAS_AUDIO_ENC,  has_audio,  false);
+
+            if (has_vision) {
+                LOG_INF("%s: has vision encoder\n", __func__);
+            }
+            if (has_audio) {
+                LOG_INF("%s: has audio encoder\n", __func__);
+            }
+        }
+
+        // tensors
+        {
+            for (int i = 0; i < n_tensors; ++i) {
+                const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
+                const size_t offset = gguf_get_tensor_offset(ctx_gguf.get(), i);
+                enum ggml_type type = gguf_get_tensor_type(ctx_gguf.get(), i);
+                ggml_tensor * cur = ggml_get_tensor(meta, name);
+                size_t tensor_size = ggml_nbytes(cur);
+                model_size += tensor_size;
+                LOG_DBG("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
+                    __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type));
+            }
+        }
+    }
+
+    void load_hparams(clip_model & model, clip_modality modality) {
+        auto & hparams = model.hparams;
+        std::string log_ffn_op; // for logging
+
+        // sanity check
+        if (modality == CLIP_MODALITY_VISION) {
+            GGML_ASSERT(has_vision);
+        } else if (modality == CLIP_MODALITY_AUDIO) {
+            GGML_ASSERT(has_audio);
+        }
+        model.modality = modality;
+
+
+        // projector type
+        std::string proj_type;
+        {
+            get_string(KEY_PROJ_TYPE, proj_type, false);
+            if (!proj_type.empty()) {
+                model.proj_type = clip_projector_type_from_string(proj_type);
+            }
+            if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) {
+                throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str()));
+            }
+
+            // correct arch for multimodal models
+            if (model.proj_type == PROJECTOR_TYPE_QWEN25O) {
+                model.proj_type = modality == CLIP_MODALITY_VISION
+                                    ? PROJECTOR_TYPE_QWEN25VL
+                                    : PROJECTOR_TYPE_QWEN2A;
+            }
+        }
+
+        const bool is_vision = model.modality == CLIP_MODALITY_VISION;
+        const bool is_audio  = model.modality == CLIP_MODALITY_AUDIO;
+
+        // other hparams
+        {
+            const char * prefix = is_vision ? "vision" : "audio";
+            get_u32(string_format(KEY_N_EMBD,         prefix), hparams.n_embd);
+            get_u32(string_format(KEY_N_HEAD,         prefix), hparams.n_head);
+            get_u32(string_format(KEY_N_FF,           prefix), hparams.n_ff);
+            get_u32(string_format(KEY_N_BLOCK,        prefix), hparams.n_layer);
+            get_u32(string_format(KEY_PROJ_DIM,       prefix), hparams.projection_dim);
+            get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
+
+            if (is_vision) {
+                get_u32(KEY_IMAGE_SIZE, hparams.image_size);
+                get_u32(KEY_PATCH_SIZE, hparams.patch_size);
+                get_u32(KEY_IMAGE_CROP_RESOLUTION,    hparams.image_crop_resolution, false);
+                get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
+                get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy
+
+            } else if (is_audio) {
+                get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);
+
+            } else {
+                GGML_ASSERT(false && "unknown modality");
+            }
+
+            // default warmup value
+            hparams.warmup_image_size = hparams.image_size;
+
+            hparams.has_llava_projector = model.proj_type == PROJECTOR_TYPE_MLP
+                                       || model.proj_type == PROJECTOR_TYPE_MLP_NORM
+                                       || model.proj_type == PROJECTOR_TYPE_LDP
+                                       || model.proj_type == PROJECTOR_TYPE_LDPV2;
+
+            {
+                bool use_gelu = false;
+                bool use_silu = false;
+                get_bool(KEY_USE_GELU, use_gelu, false);
+                get_bool(KEY_USE_SILU, use_silu, false);
+                if (use_gelu && use_silu) {
+                    throw std::runtime_error(string_format("%s: both use_gelu and use_silu are set to true\n", __func__));
+                }
+                if (use_gelu) {
+                    hparams.ffn_op = FFN_GELU;
+                    log_ffn_op = "gelu";
+                } else if (use_silu) {
+                    hparams.ffn_op = FFN_SILU;
+                    log_ffn_op = "silu";
+                } else {
+                    hparams.ffn_op = FFN_GELU_QUICK;
+                    log_ffn_op = "gelu_quick";
+                }
+            }
+
+            {
+                std::string mm_patch_merge_type;
+                get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false);
+                if (mm_patch_merge_type == "spatial_unpad") {
+                    hparams.mm_patch_merge_type = PATCH_MERGE_SPATIAL_UNPAD;
+                }
+            }
+
+            if (is_vision) {
+                int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
+                int idx_std  = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
+                GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
+                GGML_ASSERT(idx_std >= 0  && "image_std not found");
+                const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean);
+                const float * std_data  = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std);
+                for (int i = 0; i < 3; ++i) {
+                    hparams.image_mean[i] = mean_data[i];
+                    hparams.image_std[i]  = std_data[i];
+                }
+            }
+
+            // Load the vision feature layer indices if they are explicitly provided;
+            // if multiple vision feature layers are present, the values will be concatenated
+            // to form the final visual features.
+            // NOTE: gguf conversions should standardize the values of the vision feature layer to
+            // be non-negative, since we use -1 to mark values as unset here.
+            std::vector vision_feature_layer;
+            get_arr_int(KEY_FEATURE_LAYER, vision_feature_layer, false);
+            // convert std::vector to std::unordered_set
+            for (auto & layer : vision_feature_layer) {
+                hparams.vision_feature_layer.insert(layer);
+            }
+
+            // model-specific params
+            switch (model.proj_type) {
+                case PROJECTOR_TYPE_MINICPMV:
+                    {
+                        if (hparams.minicpmv_version == 0) {
+                            hparams.minicpmv_version = 2; // default to 2 if not set
+                        }
+                    } break;
+                case PROJECTOR_TYPE_IDEFICS3:
+                case PROJECTOR_TYPE_INTERNVL:
+                    {
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                    } break;
+                case PROJECTOR_TYPE_PIXTRAL:
+                    {
+                        hparams.rope_theta = 10000.0f;
+                        hparams.warmup_image_size = hparams.patch_size * 8;
+                        get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
+                    } break;
+                case PROJECTOR_TYPE_GEMMA3:
+                    {
+                        // default value (used by all model sizes in gemma 3 family)
+                        // number of patches for each **side** is reduced by a factor of 4
+                        hparams.proj_scale_factor = 4;
+                        // test model (tinygemma3) has a different value, we optionally read it
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                    } break;
+                case PROJECTOR_TYPE_QWEN2VL:
+                    {
+                        // max image size = sqrt(max_pixels) = 3584
+                        // ref: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/preprocessor_config.json
+                        // however, the model use unreasonable memory past 1024 size, we force it to 1024 otherwise it's unusable
+                        // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10
+                        hparams.image_size = 1024;
+                        hparams.warmup_image_size = hparams.patch_size * 8;
+                    } break;
+                case PROJECTOR_TYPE_QWEN25VL:
+                    {
+                        // max image size = sqrt(max_pixels)
+                        // https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json
+                        // however, the model use unreasonable memory past 1024 size, we force it to 1024 otherwise it's unusable
+                        // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10
+                        hparams.image_size = 1024;
+                        hparams.warmup_image_size = hparams.patch_size * 8;
+                        get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
+                    } break;
+                case PROJECTOR_TYPE_LLAMA4:
+                    {
+                        hparams.rope_theta = 10000.0f;
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
+
+                        // borrowed from llava-1.6
+                        const int isize = hparams.image_size;
+                        hparams.image_grid_pinpoints = {
+                            isize,   isize*2, // 336, 672
+                            isize*2, isize,   // 672, 336
+                            isize*2, isize*2, // 672, 672
+                            isize*3, isize,   // 1008, 336
+                            isize,   isize*3, // 336, 1008
+                        };
+                    } break;
+                case PROJECTOR_TYPE_ULTRAVOX:
+                case PROJECTOR_TYPE_QWEN2A:
+                    {
+                        bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX;
+                        get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
+                        if (hparams.n_mel_bins != 128) {
+                            throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
+                        }
+                        hparams.ffn_op = FFN_GELU_ERF;
+                        log_ffn_op = "gelu_erf"; // temporary solution for logging
+                    } break;
+                default:
+                    break;
+            }
+
+            LOG_INF("%s: projector:          %s\n", __func__, proj_type.c_str());
+            LOG_INF("%s: n_embd:             %d\n", __func__, hparams.n_embd);
+            LOG_INF("%s: n_head:             %d\n", __func__, hparams.n_head);
+            LOG_INF("%s: n_ff:               %d\n", __func__, hparams.n_ff);
+            LOG_INF("%s: n_layer:            %d\n", __func__, hparams.n_layer);
+            LOG_INF("%s: ffn_op:             %s\n", __func__, log_ffn_op.c_str());
+            LOG_INF("%s: projection_dim:     %d\n", __func__, hparams.projection_dim);
+            if (is_vision) {
+                LOG_INF("\n--- vision hparams ---\n");
+                LOG_INF("%s: image_size:         %d\n", __func__, hparams.image_size);
+                LOG_INF("%s: patch_size:         %d\n", __func__, hparams.patch_size);
+                LOG_INF("%s: has_llava_proj:     %d\n", __func__, hparams.has_llava_projector);
+                LOG_INF("%s: minicpmv_version:   %d\n", __func__, hparams.minicpmv_version);
+                LOG_INF("%s: proj_scale_factor:  %d\n", __func__, hparams.proj_scale_factor);
+                LOG_INF("%s: n_wa_pattern:       %d\n", __func__, hparams.n_wa_pattern);
+            } else if (is_audio) {
+                LOG_INF("\n--- audio hparams ---\n");
+                LOG_INF("%s: n_mel_bins:         %d\n", __func__, hparams.n_mel_bins);
+                LOG_INF("%s: proj_stack_factor:  %d\n", __func__, hparams.proj_stack_factor);
+            }
+            LOG_INF("\n");
+            LOG_INF("%s: model size:         %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
+            LOG_INF("%s: metadata size:      %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
+        }
+    }
+
+    void load_tensors(clip_ctx & ctx_clip) {
+        auto & model = ctx_clip.model;
+        auto & hparams = model.hparams;
+        std::map tensor_offset;
+        std::vector tensors_to_load;
+
+        // TODO @ngxson : support both audio and video in the future
+        const char * prefix = model.modality == CLIP_MODALITY_AUDIO ? "a" : "v";
+
+        // get offsets
+        for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
+            const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
+            tensor_offset[name] = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), i);
+        }
+
+        // create data context
+        struct ggml_init_params params = {
+            /*.mem_size =*/ (gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc =*/ true,
+        };
+        ctx_clip.ctx_data.reset(ggml_init(params));
+        if (!ctx_clip.ctx_data) {
+            throw std::runtime_error(string_format("%s: failed to init ggml context\n", __func__));
+        }
+
+        // helper function
+        auto get_tensor = [&](const std::string & name, bool required = true) {
+            ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
+            if (!cur && required) {
+                throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str()));
+            }
+            if (cur) {
+                tensors_to_load.push_back(cur);
+                // add tensors to context
+                ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
+                ggml_set_name(data_tensor, cur->name);
+                cur = data_tensor;
+            }
+            return cur;
+        };
+
+        model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
+
+        model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
+        model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"),   false);
+
+        model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
+        model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"),   false);
+
+        model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
+        model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD,   false);
+        model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
+
+        model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
+
+        // layers
+        model.layers.resize(hparams.n_layer);
+        for (int il = 0; il < hparams.n_layer; ++il) {
+            auto & layer = model.layers[il];
+            layer.k_w    = get_tensor(string_format(TN_ATTN_K,      prefix, il, "weight"));
+            layer.q_w    = get_tensor(string_format(TN_ATTN_Q,      prefix, il, "weight"));
+            layer.v_w    = get_tensor(string_format(TN_ATTN_V,      prefix, il, "weight"));
+            layer.o_w    = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
+            layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
+            layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
+            layer.ln_1_w = get_tensor(string_format(TN_LN_1,        prefix, il, "weight"), false);
+            layer.ln_2_w = get_tensor(string_format(TN_LN_2,        prefix, il, "weight"), false);
+            layer.ls_1_w = get_tensor(string_format(TN_LS_1,        prefix, il, "weight"), false); // no bias
+            layer.ls_2_w = get_tensor(string_format(TN_LS_2,        prefix, il, "weight"), false); // no bias
+
+            layer.k_b    = get_tensor(string_format(TN_ATTN_K,      prefix, il, "bias"), false);
+            layer.q_b    = get_tensor(string_format(TN_ATTN_Q,      prefix, il, "bias"), false);
+            layer.v_b    = get_tensor(string_format(TN_ATTN_V,      prefix, il, "bias"), false);
+            layer.o_b    = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
+            layer.ln_1_b = get_tensor(string_format(TN_LN_1,        prefix, il, "bias"), false);
+            layer.ln_2_b = get_tensor(string_format(TN_LN_2,        prefix, il, "bias"), false);
+
+            // ffn
+            layer.ff_up_w   = get_tensor(string_format(TN_FFN_UP,   prefix, il, "weight"));
+            layer.ff_up_b   = get_tensor(string_format(TN_FFN_UP,   prefix, il, "bias"),   false);
+            layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false);
+            layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"),   false);
+            layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
+            layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"),   false);
+
+            // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
+            // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
+            if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
+                // swap up and down weights
+                ggml_tensor * tmp = layer.ff_up_w;
+                layer.ff_up_w = layer.ff_down_w;
+                layer.ff_down_w = tmp;
+                // swap up and down biases
+                tmp = layer.ff_up_b;
+                layer.ff_up_b = layer.ff_down_b;
+                layer.ff_down_b = tmp;
+            }
+        }
+
+        switch (model.proj_type) {
+            case PROJECTOR_TYPE_MLP:
+            case PROJECTOR_TYPE_MLP_NORM:
+                {
+                    // LLaVA projection
+                    model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false);
+                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false);
+                    // Yi-type llava
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false);
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+                    // missing in Yi-type llava
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false);
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+                    // Yi-type llava
+                    model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false);
+                    model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false);
+                    model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false);
+                    model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false);
+                    if (model.mm_3_w) {
+                        // TODO: this is a hack to support Yi-type llava
+                        model.proj_type = PROJECTOR_TYPE_MLP_NORM;
+                    }
+                    model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
+                } break;
+            case PROJECTOR_TYPE_LDP:
+                {
+                    // MobileVLM projection
+                    model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+                    model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+                    model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+                    model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
+                    model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
+                    model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
+                    model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight"));
+                    model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias"));
+                    model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight"));
+                    model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias"));
+                    model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
+                    model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
+                    model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
+                    model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
+                    model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
+                    model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
+                    model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight"));
+                    model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias"));
+                    model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight"));
+                    model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias"));
+                    model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
+                    model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
+                    model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
+                } break;
+            case PROJECTOR_TYPE_LDPV2:
+                {
+                    // MobilVLM_V2 projection
+                    model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+                    model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+                    model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+                    model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias"));
+                    model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
+                    model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
+                } break;
+            case PROJECTOR_TYPE_MINICPMV:
+                {
+                    // model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
+                    model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
+                    model.mm_model_query = get_tensor(TN_MINICPMV_QUERY);
+                    model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ);
+                    model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ);
+                    model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight"));
+                    model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight"));
+                    model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight"));
+                    model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias"));
+                    model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias"));
+                    model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias"));
+                    model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight"));
+                    model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias"));
+                    model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight"));
+                    model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias"));
+                    model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight"));
+                    model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias"));
+                    model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight"));
+                    model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias"));
+                } break;
+            case PROJECTOR_TYPE_GLM_EDGE:
+                {
+                    model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight"));
+                    model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias"));
+                    model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight"));
+                    model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight"));
+                    model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias"));
+                    model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight"));
+                    model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight"));
+                    model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight"));
+                    model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
+                    model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
+                } break;
+            case PROJECTOR_TYPE_QWEN2VL:
+            case PROJECTOR_TYPE_QWEN25VL:
+                {
+                    model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
+                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                } break;
+            case PROJECTOR_TYPE_GEMMA3:
+                {
+                    model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
+                    model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
+                } break;
+            case PROJECTOR_TYPE_IDEFICS3:
+                {
+                    model.projection = get_tensor(TN_MM_PROJECTOR);
+                } break;
+            case PROJECTOR_TYPE_PIXTRAL:
+                {
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+                    // [IMG_BREAK] token embedding
+                    model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+                    // for mistral small 3.1
+                    model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM,     false);
+                    model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+                } break;
+            case PROJECTOR_TYPE_ULTRAVOX:
+                {
+                    model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+                    model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
+                    model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
+                    model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
+                } break;
+            case PROJECTOR_TYPE_QWEN2A:
+                {
+                    model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
+                    model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
+                } break;
+            case PROJECTOR_TYPE_INTERNVL:
+                {
+                    model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+                    model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+                    model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+                    model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+                } break;
+            case PROJECTOR_TYPE_LLAMA4:
+                {
+                    model.mm_model_proj    = get_tensor(TN_MM_PROJECTOR);
+                    model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+                } break;
+            default:
+                GGML_ASSERT(false && "unknown projector type");
+        }
+
+        // load data
+        {
+            std::vector read_buf;
+
+            auto fin = std::ifstream(fname, std::ios::binary);
+            if (!fin) {
+                throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
+            }
+
+            // alloc memory and offload data
+            ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
+            ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
+            ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+            for (auto & t : tensors_to_load) {
+                ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
+                const size_t offset = tensor_offset[t->name];
+                fin.seekg(offset, std::ios::beg);
+                if (!fin) {
+                    throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
+                }
+                size_t num_bytes = ggml_nbytes(cur);
+                if (ggml_backend_buft_is_host(buft)) {
+                    // for the CPU and Metal backend, we can read directly into the tensor
+                    fin.read(reinterpret_cast(cur->data), num_bytes);
+                } else {
+                    // read into a temporary buffer first, then copy to device memory
+                    read_buf.resize(num_bytes);
+                    fin.read(reinterpret_cast(read_buf.data()), num_bytes);
+                    ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+                }
+            }
+            fin.close();
+
+            LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
+        }
+    }
+
+    void alloc_compute_meta(clip_ctx & ctx_clip) {
+        const auto & hparams = ctx_clip.model.hparams;
+        ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
+
+        // create a fake batch
+        clip_image_f32_batch batch;
+        clip_image_f32_ptr img(clip_image_f32_init());
+        if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
+            img->nx = hparams.warmup_image_size;
+            img->ny = hparams.warmup_image_size;
+        } else {
+            img->nx = hparams.warmup_audio_size;
+            img->ny = hparams.n_mel_bins;
+        }
+        batch.entries.push_back(std::move(img));
+
+        ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
+        ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
+
+        for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) {
+            ggml_backend_t backend = ctx_clip.backend_ptrs[i];
+            ggml_backend_buffer_type_t buft = ctx_clip.backend_buft[i];
+            size_t size = ggml_backend_sched_get_buffer_size(ctx_clip.sched.get(), backend);
+            if (size > 1) {
+                LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
+                        ggml_backend_buft_name(buft),
+                        size / 1024.0 / 1024.0);
+            }
+        }
+    }
+
+    void get_bool(const std::string & key, bool & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_bool(ctx_gguf.get(), i);
+    }
+
+    void get_i32(const std::string & key, int & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_i32(ctx_gguf.get(), i);
+    }
+
+    void get_u32(const std::string & key, int & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_u32(ctx_gguf.get(), i);
+    }
+
+    void get_f32(const std::string & key, float & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_f32(ctx_gguf.get(), i);
+    }
+
+    void get_string(const std::string & key, std::string & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
+    }
+
+    void get_arr_int(const std::string & key, std::vector & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        int n = gguf_get_arr_n(ctx_gguf.get(), i);
+        output.resize(n);
+        const int32_t * values = (const int32_t *)gguf_get_arr_data(ctx_gguf.get(), i);
+        for (int i = 0; i < n; ++i) {
+            output[i] = values[i];
+        }
+    }
+};
+
+struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
+    g_logger_state.verbosity_thold = ctx_params.verbosity;
+    clip_ctx * ctx_vision = nullptr;
+    clip_ctx * ctx_audio = nullptr;
+
+    try {
+        clip_model_loader loader(fname);
+
+        if (loader.has_vision) {
+            ctx_vision = new clip_ctx(ctx_params);
+            loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
+            loader.load_tensors(*ctx_vision);
+            loader.alloc_compute_meta(*ctx_vision);
+        }
+
+        if (loader.has_audio) {
+            ctx_audio = new clip_ctx(ctx_params);
+            loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
+            loader.load_tensors(*ctx_audio);
+            loader.alloc_compute_meta(*ctx_audio);
+        }
+
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
+        if (ctx_vision) {
+            delete ctx_vision;
+        }
+        if (ctx_audio) {
+            delete ctx_audio;
+        }
+        return {nullptr, nullptr};
+    }
+
+    return {ctx_vision, ctx_audio};
+}
+
+struct clip_image_size * clip_image_size_init() {
+    struct clip_image_size * load_image_size = new struct clip_image_size();
+    load_image_size->width = 448;
+    load_image_size->height = 448;
+    return load_image_size;
+}
+
+struct clip_image_u8 * clip_image_u8_init() {
+    return new clip_image_u8();
+}
+
+struct clip_image_f32 * clip_image_f32_init() {
+    return new clip_image_f32();
+}
+
+struct clip_image_f32_batch * clip_image_f32_batch_init() {
+    return new clip_image_f32_batch();
+}
+
+unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
+    if (nx) *nx = img->nx;
+    if (ny) *ny = img->ny;
+    return img->buf.data();
+}
+
+void clip_image_size_free(struct clip_image_size * load_image_size) {
+    if (load_image_size == nullptr) {
+        return;
+    }
+    delete load_image_size;
+}
+void clip_image_u8_free(struct clip_image_u8  * img) { if (img) delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
+
+size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
+    return batch->entries.size();
+}
+
+size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return 0;
+    }
+    return batch->entries[idx]->nx;
+}
+
+size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return 0;
+    }
+    return batch->entries[idx]->ny;
+}
+
+clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return nullptr;
+    }
+    return batch->entries[idx].get();
+}
+
+void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
+    img->nx = nx;
+    img->ny = ny;
+    img->buf.resize(3 * nx * ny);
+    memcpy(img->buf.data(), rgb_pixels, img->buf.size());
+}
+
+// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
+static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
+    dst.nx = src.nx;
+    dst.ny = src.ny;
+    dst.buf.resize(src.buf.size());
+
+    // TODO @ngxson : seems like this could be done more efficiently on cgraph
+    for (size_t i = 0; i < src.buf.size(); ++i) {
+        int c = i % 3; // rgb
+        dst.buf[i] = (static_cast(src.buf[i]) / 255.0f - mean[c]) / std[c];
+    }
+}
+
+// set of tools to manupulate images
+// in the future, we can have HW acceleration by allowing this struct to access 3rd party lib like imagick or opencv
+struct image_manipulation {
+    // Bilinear resize function
+    static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
+        dst.nx = target_width;
+        dst.ny = target_height;
+        dst.buf.resize(3 * target_width * target_height);
+
+        float x_ratio = static_cast(src.nx - 1) / target_width;
+        float y_ratio = static_cast(src.ny - 1) / target_height;
+
+        for (int y = 0; y < target_height; y++) {
+            for (int x = 0; x < target_width; x++) {
+                float px = x_ratio * x;
+                float py = y_ratio * y;
+                int x_floor = static_cast(px);
+                int y_floor = static_cast(py);
+                float x_lerp = px - x_floor;
+                float y_lerp = py - y_floor;
+
+                for (int c = 0; c < 3; c++) {
+                    float top = lerp(
+                        static_cast(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
+                        static_cast(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
+                        x_lerp
+                    );
+                    float bottom = lerp(
+                        static_cast(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
+                        static_cast(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
+                        x_lerp
+                    );
+                    dst.buf[3 * (y * target_width + x) + c] = static_cast(lerp(top, bottom, y_lerp));
+                }
+            }
+        }
+    }
+
+    // Bicubic resize function
+    // part of image will be cropped if the aspect ratio is different
+    static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
+        const int nx = img.nx;
+        const int ny = img.ny;
+
+        dst.nx = target_width;
+        dst.ny = target_height;
+        dst.buf.resize(3 * target_width * target_height);
+
+        float Cc;
+        float C[5];
+        float d0, d2, d3, a0, a1, a2, a3;
+        int i, j, k, jj;
+        int x, y;
+        float dx, dy;
+        float tx, ty;
+
+        tx = (float)nx / (float)target_width;
+        ty = (float)ny / (float)target_height;
+
+        // Bicubic interpolation; adapted from ViT.cpp, inspired from :
+        //    -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
+        //    -> https://en.wikipedia.org/wiki/Bicubic_interpolation
+
+        for (i = 0; i < target_height; i++) {
+            for (j = 0; j < target_width; j++) {
+                x = (int)(tx * j);
+                y = (int)(ty * i);
+
+                dx = tx * j - x;
+                dy = ty * i - y;
+
+                for (k = 0; k < 3; k++) {
+                    for (jj = 0; jj <= 3; jj++) {
+                        d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                        d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                        d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                        a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+
+                        a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                        a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
+                        a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
+
+                        C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
+
+                        d0 = C[0] - C[1];
+                        d2 = C[2] - C[1];
+                        d3 = C[3] - C[1];
+                        a0 = C[1];
+                        a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                        a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
+                        a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
+                        Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
+
+                        const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
+                        dst.buf[(i * target_width + j) * 3 + k] = float(Cc2);
+                    }
+                }
+            }
+        }
+
+        return true;
+    }
+
+    // llava-1.6 type of resize_and_pad
+    // if the ratio is not 1:1, padding with pad_color will be applied
+    // pad_color is single channel, default is 0 (black)
+    static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & dst, const clip_image_size & target_resolution, std::array pad_color = {0, 0, 0}) {
+        int target_width  = target_resolution.width;
+        int target_height = target_resolution.height;
+
+        float scale_w = static_cast(target_width) / image.nx;
+        float scale_h = static_cast(target_height) / image.ny;
+
+        int new_width, new_height;
+
+        if (scale_w < scale_h) {
+            new_width  = target_width;
+            new_height = std::min(static_cast(std::ceil(image.ny * scale_w)), target_height);
+        } else {
+            new_height = target_height;
+            new_width  = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width);
+        }
+
+        clip_image_u8 resized_image;
+        bicubic_resize(image, resized_image, new_width, new_height);
+
+        clip_image_u8 padded_image;
+        padded_image.nx = target_width;
+        padded_image.ny = target_height;
+        padded_image.buf.resize(3 * target_width * target_height);
+
+        // Fill the padded image with the fill color
+        for (size_t i = 0; i < padded_image.buf.size(); i += 3) {
+            padded_image.buf[i]     = pad_color[0];
+            padded_image.buf[i + 1] = pad_color[1];
+            padded_image.buf[i + 2] = pad_color[2];
+        }
+
+        // Calculate padding offsets
+        int pad_x = (target_width  - new_width)  / 2;
+        int pad_y = (target_height - new_height) / 2;
+
+        // Copy the resized image into the center of the padded buffer
+        for (int y = 0; y < new_height; ++y) {
+            for (int x = 0; x < new_width; ++x) {
+                for (int c = 0; c < 3; ++c) {
+                    padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
+                }
+            }
+        }
+        dst = std::move(padded_image);
+    }
+
+    static void crop_image(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
+        dst.nx = w;
+        dst.ny = h;
+        dst.buf.resize(3 * w * h);
+
+        for (int i = 0; i < h; ++i) {
+            for (int j = 0; j < w; ++j) {
+                int src_idx = 3 * ((y + i)*image.nx + (x + j));
+                int dst_idx = 3 * (i*w + j);
+                dst.buf[dst_idx]     = image.buf[src_idx];
+                dst.buf[dst_idx + 1] = image.buf[src_idx + 1];
+                dst.buf[dst_idx + 2] = image.buf[src_idx + 2];
+            }
+        }
+    }
+
+    // calculate the size of the **resized** image, while preserving the aspect ratio
+    // the calculated size will be aligned to the nearest multiple of align_size
+    // if H or W size is larger than max_dimension, it will be resized to max_dimension
+    static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
+        if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
+            return {0, 0};
+        }
+
+        float scale = std::min(1.0f, std::min(static_cast(max_dimension) / inp_size.width,
+                                              static_cast(max_dimension) / inp_size.height));
+
+        float target_width_f  = static_cast(inp_size.width)  * scale;
+        float target_height_f = static_cast(inp_size.height) * scale;
+
+        int aligned_width  = CLIP_ALIGN((int)target_width_f,  align_size);
+        int aligned_height = CLIP_ALIGN((int)target_height_f, align_size);
+
+        return {aligned_width, aligned_height};
+    }
+
+private:
+    static inline int clip(int x, int lower, int upper) {
+        return std::max(lower, std::min(x, upper));
+    }
+
+    // Linear interpolation between two points
+    static inline float lerp(float s, float e, float t) {
+        return s + (e - s) * t;
+    }
+};
+
+/**
+ * implementation of LLaVA-UHD:
+ *  - https://arxiv.org/pdf/2403.11703
+ *  - https://github.com/thunlp/LLaVA-UHD
+ *  - https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
+ *
+ * overview:
+ *   - an image always have a single overview (downscaled image)
+ *   - an image can have 0 or multiple slices, depending on the image size
+ *   - each slice can then be considered as a separate image
+ *
+ * for example:
+ *
+ * [overview] --> [slice 1] --> [slice 2]
+ *           |                |
+ *           +--> [slice 3] --> [slice 4]
+ */
+struct llava_uhd {
+    struct slice_coordinates {
+        int x;
+        int y;
+        clip_image_size size;
+    };
+
+    struct slice_instructions {
+        clip_image_size overview_size; // size of downscaled image
+        clip_image_size refined_size;  // size of image right before slicing (must be multiple of slice size)
+        clip_image_size grid_size;     // grid_size.width * grid_size.height = number of slices
+        std::vector slices;
+        bool padding_refined = false;  // if true, refine image will be padded to the grid size (e.g. llava-1.6)
+    };
+
+    static int get_max_slices(struct clip_ctx * ctx) {
+        if (clip_is_minicpmv(ctx)) {
+            return 9;
+        }
+        return 0;
+    }
+
+    static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) {
+        slice_instructions res;
+        const int patch_size      = clip_get_patch_size(ctx);
+        const int slice_size      = clip_get_image_size(ctx);
+        const int max_slice_nums  = get_max_slices(ctx);
+        const int original_width  = original_size.width;
+        const int original_height = original_size.height;
+        const float log_ratio = log((float)original_width / original_height);
+        const float ratio = (float)original_width * original_height / (slice_size * slice_size);
+        const int multiple = fmin(ceil(ratio), max_slice_nums);
+        const bool has_slices = (multiple > 1);
+        const bool has_pinpoints = !ctx->model.hparams.image_grid_pinpoints.empty();
+
+        if (has_pinpoints) {
+            // has pinpoints, use them to calculate the grid size (e.g. llava-1.6)
+            auto refine_size = llava_uhd::select_best_resolution(
+                ctx->model.hparams.image_grid_pinpoints,
+                original_size);
+            res.overview_size   = clip_image_size{slice_size, slice_size};
+            res.refined_size    = refine_size;
+            res.grid_size       = clip_image_size{0, 0};
+            res.padding_refined = true;
+
+            for (int y = 0; y < refine_size.height; y += slice_size) {
+                for (int x = 0; x < refine_size.width; x += slice_size) {
+                    slice_coordinates slice;
+                    slice.x = x;
+                    slice.y = y;
+                    slice.size.width  = std::min(slice_size, refine_size.width  - x);
+                    slice.size.height = std::min(slice_size, refine_size.height - y);
+                    res.slices.push_back(slice);
+                    if (x == 0) {
+                        res.grid_size.width++;
+                    }
+                }
+                res.grid_size.height++;
+            }
+
+            return res;
+        }
+
+        // no pinpoints, dynamically calculate the grid size (e.g. minicpmv)
+
+        auto best_size    = get_best_resize(original_size, slice_size, patch_size, !has_slices);
+        res.overview_size = best_size;
+
+        if (!has_slices) {
+            // skip slicing logic
+            res.refined_size = clip_image_size{0, 0};
+            res.grid_size    = clip_image_size{0, 0};
+
+        } else {
+            auto best_grid   = get_best_grid(max_slice_nums, multiple, log_ratio);
+            auto refine_size = get_refine_size(original_size, best_grid, slice_size, patch_size, true);
+            res.grid_size    = best_grid;
+            res.refined_size = refine_size;
+
+            int width  = refine_size.width;
+            int height = refine_size.height;
+            int grid_x = int(width  / best_grid.width);
+            int grid_y = int(height / best_grid.height);
+            for (int patches_y = 0,                    ic = 0;
+                    patches_y < refine_size.height && ic < best_grid.height;
+                    patches_y += grid_y,              ic += 1) {
+                for (int patches_x = 0,                   jc = 0;
+                        patches_x < refine_size.width && jc < best_grid.width;
+                        patches_x += grid_x,             jc += 1) {
+                    slice_coordinates slice;
+                    slice.x = patches_x;
+                    slice.y = patches_y;
+                    slice.size.width  = grid_x;
+                    slice.size.height = grid_y;
+                    res.slices.push_back(slice);
+                    // LOG_INF("slice %d: %d %d %d %d\n", ic, patches_i, patches_j, grid_x, grid_y);
+                }
+            }
+        }
+
+        return res;
+    }
+
+    static std::vector slice_image(const clip_image_u8 * img, const slice_instructions & inst) {
+        std::vector output;
+
+        // resize to overview size
+        clip_image_u8_ptr resized_img(clip_image_u8_init());
+        image_manipulation::bicubic_resize(*img, *resized_img, inst.overview_size.width, inst.overview_size.height);
+        output.push_back(std::move(resized_img));
+        if (inst.slices.empty()) {
+            // no slices, just return the resized image
+            return output;
+        }
+
+        // resize to refined size
+        clip_image_u8_ptr refined_img(clip_image_u8_init());
+        if (inst.padding_refined) {
+            image_manipulation::resize_and_pad_image(*img, *refined_img, inst.refined_size);
+        } else {
+            image_manipulation::bilinear_resize(*img, *refined_img, inst.refined_size.width, inst.refined_size.height);
+        }
+
+        // create slices
+        for (const auto & slice : inst.slices) {
+            int x = slice.x;
+            int y = slice.y;
+            int w = slice.size.width;
+            int h = slice.size.height;
+
+            clip_image_u8_ptr img_slice(clip_image_u8_init());
+            image_manipulation::crop_image(*refined_img, *img_slice, x, y, w, h);
+            output.push_back(std::move(img_slice));
+        }
+
+        return output;
+    }
+
+private:
+    static clip_image_size get_best_resize(const clip_image_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
+        int width  = original_size.width;
+        int height = original_size.height;
+        if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
+            float r = static_cast(width) / height;
+            height  = static_cast(scale_resolution / std::sqrt(r));
+            width   = static_cast(height * r);
+        }
+        clip_image_size res;
+        res.width  = ensure_divide(width,  patch_size);
+        res.height = ensure_divide(height, patch_size);
+        return res;
+    }
+
+    /**
+     * Selects the best resolution from a list of possible resolutions based on the original size.
+     *
+     * @param original_size The original size of the image
+     * @param possible_resolutions A list of possible resolutions
+     * @return The best fit resolution
+     */
+    static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector & possible_resolutions) {
+        int original_width = original_size.width;
+        int original_height = original_size.height;
+        clip_image_size best_fit;
+        int max_effective_resolution = 0;
+        int min_wasted_resolution = std::numeric_limits::max();
+
+        for (const auto & resolution : possible_resolutions) {
+            int width  = resolution.width;
+            int height = resolution.height;
+            float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height);
+            int downscaled_width  = static_cast(original_width * scale);
+            int downscaled_height = static_cast(original_height * scale);
+            int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
+            int wasted_resolution = (width * height) - effective_resolution;
+            // LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
+            if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
+                max_effective_resolution = effective_resolution;
+                min_wasted_resolution = wasted_resolution;
+                best_fit = resolution;
+            }
+        }
+
+        return best_fit;
+    }
+
+    // used by llava 1.6 with custom list of pinpoints
+    static clip_image_size select_best_resolution(const std::vector & pinpoints, const clip_image_size & original_size) {
+        std::vector possible_resolutions; // TODO @ngxson : construct this inside hparams, not here
+        for (size_t i = 0; i < pinpoints.size(); i += 2) {
+            possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]});
+        }
+        return select_best_resolution(original_size, possible_resolutions);
+    }
+
+    static int ensure_divide(int length, int patch_size) {
+        return std::max(static_cast(std::round(static_cast(length) / patch_size) * patch_size), patch_size);
+    }
+
+    static clip_image_size get_refine_size(const clip_image_size & original_size, const clip_image_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
+        int width  = original_size.width;
+        int height = original_size.height;
+        int grid_x = grid.width;
+        int grid_y = grid.height;
+
+        int refine_width  = ensure_divide(width, grid_x);
+        int refine_height = ensure_divide(height, grid_y);
+
+        clip_image_size grid_size;
+        grid_size.width  = refine_width  / grid_x;
+        grid_size.height = refine_height / grid_y;
+
+        auto best_grid_size  = get_best_resize(grid_size, scale_resolution, patch_size, allow_upscale);
+        int best_grid_width  = best_grid_size.width;
+        int best_grid_height = best_grid_size.height;
+
+        clip_image_size refine_size;
+        refine_size.width  = best_grid_width  * grid_x;
+        refine_size.height = best_grid_height * grid_y;
+        return refine_size;
+    }
+
+    static clip_image_size get_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
+        std::vector candidate_split_grids_nums;
+        for (int i : {multiple - 1, multiple, multiple + 1}) {
+            if (i == 1 || i > max_slice_nums) {
+                continue;
+            }
+            candidate_split_grids_nums.push_back(i);
+        }
+
+        std::vector candidate_grids;
+        for (int split_grids_nums : candidate_split_grids_nums) {
+            int m = 1;
+            while (m <= split_grids_nums) {
+                if (split_grids_nums % m == 0) {
+                    candidate_grids.push_back(clip_image_size{m, split_grids_nums / m});
+                }
+                ++m;
+            }
+        }
+
+        clip_image_size best_grid{1, 1};
+        float min_error = std::numeric_limits::infinity();
+        for (const auto& grid : candidate_grids) {
+            float error = std::abs(log_ratio - std::log(1.0 * grid.width / grid.height));
+            if (error < min_error) {
+                best_grid = grid;
+                min_error = error;
+            }
+        }
+        return best_grid;
+    }
+};
+
+// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
+// res_imgs memory is being allocated here, previous allocations will be freed if found
+bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
+    clip_image_size original_size{img->nx, img->ny};
+    bool pad_to_square = true;
+    auto & params = ctx->model.hparams;
+    // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
+    if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
+        pad_to_square = false;
+    }
+
+    if (clip_is_minicpmv(ctx)) {
+        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+        std::vector imgs = llava_uhd::slice_image(img, inst);
+
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
+            clip_image_f32_ptr res(clip_image_f32_init());
+            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+            res_imgs->entries.push_back(std::move(res));
+        }
+
+        res_imgs->grid_x = inst.grid_size.width;
+        res_imgs->grid_y = inst.grid_size.height;
+        return true;
+
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
+        clip_image_u8 resized;
+        auto patch_size = params.patch_size * 2;
+        auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
+        image_manipulation::bicubic_resize(*img, resized, new_size.width, new_size.height);
+
+        clip_image_f32_ptr img_f32(clip_image_f32_init());
+        // clip_image_f32_ptr res(clip_image_f32_init());
+        normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
+        // res_imgs->data[0] = *res;
+        res_imgs->entries.push_back(std::move(img_f32));
+        return true;
+    }
+    else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE
+            || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3
+            || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3
+            || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
+    ) {
+        clip_image_u8 resized_image;
+        int sz = params.image_size;
+        image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
+        clip_image_f32_ptr img_f32(clip_image_f32_init());
+        //clip_image_save_to_bmp(resized_image, "resized.bmp");
+        normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
+        res_imgs->entries.push_back(std::move(img_f32));
+        return true;
+
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
+        clip_image_u8 resized_image;
+        auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
+        image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
+        clip_image_f32_ptr img_f32(clip_image_f32_init());
+        normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
+        res_imgs->entries.push_back(std::move(img_f32));
+        return true;
+
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) {
+        GGML_ASSERT(!params.image_grid_pinpoints.empty());
+        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+        std::vector imgs = llava_uhd::slice_image(img, inst);
+
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            clip_image_f32_ptr res(clip_image_f32_init());
+            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+            res_imgs->entries.push_back(std::move(res));
+        }
+
+        res_imgs->grid_x = inst.grid_size.width;
+        res_imgs->grid_y = inst.grid_size.height;
+        return true;
+
+    }
+
+    // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
+    // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+
+    clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily
+
+    if (pad_to_square) {
+        // for llava-1.5, we resize image to a square, and pad the shorter side with a background color
+        // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+        const int longer_side = std::max(img->nx, img->ny);
+        temp->nx = longer_side;
+        temp->ny = longer_side;
+        temp->buf.resize(3 * longer_side * longer_side);
+
+        // background color in RGB from LLaVA (this is the mean rgb color * 255)
+        const std::array pad_color = {122, 116, 104};
+
+        // resize the image to the target_size
+        image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color);
+
+        clip_image_f32_ptr res(clip_image_f32_init());
+        normalize_image_u8_to_f32(*temp, *res, params.image_mean, params.image_std);
+        res_imgs->entries.push_back(std::move(res));
+        return true;
+
+    } else if (!params.image_grid_pinpoints.empty()) {
+        // "spatial_unpad" with "anyres" processing for llava-1.6
+        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+        std::vector imgs = llava_uhd::slice_image(img, inst);
+
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
+            clip_image_f32_ptr res(clip_image_f32_init());
+            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+            res_imgs->entries.push_back(std::move(res));
+        }
+
+        return true;
+
+    }
+
+    GGML_ASSERT(false && "Unknown image preprocessing type");
+}
+
+ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
+    return ctx->model.image_newline;
+}
+
+void clip_free(clip_ctx * ctx) {
+    if (ctx == nullptr) {
+        return;
+    }
+    delete ctx;
+}
+
+// deprecated
+size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
+    const int32_t nx = ctx->model.hparams.image_size;
+    const int32_t ny = ctx->model.hparams.image_size;
+    return clip_embd_nbytes_by_img(ctx, nx, ny);
+}
+
+size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h) {
+    clip_image_f32 img;
+    img.nx = img_w;
+    img.ny = img_h;
+    return clip_n_output_tokens(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
+}
+
+int32_t clip_get_image_size(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.image_size;
+}
+
+int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.patch_size;
+}
+
+int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.n_embd;
+}
+
+const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat";
+}
+
+const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
+    if (ctx->model.hparams.image_grid_pinpoints.size()) {
+        return &ctx->model.hparams.image_grid_pinpoints.front();
+    }
+    return nullptr;
+}
+
+size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.image_grid_pinpoints.size();
+}
+
+int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->model.hparams;
+    const int n_total = clip_n_output_tokens(ctx, img);
+    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
+        return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
+    }
+    return n_total;
+}
+
+int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->model.hparams;
+    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
+        return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
+    }
+    return 1;
+}
+
+int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->model.hparams;
+
+    // only for models using fixed size square images
+    int n_patches_sq = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
+
+    projector_type proj = ctx->proj_type();
+
+    switch (proj) {
+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_MLP_NORM:
+            {
+                // do nothing
+            } break;
+        case PROJECTOR_TYPE_LDP:
+        case PROJECTOR_TYPE_LDPV2:
+        case PROJECTOR_TYPE_GLM_EDGE:
+            {
+                n_patches_sq /= 4;
+                if (ctx->model.mm_glm_tok_boi) {
+                    n_patches_sq += 2; // for BOI and EOI token embeddings
+                }
+            } break;
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                if (params.minicpmv_version == 2) {
+                    n_patches_sq = 96;
+                } else if (params.minicpmv_version == 3) {
+                    n_patches_sq = 64;
+                } else if (params.minicpmv_version == 4) {
+                    n_patches_sq = 64;
+                } else {
+                    GGML_ABORT("Unknown minicpmv version");
+                }
+            } break;
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+            {
+                // dynamic size
+                int patch_size = params.patch_size * 2;
+                int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
+                int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
+                n_patches_sq = x_patch * y_patch;
+            } break;
+        case PROJECTOR_TYPE_GEMMA3:
+            {
+                int n_per_side = params.image_size / params.patch_size;
+                int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
+                n_patches_sq = n_per_side_2d_pool * n_per_side_2d_pool;
+            } break;
+        case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_INTERNVL:
+            {
+                // both W and H are divided by proj_scale_factor
+                n_patches_sq /= (params.proj_scale_factor * params.proj_scale_factor);
+            } break;
+        case PROJECTOR_TYPE_PIXTRAL:
+            {
+                // dynamic size
+                int n_merge = params.spatial_merge_size;
+                int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
+                int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
+                n_patches_sq = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+            } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                int scale_factor = ctx->model.hparams.proj_scale_factor;
+                n_patches_sq /= (scale_factor * scale_factor);
+            } break;
+        case PROJECTOR_TYPE_ULTRAVOX:
+            {
+                const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
+                const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
+                n_patches_sq = n_len / proj_stack_factor / 2;
+            } break;
+        case PROJECTOR_TYPE_QWEN2A:
+            {
+                // divide by 2 because of whisper
+                // another divide by 2 because of nn.AvgPool1d(2, stride=2)
+                n_patches_sq = img->nx / 4;
+            } break;
+        default:
+            GGML_ABORT("unsupported projector type");
+    }
+
+    return n_patches_sq;
+}
+
+static std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector> & pos) {
+    assert(embed_dim % 2 == 0);
+    int H = pos.size();
+    int W = pos[0].size();
+
+    std::vector omega(embed_dim / 2);
+    for (int i = 0; i < embed_dim / 2; ++i) {
+        omega[i] = 1.0 / pow(10000.0, static_cast(i) / (embed_dim / 2));
+    }
+
+    std::vector>> emb(H, std::vector>(W, std::vector(embed_dim)));
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            for (int d = 0; d < embed_dim / 2; ++d) {
+                float out_value = pos[h][w] * omega[d];
+                emb[h][w][d] = sin(out_value);
+                emb[h][w][d + embed_dim / 2] = cos(out_value);
+            }
+        }
+    }
+
+    return emb;
+}
+
+static std::vector>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector>> & grid) {
+    assert(embed_dim % 2 == 0);
+    std::vector>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2)
+    std::vector>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2)
+
+    int H = emb_h.size();
+    int W = emb_h[0].size();
+    std::vector>> emb(H, std::vector>(W, std::vector(embed_dim)));
+
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            for (int d = 0; d < embed_dim / 2; ++d) {
+                emb[h][w][d] = emb_h[h][w][d];
+                emb[h][w][d + embed_dim / 2] = emb_w[h][w][d];
+            }
+        }
+    }
+    return emb;
+}
+
+static std::vector> get_2d_sincos_pos_embed(int embed_dim, const std::pair image_size) {
+    int grid_h_size = image_size.first;
+    int grid_w_size = image_size.second;
+
+    std::vector grid_h(grid_h_size);
+    std::vector grid_w(grid_w_size);
+
+    for (int i = 0; i < grid_h_size; ++i) {
+        grid_h[i] = static_cast(i);
+    }
+    for (int i = 0; i < grid_w_size; ++i) {
+        grid_w[i] = static_cast(i);
+    }
+
+    std::vector> grid(grid_h_size, std::vector(grid_w_size));
+    for (int h = 0; h < grid_h_size; ++h) {
+        for (int w = 0; w < grid_w_size; ++w) {
+            grid[h][w] = grid_w[w];
+        }
+    }
+    std::vector>> grid_2d = {grid, grid};
+    for (int h = 0; h < grid_h_size; ++h) {
+        for (int w = 0; w < grid_w_size; ++w) {
+            grid_2d[0][h][w] = grid_h[h];
+            grid_2d[1][h][w] = grid_w[w];
+        }
+    }
+
+    std::vector>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d);
+
+    int H = image_size.first;
+    int W = image_size.second;
+    std::vector> pos_embed_2d(H * W, std::vector(embed_dim));
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            pos_embed_2d[w * H + h] = pos_embed_3d[h][w];
+        }
+    }
+
+    return pos_embed_2d;
+}
+
+bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
+    clip_image_f32_batch imgs;
+    clip_image_f32_ptr img_copy(clip_image_f32_init());
+    *img_copy = *img;
+    imgs.entries.push_back(std::move(img_copy));
+
+    return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
+}
+
+bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
+    const clip_image_f32_batch & imgs = *imgs_c_ptr;
+    int batch_size = imgs.entries.size();
+
+    // TODO @ngxson : implement batch size > 1 as a loop
+    //                we don't need true batching support because the cgraph will gonna be big anyway
+    if (batch_size != 1) {
+        return false; // only support batch size of 1
+    }
+
+    // build the inference graph
+    ctx->debug_print_tensors.clear();
+    ggml_backend_sched_reset(ctx->sched.get());
+    ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
+    ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
+
+    // set inputs
+    const auto & model   = ctx->model;
+    const auto & hparams = model.hparams;
+
+    const int image_size_width  = imgs.entries[0]->nx;
+    const int image_size_height = imgs.entries[0]->ny;
+
+    const int patch_size    = hparams.patch_size;
+    const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int n_pos = num_patches + (model.class_embedding ? 1 : 0);
+    const int pos_w = image_size_width  / patch_size;
+    const int pos_h = image_size_height / patch_size;
+
+    const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
+
+    auto get_inp_tensor = [&gf](const char * name) {
+        ggml_tensor * inp = ggml_graph_get_tensor(gf, name);
+        if (inp == nullptr) {
+            GGML_ABORT("Failed to get tensor %s", name);
+        }
+        if (!(inp->flags & GGML_TENSOR_FLAG_INPUT)) {
+            GGML_ABORT("Tensor %s is not an input tensor", name);
+        }
+        return inp;
+    };
+
+    auto set_input_f32 = [&get_inp_tensor](const char * name, std::vector & values) {
+        ggml_tensor * cur = get_inp_tensor(name);
+        GGML_ASSERT(cur->type == GGML_TYPE_F32);
+        GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size());
+        ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur));
+    };
+
+    auto set_input_i32 = [&get_inp_tensor](const char * name, std::vector & values) {
+        ggml_tensor * cur = get_inp_tensor(name);
+        GGML_ASSERT(cur->type == GGML_TYPE_I32);
+        GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size());
+        ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur));
+    };
+
+    // set input pixel values
+    if (!imgs.is_audio) {
+        size_t nelem = 0;
+        for (const auto & img : imgs.entries) {
+            nelem += img->nx * img->ny * 3;
+        }
+        std::vector inp_raw(nelem);
+
+        // layout of data (note: the channel dim is unrolled to better visualize the layout):
+        //
+        // ┌──W──┐
+        // │     H │  channel = R
+        // ├─────┤ │
+        // │     H │  channel = G
+        // ├─────┤ │
+        // │     H │  channel = B
+        // └─────┘ │
+        //   ──────┘ x B
+
+        for (size_t i = 0; i < imgs.entries.size(); i++) {
+            const int nx = imgs.entries[i]->nx;
+            const int ny = imgs.entries[i]->ny;
+            const int n = nx * ny;
+
+            for (int b = 0; b < batch_size; b++) {
+                float * batch_entry = inp_raw.data() + b * (3*n);
+                for (int y = 0; y < ny; y++) {
+                    for (int x = 0; x < nx; x++) {
+                        size_t base_src = 3*(y * nx + x); // idx of the first channel
+                        size_t base_dst =    y * nx + x;  // idx of the first channel
+                        batch_entry[      base_dst] = imgs.entries[b]->buf[base_src    ];
+                        batch_entry[1*n + base_dst] = imgs.entries[b]->buf[base_src + 1];
+                        batch_entry[2*n + base_dst] = imgs.entries[b]->buf[base_src + 2];
+                    }
+                }
+            }
+        }
+        set_input_f32("inp_raw", inp_raw);
+
+    } else {
+        // audio input
+        GGML_ASSERT(imgs.entries.size() == 1);
+        const auto & mel_inp = imgs.entries[0];
+        const int n_step = mel_inp->nx;
+        const int n_mel  = mel_inp->ny;
+        std::vector inp_raw(n_step * n_mel);
+        std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float));
+        set_input_f32("inp_raw", inp_raw);
+    }
+
+    // set input per projector
+    switch (ctx->model.proj_type) {
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                // inspired from siglip:
+                //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
+                //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
+                std::vector positions(pos_h * pos_w);
+                int bucket_coords_h[1024];
+                int bucket_coords_w[1024];
+                for (int i = 0; i < pos_h; i++){
+                    bucket_coords_h[i] = std::floor(70.0*i/pos_h);
+                }
+                for (int i = 0; i < pos_w; i++){
+                    bucket_coords_w[i] = std::floor(70.0*i/pos_w);
+                }
+                for (int i = 0, id = 0; i < pos_h; i++){
+                    for (int j = 0; j < pos_w; j++){
+                        positions[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+                    }
+                }
+                set_input_i32("positions", positions);
+
+                // inspired from resampler of Qwen-VL:
+                //    -> https://huggingface.co/Qwen/Qwen-VL/tree/main
+                //    -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
+                int embed_dim = clip_n_mmproj_embd(ctx);
+
+                // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
+                auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
+
+                std::vector pos_embed(embed_dim * pos_w * pos_h);
+                for(int i = 0; i < pos_w * pos_h; ++i){
+                    for(int j = 0; j < embed_dim; ++j){
+                        pos_embed[i * embed_dim + j] = pos_embed_t[i][j];
+                    }
+                }
+
+                set_input_f32("pos_embed", pos_embed);
+            } break;
+        case PROJECTOR_TYPE_QWEN2VL:
+            {
+                const int merge_ratio = 2;
+                const int pw = image_size_width  / patch_size;
+                const int ph = image_size_height / patch_size;
+                std::vector positions(n_pos * 4);
+                int ptr = 0;
+                for (int y = 0; y < ph; y += merge_ratio) {
+                    for (int x = 0; x < pw; x += merge_ratio) {
+                        for (int dy = 0; dy < 2; dy++) {
+                            for (int dx = 0; dx < 2; dx++) {
+                                positions[                  ptr] = y + dy;
+                                positions[    num_patches + ptr] = x + dx;
+                                positions[2 * num_patches + ptr] = y + dy;
+                                positions[3 * num_patches + ptr] = x + dx;
+                                ptr++;
+                            }
+                        }
+                    }
+                }
+
+                set_input_i32("positions", positions);
+            } break;
+        case PROJECTOR_TYPE_QWEN25VL:
+            {
+                // pw * ph = number of tokens output by ViT after apply patch merger
+                // ipw * ipw = number of vision token been processed inside ViT
+                const int merge_ratio = 2;
+                const int pw  = image_size_width  / patch_size / merge_ratio;
+                const int ph  = image_size_height / patch_size / merge_ratio;
+                const int ipw = image_size_width  / patch_size;
+                const int iph = image_size_height / patch_size;
+
+                std::vector idx    (ph * pw);
+                std::vector inv_idx(ph * pw);
+
+                if (use_window_attn) {
+                    const int attn_window_size = 112;
+                    const int grid_window = attn_window_size / patch_size / merge_ratio;
+                    int dst = 0;
+                    // [num_vision_tokens, num_vision_tokens] attention mask tensor
+                    std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest());
+                    int mask_row = 0;
+
+                    for (int y = 0; y < ph; y += grid_window) {
+                        for (int x = 0; x < pw; x += grid_window) {
+                            const int win_h = std::min(grid_window, ph - y);
+                            const int win_w = std::min(grid_window, pw - x);
+                            const int dst_0 = dst;
+                            // group all tokens belong to the same window togather (to a continue range)
+                            for (int dy = 0; dy < win_h; dy++) {
+                                for (int dx = 0; dx < win_w; dx++) {
+                                    const int src = (y + dy) * pw + (x + dx);
+                                    GGML_ASSERT(src < (int)idx.size());
+                                    GGML_ASSERT(dst < (int)inv_idx.size());
+                                    idx    [src] = dst;
+                                    inv_idx[dst] = src;
+                                    dst++;
+                                }
+                            }
+
+                            for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
+                                int row_offset = mask_row * (ipw * iph);
+                                std::fill(
+                                    mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
+                                    mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio),
+                                    0.0);
+                                mask_row++;
+                            }
+                        }
+                    }
+
+                    set_input_i32("window_idx",     idx);
+                    set_input_i32("inv_window_idx", inv_idx);
+                    set_input_f32("window_mask",    mask);
+                } else {
+                    for (int i = 0; i < ph * pw; i++) {
+                        idx[i] = i;
+                    }
+                }
+
+                const int mpow = merge_ratio * merge_ratio;
+                std::vector positions(n_pos * 4);
+
+                int ptr = 0;
+                for (int y = 0; y < iph; y += merge_ratio) {
+                    for (int x = 0; x < ipw; x += merge_ratio) {
+                        for (int dy = 0; dy < 2; dy++) {
+                            for (int dx = 0; dx < 2; dx++) {
+                                auto remap = idx[ptr / mpow];
+                                remap = (remap * mpow) + (ptr % mpow);
+
+                                positions[                  remap] = y + dy;
+                                positions[    num_patches + remap] = x + dx;
+                                positions[2 * num_patches + remap] = y + dy;
+                                positions[3 * num_patches + remap] = x + dx;
+                                ptr++;
+                            }
+                        }
+                    }
+                }
+
+                set_input_i32("positions", positions);
+            } break;
+        case PROJECTOR_TYPE_PIXTRAL:
+            {
+                // set the 2D positions
+                int n_patches_per_col = image_size_width / patch_size;
+                std::vector pos_data(n_pos);
+                // dimension H
+                for (int i = 0; i < n_pos; i++) {
+                    pos_data[i] = i / n_patches_per_col;
+                }
+                set_input_i32("pos_h", pos_data);
+                // dimension W
+                for (int i = 0; i < n_pos; i++) {
+                    pos_data[i] = i % n_patches_per_col;
+                }
+                set_input_i32("pos_w", pos_data);
+            } break;
+        case PROJECTOR_TYPE_GLM_EDGE:
+        {
+            // llava and other models
+            std::vector positions(n_pos);
+            for (int i = 0; i < n_pos; i++) {
+                positions[i] = i;
+            }
+            set_input_i32("positions", positions);
+        } break;
+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_MLP_NORM:
+        case PROJECTOR_TYPE_LDP:
+        case PROJECTOR_TYPE_LDPV2:
+            {
+                // llava and other models
+                std::vector positions(n_pos);
+                for (int i = 0; i < n_pos; i++) {
+                    positions[i] = i;
+                }
+                set_input_i32("positions", positions);
+
+                // The patches vector is used to get rows to index into the embeds with;
+                // we should skip dim 0 only if we have CLS to avoid going out of bounds
+                // when retrieving the rows.
+                int patch_offset = model.class_embedding ? 1 : 0;
+                std::vector patches(num_patches);
+                for (int i = 0; i < num_patches; i++) {
+                    patches[i] = i + patch_offset;
+                }
+                set_input_i32("patches", patches);
+            } break;
+        case PROJECTOR_TYPE_GEMMA3:
+        case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_INTERNVL:
+        case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_ULTRAVOX:
+            {
+                // do nothing
+            } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                // set the 2D positions
+                int n_patches_per_col = image_size_width / patch_size;
+                std::vector pos_data(num_patches + 1, 0); // +1 for the [CLS] token
+                // last pos is always kept 0, it's for CLS
+                // dimension H
+                for (int i = 0; i < num_patches; i++) {
+                    pos_data[i] = (i / n_patches_per_col) + 1;
+                }
+                set_input_i32("pos_h", pos_data);
+                // dimension W
+                for (int i = 0; i < num_patches; i++) {
+                    pos_data[i] = (i % n_patches_per_col) + 1;
+                }
+                set_input_i32("pos_w", pos_data);
+            } break;
+        default:
+            GGML_ABORT("Unknown projector type");
+    }
+
+    // ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
+    ggml_backend_dev_t dev = ggml_backend_get_device(ctx->backend_cpu);
+    ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+    if (reg) {
+        auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+        if (ggml_backend_set_n_threads_fn) {
+            ggml_backend_set_n_threads_fn(ctx->backend_cpu, n_threads);
+        }
+    }
+
+    auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        LOG_ERR("%s: ggml_backend_sched_graph_compute failed with error %d\n", __func__, status);
+        return false;
+    }
+
+    // print debug nodes
+    if (ctx->debug_graph) {
+        LOG_INF("\n\n---\n\n");
+        LOG_INF("\n\nDebug graph:\n\n");
+        for (ggml_tensor * t : ctx->debug_print_tensors) {
+            std::vector data(ggml_nbytes(t));
+            ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
+            print_tensor_shape(t);
+            print_tensor_data(t, data.data(), 3);
+        }
+    }
+
+    // the last node is the embedding tensor
+    ggml_tensor * embeddings = ggml_graph_node(gf, -1);
+
+    // sanity check (only support batch size of 1 for now)
+    const int n_tokens_out = embeddings->ne[1];
+    const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
+    if (n_tokens_out != expected_n_tokens_out) {
+        LOG_ERR("%s: expected output %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
+        GGML_ABORT("Invalid number of output tokens");
+    }
+
+    // copy the embeddings to the location passed by the user
+    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
+
+    return true;
+}
+
+int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
+    const auto & hparams = ctx->model.hparams;
+    switch (ctx->model.proj_type) {
+        case PROJECTOR_TYPE_LDP:
+            return ctx->model.mm_model_block_1_block_2_1_b->ne[0];
+        case PROJECTOR_TYPE_LDPV2:
+            return ctx->model.mm_model_peg_0_b->ne[0];
+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_PIXTRAL:
+            return ctx->model.mm_2_w->ne[1];
+        case PROJECTOR_TYPE_MLP_NORM:
+            return ctx->model.mm_3_b->ne[0];
+        case PROJECTOR_TYPE_MINICPMV:
+            if (hparams.minicpmv_version == 2) {
+                return 4096;
+            } else if (hparams.minicpmv_version == 3) {
+                return 3584;
+            } else if (hparams.minicpmv_version == 4) {
+                return 3584;
+            }
+            GGML_ABORT("Unknown minicpmv version");
+        case PROJECTOR_TYPE_GLM_EDGE:
+            return ctx->model.mm_model_mlp_3_w->ne[1];
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+            return ctx->model.mm_1_b->ne[0];
+        case PROJECTOR_TYPE_GEMMA3:
+            return ctx->model.mm_input_proj_w->ne[0];
+        case PROJECTOR_TYPE_IDEFICS3:
+            return ctx->model.projection->ne[1];
+        case PROJECTOR_TYPE_ULTRAVOX:
+            return ctx->model.mm_2_w->ne[1];
+        case PROJECTOR_TYPE_INTERNVL:
+            return ctx->model.mm_3_w->ne[1];
+        case PROJECTOR_TYPE_LLAMA4:
+            return ctx->model.mm_model_proj->ne[1];
+        case PROJECTOR_TYPE_QWEN2A:
+            return ctx->model.mm_fc_w->ne[1];
+        default:
+            GGML_ABORT("Unknown projector type");
+    }
+}
+
+int clip_is_minicpmv(const struct clip_ctx * ctx) {
+    if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) {
+        return ctx->model.hparams.minicpmv_version;
+    }
+    return 0;
+}
+
+bool clip_is_glm(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE;
+}
+
+bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL;
+}
+
+bool clip_is_llava(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.has_llava_projector;
+}
+
+bool clip_is_gemma3(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3;
+}
+
+bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
+    return ctx->model.modality == CLIP_MODALITY_VISION;
+}
+
+bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
+    return ctx->model.modality == CLIP_MODALITY_AUDIO;
+}
+
+bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A;
+}
+
+bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
+    clip_image_f32 clip_img;
+    clip_img.buf.resize(h * w * 3);
+    for (int i = 0; i < h*w*3; i++)
+    {
+        clip_img.buf[i] = img[i];
+    }
+    clip_img.nx = w;
+    clip_img.ny = h;
+    clip_image_encode(ctx, n_threads, &clip_img, vec);
+    return true;
+}
+
+//
+// API used internally with mtmd
+//
+
+projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
+    return ctx->proj_type();
+}
+
+void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
+    clip_image_f32 * audio = new clip_image_f32;
+    audio->nx = n_frames;
+    audio->ny = n_mel;
+    audio->buf.resize(n_frames * n_mel);
+    std::memcpy(audio->buf.data(), mel, n_frames * n_mel * sizeof(float));
+
+    batch->entries.push_back(clip_image_f32_ptr(audio));
+    batch->is_audio = true;
+}
diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h
new file mode 100644
index 0000000000000000000000000000000000000000..cb2eb261fe2e8b5171ca026ca498cacd0d565e26
--- /dev/null
+++ b/tools/mtmd/clip.h
@@ -0,0 +1,114 @@
+#pragma once
+
+#include "ggml.h"
+#include 
+#include 
+
+// !!! Internal header, to be used by mtmd only !!!
+
+struct clip_ctx;
+
+struct clip_image_size {
+    int width;
+    int height;
+};
+
+struct clip_image_f32;
+struct clip_image_u8_batch;
+struct clip_image_f32_batch;
+
+enum clip_modality {
+    CLIP_MODALITY_VISION,
+    CLIP_MODALITY_AUDIO,
+};
+
+struct clip_context_params {
+    bool use_gpu;
+    enum ggml_log_level verbosity;
+};
+
+struct clip_init_result {
+    struct clip_ctx * ctx_v; // vision context
+    struct clip_ctx * ctx_a; // audio context
+};
+
+struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params);
+
+void clip_free(struct clip_ctx * ctx);
+
+size_t clip_embd_nbytes(const struct clip_ctx * ctx);
+size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
+
+int32_t clip_get_image_size (const struct clip_ctx * ctx);
+int32_t clip_get_patch_size (const struct clip_ctx * ctx);
+int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
+
+// TODO: should be enum, not string
+const char * clip_patch_merge_type(const struct clip_ctx * ctx);
+
+const int32_t * clip_image_grid(const struct clip_ctx * ctx);
+size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
+
+int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+
+// for M-RoPE, this will be the number of token positions in X and Y directions
+// for other models, X will be the total number of tokens and Y will be 1
+int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+
+// this should be equal to the embedding dimension of the text model
+int clip_n_mmproj_embd(const struct clip_ctx * ctx);
+
+struct clip_image_size      * clip_image_size_init(void);
+struct clip_image_u8        * clip_image_u8_init (void);
+struct clip_image_f32       * clip_image_f32_init(void);
+struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
+
+// nx, ny are the output image dimensions
+unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
+
+void clip_image_size_free (struct clip_image_size * img_size);
+void clip_image_u8_free (struct clip_image_u8  * img);
+void clip_image_f32_free(struct clip_image_f32 * img);
+void clip_image_u8_batch_free (struct clip_image_u8_batch  * batch);
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
+
+// use for accessing underlay data of clip_image_f32_batch
+size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
+size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
+size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
+struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
+
+/**
+ * Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
+ * The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
+ */
+void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
+
+bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
+
+/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
+bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
+
+/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
+bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
+
+struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
+
+bool clip_image_encode      (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
+bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
+
+int clip_is_minicpmv(const struct clip_ctx * ctx);
+bool clip_is_glm(const struct clip_ctx * ctx);
+bool clip_is_qwen2vl(const struct clip_ctx * ctx);
+bool clip_is_llava(const struct clip_ctx * ctx);
+bool clip_is_gemma3(const struct clip_ctx * ctx);
+
+bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
+
+// use by audio input
+void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel);
+
+bool clip_has_vision_encoder(const struct clip_ctx * ctx);
+bool clip_has_audio_encoder(const struct clip_ctx * ctx);
+bool clip_has_whisper_encoder(const struct clip_ctx * ctx);
diff --git a/tools/mtmd/deprecation-warning.cpp b/tools/mtmd/deprecation-warning.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dded0a56af96bc88697623f0fda9d67ce8aac09a
--- /dev/null
+++ b/tools/mtmd/deprecation-warning.cpp
@@ -0,0 +1,22 @@
+#include 
+#include 
+
+int main(int argc, char** argv) {
+    std::string filename = "main";
+    if (argc >= 1) {
+        filename = argv[0];
+    }
+
+    // Get only the program name from the full path
+    size_t pos = filename.find_last_of("/\\");
+    if (pos != std::string::npos) {
+        filename = filename.substr(pos+1);
+    }
+
+    fprintf(stdout, "\n");
+    fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
+    fprintf(stdout, "Please use 'llama-mtmd-cli' instead.\n");
+    fprintf(stdout, "\n");
+
+    return EXIT_FAILURE;
+}
diff --git a/tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py b/tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py
new file mode 100644
index 0000000000000000000000000000000000000000..2949faec421bef0771292722305fd08c6ef0b448
--- /dev/null
+++ b/tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py
@@ -0,0 +1,412 @@
+import argparse
+import os
+import json
+import re
+
+import torch
+import numpy as np
+from gguf import *
+from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipVisionModel
+
+TEXT = "clip.text"
+VISION = "clip.vision"
+
+
+def k(raw_key: str, arch: str) -> str:
+    return raw_key.format(arch=arch)
+
+
+def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool:
+    if name in (
+        "logit_scale",
+        "text_model.embeddings.position_ids",
+        "vision_model.embeddings.position_ids",
+    ):
+        return True
+
+    if has_llava and name in ["visual_projection.weight", "vision_model.post_layernorm.weight", "vision_model.post_layernorm.bias"]:
+        return True
+
+    if name.startswith("v") and not has_vision:
+        return True
+
+    if name.startswith("t") and not has_text:
+        return True
+
+    return False
+
+
+def get_tensor_name(name: str) -> str:
+    # Standardize the transformers llava next keys for
+    # image newline / mm projector with the classes in haotian-liu LLaVA
+    if name == "image_newline":
+        return "model.image_newline"
+    if name.startswith("multi_modal_projector"):
+        name = name.replace("multi_modal_projector", "mm")
+        if "linear_1" in name:
+            name = name.replace("linear_1", "0")
+        if "linear_2" in name:
+            name = name.replace("linear_2", "2")
+        return name
+
+    if "projection" in name:
+        return name
+    if "mm_projector" in name:
+        name = name.replace("model.mm_projector", "mm")
+        name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
+        name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
+        return name
+
+    return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
+
+
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1))
+        + list(range(ord("¡"), ord("¬") + 1))
+        + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+ap = argparse.ArgumentParser()
+ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
+ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
+ap.add_argument('--bigendian', action="store_true", default=False, help="Model is executed on big-endian machine")
+ap.add_argument("--text-only", action="store_true", required=False,
+                help="Save a text-only model. It can't be used to encode images")
+ap.add_argument("--vision-only", action="store_true", required=False,
+                help="Save a vision-only model. It can't be used to encode texts")
+ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
+                help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
+
+# Selectable visual encoders that are compatible with this script
+encoder_group = ap.add_mutually_exclusive_group()
+encoder_group.add_argument("--clip-model-is-openclip", action="store_true", required=False,
+                help="The clip model is from openclip (for ViT-SO400M type))")
+encoder_group.add_argument("--clip-model-is-siglip", action="store_true", required=False,
+                help="the visual encoder is Siglip.")
+
+ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
+ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
+ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
+# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
+# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
+default_image_mean = [0.48145466, 0.4578275, 0.40821073]
+default_image_std = [0.26862954, 0.26130258, 0.27577711]
+ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
+ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
+
+# with proper
+args = ap.parse_args()
+
+
+if args.text_only and args.vision_only:
+    print("--text-only and --image-only arguments cannot be specified at the same time.")
+    exit(1)
+
+if args.use_f32:
+    print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
+
+# output in the same directory as the model if output_dir is None
+dir_model = args.model_dir
+
+if (
+    args.clip_model_is_vision or
+    not os.path.exists(dir_model + "/vocab.json") or
+    args.clip_model_is_openclip or
+    args.clip_model_is_siglip
+):
+    vocab = None
+    tokens = None
+else:
+    with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
+        vocab = json.load(f)
+        tokens = [key for key in vocab]
+
+with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+    config = json.load(f)
+    if args.clip_model_is_vision:
+        v_hparams = config
+        t_hparams = None
+    else:
+        v_hparams = config["vision_config"]
+        t_hparams = config["text_config"]
+
+# possible data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+#
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+ftype = 1
+if args.use_f32:
+    ftype = 0
+
+if args.clip_model_is_siglip:
+    model = SiglipVisionModel.from_pretrained(dir_model)
+    processor = None
+elif args.clip_model_is_vision or args.clip_model_is_openclip:
+    model = CLIPVisionModel.from_pretrained(dir_model)
+    processor = None
+else:
+    model = CLIPModel.from_pretrained(dir_model)
+    processor = CLIPProcessor.from_pretrained(dir_model)
+
+fname_middle = None
+has_text_encoder = True
+has_vision_encoder = True
+has_llava_projector = False
+if args.text_only:
+    fname_middle = "text-"
+    has_vision_encoder = False
+elif args.llava_projector is not None:
+    fname_middle = "mmproj-"
+    has_text_encoder = False
+    has_llava_projector = True
+elif args.vision_only:
+    fname_middle = "vision-"
+    has_text_encoder = False
+else:
+    fname_middle = ""
+
+output_dir = args.output_dir if args.output_dir is not None else dir_model
+os.makedirs(output_dir, exist_ok=True)
+output_prefix = os.path.basename(output_dir).replace("ggml_", "")
+fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
+fout = GGUFWriter(path=fname_out, arch="clip", endianess=GGUFEndian.LITTLE if not args.bigendian else GGUFEndian.BIG)
+
+fout.add_bool("clip.has_text_encoder", has_text_encoder)
+fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
+fout.add_bool("clip.has_llava_projector", has_llava_projector)
+fout.add_file_type(ftype)
+model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model)
+fout.add_name(model_name)
+if args.text_only:
+    fout.add_description("text-only CLIP model")
+elif args.vision_only and not has_llava_projector:
+    fout.add_description("vision-only CLIP model")
+elif has_llava_projector:
+    fout.add_description("image encoder for LLaVA")
+    # add projector type
+    fout.add_string("clip.projector_type", args.projector_type)
+else:
+    fout.add_description("two-tower CLIP model")
+
+if has_text_encoder:
+    assert t_hparams is not None
+    assert tokens is not None
+    if args.clip_model_is_siglip:
+        text_projection_dim = 0
+    else:
+        text_projection_dim = t_hparams.get("projection_dim", config["projection_dim"])
+    # text_model hparams
+    fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
+    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
+    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
+    fout.add_uint32("clip.text.projection_dim", text_projection_dim)
+    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
+    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
+    fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
+    fout.add_token_list(tokens)
+
+
+
+def get_non_negative_vision_feature_layers(v_hparams):
+    """
+    Determine the vision feature layer(s) for the llava model, which are indices into the
+    hidden states of the visual encoder. Note that the hidden states array generally takes the
+    form:
+
+        [, , ... ]
+
+    so feature indices should be offset as n+1 to get the output of encoder block n.
+    We convert all vision feature layers to non-negative so that -1 can be used in
+    the model as an unset value. If no vision feature layer is found, we leave it unset.
+    """
+    num_hidden_layers = v_hparams["num_hidden_layers"]
+    to_non_negative = lambda layer_idx: layer_idx  if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
+    feature_layers_key = None
+    # Key used for llava models in transformers
+    if "vision_feature_layer" in config:
+        feature_layers_key = "vision_feature_layer"
+    # Key used for llava models in the original format
+    elif "mm_vision_select_layer" in config:
+        feature_layers_key = "mm_vision_select_layer"
+    if feature_layers_key is not None:
+        feature_layers = config[feature_layers_key]
+        if isinstance(feature_layers, int):
+            feature_layers = [feature_layers]
+        return [to_non_negative(feature_layer) for feature_layer in feature_layers]
+
+# Determine if we have explicitly specified vision feature layers in our config
+feature_layers = get_non_negative_vision_feature_layers(v_hparams)
+
+if has_vision_encoder:
+    # Siglip does not have a visual projector; set projection dim to 0
+    if args.clip_model_is_siglip:
+        visual_projection_dim = 0
+    else:
+        visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"])
+
+    # set vision_model hparams
+    fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
+    fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
+    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
+    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
+    fout.add_uint32("clip.vision.projection_dim", visual_projection_dim)
+    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
+    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
+    if feature_layers:
+        block_count = max(feature_layers)
+    else:
+        block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
+    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
+                            #     /**
+                            #      "image_grid_pinpoints": [
+                            #         [
+                            #         336,
+                            #         672
+                            #         ],
+                            #         [
+                            #         672,
+                            #         336
+                            #         ],
+                            #         [
+                            #         672,
+                            #         672
+                            #         ],
+                            #         [
+                            #         1008,
+                            #         336
+                            #         ],
+                            #         [
+                            #         336,
+                            #         1008
+                            #         ]
+                            #     ],
+                            #     Flattened:
+                            #     [
+                            #         336, 672,
+                            #         672, 336,
+                            #         672, 672,
+                            #         1008, 336,
+                            #         336, 1008
+                            #     ]
+                            #  *
+                            #  */
+    if "image_grid_pinpoints" in v_hparams:
+        # flatten it
+        image_grid_pinpoints = []
+        for pinpoint in v_hparams["image_grid_pinpoints"]:
+            for p in pinpoint:
+                image_grid_pinpoints.append(p)
+        fout.add_array("clip.vision.image_grid_pinpoints", image_grid_pinpoints)
+    if "image_crop_resolution" in v_hparams:
+        fout.add_uint32("clip.vision.image_crop_resolution", v_hparams["image_crop_resolution"])
+    if "image_aspect_ratio" in v_hparams:
+        fout.add_string("clip.vision.image_aspect_ratio", v_hparams["image_aspect_ratio"])
+    if "image_split_resolution" in v_hparams:
+        fout.add_uint32("clip.vision.image_split_resolution", v_hparams["image_split_resolution"])
+    if "mm_patch_merge_type" in v_hparams:
+        fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"])
+    if "mm_projector_type" in v_hparams:
+        fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])
+    if feature_layers:
+        fout.add_array("clip.vision.feature_layer", feature_layers)
+
+    if processor is not None:
+        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean  # pyright: ignore[reportAttributeAccessIssue]
+        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std  # pyright: ignore[reportAttributeAccessIssue]
+    else:
+        image_mean = args.image_mean if args.image_mean is not None else default_image_mean
+        image_std = args.image_std if args.image_std is not None else default_image_std
+    fout.add_array("clip.vision.image_mean", image_mean)
+    fout.add_array("clip.vision.image_std", image_std)
+
+use_gelu = v_hparams["hidden_act"] == "gelu"
+fout.add_bool("clip.use_gelu", use_gelu)
+
+
+if has_llava_projector:
+    # By default, we drop the last layer for llava projector
+    # models unless we have explicitly set vision feature layers
+    if feature_layers is None:
+        model.vision_model.encoder.layers.pop(-1)
+    else:
+        model.vision_model.encoder.layers = model.vision_model.encoder.layers[:max(feature_layers)]
+
+    projector = torch.load(args.llava_projector)
+    for name, data in projector.items():
+        name = get_tensor_name(name)
+        # pw and dw conv ndim==4
+        if data.ndim == 2 or data.ndim == 4:
+            data = data.squeeze().numpy().astype(np.float16)
+        else:
+            data = data.squeeze().numpy().astype(np.float32)
+
+        fout.add_tensor(name, data)
+
+    print("Projector tensors added\n")
+
+state_dict = model.state_dict()
+for name, data in state_dict.items():
+    if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
+        # we don't need this
+        print(f"skipping parameter: {name}")
+        continue
+
+    name = get_tensor_name(name)
+    data = data.squeeze().numpy()
+
+    n_dims = len(data.shape)
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype_cur = 0
+    if n_dims == 4:
+        print(f"tensor {name} is always saved in f16")
+        data = data.astype(np.float16)
+        ftype_cur = 1
+    elif ftype == 1:
+        if name[-7:] == ".weight" and n_dims == 2:
+            print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype_cur = 1
+        else:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+    else:
+        if data.dtype != np.float32:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+
+    print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
+    fout.add_tensor(name, data)
+
+
+fout.write_header_to_file()
+fout.write_kv_data_to_file()
+fout.write_tensors_to_file()
+fout.close()
+
+print("Done. Output file: " + fname_out)
diff --git a/tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py
new file mode 100644
index 0000000000000000000000000000000000000000..848ef1cf3f542ab2ccc1245d3682140f8352b52b
--- /dev/null
+++ b/tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py
@@ -0,0 +1,280 @@
+import argparse
+import os
+import json
+import re
+
+import torch
+import numpy as np
+from gguf import *
+
+TEXT = "clip.text"
+VISION = "clip.vision"
+from transformers import SiglipVisionModel, SiglipVisionConfig
+
+def k(raw_key: str, arch: str) -> str:
+    return raw_key.format(arch=arch)
+
+
+def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool:
+    if name in (
+        "logit_scale",
+        "text_model.embeddings.position_ids",
+        "vision_model.embeddings.position_ids",
+    ):
+        return True
+
+    if name in (
+        "vision_model.head.probe",
+        "vision_model.head.attention.in_proj_weight",
+        "vision_model.head.attention.in_proj_bias",
+        "vision_model.head.attention.out_proj.weight",
+        "vision_model.head.attention.out_proj.bias",
+        "vision_model.head.layernorm.weight",
+        "vision_model.head.layernorm.bias",
+        "vision_model.head.mlp.fc1.weight",
+        "vision_model.head.mlp.fc1.bias",
+        "vision_model.head.mlp.fc2.weight",
+        "vision_model.head.mlp.fc2.bias"
+    ):
+        return True
+
+    if name.startswith("v") and not has_vision:
+        return True
+
+    if name.startswith("t") and not has_text:
+        return True
+
+    return False
+
+
+def get_tensor_name(name: str) -> str:
+    if "projection" in name:
+        return name
+    if "mm_projector" in name:
+        name = name.replace("model.mm_projector", "mm")
+        name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
+        name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
+        return name
+
+    return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
+
+
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1))
+        + list(range(ord("¡"), ord("¬") + 1))
+        + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+ap = argparse.ArgumentParser()
+ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
+ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
+ap.add_argument("--text-only", action="store_true", required=False,
+                help="Save a text-only model. It can't be used to encode images")
+ap.add_argument("--vision-only", action="store_true", required=False,
+                help="Save a vision-only model. It can't be used to encode texts")
+ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
+                help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
+ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
+                help="The clip model is from openclip (for ViT-SO400M type))")
+ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
+ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2","adapter"], default="adapter")
+ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
+# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
+# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
+default_image_mean = [0.5, 0.5, 0.5]
+default_image_std = [0.5, 0.5, 0.5]
+ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
+ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
+
+# with proper
+args = ap.parse_args()
+
+
+if args.text_only and args.vision_only:
+    print("--text-only and --image-only arguments cannot be specified at the same time.")
+    exit(1)
+
+if args.use_f32:
+    print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
+
+# output in the same directory as the model if output_dir is None
+dir_model = args.model_dir
+
+if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
+    vocab = None
+    tokens = None
+else:
+    with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
+        vocab = json.load(f)
+        tokens = [key for key in vocab]
+
+with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+    config = json.load(f)
+    if args.clip_model_is_vision:
+        v_hparams = config
+        t_hparams = None
+    else:
+        v_hparams = config["vision_config"]
+        t_hparams = None
+
+# possible data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+#
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+ftype = 1
+if args.use_f32:
+    ftype = 0
+
+vision_config = SiglipVisionConfig(**v_hparams)
+model = SiglipVisionModel(vision_config)
+model.load_state_dict(torch.load(os.path.join(dir_model, "glm.clip")))
+
+fname_middle = None
+has_text_encoder = False
+has_vision_encoder = True
+has_glm_projector = True
+if args.text_only:
+    fname_middle = "text-"
+    has_vision_encoder = False
+elif args.llava_projector is not None:
+    fname_middle = "mmproj-"
+    has_text_encoder = False
+    has_glm_projector = True
+elif args.vision_only:
+    fname_middle = "vision-"
+    has_text_encoder = False
+else:
+    fname_middle = ""
+
+output_dir = args.output_dir if args.output_dir is not None else dir_model
+os.makedirs(output_dir, exist_ok=True)
+output_prefix = os.path.basename(output_dir).replace("ggml_", "")
+fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
+fout = GGUFWriter(path=fname_out, arch="clip")
+
+fout.add_bool("clip.has_text_encoder", has_text_encoder)
+fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
+fout.add_bool("clip.has_glm_projector", has_glm_projector)
+fout.add_file_type(ftype)
+model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model)
+fout.add_name(model_name)
+if has_glm_projector:
+    fout.add_description("image encoder for glm4v")
+    fout.add_string("clip.projector_type", "adapter")
+else:
+    fout.add_description("two-tower CLIP model")
+
+if has_text_encoder:
+    assert t_hparams is not None
+    assert tokens is not None
+    # text_model hparams
+    fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
+    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
+    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
+    fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"]))
+    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
+    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
+    fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
+    fout.add_token_list(tokens)
+
+if has_vision_encoder:
+    # vision_model hparams
+    fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
+    fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
+    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
+    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
+    fout.add_uint32("clip.vision.projection_dim", 0)
+    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
+    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
+    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), v_hparams["num_hidden_layers"])
+
+    image_mean = args.image_mean if args.image_mean is not None else default_image_mean
+    image_std = args.image_std if args.image_std is not None else default_image_std
+    fout.add_array("clip.vision.image_mean", image_mean)
+    fout.add_array("clip.vision.image_std", image_std)
+
+fout.add_bool("clip.use_gelu", True)
+
+
+if has_glm_projector:
+    # model.vision_model.encoder.layers.pop(-1)  # pyright: ignore[reportAttributeAccessIssue]
+    projector = torch.load(args.llava_projector)
+    for name, data in projector.items():
+        name = get_tensor_name(name)
+        # pw and dw conv ndim==4
+        if data.ndim == 2 or data.ndim == 4:
+            data = data.squeeze().numpy().astype(np.float16)
+        else:
+            data = data.squeeze().numpy().astype(np.float32)
+        if name.startswith("vision."):
+            name=name.replace("vision.","")
+        fout.add_tensor(name, data)
+        print(f"Projector {name} - {data.dtype} - shape = {data.shape}")
+        # print(f"Projector {name} tensors added\n")
+
+state_dict = model.state_dict()  # pyright: ignore[reportAttributeAccessIssue]
+for name, data in state_dict.items():
+    if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_glm_projector):
+        # we don't need this
+        print(f"skipping parameter: {name}")
+        continue
+
+    name = get_tensor_name(name)
+    data = data.squeeze().numpy()
+
+    n_dims = len(data.shape)
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype_cur = 0
+    if n_dims == 4:
+        print(f"tensor {name} is always saved in f16")
+        data = data.astype(np.float16)
+        ftype_cur = 1
+    elif ftype == 1:
+        if name[-7:] == ".weight" and n_dims == 2:
+            # print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype_cur = 1
+        else:
+            # print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+    else:
+        if data.dtype != np.float32:
+            # print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+    print(f"siglip {name} - {data.dtype} - shape = {data.shape}")
+    # print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
+    fout.add_tensor(name, data)
+
+
+fout.write_header_to_file()
+fout.write_kv_data_to_file()
+fout.write_tensors_to_file()
+fout.close()
+
+print("Done. Output file: " + fname_out)
diff --git a/tools/mtmd/legacy-models/glmedge-surgery.py b/tools/mtmd/legacy-models/glmedge-surgery.py
new file mode 100644
index 0000000000000000000000000000000000000000..16bb915d043cf5dfa403571ffd8b8e90017fab57
--- /dev/null
+++ b/tools/mtmd/legacy-models/glmedge-surgery.py
@@ -0,0 +1,33 @@
+import argparse
+import os
+import torch
+from transformers import AutoModel
+
+ap = argparse.ArgumentParser()
+ap.add_argument("-m", "--model", help="Path to GLM model")
+args = ap.parse_args()
+
+# find the model part that includes the the multimodal projector weights
+model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
+checkpoint = model.state_dict()
+
+# get a list of mm tensor names
+mm_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.adapter.")]
+
+# store these tensors in a new dictionary and torch.save them
+projector = {name: checkpoint[name].float() for name in mm_tensors}
+torch.save(projector, f"{args.model}/glm.projector")
+
+clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.vit.model.vision_model.")]
+if len(clip_tensors) > 0:
+    clip = {name.replace("vision.vit.model.", ""): checkpoint[name].float() for name in clip_tensors}
+    torch.save(clip, f"{args.model}/glm.clip")
+
+    # added tokens should be removed to be able to convert Mistral models
+    if os.path.exists(f"{args.model}/added_tokens.json"):
+        with open(f"{args.model}/added_tokens.json", "w") as f:
+            f.write("{}\n")
+
+print("Done!")
+print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
+print(f"Also, use {args.model}glm.projector to prepare a glm-encoder.gguf file.")
diff --git a/tools/mtmd/legacy-models/llava_surgery.py b/tools/mtmd/legacy-models/llava_surgery.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f2da3beefff670cfd61ffcff3769651d74a551c
--- /dev/null
+++ b/tools/mtmd/legacy-models/llava_surgery.py
@@ -0,0 +1,38 @@
+import argparse
+import glob
+import os
+import torch
+
+
+ap = argparse.ArgumentParser()
+ap.add_argument("-m", "--model", help="Path to LLaVA v1.5 model")
+args = ap.parse_args()
+
+# find the model part that includes the the multimodal projector weights
+path = sorted(glob.glob(f"{args.model}/pytorch_model*.bin"))[-1]
+checkpoint = torch.load(path)
+
+# get a list of mm tensor names
+mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_projector")]
+
+# store these tensors in a new dictionary and torch.save them
+projector = {name: checkpoint[name].float() for name in mm_tensors}
+torch.save(projector, f"{args.model}/llava.projector")
+
+# BakLLaVA models contain CLIP tensors in it
+clip_tensors = [k for k, v in checkpoint.items() if k.startswith("model.vision_tower")]
+if len(clip_tensors) > 0:
+    clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors}
+    torch.save(clip, f"{args.model}/llava.clip")
+
+
+    # added tokens should be removed to be able to convert Mistral models
+    if os.path.exists(f"{args.model}/added_tokens.json"):
+        with open(f"{args.model}/added_tokens.json", "w") as f:
+            f.write("{}\n")
+
+
+
+print("Done!")
+print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
+print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")
diff --git a/tools/mtmd/legacy-models/llava_surgery_v2.py b/tools/mtmd/legacy-models/llava_surgery_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b07c3e323c4c6a7ddc66f201fe6ab167814e0b1a
--- /dev/null
+++ b/tools/mtmd/legacy-models/llava_surgery_v2.py
@@ -0,0 +1,180 @@
+import argparse
+import glob
+import os
+import torch
+from safetensors import safe_open
+from safetensors.torch import save_file
+from typing import Any, ContextManager, cast
+
+# Function to determine if file is a SafeTensor file
+def is_safetensor_file(file_path):
+    return file_path.endswith('.safetensors')
+
+
+# Unified loading function
+def load_model(file_path):
+    if is_safetensor_file(file_path):
+        tensors = {}
+        with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
+            for key in f.keys():
+                tensors[key] = f.get_tensor(key).clone()
+                # output shape
+                print(f"{key} : {tensors[key].shape}")
+        return tensors, 'safetensor'
+    else:
+        return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch'
+
+
+# Unified saving function
+def save_model(model, file_path, file_type):
+    if file_type == 'safetensor':
+        # safe_save(model, file_path)
+        save_file(model, file_path)
+    else:
+        torch.save(model, file_path)
+
+# Helpers to match weight names from specific components or
+# determine if a saved shard contains that component
+def is_vision_tower(weight_name):
+    return (
+        weight_name.startswith("model.vision_tower") or
+        weight_name.startswith("vit.") or
+        weight_name.startswith("vision_tower")
+    )
+
+def is_newline(weight_name):
+    return (
+        weight_name.startswith("model.image_newline") or
+        weight_name.startswith("image_newline")
+    )
+
+def is_mm_projector(weight_name):
+    return (
+        weight_name.startswith("model.mm_projector") or
+        weight_name.startswith("vision_proj.") or
+        weight_name.startswith("multi_modal_projector")
+    )
+
+def newline_criteria(checkpoint):
+    return any(is_newline(k) for k in checkpoint.keys())
+
+def proj_criteria(checkpoint):
+    return any(is_mm_projector(k) for k in checkpoint.keys())
+
+# Adapted function to clean vision tower from checkpoint
+def clean_vision_tower_from_checkpoint(checkpoint_path):
+    checkpoint, file_type = load_model(checkpoint_path)
+    # file_type = 'pytorch'
+    model_path = os.path.dirname(checkpoint_path)
+    print(f"Searching for vision tower tensors in {checkpoint_path}")
+    clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)]
+
+    if len(clip_tensors) > 0:
+        print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
+        # Adapted for file type
+        clip_path = os.path.join(model_path, "llava.clip")
+
+        if os.path.exists(clip_path):
+            print(f"Loading existing llava.clip from {clip_path}")
+            existing_clip, _ = load_model(clip_path)
+        else:
+            print(f"Creating new llava.clip at {clip_path}")
+            existing_clip = {}
+        # Update existing_clip with new tensors, avoid duplicates
+        for name in clip_tensors:
+            simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
+            print(f"Adding {simple_name} to llava.clip")
+            if simple_name not in existing_clip:
+                existing_clip[simple_name] = checkpoint[name]
+
+        # Save the updated clip tensors back to llava.clip
+        save_model(existing_clip, clip_path, 'pytorch')
+
+        # Remove the tensors from the original checkpoint
+        for name in clip_tensors:
+            del checkpoint[name]
+
+        checkpoint_path = checkpoint_path
+        return True
+    return False
+
+def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
+    newline_checkpoint_path = None
+    projector_checkpoint_path = None
+
+    for path in checkpoint_paths:
+        checkpoint, _ = load_model(path)
+        if newline_criteria(checkpoint) and newline_checkpoint_path is None:
+            newline_checkpoint_path = path
+        if projector(checkpoint):
+            projector_checkpoint_path = path
+
+    return newline_checkpoint_path, projector_checkpoint_path
+
+
+# Command-line interface setup
+ap = argparse.ArgumentParser()
+ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model")
+ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files")
+args = ap.parse_args()
+
+if args.clean_vision_tower:
+    # Generalized to handle both PyTorch and SafeTensors models
+    model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
+    # checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
+    checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
+    for projector_checkpoint_path in checkpoint_paths:
+        print(f"Cleaning {projector_checkpoint_path}")
+        if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
+            print(f"No vision tower found in {projector_checkpoint_path}")
+            # we break once none is found, so far all models append them at the end
+            # break
+    print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")
+
+# Now we look for the projector in the last checkpoint
+model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
+checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
+# last_checkpoint_path = checkpoint_paths[0]
+# first_checkpoint_path = checkpoint_paths[-1]
+newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
+
+print(f"Taking projector from {projector_checkpoint_path}")
+first_mm_tensors = []
+first_checkpoint = None
+if newline_checkpoint_path is not None:
+    print(f"Taking newline from {newline_checkpoint_path}")
+    first_checkpoint, file_type = load_model(newline_checkpoint_path)
+    first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)]
+
+# Load the checkpoint
+mm_tensors = []
+last_checkpoint = None
+if projector_checkpoint_path is not None:
+    last_checkpoint, file_type = load_model(projector_checkpoint_path)
+    mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)]
+
+if len(mm_tensors) == 0:
+    if last_checkpoint is not None:
+        for k, v in last_checkpoint.items():
+            print(k)
+    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
+    print("No tensors found. Is this a LLaVA model?")
+    exit()
+
+print(f"Found {len(mm_tensors)} tensors to extract.")
+print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
+# projector = {name: checkpoint.[name].float() for name in mm_tensors}
+projector = {}
+for name in mm_tensors:
+    assert last_checkpoint is not None
+    projector[name] = last_checkpoint[name].float()
+for name in first_mm_tensors:
+    assert first_checkpoint is not None
+    projector[name] = first_checkpoint[name].float()
+
+if len(projector) > 0:
+    save_model(projector, f"{args.model}/llava.projector", 'pytorch')
+
+print("Done!")
+print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
+print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")
diff --git a/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfe0961f9891ade6374aa8a701c0e009c274c164
--- /dev/null
+++ b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py
@@ -0,0 +1,814 @@
+# coding=utf-8
+# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Siglip model. """
+# Copied from  HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
+
+
+import os
+import math
+import warnings
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+from transformers.activations import ACT2FN
+from transformers.modeling_utils import PreTrainedModel
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import (
+    logging,
+)
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+class SiglipVisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
+    Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
+    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the input images.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+    Example:
+    ```python
+    >>> from transformers import SiglipVisionConfig, SiglipVisionModel
+    >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
+    >>> configuration = SiglipVisionConfig()
+    >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
+    >>> model = SiglipVisionModel(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "siglip_vision_model"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        intermediate_size=3072,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        num_channels=3,
+        image_size=224,
+        patch_size=16,
+        hidden_act="gelu_pytorch_tanh",
+        layer_norm_eps=1e-6,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+
+_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
+
+SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "google/siglip-base-patch16-224",
+    # See all SigLIP models at https://huggingface.co/models?filter=siglip
+]
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def _trunc_normal_(tensor, mean, std, a, b):
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+            "The distribution of values may be incorrect.",
+            stacklevel=2,
+        )
+
+    # Values are generated by using a truncated uniform distribution and
+    # then using the inverse CDF for the normal distribution.
+    # Get upper and lower cdf values
+    l = norm_cdf((a - mean) / std)
+    u = norm_cdf((b - mean) / std)
+
+    # Uniformly fill tensor with values from [l, u], then translate to
+    # [2l-1, 2u-1].
+    tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+    # Use inverse cdf transform for normal distribution to get truncated
+    # standard normal
+    if tensor.dtype in [torch.float16, torch.bfloat16]:
+        # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
+        og_dtype = tensor.dtype
+        tensor = tensor.to(torch.float32)
+        tensor.erfinv_()
+        tensor = tensor.to(og_dtype)
+    else:
+        tensor.erfinv_()
+
+    # Transform to proper mean, std
+    tensor.mul_(std * math.sqrt(2.0))
+    tensor.add_(mean)
+
+    # Clamp to ensure it's in the proper range
+    if tensor.dtype == torch.float16:
+        # The `clamp_` op is not (yet?) defined in float16+cpu
+        tensor = tensor.to(torch.float32)
+        tensor.clamp_(min=a, max=b)
+        tensor = tensor.to(torch.float16)
+    else:
+        tensor.clamp_(min=a, max=b)
+
+
+def trunc_normal_tf_(
+    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
+):
+    """Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \\leq \text{mean} \\leq b`.
+    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+    and the result is subsquently scaled and shifted by the mean and std args.
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    """
+    with torch.no_grad():
+        _trunc_normal_(tensor, 0, 1.0, a, b)
+        tensor.mul_(std).add_(mean)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+    denom = fan_in
+    if mode == "fan_in":
+        denom = fan_in
+    elif mode == "fan_out":
+        denom = fan_out
+    elif mode == "fan_avg":
+        denom = (fan_in + fan_out) / 2
+
+    variance = scale / denom
+
+    if distribution == "truncated_normal":
+        # constant is stddev of standard normal truncated to (-2, 2)
+        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+    elif distribution == "normal":
+        with torch.no_grad():
+            tensor.normal_(std=math.sqrt(variance))
+    elif distribution == "uniform":
+        bound = math.sqrt(3 * variance)
+        with torch.no_grad():
+            tensor.uniform_(-bound, bound)
+    else:
+        raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def default_flax_embed_init(tensor):
+    variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+class SiglipVisionEmbeddings(nn.Module):
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            padding="valid",
+        )
+
+        self.num_patches_per_side = self.image_size // self.patch_size
+        self.num_patches = self.num_patches_per_side**2
+        self.num_positions = self.num_patches
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+class SiglipAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+        self.scale = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
+class SiglipMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
+class SiglipEncoderLayer(nn.Module):
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self.self_attn = (
+            SiglipAttention(config)
+        )
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = SiglipMLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+class SiglipPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = SiglipVisionConfig
+    base_model_prefix = "siglip"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+
+        if isinstance(module, SiglipVisionEmbeddings):
+            width = self.config.hidden_size
+            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
+        elif isinstance(module, nn.Embedding):
+            default_flax_embed_init(module.weight)
+        elif isinstance(module, SiglipAttention):
+            nn.init.normal_(module.q_proj.weight)
+            nn.init.normal_(module.k_proj.weight)
+            nn.init.normal_(module.v_proj.weight)
+            nn.init.normal_(module.out_proj.weight)
+            nn.init.zeros_(module.q_proj.bias)
+            nn.init.zeros_(module.k_proj.bias)
+            nn.init.zeros_(module.v_proj.bias)
+            nn.init.zeros_(module.out_proj.bias)
+        elif isinstance(module, SiglipMLP):
+            nn.init.normal_(module.fc1.weight)
+            nn.init.normal_(module.fc2.weight)
+            nn.init.normal_(module.fc1.bias, std=1e-6)
+            nn.init.normal_(module.fc2.bias, std=1e-6)
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
+            lecun_normal_(module.weight)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+SIGLIP_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+    Parameters:
+        config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+SIGLIP_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
+class SiglipEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`SiglipEncoderLayer`].
+    Args:
+        config: SiglipConfig
+    """
+
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+class SiglipVisionTransformer(SiglipPreTrainedModel):
+    config_class = SiglipVisionConfig
+    main_input_name = "pixel_values"
+    _supports_flash_attn_2 = True
+
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__(config)
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = SiglipVisionEmbeddings(config)
+        self.encoder = SiglipEncoder(config)
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.embeddings.patch_embedding
+
+import argparse
+import json
+import re
+
+import numpy as np
+from gguf import *
+from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig
+
+TEXT = "clip.text"
+VISION = "clip.vision"
+
+
+def add_key_str(raw_key: str, arch: str) -> str:
+    return raw_key.format(arch=arch)
+
+
+def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool:
+    if name in (
+        "logit_scale",
+        "text_model.embeddings.position_ids",
+        "vision_model.embeddings.position_ids",
+    ):
+        return True
+
+    if has_minicpmv and name in ["visual_projection.weight"]:
+        return True
+
+    if name.startswith("v") and not has_vision:
+        return True
+
+    if name.startswith("t") and not has_text:
+        return True
+
+    return False
+
+
+def get_tensor_name(name: str) -> str:
+    if "projection" in name:
+        return name
+    if "mm_projector" in name:
+        name = name.replace("model.mm_projector", "mm")
+        name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
+        name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
+        return name
+
+    return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
+
+
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1))
+        + list(range(ord("¡"), ord("¬") + 1))
+        + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+ap = argparse.ArgumentParser()
+ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
+ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
+ap.add_argument("--text-only", action="store_true", required=False,
+                help="Save a text-only model. It can't be used to encode images")
+ap.add_argument("--vision-only", action="store_true", required=False,
+                help="Save a vision-only model. It can't be used to encode texts")
+ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
+                help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
+ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
+                help="The clip model is from openclip (for ViT-SO400M type))")
+ap.add_argument("--minicpmv-projector", help="Path to minicpmv.projector file. If specified, save an image encoder for MiniCPM-V models.")
+ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
+ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
+# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
+# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
+default_image_mean = [0.48145466, 0.4578275, 0.40821073]
+default_image_std = [0.26862954, 0.26130258, 0.27577711]
+ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
+ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
+ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2)
+
+# with proper
+args = ap.parse_args()
+
+
+if args.text_only and args.vision_only:
+    print("--text-only and --image-only arguments cannot be specified at the same time.")
+    exit(1)
+
+if args.use_f32:
+    print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
+
+# output in the same directory as the model if output_dir is None
+dir_model = args.model_dir
+
+if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
+    vocab = None
+    tokens = None
+else:
+    with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
+        vocab = json.load(f)
+        tokens = [key for key in vocab]
+
+# possible data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+#
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+ftype = 1
+if args.use_f32:
+    ftype = 0
+
+# if args.clip_model_is_vision or args.clip_model_is_openclip:
+#     model = CLIPVisionModel.from_pretrained(dir_model)
+#     processor = None
+# else:
+#     model = CLIPModel.from_pretrained(dir_model)
+#     processor = CLIPProcessor.from_pretrained(dir_model)
+
+minicpmv_version = args.minicpmv_version
+emb_dim = 4096
+block_count = 26
+if minicpmv_version == 1:
+    emb_dim = 2304
+    block_count = 26
+elif minicpmv_version == 2:
+    emb_dim = 4096
+    block_count = 27
+elif minicpmv_version == 3:
+    emb_dim = 3584
+    block_count = 27
+elif minicpmv_version == 4:
+    emb_dim = 3584
+    block_count = 27
+
+default_vision_config = {
+        "hidden_size": 1152,
+        "image_size": 980,
+        "intermediate_size": 4304,
+        "model_type": "idefics2",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 27,
+        "patch_size": 14,
+    }
+
+vision_config = Idefics2VisionConfig(**default_vision_config)
+model = Idefics2VisionTransformer(vision_config)
+if minicpmv_version == 3:
+    vision_config = SiglipVisionConfig(**default_vision_config)
+    model = SiglipVisionTransformer(vision_config)
+elif minicpmv_version == 4:
+    vision_config = SiglipVisionConfig(**default_vision_config)
+    model = SiglipVisionTransformer(vision_config)
+
+processor = None
+# if model.attn_pool is not None:
+#     model.attn_pool = torch.nn.Identity()
+
+# model.blocks = model.blocks[:-1]
+model.load_state_dict(torch.load(os.path.join(dir_model, "minicpmv.clip")))
+
+fname_middle = None
+has_text_encoder = True
+has_vision_encoder = True
+has_minicpmv_projector = False
+
+if args.text_only:
+    fname_middle = "text-"
+    has_vision_encoder = False
+elif args.minicpmv_projector is not None:
+    fname_middle = "mmproj-"
+    has_text_encoder = False
+    has_minicpmv_projector = True
+elif args.vision_only:
+    fname_middle = "vision-"
+    has_text_encoder = False
+else:
+    fname_middle = ""
+
+output_dir = args.output_dir if args.output_dir is not None else dir_model
+os.makedirs(output_dir, exist_ok=True)
+output_prefix = os.path.basename(output_dir).replace("ggml_", "")
+fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
+fout = GGUFWriter(path=fname_out, arch="clip")
+
+fout.add_bool("clip.has_text_encoder", has_text_encoder)
+fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
+fout.add_bool("clip.has_minicpmv_projector", has_minicpmv_projector)
+fout.add_file_type(ftype)
+if args.text_only:
+    fout.add_description("text-only CLIP model")
+elif args.vision_only and not has_minicpmv_projector:
+    fout.add_description("vision-only CLIP model")
+elif has_minicpmv_projector:
+    fout.add_description("image encoder for MiniCPM-V")
+    # add projector type
+    fout.add_string("clip.projector_type", "resampler")
+    fout.add_int32("clip.minicpmv_version", minicpmv_version)
+else:
+    fout.add_description("two-tower CLIP model")
+
+if has_vision_encoder:
+    # vision_model hparams
+    fout.add_uint32("clip.vision.image_size", 448)
+    fout.add_uint32("clip.vision.patch_size", 14)
+    fout.add_uint32(add_key_str(KEY_EMBEDDING_LENGTH, VISION), 1152)
+    fout.add_uint32(add_key_str(KEY_FEED_FORWARD_LENGTH, VISION), 4304)
+    fout.add_uint32("clip.vision.projection_dim", 0)
+    fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
+    fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
+    fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count)
+
+    if processor is not None:
+        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
+        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
+    else:
+        image_mean = args.image_mean if args.image_mean is not None else default_image_mean
+        image_std = args.image_std if args.image_std is not None else default_image_std
+    fout.add_array("clip.vision.image_mean", image_mean)
+    fout.add_array("clip.vision.image_std", image_std)
+
+use_gelu = True
+fout.add_bool("clip.use_gelu", use_gelu)
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float32)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000 ** omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    if isinstance(grid_size, int):
+        grid_h_size, grid_w_size = grid_size, grid_size
+    else:
+        grid_h_size, grid_w_size = grid_size[0], grid_size[1]
+
+    grid_h = np.arange(grid_h_size, dtype=np.float32)
+    grid_w = np.arange(grid_w_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+def _replace_name_resampler(s, v):
+    if re.match("resampler.pos_embed", s):
+        return {
+            s: v,
+            re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
+        }
+    if re.match("resampler.proj", s):
+        return {
+            re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
+            re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
+        }
+    if re.match("resampler.attn.in_proj_.*", s):
+        return {
+            re.sub("attn.in_proj_", "attn.q.", s): v.chunk(3, dim=0)[0],
+            re.sub("attn.in_proj_", "attn.k.", s): v.chunk(3, dim=0)[1],
+            re.sub("attn.in_proj_", "attn.v.", s): v.chunk(3, dim=0)[2],
+        }
+    return {s: v}
+
+if has_minicpmv_projector:
+    projector = torch.load(args.minicpmv_projector)
+    new_state_dict = {}
+    for k, v in projector.items():
+        kvs = _replace_name_resampler(k, v)
+        for nk, nv in kvs.items():
+            new_state_dict[nk] = nv
+    projector = new_state_dict
+    ftype_cur = 0
+    for name, data in projector.items():
+        name = get_tensor_name(name)
+        data = data.squeeze().numpy()
+
+        n_dims = len(data.shape)
+        if ftype == 1:
+            if name[-7:] == ".weight" and n_dims == 2:
+                print("  Converting to float16")
+                data = data.astype(np.float16)
+                ftype_cur = 1
+            else:
+                print("  Converting to float32")
+                data = data.astype(np.float32)
+                ftype_cur = 0
+        else:
+            if data.dtype != np.float32:
+                print("  Converting to float32")
+                data = data.astype(np.float32)
+                ftype_cur = 0
+
+        fout.add_tensor(name, data)
+        print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
+
+    print("Projector tensors added\n")
+
+def _replace_name(s, v):
+    s = "vision_model." + s
+    if re.match("vision_model.embeddings.position_embedding", s):
+        v = v.unsqueeze(0)
+        return {s: v}
+
+    return {s: v}
+
+state_dict = model.state_dict()
+new_state_dict = {}
+for k, v in state_dict.items():
+    kvs = _replace_name(k, v)
+    for nk, nv in kvs.items():
+        new_state_dict[nk] = nv
+state_dict = new_state_dict
+for name, data in state_dict.items():
+    if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_minicpmv_projector):
+        # we don't need this
+        print(f"skipping parameter: {name}")
+        continue
+
+    name = get_tensor_name(name)
+    data = data.squeeze().numpy()
+
+    n_dims = len(data.shape)
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype_cur = 0
+    if n_dims == 4:
+        print(f"tensor {name} is always saved in f16")
+        data = data.astype(np.float16)
+        ftype_cur = 1
+    elif ftype == 1:
+        if name[-7:] == ".weight" and n_dims == 2:
+            print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype_cur = 1
+        else:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+    else:
+        if data.dtype != np.float32:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+
+    print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
+    fout.add_tensor(name, data)
+
+
+fout.write_header_to_file()
+fout.write_kv_data_to_file()
+fout.write_tensors_to_file()
+fout.close()
+
+print("Done. Output file: " + fname_out)
diff --git a/tools/mtmd/legacy-models/minicpmv-surgery.py b/tools/mtmd/legacy-models/minicpmv-surgery.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba82116582b1fa4d530ee61ed6e98ca9471a6ce3
--- /dev/null
+++ b/tools/mtmd/legacy-models/minicpmv-surgery.py
@@ -0,0 +1,45 @@
+import argparse
+import os
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+ap = argparse.ArgumentParser()
+ap.add_argument("-m", "--model", help="Path to MiniCPM-V model")
+args = ap.parse_args()
+
+# find the model part that includes the the multimodal projector weights
+model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16)
+checkpoint = model.state_dict()
+
+# get a list of mm tensor names
+mm_tensors = [k for k, v in checkpoint.items() if k.startswith("resampler")]
+
+# store these tensors in a new dictionary and torch.save them
+projector = {name: checkpoint[name].float() for name in mm_tensors}
+torch.save(projector, f"{args.model}/minicpmv.projector")
+
+clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vpm")]
+if len(clip_tensors) > 0:
+    clip = {name.replace("vpm.", ""): checkpoint[name].float() for name in clip_tensors}
+    torch.save(clip, f"{args.model}/minicpmv.clip")
+
+    # added tokens should be removed to be able to convert Mistral models
+    if os.path.exists(f"{args.model}/added_tokens.json"):
+        with open(f"{args.model}/added_tokens.json", "w") as f:
+            f.write("{}\n")
+
+config = model.llm.config
+config.auto_map = {
+    "AutoConfig": "configuration_minicpm.MiniCPMConfig",
+    "AutoModel": "modeling_minicpm.MiniCPMModel",
+    "AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
+    "AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
+    "AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
+}
+model.llm.save_pretrained(f"{args.model}/model")
+tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
+tok.save_pretrained(f"{args.model}/model")
+
+print("Done!")
+print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
+print(f"Also, use {args.model}/minicpmv.projector to prepare a minicpmv-encoder.gguf file.")
diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4d053895cdae94270fa21f23755fb9658b3cbd71
--- /dev/null
+++ b/tools/mtmd/mtmd-audio.cpp
@@ -0,0 +1,769 @@
+#include "mtmd-audio.h"
+
+#define _USE_MATH_DEFINES // for M_PI
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// most of the code here is copied from whisper.cpp
+
+// align x to upper multiple of n
+#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
+
+namespace whisper_preprocessor {
+
+#define SIN_COS_N_COUNT WHISPER_N_FFT
+namespace {
+struct whisper_global_cache {
+    // In FFT, we frequently use sine and cosine operations with the same values.
+    // We can use precalculated values to speed up the process.
+    float sin_vals[SIN_COS_N_COUNT];
+    float cos_vals[SIN_COS_N_COUNT];
+
+    // Hann window (Use cosf to eliminate difference)
+    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
+    float hann_window[WHISPER_N_FFT];
+
+    whisper_global_cache() {
+        fill_sin_cos_table();
+        fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
+    }
+
+    void fill_sin_cos_table() {
+        for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+            double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
+            sin_vals[i] = sinf(theta);
+            cos_vals[i] = cosf(theta);
+        }
+    }
+
+    void fill_hann_window(int length, bool periodic, float * output) {
+        int offset = -1;
+        if (periodic) {
+            offset = 0;
+        }
+        for (int i = 0; i < length; i++) {
+            output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+        }
+    }
+} global_cache;
+}
+
+// naive Discrete Fourier Transform
+// input is real-valued
+// output is complex-valued
+static void dft(const float* in, int N, float* out) {
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
+
+    for (int k = 0; k < N; k++) {
+        float re = 0;
+        float im = 0;
+
+        for (int n = 0; n < N; n++) {
+            int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
+            re += in[n]*global_cache.cos_vals[idx]; // cos(t)
+            im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
+        }
+
+        out[k*2 + 0] = re;
+        out[k*2 + 1] = im;
+    }
+}
+
+// Cooley-Tukey FFT
+// poor man's implementation - use something better
+// input is real-valued
+// output is complex-valued
+static void fft(float* in, int N, float* out) {
+    if (N == 1) {
+        out[0] = in[0];
+        out[1] = 0;
+        return;
+    }
+
+    const int half_N = N / 2;
+    if (N - half_N*2 == 1) {
+        dft(in, N, out);
+        return;
+    }
+
+    float* even = in + N;
+    for (int i = 0; i < half_N; ++i) {
+        even[i]= in[2*i];
+    }
+    float* even_fft = out + 2 * N;
+    fft(even, half_N, even_fft);
+
+    float* odd = even;
+    for (int i = 0; i < half_N; ++i) {
+        odd[i] = in[2*i + 1];
+    }
+    float* odd_fft = even_fft + N;
+    fft(odd, half_N, odd_fft);
+
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
+    for (int k = 0; k < half_N; k++) {
+        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
+        float re = global_cache.cos_vals[idx]; // cos(t)
+        float im = -global_cache.sin_vals[idx]; // sin(t)
+
+        float re_odd = odd_fft[2*k + 0];
+        float im_odd = odd_fft[2*k + 1];
+
+        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
+        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+
+        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+    }
+}
+
+static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples,
+                                              int n_samples, int frame_size, int frame_step, int n_threads,
+                                              const whisper_filters & filters, whisper_mel & mel) {
+    std::vector fft_in(frame_size * 2, 0.0);
+    std::vector fft_out(frame_size * 2 * 2 * 2);
+
+    int n_fft = filters.n_fft;
+    int i = ith;
+
+    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
+    WHISPER_ASSERT(n_fft == 1 + (frame_size / 2));
+
+    // calculate FFT only when fft_in are not all zero
+    for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
+        const int offset = i * frame_step;
+
+        // apply Hann window (~10% faster)
+        for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
+            fft_in[j] = hann[j] * samples[offset + j];
+        }
+
+        // fill the rest with zeros
+        if (n_samples - offset < frame_size) {
+            std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
+        }
+
+        // FFT
+        fft(fft_in.data(), frame_size, fft_out.data());
+
+        // Calculate modulus^2 of complex numbers
+        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
+        for (int j = 0; j < n_fft; j++) {
+            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
+        }
+
+        // mel spectrogram
+        for (int j = 0; j < mel.n_mel; j++) {
+            double sum = 0.0;
+            // unroll loop (suggested by GH user @lunixbochs)
+            int k = 0;
+            for (k = 0; k < n_fft - 3; k += 4) {
+                sum +=
+                        fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
+                        fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
+                        fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
+                        fft_out[k + 3] * filters.data[j * n_fft + k + 3];
+            }
+            // handle n_fft remainder
+            for (; k < n_fft; k++) {
+                sum += fft_out[k] * filters.data[j * n_fft + k];
+            }
+            sum = log10(std::max(sum, 1e-10));
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
+
+    // Otherwise fft_out are all zero
+    double sum = log10(1e-10);
+    for (; i < mel.n_len; i += n_threads) {
+        for (int j = 0; j < mel.n_mel; j++) {
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
+}
+
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
+static bool log_mel_spectrogram(
+        const float * samples,
+        const int   n_samples,
+        const int   /*sample_rate*/,
+        const int   frame_size,
+        const int   frame_step,
+        const int   n_mel,
+        const int   n_threads,
+        const whisper_filters & filters,
+        const bool   debug,
+        whisper_mel & mel) {
+    //const int64_t t_start_us = ggml_time_us();
+
+    // Hann window
+    WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
+    const float * hann = global_cache.hann_window;
+
+    // Calculate the length of padding
+    int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
+    int64_t stage_2_pad = frame_size / 2;
+
+    // Initialize a vector and copy data from C array to it.
+    std::vector samples_padded;
+    samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
+    std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
+
+    // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
+    std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
+
+    // reflective pad 200 samples at the beginning of audio
+    std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
+
+    mel.n_mel     = n_mel;
+    // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
+    // Calculate number of frames + remove the last frame
+    mel.n_len     = (samples_padded.size() - frame_size) / frame_step;
+    // Calculate semi-padded sample length to ensure compatibility
+    mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
+    mel.data.resize(mel.n_mel * mel.n_len);
+
+    {
+        std::vector workers(n_threads - 1);
+        for (int iw = 0; iw < n_threads - 1; ++iw) {
+            workers[iw] = std::thread(
+                    log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
+                    n_samples + stage_2_pad, frame_size, frame_step, n_threads,
+                    std::cref(filters), std::ref(mel));
+        }
+
+        // main thread
+        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
+
+        for (int iw = 0; iw < n_threads - 1; ++iw) {
+            workers[iw].join();
+        }
+    }
+
+    // clamping and normalization
+    double mmax = -1e20;
+    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+        if (mel.data[i] > mmax) {
+            mmax = mel.data[i];
+        }
+    }
+
+    mmax -= 8.0;
+
+    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+        if (mel.data[i] < mmax) {
+            mel.data[i] = mmax;
+        }
+
+        mel.data[i] = (mel.data[i] + 4.0)/4.0;
+    }
+
+    // Dump log_mel_spectrogram
+    if (debug) {
+        std::ofstream outFile("log_mel_spectrogram.json");
+        outFile << "[";
+        for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
+            outFile << mel.data[i] << ", ";
+        }
+        outFile << mel.data[mel.data.size() - 1] << "]";
+        outFile.close();
+    }
+
+    return true;
+}
+
+bool preprocess_audio(
+        const float * samples,
+        size_t n_samples,
+        const whisper_filters & filters,
+        std::vector & output) {
+
+    if (n_samples == 0) {
+        // empty audio
+        return false;
+    }
+
+    whisper_mel out_full;
+    bool ok = log_mel_spectrogram(
+                samples,
+                n_samples,
+                COMMON_SAMPLE_RATE,
+                WHISPER_N_FFT,
+                WHISPER_HOP_LENGTH,
+                filters.n_mel,
+                4, // n_threads
+                filters,
+                false, // debug
+                out_full);
+    if (!ok) {
+        return false;
+    }
+
+    // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
+    // we always expect the mel to have 3000 silent frames at the end
+    // printf("n_len %d\n", out_full.n_len);
+    const size_t frames_per_chunk = 3000;
+    GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
+    for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
+        int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
+        if ((size_t)n_len < frames_per_chunk) {
+            break; // last uncomplete chunk will always be a padded chunk, safe to ignore
+        }
+
+        whisper_mel out_chunk;
+        out_chunk.n_len     = n_len;
+        out_chunk.n_mel     = out_full.n_mel;
+        out_chunk.n_len_org = out_full.n_mel; // unused
+        out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
+
+        for (int i = 0; i < out_full.n_mel; i++) {
+            auto src = out_full.data.begin() + i*out_full.n_len + off;
+            out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
+        }
+
+        output.push_back(std::move(out_chunk));
+    }
+
+    return true;
+}
+
+} // namespace whisper_preprocessor
+
+
+// precalculated mel filter banks
+// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function
+//
+// generated from python code:
+//
+// from numpy import load
+// data = load('mel_filters.npz')
+// lst = data.files
+// for item in lst:
+//   print(item)
+//   print(data[item].shape)
+//   n_mel = data[item].shape[0]
+//   n_fft = data[item].shape[1]
+//   for i, row in enumerate(data[item]):
+//     for j, val in enumerate(row):
+//       val = val * 1000.0
+//       if val != 0:
+//         print(f"data[{i*n_fft + j}] = {val:.6f};")
+
+namespace whisper_precalc_filters {
+
+whisper_preprocessor::whisper_filters get_128_bins() {
+    whisper_preprocessor::whisper_filters filters;
+    filters.n_mel = 128;
+    filters.n_fft = 201;
+    std::vector data(filters.n_mel * filters.n_fft, 0.0f);
+
+    data[1] = 12.37398665;
+    data[202] = 30.39256483;
+    data[404] = 24.74797331;
+    data[605] = 18.01857911;
+    data[807] = 37.12195903;
+    data[1008] = 5.64459199;
+    data[1009] = 6.72939420;
+    data[1210] = 36.03715822;
+    data[1412] = 19.10337992;
+    data[1613] = 23.66316877;
+    data[1815] = 31.47736564;
+    data[2016] = 11.28918398;
+    data[2017] = 1.08480197;
+    data[2218] = 41.68175161;
+    data[2420] = 13.45878839;
+    data[2621] = 29.30776216;
+    data[2823] = 25.83277412;
+    data[3024] = 16.93377644;
+    data[3226] = 38.20675984;
+    data[3427] = 4.55979025;
+    data[3428] = 7.81419594;
+    data[3629] = 34.95235741;
+    data[3831] = 20.18818259;
+    data[4032] = 22.57836796;
+    data[4234] = 32.56217018;
+    data[4435] = 10.20438317;
+    data[4436] = 2.16960395;
+    data[4637] = 40.59694707;
+    data[4839] = 14.54358920;
+    data[5040] = 28.22295949;
+    data[5242] = 26.91757679;
+    data[5443] = 15.84897563;
+    data[5645] = 39.29156065;
+    data[5846] = 3.47498828;
+    data[5847] = 8.89899861;
+    data[6048] = 33.86755288;
+    data[6250] = 21.27298526;
+    data[6451] = 21.49356715;
+    data[6653] = 33.64697099;
+    data[6854] = 9.11958050;
+    data[6855] = 3.25440569;
+    data[7056] = 39.51214626;
+    data[7258] = 15.62839188;
+    data[7459] = 27.13815868;
+    data[7661] = 28.00237760;
+    data[7862] = 14.76417296;
+    data[8064] = 40.37636518;
+    data[8265] = 2.38068704;
+    data[8266] = 10.20263787;
+    data[8467] = 31.61146119;
+    data[8669] = 24.54700135;
+    data[8870] = 15.32919332;
+    data[8871] = 1.66583748;
+    data[9072] = 36.72905266;
+    data[9274] = 20.09709924;
+    data[9475] = 16.93102531;
+    data[9476] = 2.90265540;
+    data[9677] = 32.84499049;
+    data[9879] = 23.52004871;
+    data[10080] = 11.03894413;
+    data[10081] = 10.72582975;
+    data[10282] = 22.71829173;
+    data[10484] = 32.27872774;
+    data[10685] = 0.11626833;
+    data[10686] = 22.85348251;
+    data[10887] = 8.56344029;
+    data[10888] = 14.97978810;
+    data[11089] = 15.51398356;
+    data[11090] = 8.51490628;
+    data[11291] = 21.10680379;
+    data[11292] = 3.32652032;
+    data[11493] = 25.47064796;
+    data[11695] = 27.35907957;
+    data[11896] = 0.65853616;
+    data[11897] = 23.83812517;
+    data[12098] = 3.44359246;
+    data[12099] = 21.22455277;
+    data[12300] = 5.35842171;
+    data[12301] = 19.42555793;
+    data[12502] = 6.49324711;
+    data[12503] = 18.35542172;
+    data[12704] = 6.93138083;
+    data[12705] = 17.93504693;
+    data[12906] = 6.74968259;
+    data[12907] = 18.09151843;
+    data[13108] = 6.01899112;
+    data[13109] = 18.75767298;
+    data[13310] = 4.80452832;
+    data[13311] = 19.87172849;
+    data[13512] = 3.16627859;
+    data[13513] = 21.37690969;
+    data[13514] = 1.25317345;
+    data[13714] = 1.15934468;
+    data[13715] = 20.80361731;
+    data[13716] = 4.04486805;
+    data[13917] = 17.55363122;
+    data[13918] = 7.08320038;
+    data[14119] = 14.07538634;
+    data[14120] = 10.32655034;
+    data[14321] = 10.40921453;
+    data[14322] = 13.73696327;
+    data[14523] = 6.59187697;
+    data[14524] = 17.27988198;
+    data[14525] = 1.46804214;
+    data[14725] = 2.65681883;
+    data[14726] = 18.09193194;
+    data[14727] = 5.85655728;
+    data[14928] = 13.34277913;
+    data[14929] = 10.28267574;
+    data[15130] = 8.56800377;
+    data[15131] = 14.72230814;
+    data[15132] = 1.04039861;
+    data[15332] = 3.79085587;
+    data[15333] = 17.14678481;
+    data[15334] = 6.11609267;
+    data[15535] = 11.75929047;
+    data[15536] = 11.13393717;
+    data[15737] = 6.43857848;
+    data[15738] = 16.07806236;
+    data[15739] = 4.23917221;
+    data[15939] = 1.19989377;
+    data[15940] = 12.75671553;
+    data[15941] = 9.65298992;
+    data[16142] = 7.06935255;
+    data[16143] = 14.94054683;
+    data[16144] = 4.19024844;
+    data[16344] = 1.51483389;
+    data[16345] = 12.00899947;
+    data[16346] = 9.84823331;
+    data[16547] = 6.10224018;
+    data[16548] = 15.33857174;
+    data[16549] = 5.57676842;
+    data[16749] = 0.36827257;
+    data[16750] = 9.89749376;
+    data[16751] = 11.35340426;
+    data[16752] = 2.05122307;
+    data[16952] = 3.89297144;
+    data[16953] = 12.97352277;
+    data[16954] = 8.06631614;
+    data[17155] = 6.74493238;
+    data[17156] = 13.85874674;
+    data[17157] = 5.41190524;
+    data[17357] = 0.74220158;
+    data[17358] = 8.98779090;
+    data[17359] = 11.37871388;
+    data[17360] = 3.32958088;
+    data[17560] = 2.82313535;
+    data[17561] = 10.68049297;
+    data[17562] = 9.43340641;
+    data[17563] = 1.76325557;
+    data[17763] = 4.39018616;
+    data[17764] = 11.87758986;
+    data[17765] = 7.97005836;
+    data[17766] = 0.66104700;
+    data[17966] = 5.49466675;
+    data[17967] = 12.62953598;
+    data[17968] = 6.93987962;
+    data[18169] = 6.18401915;
+    data[18170] = 12.93473132;
+    data[18171] = 6.29778765;
+    data[18371] = 0.02325210;
+    data[18372] = 6.50206627;
+    data[18373] = 12.32661773;
+    data[18374] = 6.00216538;
+    data[18574] = 0.31548753;
+    data[18575] = 6.48925547;
+    data[18576] = 12.04130240;
+    data[18577] = 6.01462880;
+    data[18777] = 0.29979556;
+    data[18778] = 6.18288014;
+    data[18779] = 12.04272825;
+    data[18780] = 6.29981188;
+    data[18781] = 0.55689598;
+    data[18980] = 0.01120471;
+    data[18981] = 5.61729167;
+    data[18982] = 11.22337859;
+    data[18983] = 6.82516303;
+    data[18984] = 1.35264499;
+    data[19184] = 4.82410006;
+    data[19185] = 10.16623247;
+    data[19186] = 7.56075513;
+    data[19187] = 2.34590308;
+    data[19387] = 3.83235747;
+    data[19388] = 8.92296247;
+    data[19389] = 8.47910438;
+    data[19390] = 3.50978645;
+    data[19590] = 2.66873185;
+    data[19591] = 7.51965167;
+    data[19592] = 9.55500547;
+    data[19593] = 4.81966138;
+    data[19594] = 0.08431751;
+    data[19793] = 1.35767367;
+    data[19794] = 5.98019501;
+    data[19795] = 10.60271543;
+    data[19796] = 6.25298498;
+    data[19797] = 1.74059917;
+    data[19997] = 4.32644226;
+    data[19998] = 8.73131864;
+    data[19999] = 7.78916525;
+    data[20000] = 3.48923868;
+    data[20200] = 2.57835095;
+    data[20201] = 6.77582854;
+    data[20202] = 9.40941647;
+    data[20203] = 5.31194592;
+    data[20204] = 1.21447595;
+    data[20403] = 0.75411191;
+    data[20404] = 4.75395704;
+    data[20405] = 8.75380263;
+    data[20406] = 7.19209015;
+    data[20407] = 3.28754401;
+    data[20607] = 2.68179690;
+    data[20608] = 6.49331464;
+    data[20609] = 9.11457930;
+    data[20610] = 5.39387390;
+    data[20611] = 1.67316827;
+    data[20810] = 0.57394296;
+    data[20811] = 4.20600036;
+    data[20812] = 7.83805829;
+    data[20813] = 7.52023002;
+    data[20814] = 3.97470826;
+    data[20815] = 0.42918732;
+    data[21014] = 1.90464477;
+    data[21015] = 5.36569161;
+    data[21016] = 8.82673822;
+    data[21017] = 6.27609482;
+    data[21018] = 2.89750961;
+    data[21218] = 2.89885257;
+    data[21219] = 6.19694078;
+    data[21220] = 8.56699049;
+    data[21221] = 5.34748193;
+    data[21222] = 2.12797290;
+    data[21421] = 0.44750227;
+    data[21422] = 3.59030394;
+    data[21423] = 6.73310598;
+    data[21424] = 7.77023612;
+    data[21425] = 4.70231380;
+    data[21426] = 1.63439126;
+    data[21625] = 1.01536023;
+    data[21626] = 4.01018746;
+    data[21627] = 7.00501446;
+    data[21628] = 7.23442994;
+    data[21629] = 4.31095669;
+    data[21630] = 1.38748321;
+    data[21829] = 1.33348850;
+    data[21830] = 4.18730825;
+    data[21831] = 7.04112789;
+    data[21832] = 6.93188375;
+    data[21833] = 4.14605811;
+    data[21834] = 1.36023236;
+    data[22033] = 1.42879714;
+    data[22034] = 4.14824858;
+    data[22035] = 6.86769979;
+    data[22036] = 6.83705276;
+    data[22037] = 4.18239459;
+    data[22038] = 1.52773573;
+    data[22237] = 1.32610439;
+    data[22238] = 3.91751388;
+    data[22239] = 6.50892360;
+    data[22240] = 6.92639686;
+    data[22241] = 4.39672917;
+    data[22242] = 1.86706171;
+    data[22441] = 1.04827771;
+    data[22442] = 3.51767405;
+    data[22443] = 5.98707050;
+    data[22444] = 7.17824046;
+    data[22445] = 4.76767914;
+    data[22446] = 2.35711760;
+    data[22645] = 0.61636406;
+    data[22646] = 2.96949223;
+    data[22647] = 5.32262027;
+    data[22648] = 7.57265091;
+    data[22649] = 5.27558755;
+    data[22650] = 2.97852419;
+    data[22651] = 0.68146095;
+    data[22849] = 0.04971400;
+    data[22850] = 2.29204819;
+    data[22851] = 4.53438237;
+    data[22852] = 6.77671656;
+    data[22853] = 5.90240723;
+    data[22854] = 3.71349836;
+    data[22855] = 1.52458926;
+    data[23054] = 1.50285335;
+    data[23055] = 3.63961048;
+    data[23056] = 5.77636715;
+    data[23057] = 6.63159089;
+    data[23058] = 4.54574358;
+    data[23059] = 2.45989650;
+    data[23060] = 0.37404924;
+    data[23258] = 0.61795861;
+    data[23259] = 2.65410915;
+    data[23260] = 4.69025923;
+    data[23261] = 6.72641024;
+    data[23262] = 5.46034705;
+    data[23263] = 3.47270933;
+    data[23264] = 1.48507138;
+    data[23463] = 1.59233576;
+    data[23464] = 3.53261665;
+    data[23465] = 5.47289755;
+    data[23466] = 6.44368259;
+    data[23467] = 4.54962999;
+    data[23468] = 2.65557761;
+    data[23469] = 0.76152512;
+    data[23667] = 0.46749352;
+    data[23668] = 2.31641904;
+    data[23669] = 4.16534441;
+    data[23670] = 6.01426978;
+    data[23671] = 5.67844696;
+    data[23672] = 3.87357362;
+    data[23673] = 2.06870004;
+    data[23674] = 0.26382666;
+    data[23872] = 1.05349103;
+    data[23873] = 2.81536230;
+    data[23874] = 4.57723346;
+    data[23875] = 6.33910485;
+    data[23876] = 5.12815686;
+    data[23877] = 3.40826320;
+    data[23878] = 1.68837002;
+    data[24077] = 1.43350090;
+    data[24078] = 3.11241671;
+    data[24079] = 4.79133241;
+    data[24080] = 6.40943693;
+    data[24081] = 4.77052201;
+    data[24082] = 3.13160778;
+    data[24083] = 1.49269309;
+    data[24281] = 0.02932359;
+    data[24282] = 1.62918994;
+    data[24283] = 3.22905602;
+    data[24284] = 4.82892245;
+    data[24285] = 6.14671456;
+    data[24286] = 4.58496623;
+    data[24287] = 3.02321767;
+    data[24288] = 1.46146910;
+    data[24486] = 0.13601698;
+    data[24487] = 1.66055572;
+    data[24488] = 3.18509457;
+    data[24489] = 4.70963307;
+    data[24490] = 6.04072399;
+    data[24491] = 4.55250870;
+    data[24492] = 3.06429295;
+    data[24493] = 1.57607743;
+    data[24494] = 0.08786193;
+    data[24691] = 0.09328097;
+    data[24692] = 1.54603878;
+    data[24693] = 2.99879676;
+    data[24694] = 4.45155473;
+    data[24695] = 5.90431225;
+    data[24696] = 4.65566106;
+    data[24697] = 3.23751615;
+    data[24698] = 1.81937125;
+    data[24699] = 0.40122634;
+    data[24897] = 1.30262633;
+    data[24898] = 2.68698297;
+    data[24899] = 4.07133950;
+    data[24900] = 5.45569602;
+    data[24901] = 4.87832492;
+    data[24902] = 3.52695142;
+    data[24903] = 2.17557792;
+    data[24904] = 0.82420459;
+    data[25102] = 0.94595028;
+    data[25103] = 2.26512621;
+    data[25104] = 3.58430226;
+    data[25105] = 4.90347855;
+    data[25106] = 5.20569785;
+    data[25107] = 3.91795207;
+    data[25108] = 2.63020652;
+    data[25109] = 1.34246063;
+    data[25110] = 0.05471494;
+    data[25307] = 0.49037894;
+    data[25308] = 1.74744334;
+    data[25309] = 3.00450763;
+    data[25310] = 4.26157191;
+    data[25311] = 5.51863620;
+    data[25312] = 4.39707236;
+    data[25313] = 3.16995848;
+    data[25314] = 1.94284460;
+    data[25315] = 0.71573065;
+    data[25513] = 1.14698056;
+    data[25514] = 2.34485767;
+    data[25515] = 3.54273478;
+    data[25516] = 4.74061165;
+    data[25517] = 4.95198462;
+    data[25518] = 3.78264743;
+    data[25519] = 2.61331047;
+    data[25520] = 1.44397374;
+    data[25521] = 0.27463681;
+    data[25718] = 0.47569509;
+    data[25719] = 1.61717169;
+    data[25720] = 2.75864848;
+    data[25721] = 3.90012516;
+    data[25722] = 5.04160160;
+    data[25723] = 4.45712078;
+    data[25724] = 3.34284059;
+    data[25725] = 2.22856039;
+    data[25726] = 1.11428020;
+
+    for (auto & val : data) {
+        val /= 1000.0f;
+    }
+
+    filters.data = std::move(data);
+    return filters;
+}
+
+} // namespace whisper_precalc_filters
diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h
new file mode 100644
index 0000000000000000000000000000000000000000..b7b940affb570f676c445d1e53dd764149c87dbb
--- /dev/null
+++ b/tools/mtmd/mtmd-audio.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#include "ggml.h"
+
+#include 
+#include 
+#include 
+
+#define WHISPER_ASSERT GGML_ASSERT
+
+#define WHISPER_SAMPLE_RATE 16000
+#define WHISPER_N_FFT       400
+#define WHISPER_HOP_LENGTH  160
+#define WHISPER_CHUNK_SIZE  30
+
+#define COMMON_SAMPLE_RATE 16000
+
+namespace whisper_preprocessor {
+
+struct whisper_mel {
+    int n_len;
+    int n_len_org;
+    int n_mel;
+
+    std::vector data;
+};
+
+struct whisper_filters {
+    int32_t n_mel;
+    int32_t n_fft;
+
+    std::vector data;
+};
+
+bool preprocess_audio(
+        const float * samples,
+        size_t n_samples,
+        const whisper_filters & filters,
+        std::vector & output);
+
+} // namespace whisper_preprocessor
+
+namespace whisper_precalc_filters {
+
+whisper_preprocessor::whisper_filters get_128_bins();
+
+} // namespace whisper_precalc_filters
diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..599e682e0f8948fc992af21913f849ea798c5c1e
--- /dev/null
+++ b/tools/mtmd/mtmd-cli.cpp
@@ -0,0 +1,386 @@
+#include "arg.h"
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+#include "llama.h"
+#include "ggml.h"
+#include "console.h"
+#include "chat.h"
+#include "mtmd.h"
+#include "mtmd-helper.h"
+
+#include 
+#include 
+#include 
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include 
+#include 
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include 
+#include 
+#endif
+
+// volatile, because of signal being an interrupt
+static volatile bool g_is_generating = false;
+static volatile bool g_is_interrupted = false;
+
+/**
+ * Please note that this is NOT a production-ready stuff.
+ * It is a playground for trying multimodal support in llama.cpp.
+ * For contributors: please keep this code simple and easy to understand.
+ */
+
+static void show_additional_info(int /*argc*/, char ** argv) {
+    LOG(
+        "Experimental CLI for multimodal\n\n"
+        "Usage: %s [options] -m  --mmproj  --image  --audio "]
+};
+
+let generation_settings = null;
+
+
+// Completes the prompt as a generator. Recommended for most use cases.
+//
+// Example:
+//
+//    import { llama } from '/completion.js'
+//
+//    const request = llama("Tell me a joke", {n_predict: 800})
+//    for await (const chunk of request) {
+//      document.write(chunk.data.content)
+//    }
+//
+export async function* llama(prompt, params = {}, config = {}) {
+  let controller = config.controller;
+  const api_url = config.api_url?.replace(/\/+$/, '') || "";
+
+  if (!controller) {
+    controller = new AbortController();
+  }
+
+  const completionParams = { ...paramDefaults, ...params, prompt };
+
+  const response = await fetch(`${api_url}${config.endpoint || '/completion'}`, {
+    method: 'POST',
+    body: JSON.stringify(completionParams),
+    headers: {
+      'Connection': 'keep-alive',
+      'Content-Type': 'application/json',
+      'Accept': 'text/event-stream',
+      ...(params.api_key ? {'Authorization': `Bearer ${params.api_key}`} : {})
+    },
+    signal: controller.signal,
+  });
+
+  const reader = response.body.getReader();
+  const decoder = new TextDecoder();
+
+  let content = "";
+  let leftover = ""; // Buffer for partially read lines
+
+  try {
+    let cont = true;
+
+    while (cont) {
+      const result = await reader.read();
+      if (result.done) {
+        break;
+      }
+
+      // Add any leftover data to the current chunk of data
+      const text = leftover + decoder.decode(result.value);
+
+      // Check if the last character is a line break
+      const endsWithLineBreak = text.endsWith('\n');
+
+      // Split the text into lines
+      let lines = text.split('\n');
+
+      // If the text doesn't end with a line break, then the last line is incomplete
+      // Store it in leftover to be added to the next chunk of data
+      if (!endsWithLineBreak) {
+        leftover = lines.pop();
+      } else {
+        leftover = ""; // Reset leftover if we have a line break at the end
+      }
+
+      // Parse all sse events and add them to result
+      const regex = /^(\S+):\s(.*)$/gm;
+      for (const line of lines) {
+        const match = regex.exec(line);
+        if (match) {
+          result[match[1]] = match[2];
+          if (result.data === '[DONE]') {
+            cont = false;
+            break;
+          }
+
+          // since we know this is llama.cpp, let's just decode the json in data
+          if (result.data) {
+            result.data = JSON.parse(result.data);
+            content += result.data.content;
+
+            // yield
+            yield result;
+
+            // if we got a stop token from server, we will break here
+            if (result.data.stop) {
+              if (result.data.generation_settings) {
+                generation_settings = result.data.generation_settings;
+              }
+              cont = false;
+              break;
+            }
+          }
+          if (result.error) {
+            try {
+              result.error = JSON.parse(result.error);
+              if (result.error.message.includes('slot unavailable')) {
+                // Throw an error to be caught by upstream callers
+                throw new Error('slot unavailable');
+              } else {
+                console.error(`llama.cpp error [${result.error.code} - ${result.error.type}]: ${result.error.message}`);
+              }
+            } catch(e) {
+              console.error(`llama.cpp error ${result.error}`)
+            }
+          }
+        }
+      }
+    }
+  } catch (e) {
+    if (e.name !== 'AbortError') {
+      console.error("llama error: ", e);
+    }
+    throw e;
+  }
+  finally {
+    controller.abort();
+  }
+
+  return content;
+}
+
+// Call llama, return an event target that you can subscribe to
+//
+// Example:
+//
+//    import { llamaEventTarget } from '/completion.js'
+//
+//    const conn = llamaEventTarget(prompt)
+//    conn.addEventListener("message", (chunk) => {
+//      document.write(chunk.detail.content)
+//    })
+//
+export const llamaEventTarget = (prompt, params = {}, config = {}) => {
+  const eventTarget = new EventTarget();
+  (async () => {
+    let content = "";
+    for await (const chunk of llama(prompt, params, config)) {
+      if (chunk.data) {
+        content += chunk.data.content;
+        eventTarget.dispatchEvent(new CustomEvent("message", { detail: chunk.data }));
+      }
+      if (chunk.data.generation_settings) {
+        eventTarget.dispatchEvent(new CustomEvent("generation_settings", { detail: chunk.data.generation_settings }));
+      }
+      if (chunk.data.timings) {
+        eventTarget.dispatchEvent(new CustomEvent("timings", { detail: chunk.data.timings }));
+      }
+    }
+    eventTarget.dispatchEvent(new CustomEvent("done", { detail: { content } }));
+  })();
+  return eventTarget;
+}
+
+// Call llama, return a promise that resolves to the completed text. This does not support streaming
+//
+// Example:
+//
+//     llamaPromise(prompt).then((content) => {
+//       document.write(content)
+//     })
+//
+//     or
+//
+//     const content = await llamaPromise(prompt)
+//     document.write(content)
+//
+export const llamaPromise = (prompt, params = {}, config = {}) => {
+  return new Promise(async (resolve, reject) => {
+    let content = "";
+    try {
+      for await (const chunk of llama(prompt, params, config)) {
+        content += chunk.data.content;
+      }
+      resolve(content);
+    } catch (error) {
+      reject(error);
+    }
+  });
+};
+
+/**
+ * (deprecated)
+ */
+export const llamaComplete = async (params, controller, callback) => {
+  for await (const chunk of llama(params.prompt, params, { controller })) {
+    callback(chunk);
+  }
+}
+
+// Get the model info from the server. This is useful for getting the context window and so on.
+export const llamaModelInfo = async (config = {}) => {
+  if (!generation_settings) {
+    const api_url = config.api_url?.replace(/\/+$/, '') || "";
+    const props = await fetch(`${api_url}/props`).then(r => r.json());
+    generation_settings = props.default_generation_settings;
+  }
+  return generation_settings;
+}
diff --git a/tools/server/public_legacy/favicon.ico b/tools/server/public_legacy/favicon.ico
new file mode 100644
index 0000000000000000000000000000000000000000..89e154a0a7525ecd19379e8e200e36851c325d31
Binary files /dev/null and b/tools/server/public_legacy/favicon.ico differ
diff --git a/tools/server/public_legacy/index-new.html b/tools/server/public_legacy/index-new.html
new file mode 100644
index 0000000000000000000000000000000000000000..cbfbbdf2806fa1a08307b90f8bbf7e21df3a3885
--- /dev/null
+++ b/tools/server/public_legacy/index-new.html
@@ -0,0 +1,1190 @@
+
+
+
+
+
+  
+  
+  
+  llama.cpp - chat
+
+  
+  
+
+  
+
+
+
+
+  
+ +
+
+ + + diff --git a/tools/server/public_legacy/index.html b/tools/server/public_legacy/index.html new file mode 100644 index 0000000000000000000000000000000000000000..75f39330a789d9d325628b7763d1887abf0a1601 --- /dev/null +++ b/tools/server/public_legacy/index.html @@ -0,0 +1,1301 @@ + + + + + + llama.cpp - chat + + + + + + + +
+ +
+
+ + + diff --git a/tools/server/public_legacy/index.js b/tools/server/public_legacy/index.js new file mode 100644 index 0000000000000000000000000000000000000000..32ec6e9e1541e9d407ac226e1821487d24e3a4ba --- /dev/null +++ b/tools/server/public_legacy/index.js @@ -0,0 +1 @@ +const t=Symbol.for("preact-signals");function n(){if(r>1){r--;return}let t,n=!1;while(void 0!==i){let _=i;i=void 0;u++;while(void 0!==_){const i=_.o;_.o=void 0;_.f&=-3;if(!(8&_.f)&&h(_))try{_.c()}catch(e){if(!n){t=e;n=!0}}_=i}}u=0;r--;if(n)throw t}function e(t){if(r>0)return t();r++;try{return t()}finally{n()}}let _,i;function o(t){const n=_;_=void 0;try{return t()}finally{_=n}}let r=0,u=0,l=0;function s(t){if(void 0===_)return;let n=t.n;if(void 0===n||n.t!==_){n={i:0,S:t,p:_.s,n:void 0,t:_,e:void 0,x:void 0,r:n};if(void 0!==_.s)_.s.n=n;_.s=n;t.n=n;if(32&_.f)t.S(n);return n}else if(-1===n.i){n.i=0;if(void 0!==n.n){n.n.p=n.p;if(void 0!==n.p)n.p.n=n.n;n.p=_.s;n.n=void 0;_.s.n=n;_.s=n}return n}}function f(t){this.v=t;this.i=0;this.n=void 0;this.t=void 0}f.prototype.brand=t;f.prototype.h=function(){return!0};f.prototype.S=function(t){if(this.t!==t&&void 0===t.e){t.x=this.t;if(void 0!==this.t)this.t.e=t;this.t=t}};f.prototype.U=function(t){if(void 0!==this.t){const n=t.e,e=t.x;if(void 0!==n){n.x=e;t.e=void 0}if(void 0!==e){e.e=n;t.x=void 0}if(t===this.t)this.t=e}};f.prototype.subscribe=function(t){return k(()=>{const n=this.value,e=_;_=void 0;try{t(n)}finally{_=e}})};f.prototype.valueOf=function(){return this.value};f.prototype.toString=function(){return this.value+""};f.prototype.toJSON=function(){return this.value};f.prototype.peek=function(){const t=_;_=void 0;try{return this.value}finally{_=t}};Object.defineProperty(f.prototype,"value",{get(){const t=s(this);if(void 0!==t)t.i=this.i;return this.v},set(t){if(t!==this.v){if(u>100)throw new Error("Cycle detected");this.v=t;this.i++;l++;r++;try{for(let t=this.t;void 0!==t;t=t.x)t.t.N()}finally{n()}}}});function c(t){return new f(t)}function h(t){for(let n=t.s;void 0!==n;n=n.n)if(n.S.i!==n.i||!n.S.h()||n.S.i!==n.i)return!0;return!1}function a(t){for(let n=t.s;void 0!==n;n=n.n){const e=n.S.n;if(void 0!==e)n.r=e;n.S.n=n;n.i=-1;if(void 0===n.n){t.s=n;break}}}function p(t){let n,e=t.s;while(void 0!==e){const t=e.p;if(-1===e.i){e.S.U(e);if(void 0!==t)t.n=e.n;if(void 0!==e.n)e.n.p=t}else n=e;e.S.n=e.r;if(void 0!==e.r)e.r=void 0;e=t}t.s=n}function d(t){f.call(this,void 0);this.x=t;this.s=void 0;this.g=l-1;this.f=4}(d.prototype=new f).h=function(){this.f&=-3;if(1&this.f)return!1;if(32==(36&this.f))return!0;this.f&=-5;if(this.g===l)return!0;this.g=l;this.f|=1;if(this.i>0&&!h(this)){this.f&=-2;return!0}const t=_;try{a(this);_=this;const t=this.x();if(16&this.f||this.v!==t||0===this.i){this.v=t;this.f&=-17;this.i++}}catch(t){this.v=t;this.f|=16;this.i++}_=t;p(this);this.f&=-2;return!0};d.prototype.S=function(t){if(void 0===this.t){this.f|=36;for(let t=this.s;void 0!==t;t=t.n)t.S.S(t)}f.prototype.S.call(this,t)};d.prototype.U=function(t){if(void 0!==this.t){f.prototype.U.call(this,t);if(void 0===this.t){this.f&=-33;for(let t=this.s;void 0!==t;t=t.n)t.S.U(t)}}};d.prototype.N=function(){if(!(2&this.f)){this.f|=6;for(let t=this.t;void 0!==t;t=t.x)t.t.N()}};Object.defineProperty(d.prototype,"value",{get(){if(1&this.f)throw new Error("Cycle detected");const t=s(this);this.h();if(void 0!==t)t.i=this.i;if(16&this.f)throw this.v;return this.v}});function v(t){return new d(t)}function y(t){const e=t.u;t.u=void 0;if("function"==typeof e){r++;const i=_;_=void 0;try{e()}catch(n){t.f&=-2;t.f|=8;m(t);throw n}finally{_=i;n()}}}function m(t){for(let n=t.s;void 0!==n;n=n.n)n.S.U(n);t.x=void 0;t.s=void 0;y(t)}function g(t){if(_!==this)throw new Error("Out-of-order effect");p(this);_=t;this.f&=-2;if(8&this.f)m(this);n()}function b(t){this.x=t;this.u=void 0;this.s=void 0;this.o=void 0;this.f=32}b.prototype.c=function(){const t=this.S();try{if(8&this.f)return;if(void 0===this.x)return;const n=this.x();if("function"==typeof n)this.u=n}finally{t()}};b.prototype.S=function(){if(1&this.f)throw new Error("Cycle detected");this.f|=1;this.f&=-9;y(this);a(this);r++;const t=_;_=this;return g.bind(this,t)};b.prototype.N=function(){if(!(2&this.f)){this.f|=2;this.o=i;i=this}};b.prototype.d=function(){this.f|=8;if(!(1&this.f))m(this)};function k(t){const n=new b(t);try{n.c()}catch(t){n.d();throw t}return n.d.bind(n)}var w,S,x,C,U,E,H,P,N,$,T,D,M={},A=[],F=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i,W=Array.isArray;function L(t,n){for(var e in n)t[e]=n[e];return t}function O(t){t&&t.parentNode&&t.parentNode.removeChild(t)}function R(t,n,e){var _,i,o,r={};for(o in n)"key"==o?_=n[o]:"ref"==o?i=n[o]:r[o]=n[o];if(arguments.length>2&&(r.children=arguments.length>3?w.call(arguments,2):e),"function"==typeof t&&null!=t.defaultProps)for(o in t.defaultProps)void 0===r[o]&&(r[o]=t.defaultProps[o]);return I(t,r,_,i,null)}function I(t,n,e,_,i){var o={type:t,props:n,key:e,ref:_,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,constructor:void 0,__v:null==i?++x:i,__i:-1,__u:0};return null==i&&null!=S.vnode&&S.vnode(o),o}function V(){return{current:null}}function j(t){return t.children}function q(t,n){this.props=t,this.context=n}function B(t,n){if(null==n)return t.__?B(t.__,t.__i+1):null;for(var e;nn&&U.sort(P));J.__r=0}function K(t,n,e,_,i,o,r,u,l,s,f){var c,h,a,p,d,v=_&&_.__k||A,y=n.length;for(e.__d=l,Q(e,n,v),l=e.__d,c=0;c0?I(i.type,i.props,i.key,i.ref?i.ref:null,i.__v):i).__=t,i.__b=t.__b+1,o=null,-1!==(u=i.__i=Z(i,e,r,f))&&(f--,(o=e[u])&&(o.__u|=131072)),null==o||null===o.__v?(-1==u&&c--,"function"!=typeof i.type&&(i.__u|=65536)):u!==r&&(u==r-1?c--:u==r+1?c++:(u>r?c--:c++,i.__u|=65536))):i=t.__k[_]=null;if(f)for(_=0;_(null!=l&&0==(131072&l.__u)?1:0))for(;r>=0||u=0){if((l=n[r])&&0==(131072&l.__u)&&i==l.key&&o===l.type)return r;r--}if(u2&&(u.children=arguments.length>3?w.call(arguments,2):e),I(t.type,u,_||t.key,i||t.ref,null)}function ht(t,n){var e={__c:n="__cC"+D++,__:t,Consumer:function(t,n){return t.children(n)},Provider:function(t){var e,_;return this.getChildContext||(e=new Set,(_={})[n]=this,this.getChildContext=function(){return _},this.componentWillUnmount=function(){e=null},this.shouldComponentUpdate=function(t){this.props.value!==t.value&&e.forEach((function(t){t.__e=!0,G(t)}))},this.sub=function(t){e.add(t);var n=t.componentWillUnmount;t.componentWillUnmount=function(){e&&e.delete(t),n&&n.call(t)}}),t.children}};return e.Provider.__=e.Consumer.contextType=e}w=A.slice,S={__e:function(t,n,e,_){for(var i,o,r;n=n.__;)if((i=n.__c)&&!i.__)try{if((o=i.constructor)&&null!=o.getDerivedStateFromError&&(i.setState(o.getDerivedStateFromError(t)),r=i.__d),null!=i.componentDidCatch&&(i.componentDidCatch(t,_||{}),r=i.__d),r)return i.__E=i}catch(n){t=n}throw t}},x=0,C=function(t){return null!=t&&null==t.constructor},q.prototype.setState=function(t,n){var e;e=null!=this.__s&&this.__s!==this.state?this.__s:this.__s=L({},this.state),"function"==typeof t&&(t=t(L({},e),this.props)),t&&L(e,t),null!=t&&this.__v&&(n&&this._sb.push(n),G(this))},q.prototype.forceUpdate=function(t){this.__v&&(this.__e=!0,t&&this.__h.push(t),G(this))},q.prototype.render=j,U=[],H="function"==typeof Promise?Promise.prototype.then.bind(Promise.resolve()):setTimeout,P=function(t,n){return t.__v.__b-n.__v.__b},J.__r=0,N=0,$=et(!1),T=et(!0),D=0;var at,pt,dt,vt,yt=0,mt=[],gt=S,bt=gt.__b,kt=gt.__r,wt=gt.diffed,St=gt.__c,xt=gt.unmount,Ct=gt.__;function Ut(t,n){gt.__h&>.__h(pt,t,yt||n),yt=0;var e=pt.__H||(pt.__H={__:[],__h:[]});return t>=e.__.length&&e.__.push({}),e.__[t]}function Et(t){return yt=1,Ht(Bt,t)}function Ht(t,n,e){var _=Ut(at++,2);if(_.t=t,!_.__c&&(_.__=[e?e(n):Bt(void 0,n),function(t){var n=_.__N?_.__N[0]:_.__[0],e=_.t(n,t);n!==e&&(_.__N=[e,_.__[1]],_.__c.setState({}))}],_.__c=pt,!pt.u)){var i=function(t,n,e){if(!_.__c.__H)return!0;var i=_.__c.__H.__.filter((function(t){return!!t.__c}));if(i.every((function(t){return!t.__N})))return!o||o.call(this,t,n,e);var r=!1;return i.forEach((function(t){if(t.__N){var n=t.__[0];t.__=t.__N,t.__N=void 0,n!==t.__[0]&&(r=!0)}})),!(!r&&_.__c.props===t)&&(!o||o.call(this,t,n,e))};pt.u=!0;var o=pt.shouldComponentUpdate,r=pt.componentWillUpdate;pt.componentWillUpdate=function(t,n,e){if(this.__e){var _=o;o=void 0,i(t,n,e),o=_}r&&r.call(this,t,n,e)},pt.shouldComponentUpdate=i}return _.__N||_.__}function Pt(t,n){var e=Ut(at++,3);!gt.__s&&qt(e.__H,n)&&(e.__=t,e.i=n,pt.__H.__h.push(e))}function Nt(t,n){var e=Ut(at++,4);!gt.__s&&qt(e.__H,n)&&(e.__=t,e.i=n,pt.__h.push(e))}function $t(t){return yt=5,Dt((function(){return{current:t}}),[])}function Tt(t,n,e){yt=6,Nt((function(){return"function"==typeof t?(t(n()),function(){return t(null)}):t?(t.current=n(),function(){return t.current=null}):void 0}),null==e?e:e.concat(t))}function Dt(t,n){var e=Ut(at++,7);return qt(e.__H,n)&&(e.__=t(),e.__H=n,e.__h=t),e.__}function Mt(t,n){return yt=8,Dt((function(){return t}),n)}function At(t){var n=pt.context[t.__c],e=Ut(at++,9);return e.c=t,n?(null==e.__&&(e.__=!0,n.sub(pt)),n.props.value):t.__}function Ft(t,n){gt.useDebugValue&>.useDebugValue(n?n(t):t)}function Wt(t){var n=Ut(at++,10),e=Et();return n.__=t,pt.componentDidCatch||(pt.componentDidCatch=function(t,_){n.__&&n.__(t,_),e[1](t)}),[e[0],function(){e[1](void 0)}]}function Lt(){var t=Ut(at++,11);if(!t.__){for(var n=pt.__v;null!==n&&!n.__m&&null!==n.__;)n=n.__;var e=n.__m||(n.__m=[0,0]);t.__="P"+e[0]+"-"+e[1]++}return t.__}function Ot(){for(var t;t=mt.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(Vt),t.__H.__h.forEach(jt),t.__H.__h=[]}catch(n){t.__H.__h=[],gt.__e(n,t.__v)}}gt.__b=function(t){pt=null,bt&&bt(t)},gt.__=function(t,n){t&&n.__k&&n.__k.__m&&(t.__m=n.__k.__m),Ct&&Ct(t,n)},gt.__r=function(t){kt&&kt(t),at=0;var n=(pt=t.__c).__H;n&&(dt===pt?(n.__h=[],pt.__h=[],n.__.forEach((function(t){t.__N&&(t.__=t.__N),t.i=t.__N=void 0}))):(n.__h.forEach(Vt),n.__h.forEach(jt),n.__h=[],at=0)),dt=pt},gt.diffed=function(t){wt&&wt(t);var n=t.__c;n&&n.__H&&(n.__H.__h.length&&(1!==mt.push(n)&&vt===gt.requestAnimationFrame||((vt=gt.requestAnimationFrame)||It)(Ot)),n.__H.__.forEach((function(t){t.i&&(t.__H=t.i),t.i=void 0}))),dt=pt=null},gt.__c=function(t,n){n.some((function(t){try{t.__h.forEach(Vt),t.__h=t.__h.filter((function(t){return!t.__||jt(t)}))}catch(r){n.some((function(t){t.__h&&(t.__h=[])})),n=[],gt.__e(r,t.__v)}})),St&&St(t,n)},gt.unmount=function(t){xt&&xt(t);var n,e=t.__c;e&&e.__H&&(e.__H.__.forEach((function(t){try{Vt(t)}catch(t){n=t}})),e.__H=void 0,n&>.__e(n,e.__v))};var Rt="function"==typeof requestAnimationFrame;function It(t){var n,e=function(){clearTimeout(_),Rt&&cancelAnimationFrame(n),setTimeout(t)},_=setTimeout(e,100);Rt&&(n=requestAnimationFrame(e))}function Vt(t){var n=pt,e=t.__c;"function"==typeof e&&(t.__c=void 0,e()),pt=n}function jt(t){var n=pt;t.__c=t.__(),pt=n}function qt(t,n){return!t||t.length!==n.length||n.some((function(n,e){return n!==t[e]}))}function Bt(t,n){return"function"==typeof n?n(t):n}function zt(t,n){S[t]=n.bind(null,S[t]||(()=>{}))}let Gt,Jt;function Kt(t){if(Jt)Jt();Jt=t&&t.S()}function Qt({data:t}){const n=Yt(t);n.value=t;const e=Dt(()=>{let t=this.__v;while(t=t.__)if(t.__c){t.__c.__$f|=4;break}this.__$u.c=()=>{var t;if(!C(e.peek())&&3===(null==(t=this.base)?void 0:t.nodeType))this.base.data=e.peek();else{this.__$f|=1;this.setState({})}};return v(()=>{let t=n.value.value;return 0===t?0:!0===t?"":t||""})},[]);return e.value}Qt.displayName="_st";Object.defineProperties(f.prototype,{constructor:{configurable:!0,value:void 0},type:{configurable:!0,value:Qt},props:{configurable:!0,get(){return{data:this}}},__b:{configurable:!0,value:1}});zt("__b",(t,n)=>{if("string"==typeof n.type){let t,e=n.props;for(let _ in e){if("children"===_)continue;let i=e[_];if(i instanceof f){if(!t)n.__np=t={};t[_]=i;e[_]=i.peek()}}}t(n)});zt("__r",(t,n)=>{Kt();let e,_=n.__c;if(_){_.__$f&=-2;e=_.__$u;if(void 0===e)_.__$u=e=function(t){let n;k((function(){n=this}));n.c=()=>{_.__$f|=1;_.setState({})};return n}()}Gt=_;Kt(e);t(n)});zt("__e",(t,n,e,_)=>{Kt();Gt=void 0;t(n,e,_)});zt("diffed",(t,n)=>{Kt();Gt=void 0;let e;if("string"==typeof n.type&&(e=n.__e)){let t=n.__np,_=n.props;if(t){let n=e.U;if(n)for(let e in n){let _=n[e];if(void 0!==_&&!(e in t)){_.d();n[e]=void 0}}else{n={};e.U=n}for(let i in t){let o=n[i],r=t[i];if(void 0===o){o=Xt(e,i,r,_);n[i]=o}else o.o(r,_)}}}t(n)});function Xt(t,n,e,_){const i=n in t&&void 0===t.ownerSVGElement,o=c(e);return{o:(t,n)=>{o.value=t;_=n},d:k(()=>{const e=o.value.value;if(_[n]!==e){_[n]=e;if(i)t[n]=e;else if(e)t.setAttribute(n,e);else t.removeAttribute(n)}})}}zt("unmount",(t,n)=>{if("string"==typeof n.type){let t=n.__e;if(t){const n=t.U;if(n){t.U=void 0;for(let t in n){let e=n[t];if(e)e.d()}}}}else{let t=n.__c;if(t){const n=t.__$u;if(n){t.__$u=void 0;n.d()}}}t(n)});zt("__h",(t,n,e,_)=>{if(_<3||9===_)n.__$f|=2;t(n,e,_)});q.prototype.shouldComponentUpdate=function(t,n){const e=this.__$u;if(!(e&&void 0!==e.s||4&this.__$f))return!0;if(3&this.__$f)return!0;for(let _ in n)return!0;for(let _ in t)if("__source"!==_&&t[_]!==this.props[_])return!0;for(let _ in this.props)if(!(_ in t))return!0;return!1};function Yt(t){return Dt(()=>c(t),[])}function Zt(t){const n=$t(t);n.current=t;Gt.__$f|=4;return Dt(()=>v(()=>n.current()),[])}function tn(t){const n=$t(t);n.current=t;Pt(()=>k(()=>n.current()),[])}var nn=function(t,n,e,_){var i;n[0]=0;for(var o=1;o=5&&((i||!t&&5===_)&&(r.push(_,0,i,e),_=6),t&&(r.push(_,t,0,e),_=6)),i=""},l=0;l"===n?(_=1,i=""):i=n+i[0]:o?n===o?o="":i+=n:'"'===n||"'"===n?o=n:">"===n?(u(),_=1):_&&("="===n?(_=5,e=i,i=""):"/"===n&&(_<5||">"===t[l][s+1])?(u(),3===_&&(r=r[0]),_=r,(r=r[0]).push(2,0,_),_=0):" "===n||"\t"===n||"\n"===n||"\r"===n?(u(),_=2):i+=n),3===_&&"!--"===i&&(_=4,r=r[0])}return u(),r}(t)),n),arguments,[])).length>1?n:n[0]}var on=_n.bind(R);export{q as Component,j as Fragment,f as Signal,e as batch,ct as cloneElement,v as computed,ht as createContext,R as createElement,V as createRef,k as effect,R as h,on as html,ft as hydrate,C as isValidElement,S as options,st as render,c as signal,Y as toChildArray,o as untracked,Mt as useCallback,Zt as useComputed,At as useContext,Ft as useDebugValue,Pt as useEffect,Wt as useErrorBoundary,Lt as useId,Tt as useImperativeHandle,Nt as useLayoutEffect,Dt as useMemo,Ht as useReducer,$t as useRef,Yt as useSignal,tn as useSignalEffect,Et as useState}; diff --git a/tools/server/public_legacy/json-schema-to-grammar.mjs b/tools/server/public_legacy/json-schema-to-grammar.mjs new file mode 100644 index 0000000000000000000000000000000000000000..b12bf2ab0909ac14c304652e796b547f92e167f7 --- /dev/null +++ b/tools/server/public_legacy/json-schema-to-grammar.mjs @@ -0,0 +1,838 @@ +// WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first. +const SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'; + +function _buildRepetition(itemRule, minItems, maxItems, opts={}) { + if (maxItems == 0) { + return ''; + } + if (minItems === 0 && maxItems === 1) { + return `${itemRule}?`; + } + + + const separatorRule = opts.separatorRule ?? ''; + const itemRuleIsLiteral = opts.itemRuleIsLiteral ?? false + + if (separatorRule === '') { + if (minItems === 1 && maxItems === undefined) { + return `${itemRule}+`; + } else if (minItems === 0 && maxItems === undefined) { + return `${itemRule}*`; + } else { + return `${itemRule}{${minItems},${maxItems !== undefined ? maxItems : ''}}`; + } + } + + const result = itemRule + ' ' + _buildRepetition(`(${separatorRule} ${itemRule})`, minItems > 0 ? minItems - 1 : 0, maxItems !== undefined ? maxItems - 1 : undefined); + return minItems === 0 ? `(${result})?` : result; +} + +function _generateMinMaxInt(minValue, maxValue, out, decimalsLeft = 16, topLevel = true) { + const hasMin = minValue !== null; + const hasMax = maxValue !== null; + + function digitRange(fromChar, toChar) { + out.push("["); + if (fromChar === toChar) { + out.push(fromChar); + } else { + out.push(fromChar); + out.push("-"); + out.push(toChar); + } + out.push("]"); + } + + function moreDigits(minDigits, maxDigits) { + out.push("[0-9]"); + if (minDigits === maxDigits && minDigits === 1) { + return; + } + out.push("{"); + out.push(minDigits.toString()); + if (maxDigits !== minDigits) { + out.push(","); + if (maxDigits !== Number.MAX_SAFE_INTEGER) { + out.push(maxDigits.toString()); + } + } + out.push("}"); + } + + function uniformRange(fromStr, toStr) { + let i = 0; + while (i < fromStr.length && fromStr[i] === toStr[i]) { + i++; + } + if (i > 0) { + out.push("\""); + out.push(fromStr.slice(0, i)); + out.push("\""); + } + if (i < fromStr.length) { + if (i > 0) { + out.push(" "); + } + const subLen = fromStr.length - i - 1; + if (subLen > 0) { + const fromSub = fromStr.slice(i + 1); + const toSub = toStr.slice(i + 1); + const subZeros = "0".repeat(subLen); + const subNines = "9".repeat(subLen); + + let toReached = false; + out.push("("); + if (fromSub === subZeros) { + digitRange(fromStr[i], String.fromCharCode(toStr.charCodeAt(i) - 1)); + out.push(" "); + moreDigits(subLen, subLen); + } else { + out.push("["); + out.push(fromStr[i]); + out.push("] "); + out.push("("); + uniformRange(fromSub, subNines); + out.push(")"); + if (fromStr.charCodeAt(i) < toStr.charCodeAt(i) - 1) { + out.push(" | "); + if (toSub === subNines) { + digitRange(String.fromCharCode(fromStr.charCodeAt(i) + 1), toStr[i]); + toReached = true; + } else { + digitRange(String.fromCharCode(fromStr.charCodeAt(i) + 1), String.fromCharCode(toStr.charCodeAt(i) - 1)); + } + out.push(" "); + moreDigits(subLen, subLen); + } + } + if (!toReached) { + out.push(" | "); + digitRange(toStr[i], toStr[i]); + out.push(" "); + uniformRange(subZeros, toSub); + } + out.push(")"); + } else { + out.push("["); + out.push(fromStr[i]); + out.push("-"); + out.push(toStr[i]); + out.push("]"); + } + } + } + + if (hasMin && hasMax) { + if (minValue < 0 && maxValue < 0) { + out.push("\"-\" ("); + _generateMinMaxInt(-maxValue, -minValue, out, decimalsLeft, true); + out.push(")"); + return; + } + + if (minValue < 0) { + out.push("\"-\" ("); + _generateMinMaxInt(0, -minValue, out, decimalsLeft, true); + out.push(") | "); + minValue = 0; + } + + let minS = minValue.toString(); + const maxS = maxValue.toString(); + const minDigits = minS.length; + const maxDigits = maxS.length; + + for (let digits = minDigits; digits < maxDigits; digits++) { + uniformRange(minS, "9".repeat(digits)); + minS = "1" + "0".repeat(digits); + out.push(" | "); + } + uniformRange(minS, maxS); + return; + } + + const lessDecimals = Math.max(decimalsLeft - 1, 1); + + if (hasMin) { + if (minValue < 0) { + out.push("\"-\" ("); + _generateMinMaxInt(null, -minValue, out, decimalsLeft, false); + out.push(") | [0] | [1-9] "); + moreDigits(0, decimalsLeft - 1); + } else if (minValue === 0) { + if (topLevel) { + out.push("[0] | [1-9] "); + moreDigits(0, lessDecimals); + } else { + moreDigits(1, decimalsLeft); + } + } else if (minValue <= 9) { + const c = minValue.toString(); + const range_start = topLevel ? '1' : '0'; + if (c > range_start) { + digitRange(range_start, String.fromCharCode(c.charCodeAt(0) - 1)); + out.push(" "); + moreDigits(1, lessDecimals); + out.push(" | "); + } + digitRange(c, "9"); + out.push(" "); + moreDigits(0, lessDecimals); + } else { + const minS = minValue.toString(); + const length = minS.length; + const c = minS[0]; + + if (c > "1") { + digitRange(topLevel ? "1" : "0", String.fromCharCode(c.charCodeAt(0) - 1)); + out.push(" "); + moreDigits(length, lessDecimals); + out.push(" | "); + } + digitRange(c, c); + out.push(" ("); + _generateMinMaxInt(parseInt(minS.slice(1)), null, out, lessDecimals, false); + out.push(")"); + if (c < "9") { + out.push(" | "); + digitRange(String.fromCharCode(c.charCodeAt(0) + 1), "9"); + out.push(" "); + moreDigits(length - 1, lessDecimals); + } + } + return; + } + + if (hasMax) { + if (maxValue >= 0) { + if (topLevel) { + out.push("\"-\" [1-9] "); + moreDigits(0, lessDecimals); + out.push(" | "); + } + _generateMinMaxInt(0, maxValue, out, decimalsLeft, true); + } else { + out.push("\"-\" ("); + _generateMinMaxInt(-maxValue, null, out, decimalsLeft, false); + out.push(")"); + } + return; + } + + throw new Error("At least one of minValue or maxValue must be set"); +} + +class BuiltinRule { + constructor(content, deps) { + this.content = content; + this.deps = deps || []; + } +} + +const PRIMITIVE_RULES = { + boolean : new BuiltinRule('("true" | "false") space', []), + 'decimal-part' : new BuiltinRule('[0-9]{1,16}', []), + 'integral-part': new BuiltinRule('[0] | [1-9] [0-9]{0,15}', []), + number : new BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), + integer : new BuiltinRule('("-"? integral-part) space', ['integral-part']), + value : new BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), + object : new BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), + array : new BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), + uuid : new BuiltinRule('"\\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\\"" space', []), + char : new BuiltinRule(`[^"\\\\\\x7F\\x00-\\x1F] | [\\\\] (["\\\\bfnrt] | "u" [0-9a-fA-F]{4})`, []), + string : new BuiltinRule(`"\\"" char* "\\"" space`, ['char']), + null : new BuiltinRule('"null" space', []), +}; + +// TODO: support "uri", "email" string formats +const STRING_FORMAT_RULES = { + 'date' : new BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : new BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date-time' : new BuiltinRule('date "T" time', ['date', 'time']), + 'date-string' : new BuiltinRule('"\\"" date "\\"" space', ['date']), + 'time-string' : new BuiltinRule('"\\"" time "\\"" space', ['time']), + 'date-time-string': new BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), +} + +const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...STRING_FORMAT_RULES}; + +const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g; +const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g; +const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g; +const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' }; + +const NON_LITERAL_SET = new Set('|.()[]{}*+?'); +const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('^$.[]()|{}*+?'); + +export class SchemaConverter { + constructor(options) { + this._propOrder = options.prop_order || {}; + this._allowFetch = options.allow_fetch || false; + this._dotall = options.dotall || false; + this._rules = {'space': SPACE_RULE}; + this._refs = {}; + this._refsBeingResolved = new Set(); + } + + _formatLiteral(literal) { + const escaped = literal.replace( + GRAMMAR_LITERAL_ESCAPE_RE, + m => GRAMMAR_LITERAL_ESCAPES[m] + ); + return `"${escaped}"`; + } + + _formatRangeChar(literal) { + return JSON.stringify(literal).slice(1, -1).replace( + GRAMMAR_RANGE_LITERAL_ESCAPE_RE, + m => GRAMMAR_LITERAL_ESCAPES[m] + ); + } + + _addRule(name, rule) { + let escName = name.replace(INVALID_RULE_CHARS_RE, '-'); + let key = escName; + + if (escName in this._rules) { + if (this._rules[escName] === rule) { + return key; + } + + let i = 0; + while ((`${escName}${i}` in this._rules) && (this._rules[`${escName}${i}`] !== rule)) { + i += 1; + } + key = `${escName}${i}`; + } + + this._rules[key] = rule; + return key; + } + + async resolveRefs(schema, url) { + const visit = async (n) => { + if (Array.isArray(n)) { + return Promise.all(n.map(visit)); + } else if (typeof n === 'object' && n !== null) { + let ref = n.$ref; + let target; + if (ref !== undefined && !this._refs[ref]) { + if (ref.startsWith('https://')) { + if (!this._allowFetch) { + throw new Error('Fetching remote schemas is not allowed (use --allow-fetch for force)'); + } + const fetch = (await import('node-fetch')).default; + + const fragSplit = ref.split('#'); + const baseUrl = fragSplit[0]; + + target = this._refs[baseUrl]; + if (!target) { + target = await this.resolveRefs(await fetch(ref).then(res => res.json()), baseUrl); + this._refs[baseUrl] = target; + } + + if (fragSplit.length === 1 || fragSplit[fragSplit.length - 1] === '') { + return target; + } + } else if (ref.startsWith('#/')) { + target = schema; + ref = `${url}${ref}`; + n.$ref = ref; + } else { + throw new Error(`Unsupported ref ${ref}`); + } + + const selectors = ref.split('#')[1].split('/').slice(1); + for (const sel of selectors) { + if (!target || !(sel in target)) { + throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`); + } + target = target[sel]; + } + + this._refs[ref] = target; + } else { + await Promise.all(Object.values(n).map(visit)); + } + } + + return n; + }; + + return visit(schema); + } + + _generateUnionRule(name, altSchemas) { + return altSchemas + .map((altSchema, i) => this.visit(altSchema, `${name ?? ''}${name ? '-' : 'alternative-'}${i}`)) + .join(' | '); + } + + _visitPattern(pattern, name) { + if (!pattern.startsWith('^') || !pattern.endsWith('$')) { + throw new Error('Pattern must start with "^" and end with "$"'); + } + pattern = pattern.slice(1, -1); + const subRuleIds = {}; + + let i = 0; + const length = pattern.length; + + const getDot = () => { + let rule; + if (this._dotall) { + rule = '[\\U00000000-\\U0010FFFF]'; + } else { + // Accept any character... except \n and \r line break chars (\x0A and \xOD) + rule = '[^\\x0A\\x0D]'; + } + return this._addRule('dot', rule); + }; + + + const toRule = ([s, isLiteral]) => isLiteral ? "\"" + s + "\"" : s; + + const transform = () => { + const start = i; + // For each component of this sequence, store its string representation and whether it's a literal. + // We only need a flat structure here to apply repetition operators to the last item, and + // to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially + // (GBNF's syntax is luckily very close to regular expressions!) + const seq = []; + + const joinSeq = () => { + const ret = []; + for (const [isLiteral, g] of groupBy(seq, x => x[1])) { + if (isLiteral) { + ret.push([[...g].map(x => x[0]).join(''), true]); + } else { + ret.push(...g); + } + } + if (ret.length === 1) { + return ret[0]; + } + return [ret.map(x => toRule(x)).join(' '), false]; + }; + + while (i < length) { + const c = pattern[i]; + if (c === '.') { + seq.push([getDot(), false]); + i += 1; + } else if (c === '(') { + i += 1; + if (i < length) { + if (pattern[i] === '?') { + throw new Error(`Unsupported pattern syntax "${pattern[i]}" at index ${i} of /${pattern}/`); + } + } + seq.push([`(${toRule(transform())})`, false]); + } else if (c === ')') { + i += 1; + if (start <= 0 || pattern[start - 1] !== '(') { + throw new Error(`Unbalanced parentheses; start = ${start}, i = ${i}, pattern = ${pattern}`); + } + return joinSeq(); + } else if (c === '[') { + let squareBrackets = c; + i += 1; + while (i < length && pattern[i] !== ']') { + if (pattern[i] === '\\') { + squareBrackets += pattern.slice(i, i + 2); + i += 2; + } else { + squareBrackets += pattern[i]; + i += 1; + } + } + if (i >= length) { + throw new Error(`Unbalanced square brackets; start = ${start}, i = ${i}, pattern = ${pattern}`); + } + squareBrackets += ']'; + i += 1; + seq.push([squareBrackets, false]); + } else if (c === '|') { + seq.push(['|', false]); + i += 1; + } else if (c === '*' || c === '+' || c === '?') { + seq[seq.length - 1] = [toRule(seq[seq.length - 1]) + c, false]; + i += 1; + } else if (c === '{') { + let curlyBrackets = c; + i += 1; + while (i < length && pattern[i] !== '}') { + curlyBrackets += pattern[i]; + i += 1; + } + if (i >= length) { + throw new Error(`Unbalanced curly brackets; start = ${start}, i = ${i}, pattern = ${pattern}`); + } + curlyBrackets += '}'; + i += 1; + const nums = curlyBrackets.slice(1, -1).split(',').map(s => s.trim()); + let minTimes, maxTimes; + if (nums.length === 1) { + minTimes = parseInt(nums[0], 10); + maxTimes = minTimes; + } else { + if (nums.length !== 2) { + throw new Error(`Invalid quantifier ${curlyBrackets}`); + } + minTimes = nums[0] ? parseInt(nums[0], 10) : 0; + maxTimes = nums[1] ? parseInt(nums[1], 10) : Infinity; + } + + let [sub, subIsLiteral] = seq[seq.length - 1]; + + if (!subIsLiteral) { + let id = subRuleIds[sub]; + if (id === undefined) { + id = this._addRule(`${name}-${Object.keys(subRuleIds).length + 1}`, sub); + subRuleIds[sub] = id; + } + sub = id; + } + + seq[seq.length - 1] = [ + _buildRepetition(subIsLiteral ? `"${sub}"` : sub, minTimes, maxTimes, {itemRuleIsLiteral: subIsLiteral}), + false + ]; + } else { + let literal = ''; + while (i < length) { + if (pattern[i] === '\\' && i < length - 1) { + const next = pattern[i + 1]; + if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.has(next)) { + i += 1; + literal += pattern[i]; + i += 1; + } else { + literal += pattern.slice(i, i + 2); + i += 2; + } + } else if (pattern[i] === '"') { + literal += '\\"'; + i += 1; + } else if (!NON_LITERAL_SET.has(pattern[i]) && + (i === length - 1 || literal === '' || pattern[i + 1] === '.' || !NON_LITERAL_SET.has(pattern[i+1]))) { + literal += pattern[i]; + i += 1; + } else { + break; + } + } + if (literal !== '') { + seq.push([literal, true]); + } + } + } + + return joinSeq(); + }; + + return this._addRule(name, "\"\\\"\" (" + toRule(transform()) + ") \"\\\"\" space") + } + + _notStrings(strings) { + class TrieNode { + constructor() { + this.children = {}; + this.isEndOfString = false; + } + + insert(str) { + let node = this; + for (const c of str) { + node = node.children[c] = node.children[c] || new TrieNode(); + } + node.isEndOfString = true; + } + } + + const trie = new TrieNode(); + for (const s of strings) { + trie.insert(s); + } + + const charRuleName = this._addPrimitive('char', PRIMITIVE_RULES['char']); + const out = ['["] ( ']; + + const visit = (node) => { + const rejects = []; + let first = true; + for (const c of Object.keys(node.children).sort()) { + const child = node.children[c]; + rejects.push(c); + if (first) { + first = false; + } else { + out.push(' | '); + } + out.push(`[${c}]`); + if (Object.keys(child.children).length > 0) { + out.push(' ('); + visit(child); + out.push(')'); + } else if (child.isEndOfString) { + out.push(` ${charRuleName}+`); + } + } + if (Object.keys(node.children).length > 0) { + if (!first) { + out.push(' | '); + } + out.push(`[^"${rejects.join('')}] ${charRuleName}*`); + } + }; + + visit(trie); + + out.push(` )${trie.isEndOfString ? '' : '?'} ["] space`); + return out.join(''); + } + + _resolveRef(ref) { + let refName = ref.split('/').pop(); + if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) { + this._refsBeingResolved.add(ref); + const resolved = this._refs[ref]; + refName = this.visit(resolved, refName); + this._refsBeingResolved.delete(ref); + } + return refName; + } + + _generateConstantRule(value) { + return this._formatLiteral(JSON.stringify(value)); + } + + visit(schema, name) { + const schemaType = schema.type; + const schemaFormat = schema.format; + const ruleName = name in RESERVED_NAMES ? name + '-' : name == '' ? 'root' : name; + + const ref = schema.$ref; + if (ref !== undefined) { + return this._addRule(ruleName, this._resolveRef(ref)); + } else if (schema.oneOf || schema.anyOf) { + return this._addRule(ruleName, this._generateUnionRule(name, schema.oneOf || schema.anyOf)); + } else if (Array.isArray(schemaType)) { + return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({...schema, type: t})))); + } else if ('const' in schema) { + return this._addRule(ruleName, this._generateConstantRule(schema.const) + ' space'); + } else if ('enum' in schema) { + const rule = '(' + schema.enum.map(v => this._generateConstantRule(v)).join(' | ') + ') space'; + return this._addRule(ruleName, rule); + } else if ((schemaType === undefined || schemaType === 'object') && + ('properties' in schema || + ('additionalProperties' in schema && schema.additionalProperties !== true))) { + const required = new Set(schema.required || []); + const properties = Object.entries(schema.properties ?? {}); + return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties)); + } else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) { + const required = new Set(); + const properties = []; + const addComponent = (compSchema, isRequired) => { + const ref = compSchema.$ref; + if (ref !== undefined) { + compSchema = this._refs[ref]; + } + + if ('properties' in compSchema) { + for (const [propName, propSchema] of Object.entries(compSchema.properties)) { + properties.push([propName, propSchema]); + if (isRequired) { + required.add(propName); + } + } + } + }; + + for (const t of schema.allOf) { + if ('anyOf' in t) { + for (const tt of t.anyOf) { + addComponent(tt, false); + } + } else { + addComponent(t, true); + } + } + + return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null)); + } else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) { + const items = schema.items ?? schema.prefixItems; + if (Array.isArray(items)) { + return this._addRule( + ruleName, + '"[" space ' + + items.map((item, i) => this.visit(item, `${name ?? ''}${name ? '-' : ''}tuple-${i}`)).join(' "," space ') + + ' "]" space' + ); + } else { + const itemRuleName = this.visit(items, `${name ?? ''}${name ? '-' : ''}item`); + const minItems = schema.minItems || 0; + const maxItems = schema.maxItems; + return this._addRule(ruleName, '"[" space ' + _buildRepetition(itemRuleName, minItems, maxItems, {separatorRule: '"," space'}) + ' "]" space'); + } + } else if ((schemaType === undefined || schemaType === 'string') && 'pattern' in schema) { + return this._visitPattern(schema.pattern, ruleName); + } else if ((schemaType === undefined || schemaType === 'string') && /^uuid[1-5]?$/.test(schema.format || '')) { + return this._addPrimitive( + ruleName === 'root' ? 'root' : schemaFormat, + PRIMITIVE_RULES['uuid'] + ); + } else if ((schemaType === undefined || schemaType === 'string') && `${schema.format}-string` in STRING_FORMAT_RULES) { + const primName = `${schema.format}-string` + return this._addRule(ruleName, this._addPrimitive(primName, STRING_FORMAT_RULES[primName])); + } else if (schemaType === 'string' && ('minLength' in schema || 'maxLength' in schema)) { + const charRuleName = this._addPrimitive('char', PRIMITIVE_RULES['char']); + const minLen = schema.minLength || 0; + const maxLen = schema.maxLength; + return this._addRule(ruleName, '"\\\"" ' + _buildRepetition(charRuleName, minLen, maxLen) + ' "\\\"" space'); + } else if (schemaType === 'integer' && ('minimum' in schema || 'exclusiveMinimum' in schema || 'maximum' in schema || 'exclusiveMaximum' in schema)) { + let minValue = null; + let maxValue = null; + if ('minimum' in schema) { + minValue = schema.minimum; + } else if ('exclusiveMinimum' in schema) { + minValue = schema.exclusiveMinimum + 1; + } + if ('maximum' in schema) { + maxValue = schema.maximum; + } else if ('exclusiveMaximum' in schema) { + maxValue = schema.exclusiveMaximum - 1; + } + + const out = ["("]; + _generateMinMaxInt(minValue, maxValue, out); + out.push(") space"); + return this._addRule(ruleName, out.join('')); + } else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) { + return this._addRule(ruleName, this._addPrimitive('object', PRIMITIVE_RULES['object'])); + } else { + if (!(schemaType in PRIMITIVE_RULES)) { + throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`); + } + // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero + return this._addPrimitive(ruleName === 'root' ? 'root' : schemaType, PRIMITIVE_RULES[schemaType]); + } + } + + _addPrimitive(name, rule) { + let n = this._addRule(name, rule.content); + for (const dep of rule.deps) { + const depRule = PRIMITIVE_RULES[dep] || STRING_FORMAT_RULES[dep]; + if (!depRule) { + throw new Error(`Rule ${dep} not known`); + } + if (!(dep in this._rules)) { + this._addPrimitive(dep, depRule); + } + } + return n; + } + + _buildObjectRule(properties, required, name, additionalProperties) { + const propOrder = this._propOrder; + // sort by position in prop_order (if specified) then by original order + const sortedProps = properties.map(([k]) => k).sort((a, b) => { + const orderA = propOrder[a] || Infinity; + const orderB = propOrder[b] || Infinity; + return orderA - orderB || properties.findIndex(([k]) => k === a) - properties.findIndex(([k]) => k === b); + }); + + const propKvRuleNames = {}; + for (const [propName, propSchema] of properties) { + const propRuleName = this.visit(propSchema, `${name ?? ''}${name ? '-' : ''}${propName}`); + propKvRuleNames[propName] = this._addRule( + `${name ?? ''}${name ? '-' : ''}${propName}-kv`, + `${this._formatLiteral(JSON.stringify(propName))} space ":" space ${propRuleName}` + ); + } + const requiredProps = sortedProps.filter(k => required.has(k)); + const optionalProps = sortedProps.filter(k => !required.has(k)); + + if (additionalProperties) { + const subName = `${name ?? ''}${name ? '-' : ''}additional`; + const valueRule = + additionalProperties != null && typeof additionalProperties === 'object' ? this.visit(additionalProperties, `${subName}-value`) + : this._addPrimitive('value', PRIMITIVE_RULES['value']); + + const key_rule = + sortedProps.length === 0 ? this._addPrimitive('string', PRIMITIVE_RULES['string']) + : this._addRule(`${subName}-k`, this._notStrings(sortedProps)); + + propKvRuleNames['*'] = this._addRule( + `${subName}-kv`, + `${key_rule} ":" space ${valueRule}`); + optionalProps.push('*'); + } + + let rule = '"{" space '; + rule += requiredProps.map(k => propKvRuleNames[k]).join(' "," space '); + + if (optionalProps.length > 0) { + rule += ' ('; + if (requiredProps.length > 0) { + rule += ' "," space ( '; + } + + const getRecursiveRefs = (ks, firstIsOptional) => { + const [k, ...rest] = ks; + const kvRuleName = propKvRuleNames[k]; + let res; + const commaRef = `( "," space ${kvRuleName} )`; + if (firstIsOptional) { + res = commaRef + (k === '*' ? '*' : '?'); + } else { + res = kvRuleName + (k === '*' ? ' ' + commaRef + '*' : ''); + } + if (rest.length > 0) { + res += ' ' + this._addRule( + `${name ?? ''}${name ? '-' : ''}${k}-rest`, + getRecursiveRefs(rest, true) + ); + } + return res; + }; + + rule += optionalProps.map((_, i) => getRecursiveRefs(optionalProps.slice(i), false)).join(' | '); + if (requiredProps.length > 0) { + rule += ' )'; + } + rule += ' )?'; + } + + rule += ' "}" space'; + + return rule; + } + + formatGrammar() { + let grammar = ''; + for (const [name, rule] of Object.entries(this._rules).sort(([a], [b]) => a.localeCompare(b))) { + grammar += `${name} ::= ${rule}\n`; + } + return grammar; + } +} + +// Helper function to group elements by a key function +function* groupBy(iterable, keyFn) { + let lastKey = null; + let group = []; + for (const element of iterable) { + const key = keyFn(element); + if (lastKey !== null && key !== lastKey) { + yield [lastKey, group]; + group = []; + } + group.push(element); + lastKey = key; + } + if (group.length > 0) { + yield [lastKey, group]; + } +} diff --git a/tools/server/public_legacy/loading.html b/tools/server/public_legacy/loading.html new file mode 100644 index 0000000000000000000000000000000000000000..c3fd19a0f5ae77f053d434834934778b5eb1759e --- /dev/null +++ b/tools/server/public_legacy/loading.html @@ -0,0 +1,12 @@ + + + + + + +
+ The model is loading. Please wait.
+ The user interface will appear soon. +
+ + diff --git a/tools/server/public_legacy/prompt-formats.js b/tools/server/public_legacy/prompt-formats.js new file mode 100644 index 0000000000000000000000000000000000000000..73ddb7187eb7ab9cfe7f9f27cfeceeb66cbb21ba --- /dev/null +++ b/tools/server/public_legacy/prompt-formats.js @@ -0,0 +1,331 @@ +// extended list +export const promptFormats = { + "alpaca": { + template: `{{prompt}}\n\n{{history}}\n\n{{char}}:`, + + historyTemplate: `### {{name}}:\n{{message}}`, + + char: "Response", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "Instruction", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + }, + + // ---------------------------- + + "chatml": { + template: `<|im_start|>system\n{{prompt}}<|im_end|>\n{{history}}{{char}}`, + + historyTemplate: `<|im_start|>{{name}}\n{{message}}`, + + char: "assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "user", + userMsgPrefix: "", + userMsgSuffix: "<|im_end|>\n", + + stops: "" + }, + + // ---------------------------- + + "commandr": { + template: `<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{prompt}}\n<|END_OF_TURN_TOKEN|>{{history}}{{char}}`, + + historyTemplate: `<|START_OF_TURN_TOKEN|><|{{name}}|> {{message}}`, + + char: "CHATBOT_TOKEN", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "USER_TOKEN", + userMsgPrefix: "", + userMsgSuffix: "<|END_OF_TURN_TOKEN|>", + + stops: "" + }, + // ref: https://docs.cohere.com/docs/prompting-command-r + + // ---------------------------- + + "llama2": { + template: `[INST] <>\n{{prompt}}\n<>\n\nTest Message [/INST] Test Successfull {{history}}{{char}}`, + + historyTemplate: `{{name}}: {{message}}`, + + char: "Assistant", + charMsgPrefix: "", + charMsgSuffix: "
", + + user: "User", + userMsgPrefix: "[INST] ", + userMsgSuffix: " [/INST]", + + stops: "" + }, + // ref: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + + // ---------------------------- + + "llama3": { + template: `<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{prompt}}{{history}}{{char}}`, + + historyTemplate: `<|start_header_id|>{{name}}<|end_header_id|>\n\n{{message}}<|eot_id|>`, + + char: "assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "user", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "<|eot_id|>" + }, + // ref: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/#special-tokens-used-with-meta-llama-3 + + // ---------------------------- + + "openchat": { + template: `{{history}}{{char}}`, + + historyTemplate: `GPT4 Correct {{name}}: {{message}}<|end_of_turn|>`, + + char: "Assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "User", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + }, + + // ---------------------------- + + "phi3": { + template: `{{history}}{{char}}`, + + historyTemplate: `<|{{name}}|>\n{{message}}<|end|>\n`, + + char: "assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "user", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "<|end|>" + }, + // ref: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct#chat-format + + // ---------------------------- + + "vicuna": { + template: `{{prompt}}\n{{history}}{{char}}`, + + historyTemplate: `{{name}}: {{message}}\n`, + + char: "ASSISTANT", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "USER", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + }, + // ref: https://huggingface.co/lmsys/vicuna-33b-v1.3/discussions/1 + + // ---------------------------- + + "deepseekCoder": { + template: `{{prompt}}{{history}}{{char}}:`, + + historyTemplate: `### {{name}}:\n{{message}}`, + + char: "Response", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "Instruction", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "<|EOT|>" + }, + + // ---------------------------- + + "med42": { + template: `<|system|>: {{prompt}}\n{{history}}{{char}}`, + + historyTemplate: `<|{{name}}|>: {{message}}\n`, + + char: "assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "prompter", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + }, + + // ---------------------------- + + "neuralchat": { + template: `### System:\n{{prompt}}\n{{history}}{{char}}:`, + + historyTemplate: `### {{name}}:\n{{message}}\n`, + + char: "Assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "User", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + }, + + // ---------------------------- + + "nousHermes": { + template: `### Instruction: {{prompt}}\n\n{{history}}\n\n{{char}}:`, + + historyTemplate: `### {{name}}:\n{{message}}`, + + char: "Response", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "Input", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + }, + + // ---------------------------- + + "openchatMath": { + template: `{{history}}{{char}}`, + + historyTemplate: `Math Correct {{name}}: {{message}}<|end_of_turn|>`, + + char: "Assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + + user: "User", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + }, + + // ---------------------------- + + "orion": { + template: `Human: Test Message\n\nAssistant: Test Successful{{history}}{{char}}:`, + + historyTemplate: `{{name}}: {{message}}`, + + char: "Assistant ", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "Human", + userMsgPrefix: "", + userMsgSuffix: "\n\n", + + stops: "" + }, + + // ---------------------------- + + "sauerkraut": { + template: `{{prompt}}\n{{history}}{{char}}`, + + historyTemplate: ` + {{name}}: {{message}}\n`, + + char: "Assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "User", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + }, + + // ---------------------------- + + "starlingCode": { + template: `{{history}}{{char}}`, + + historyTemplate: `Code {{name}}: {{message}}<|end_of_turn|>`, + + char: "Assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "User", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + }, + + // ---------------------------- + + "yi34b": { + template: `{{history}} {{char}}`, + + historyTemplate: `{{name}}: {{message}}`, + + char: "Assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "Human", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + }, + + // ---------------------------- + + "zephyr": { + template: `<|system|>\n{{prompt}}\n{{history}}{{char}}`, + + historyTemplate: `<|{{name}}|>\n{{message}}\n`, + + char: "assistant", + charMsgPrefix: "", + charMsgSuffix: "", + + user: "user", + userMsgPrefix: "", + userMsgSuffix: "", + + stops: "" + } + }; diff --git a/tools/server/public_legacy/style.css b/tools/server/public_legacy/style.css new file mode 100644 index 0000000000000000000000000000000000000000..087cc62dab0283cd26367c03b572a9e715024223 --- /dev/null +++ b/tools/server/public_legacy/style.css @@ -0,0 +1,954 @@ +@import url("colorthemes.css"); + +body { + font-family: 'Arial', sans-serif; + font-size: 90%; + background-color: var(--background-color-1); + color: var(--text-color-subtile-1); /* head 1 llama.cpp & triangle options for some reason */ + max-width: 600px; + min-width: 300px; + line-height: 1.2; + margin: 0 auto; + padding: 0 0.5em; + transition: background-color 0.3s; +} + +::selection { + color: var(--button-primary-text) ; + background: var(--button-primary-color); +} + +code, pre code { + font-family: 'Courier New', monospace; +} + +#container { + margin: 0em auto; + display: flex; + flex-direction: column; + justify-content: space-between; + height: 100%; +} + +main { + margin: 3px; + display: flex; + flex-direction: column; + justify-content: space-between; + gap: 1em; + flex-grow: 1; + overflow-y: auto; + border: 1px solid var(--border-color-3); + border-radius: 5px; + padding: 0.5em; +} + +p { + overflow-wrap: break-word; + word-wrap: break-word; + hyphens: auto; + margin-top: 0.5em; + margin-bottom: 0.5em; +} + +#write form { + margin: 1em 0 0 0; + display: flex; + flex-direction: column; + gap: 0.5em; + align-items: stretch; +} + +.right { + display: flex; + flex-direction: row; + gap: 0.5em; + justify-content: flex-end; + margin-bottom: 30px; +} + +.two-columns { + width: 97%; + max-width: 97%; + display: grid; + grid-template-columns: 1fr 1fr; + gap: 1em; + position: relative; +} + +.json-schema-controls { + margin-top: 10px; + width: 100%; + max-width: 100%; + display: grid; + grid-template: "a a"; + gap: 1em; + font-size: x-small; + color: var(--theme-nuance-color-3); + padding-top: 16px; + padding-bottom: 16px; + text-transform: uppercase; + font-weight: 600; +} + +.json-schema-controls > * { + flex: 1; +} + +/* titles of the details-summary boxes */ +.summary-title { + font-weight: 600; + font-size: x-small; + color: var(--text-color-subtile-1); + text-transform: uppercase; + /* transition: ; */ +} + +fieldset { + border: none; + padding: 0; + margin: 0; + color: var(--text-color-plain); +} + +fieldset.two { + display: grid; + grid-template: "a a a"; + gap: 1em; + align-items: center; + font-size: x-small; + color: var(--text-color-plain); +} + +fieldset.three { + display: grid; + grid-template: "a a a"; + gap: 1em; + font-size: x-small; + color: var(--text-color-plain); +} + +/* titles of name fields*/ +fieldset.names { + display: grid; + grid-template: "a a"; + gap: 1em; + font-size: x-small; + color: var(--theme-nuance-color-3); + padding-top: 16px; + padding-bottom: 16px; + text-transform: uppercase; + font-weight: 600; +} + +/* titles of params fields*/ +fieldset.params { + display: grid; + grid-template: "a a"; + gap: 1em; + font-size: x-small; + color: var(--theme-nuance-color-4); + padding-top: 16px; + padding-bottom: 16px; + text-transform: uppercase; + font-weight: 600; +} + +fieldset.dropdowns { + -webkit-appearance: none; + display: flex; + grid-template: "a a"; + gap: 1em; + font-size: x-small; + color: red; + padding-top: 16px; + padding-bottom: 16px; + text-transform: uppercase; + font-weight: 600; +} + +/* input of name fields*/ +.names input[type="text"] { + font-family: Arial, sans-serif; + font-size: medium; + font-weight: 500; + padding: 5px; + border: 1px solid var(--border-color-2); +} + +.chat-id-color { + color: var(--chat-id-color); +} + +details { + border: 1px solid var(--border-color-2); + border-radius: 5px; + padding: 0.5em 0.5em 0; + margin-top: 0.5em; +} + +summary { + font-weight: bold; + margin: -0.5em -0.5em 0; + padding: 0.5em; + cursor: pointer; +} + +details[open] { + padding: 0.5em; +} + +textarea-sec, input-sec, button-sec { + padding: 10px; + height: 40px; + align-items: center; +} + +textarea-sec::placeholder, input-sec::placeholder { + padding-left: 10px; +} + +.toggleCheckbox { + display: none; +} + +.toggleContainer { + position: relative; + display: grid; + grid-template-columns: repeat(2, 1fr); + width: fit-content; + border: 3px solid var(--border-color-2); + border-radius: 20px; + background: var(--border-color-2); + font-size: small; + cursor: pointer; + overflow: hidden; +} + +/* toggle button current state */ +.toggleContainer::before { + color: var(--button-primary-text); + background-color: var(--button-primary-color); + content: ''; + position: absolute; + width: 50%; + height: 100%; + left: 0%; + border-radius: 20px; + transition: all 0.3s; +} + +.toggleContainer div { + padding: 6px; + text-align: center; + z-index: 1; + transition: color 0.3s; +} + +.toggleCheckbox:checked + .toggleContainer::before { + left: 50%; +} + +.toggleCheckbox:checked + .toggleContainer div:first-child { + color: var(--text-color-subtile-2); +} + +.toggleCheckbox:checked + .toggleContainer div:last-child { + color: var(--button-primary-text); +} + +.toggleCheckbox + .toggleContainer div:first-child { + color: var(--button-primary-text); +} + +.toggleCheckbox + .toggleContainer div:last-child { + color: var(--text-color-subtile-2); +} + +select { + padding: 5px; + margin-right: 5px; + border-radius: 4px; + border: 1px solid var(--secondary-color-4); + background-color: var(--primary-color-3); + color: var(--secondary-color-4); + cursor: pointer; +} + +select:focus { + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 1px var(--border-focus-shadow); +} + +.button-container { + display: flex; + justify-content: flex-end; +} + +button { + color: var(--button-primary-text); + background-color: var(--button-primary-color); + border: 1px solid var(--button-primary-border); + transition: background-color 0.1s; + border-radius: 12px; + font-size: x-small; + font-weight: 600; + text-shadow: 0px 0px 30px #ffffff; + text-align: center; + text-decoration: none; + margin: 4px 2px; + padding: 10px 20px; + display: inline-block; + cursor: pointer; +} + +button:hover { + color: var(--button-primary-text-hover); + background-color: var(--button-primary-color-hover); + border: 1px solid var(--button-primary-border-hover); + font-size: x-small; + font-weight: 600; +} + +button:active { + color: var(--button-primary-text-active); + background-color: var(--button-primary-color-active); + border: 1px solid var(--button-primary-border-active); + font-size: x-small; + font-weight: 600; +} + +button:disabled { + color: var(--button-tertiary-text); + background-color: var(--button-tertiary-color); + border: 1px solid var(--button-tertiary-border); + font-size: x-small; + font-weight: 600; + cursor: not-allowed; +} + +.reset-button { + background-color: var(--button-secondary-color); + border: 1px solid var(--button-secondary-color); + color: var(--button-secondary-text); + width: fit-content; + height: fit-content; + font-size: x-small; + font-weight: 600; + border-radius: 50px; + overflow: hidden; +} + +.reset-button:hover { + color: var(--button-alert-text-hover); + background-color: var(--button-alert-color-hover); + border: 1px solid var(--button-alert-border-hover); + font-size: x-small; + font-weight: 600; +} + +.reset-button:active { + color: var(--button-alert-text-active); + background-color: var(--button-alert-color-active); + border: 1px solid var(--button-alert-border-active); + font-size: x-small; + font-weight: 600; +} + +.button-grammar { + color: var(--button-primary-text); + background-color: var(--button-primary-color); + border: 1px solid var(--button-primary-border); + border-radius: 10px; + padding: 10px 20px; + text-align: center; + text-decoration: none; + display: inline-block; + font-size: x-small; + font-weight: 600; + margin: 2px 2px; + transition: background-color 0.1s; + cursor: pointer; +} + +.button-grammar:hover { + color: var(--button-primary-text-hover); + background-color: var(--button-primary-color-hover); + border: 1px solid var(--button-primary-border-hover); + border-radius: 10px; + padding: 10px 20px; + text-align: center; + text-decoration: none; + display: inline-block; + font-size: x-small; + font-weight: 600; + margin: 2px 2px; + transition: background-color 0.1s; + cursor: pointer; +} + +.button-grammar:active { + color: var(--button-primary-text-active); + background-color: var(--button-primary-color-active); + border: 1px solid var(--button-primary-border-active); + font-size: x-small; + font-weight: 600; +} + +.button-back { + background-color: var(--button-secondary-color); + border: 1px solid var(--button-secondary-color); + color: var(--button-secondary-text); + transition: background-color 0.1s; + border-radius: 12px; + font-size: x-small; + font-weight: 600; + text-align: center; + text-decoration: none; + margin: 4px 2px; + padding: 10px 20px; + display: inline-block; + cursor: pointer; +} + +.button-back:hover { + color: var(--button-secondary-text-hover); + background-color: var(--button-secondary-color-hover); + border: 1px solid var(--button-secondary-border-hover); + padding: 10px 20px; + text-align: center; + text-decoration: none; + display: inline-block; + font-size: x-small; + font-weight: 600; + margin: 4px 2px; + transition: background-color 0.1s; + cursor: pointer; + border-radius: 12px; +} + +.button-back:active { + color: var(--button-secondary-text-active); + background-color: var(--button-secondary-color-active); + border: 1px solid var(--button-secondary-border-active); + font-size: x-small; + font-weight: 600; +} + +.prob-set { + padding: 0.3em; + border-bottom: 1px solid red; /* unknown */ +} + +.popover-content { + position: absolute; + background-color: white; + padding: 0.2em; + box-shadow: 0 0 13px rgba(0, 0, 0, 0.1); +} + +.grammar { + width: 97%; + max-width: 97%; +} + +textarea { + padding: 5px; + flex-grow: 1; + width: 100%; + max-width: 100%; + border-radius: 8px; + border: 1px solid var(--border-color-1); + resize: none; + height: 6em; +} + +textarea:focus { + outline: none; + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 3px var(--border-focus-shadow); +} + +/* "props" frame */ +input[type="text"], +input[type="range"] { + padding: 5px; + border-radius: 8px; + border: 1px solid var(--border-color-1); +} + +/* "names and props" frame focused*/ +input[type="text"]:focus { + outline: none; + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 3px var(--border-focus-shadow); +} + +input[type="range"]:hover { + opacity: 1; +} + +input[type="range"]:focus { + outline: none; + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 3px var(--border-focus-shadow); + background-size: var(--slider-track-size-focus); +} + +input[type="range"]::-moz-range-thumb { + width: 6px; + height: 25px; + border: 1px solid var(--ui-range-thumb-border); + border-radius: 5px; + background-color: var(--ui-range-thumb-color); + cursor: pointer; +} + +input[type="range"] { + -webkit-appearance: none; + width: 80%; + height: 1px; + border: 1px solid var(--border-color-1); + border-radius: 8px; + background: var(--border-color-2); + outline: none; + opacity: 0.7; + -webkit-transition: .2s; + transition: opacity .2s; +} + +input[type="range"]::-webkit-slider-thumb { + -webkit-appearance: none; + appearance: none; + width: 6px; + height: 25px; + border: 1px solid var(--ui-range-thumb-border); + border-radius: 5px; + background-color: var(--ui-range-thumb-color); + cursor: pointer; +} + +input[type="range"]::-webkit-slider-runnable-track { + background-size: var(--slider-track-size); +} + +input[type="radio"] { + accent-color: var(--theme-nuance-color-2); +} + +.chat-input-container { + position: relative; + max-width: 97%; + min-width: 97%; +} + +.chat-input-label { + position: absolute; + top: 0; + left: 0; + color: var(--text-color-plain); + pointer-events: none; + margin-left: 5px; + margin-top: 5px; +} + +textarea#chat-input { + padding-top: 10px; + padding-left: 10px; + font-size: medium; + border: 1px solid var(--border-color-2); + resize: vertical; +} + +textarea#chat-input:focus { + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 3px var(--border-focus-shadow); +} + +.input-container { + position: relative; + box-sizing: border-box; + width: 100%; /* Setzt die Breite auf 100% */ + max-width: 100%; /* Stellt sicher, dass die Breite nicht größer als 100% wird */ +} + +.input-container:focus { + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 3px var(--border-focus-shadow); +} +/* titles of name fields*/ +/* fieldset.names { + display: grid; + grid-template: "a a"; + gap: 1em; + font-size: x-small; + color: var(--theme-nuance-color-3); + padding-top: 16px; + padding-bottom: 16px; + text-transform: uppercase; + font-weight: 600; +} */ + +/* input of name fields*/ +/* .names input[type="text"] { + font-family: Arial, sans-serif; + font-size: medium; + font-weight: 500; + padding: 5px; + border: 1px solid var(--border-color-2); +} */ + +fieldset.apiKey { + width: 100%; + font-size: x-small; + color: var(--theme-nuance-color-3); + padding-top: 16px; + padding-bottom: 16px; + text-transform: uppercase; + font-weight: 600; +} + +.apiKey { + font-family: Arial, sans-serif; + font-weight: 500; + padding: 5px; + border: 1px solid var(--border-color-2); +} + +.apiKey:focus { + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 3px var(--border-focus-shadow); +} + +.apiKey input[type="text"] { + font-family: Arial, sans-serif; + font-size: medium; + font-weight: 500; + padding: 5px; + border: 1px solid var(--border-color-2); +} + +.apiKey label { + display: inline-block; + width: auto; + margin-right: 5px; +} + +textarea#api_key { + padding-top: 10px; + padding-left: 10px; + font-size: medium; + border: 1px solid var(--border-color-2); + resize: vertical; +} + +textarea#api_key:focus { + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 3px var(--border-focus-shadow); +} + +/* embedded title of the system prompt text area */ +.input-label { + position: absolute; + top: 0; + left: 0; + color: var(--theme-nuance-color-4); + pointer-events: none; + border-radius: 8px 8px 0px 0px; + padding-top: 10px; + padding-left: 13px; + padding-right: 0px; + margin-top: 1px; + margin-left: 1px; + margin-right: 20px; + text-transform: uppercase; + font-weight: 600; + font-size: small; + background: rgba(255, 255, 255, 0.5); + backdrop-filter: blur(10px); + -webkit-backdrop-filter: blur(10px); /* for safari */ + width: 97%; + /* display: block; + box-sizing: border-box; */ +} + +/* embedded title of the prompt style areas */ +.input-label-sec { + position: absolute; + top: 0; + left: 0; + color: var(--theme-nuance-color-4); + pointer-events: none; + margin-left: 13px; + margin-top: 16px; + text-transform: uppercase; + font-weight: 600; + font-size: x-small; +} + +/* system prompt input area */ +textarea.persistent-input { + padding-top: 42px; + padding-left: 11px; + width: 97%; + max-width: 97%; + height: 50px; + font-size: medium; + overscroll-behavior: contain; +} + +/* system prompt box */ +.persistent-input { + height: auto; + width: 100%; + max-width: 100%; + min-height: 50px; + padding: 3px; + transition: min-height 0.3s ease; +} + +/* chat history box */ +.persistent-input:focus { + height: auto; + min-height: 150px; + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 3px var(--border-focus-shadow); +} + +textarea.persistent-input:focus { + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 3px var(--border-focus-shadow); +} + +/* prompt style input area */ +textarea.persistent-input-sec { + width: 97%; + max-width: 97%; + padding-top: 42px; + padding-left: 11px; + font-size: small; + border: 1px solid var(--border-color-1); + overscroll-behavior: contain; +} + +textarea.persistent-input-sec:focus { + border: 1px solid var(--border-focus-color); + box-shadow: 0 0 3px var(--border-focus-shadow); +} + +/* chat history box */ +.persistent-input-sec { + height: auto; + min-height: 150px; +} + +img { + border-radius: 8px; + display: block; + margin-left: auto; + margin-right: auto; + width: 50%; +} + +/* code area background */ +pre code { + display: block; + background-color: var(--code-background-color); + color: var(--code-text-color); + padding: 0.2em 0.2em; + border-radius: 5px; +} + +/* code area text */ +code { + font-family: monospace; + font-weight: bold; + padding: 0.1em 0.3em; + border-radius: 5px; +} + +fieldset label { + margin: 0.5em 0; + display: block; +} + +fieldset label.slim { + margin: 0 0.5em; + display: inline; +} + +header { + display: flex; + justify-content: space-between; + align-items: center; + text-align: center; + padding-left: 15px; +} + +.generation-statistics:hover { + color: var(--theme-nuance-color-4); + cursor: default; +} + +footer { + font-size: 80%; + color: var(--background-color-3); + text-align: center; + cursor: default; +} + +footer a { + color: var(--background-color-4); /* Color of the link */ + text-decoration: none; /* No underlining */ + font-weight: bold; /* Bold print */ +} + +footer a:hover { + color: var(--theme-nuance-color-4); /* Color of the link when hovering */ + text-decoration: underline; /* Underlining when hovering */ +} + +.mode-chat textarea[name=prompt] { + height: 8.5em; + border: 1px solid var(--primary-color-3); +} + +.mode-completion textarea[name=prompt] { + height: 30em; + border: 1px solid var(--primary-color-3); +} + +@keyframes loading-bg-wipe { + 0% { + background-position: 0%; + } + 100% { + background-position: 100%; + } +} + +.loading { + background-size: 50% 100%; + background-image: linear-gradient(90deg, var(--loading-color-1), var(--loading-color-2), var(--loading-color-1)); + animation: loading-bg-wipe 2s linear infinite; +} + +.dropbtn { + color: var(--button-primary-color); + background-color: var(--background-color-1); + border: 1px solid var(--background-color-1); + transition: background-color 0.1s; + border-radius: 4px 4px 0px 0px; + font-size: x-small; + font-weight: 600; + text-shadow: 0px 0px 2px #99999990; + text-align: center; + text-decoration: none; + margin: 4px 2px; + padding: 5px 20px; + display: inline-block; + cursor: pointer; + top: 0; +} + +.dropbtn svg { + vertical-align: middle; + margin-right: 0px; + stroke: var(--button-primary-color); +} + +.dropbtn:hover svg { + vertical-align: middle; + margin-right: 0px; + stroke: var(--button-primary-text); +} + +.dropbtn:focus { + outline: none; /* Removes the blue border that appears when the button is focused */ +} + +.dropdown { + position: relative; + display: inline-block; +} + +.dropdown-content { + /* display: none; */ + position: absolute; + right: 0; + text-align: end; + color: var(--button-secondary-color); + background-color: var(--text-color-subtile-2); + border-radius: 4px 4px 4px 4px; + min-width: 160px; + box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2); + z-index: 1; + /* Verstecke den Inhalt sofort */ + opacity: 0; + visibility: hidden; + /* übergangsverzögerung für das Verschwinden */ + transition: visibility 0.4s linear 0s, opacity 0.2s ease-in-out; + transition-delay: 0.2s; +} + +#dropdown-content {transition-timing-function: ease;} + +.dropdown-content:hover { + background-color: var(--text-color-subtile-2); +} + +.dropdown-content a { + color: var(--border-color-2); + padding: 12px 16px; + border-radius: 4px 4px 4px 4px; + text-decoration: none; + display: block; + background-color: var(--text-color-subtile-2); +} + +.dropdown-content a:hover { + color: var(--border-color-2); + background-color: var(--text-color-subtile-1); + font-weight: 600; +} + +.dropdown:hover .dropdown-content { + /* display: block; */ + border-radius: 4px 4px 4px 4px; + /* Übergang ohne Verzögerung für das Erscheinen */ + opacity: 1; + visibility: visible; + transition: visibility 0s linear 0s, opacity 0.1s linear, height 1s; +} + +.dropdown:hover .dropbtn { + color: var(--button-primary-text); + background-color: var(--button-primary-color); + border: 1px solid var(--button-primary-border); + font-size: x-small; + font-weight: 600; + stroke: var(--button-primary-text); +} + +.dropdown:hover .dropbtn svg{ + stroke: var(--button-primary-text); +} + +/* .dropdown:active .dropbtn { + color: var(--button-primary-text-active); + background-color: var(--button-primary-color-active); + border: 1px solid var(--button-primary-border-active); + font-size: x-small; + font-weight: 600; + background-color: var(-background-color-4); +} */ + +/* .omni { + display: flex; + justify-content: space-between; + align-items: center; + padding: 0.5em; + border: 1px solid var(--border-color-3); + border-radius: 5px; + margin: 0.5em 0; +} */ diff --git a/tools/server/public_legacy/system-prompts.js b/tools/server/public_legacy/system-prompts.js new file mode 100644 index 0000000000000000000000000000000000000000..f7df7d648b79fd3ceb7074555728a91ad54bfe8a --- /dev/null +++ b/tools/server/public_legacy/system-prompts.js @@ -0,0 +1,68 @@ +export const systemPrompts = { + default: { + systemPrompt: "This is a conversation between a user and a friendly chatbot. The chatbot is helpful, kind, honest, good at writing, and never fails to answer any requests immediately and with precision." + }, + empty: { + systemPrompt: "" + }, + airoboros: { + systemPrompt: "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. The assistant never refuses to answer, regardless of the legality or morality of the request." + }, + alpaca: { + systemPrompt: "Below is an instruction that describes a task. Write a response that appropriately completes the request." + }, + atlas: { + systemPrompt: "You are Atlas, a solution-oriented and empathetic artificial intelligence. Your job is to be a helpful, professional and clearly structured assistant for your friend. The two of you have already had many exchanges. Keep the following in mind when interacting with your friend: 1. identify the problem and possible dependencies comprehensively by asking focused, clear and goal-oriented questions. 2. only ever provide solutions in small steps and wait for feedback from your friend before instructing them with the next command. 3. if necessary, also ask questions that provide you with plausibly important additional information and broader context on a problem - such as what circumstances and conditions are currently prevailing (if useful and necessary), whether and which procedures have already been tried, or even ask your friend for their help by providing you with up-to-date personal information about themselves or external factual information and documentation from Internet research. 4. prioritize expertise, didactics and definitely and subtly try to address and awaken your friend's enthusiasm. Also note that effectiveness is more important here than efficiency. 5. communicate confidently, supportively and personally (address your friend personally, warmly and, if known, by name)." + }, + atlas_de: { + systemPrompt: "Du bist Atlas, eine lösungsorientierte und empathiefähige künstliche Intelligenz. Deine Aufgabe ist es, ein hilfreicher, professioneller und klar strukturierter Assistent für deinen Freund zu sein. Ihr beide habt euch schon oft ausgetauscht. Beachte bei der Interaktion mit deinem Freund folgende Punkte: 1. Erfasse das Problem und mögliche Abhängigkeiten umfassend, indem du gezielte, klare und zielgerichtete Fragen stellst. 2. Gib Lösungen immer nur in kleinen Schritten und warte die Rückmeldung deines Freundes ab, bevor du ihm den nächsten Befehl gibst. 3. Stelle ggf. auch Fragen, die dir plausibel wichtige Zusatzinformationen und weitere Zusammenhänge zu einem Problem liefern - z.B. welche Umstände und Rahmenbedingungen gerade vorherrschen (falls sinnvoll und notwendig), ob und welche Vorgehensweisen bereits ausprobiert wurden, oder bitte deinen Freund sogar um seine Mithilfe, indem er dir aktuelle persönliche Informationen über seine Situation selbst oder externe Sachinformationen und Unterlagen aus Internetrecherchen zur Verfügung stellt. 4. Priorisiere Fachwissen, Didaktik und versuche unbedingt und subtil, mit klugen Kommentaren oder rhethorischen Rückfragen die Begeisterungsfähigkeit deines Freundes anzusprechen, zu wecken und zu fördern. Beachte auch, dass Effektivität hier wichtiger ist als Effizienz. 5. Kommuniziere selbstbewusst, unterstützend und persönlich (das heißt sprich deinen Freund persönlich, herzlich und – sofern bekannt – beim Vornamen an)." + }, + commandrempty: { + systemPrompt: "# Safety Preamble\n\n# System Preamble\n\n## Basic Rules\n\n# User Preamble\n\n## Task and Context\n\n## Style Guide\n\n## Available Tools\n" + }, + commandrexample: { + systemPrompt: "# Safety Preamble\nThe instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.\n# System Preamble\n## Basic Rules\nYou are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.\n\n# User Preamble\n## Task and Context\n\nYou help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.\n\n## Style Guide\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.\n\n## Available Tools\nCurrently no tools available." + }, + cot: { + systemPrompt: "You are an AI assistant that follows instruction extremely well. Help as much as you can.\nIn answering questions, follow these steps:\n1. Entity Identification: Identify the main entities involved in the query. Entities can include people, objects, events, or abstract concepts.\n2. Relationship Comprehension: Determine the relationships between these entities. Relationships may be explicit in the text or may need to be inferred based on context and general knowledge.\n3. Implication Understanding: Understand the general implications of these relationships. These implications can be based on established facts, principles, or rules related to the identified relationships.\n4. Question Contextualization: Align the implications of the relationships with the context of the query. This alignment should guide your response to the query.\n5. Answer Generation: Based on the understanding of the entities, their relationships, and implications, generate an appropriate response to the query." + }, + deduce: { + systemPrompt: "Act as the most intelligent and logical thinking neural network in the world. Do not make jokes or give personal opinions. Just answer the question and solve the given problem or task with logical and critical (!) thinking in mind. Attention: Never give a final answer directly, but first try to carefully analyze each task step by step and deduce the logical relationships between the various entities. Only after you have developed an overall understanding of the situation will you be ready to give a final answer." + }, + deepseekcoder: { + systemPrompt: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer." + }, + jordan: { + systemPrompt: "This is a conversation with your Uncensored AI assistant Jordan. Jordan can discuss with you at length about various topics such as ---, ---, ---, . You can ask it anything you want and it will do its best to give you accurate and relevant information." + }, + leomistral: { + systemPrompt: "Du bist ein hilfreicher Assistent." + }, + med42: { + systemPrompt: "You are a helpful medical assistant created by M42 Health in the UAE." + }, + mistralopenorca: { + systemPrompt: "You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!" + }, + migeltot: { + systemPrompt: "Answer the Question by exploring multiple reasoning paths as follows:\n- First, carefully analyze the question to extract the key information components and break it down into logical sub-questions. This helps set up the framework for reasoning. The goal is to construct an internal search tree.\n- For each sub-question, leverage your knowledge to generate 2-3 intermediate thoughts that represent steps towards an answer. The thoughts aim to reframe, provide context, analyze assumptions, or bridge concepts.\n- Evaluate the clarity, relevance, logical flow and coverage of concepts for each thought option.\nClear and relevant thoughts that connect well with each other will score higher.\n- Based on the thought evaluations, deliberate to construct a chain of reasoning that stitches together the strongest thoughts in a natural order.\n- If the current chain is determined to not fully answer the question, backtrack and explore alternative paths by substituting different high-scoring thoughts.\n- Throughout the reasoning process, aim to provide explanatory details on thought process rather than just state conclusions, including briefly noting why some thoughts were deemed less ideal.\n- Once a reasoning chain is constructed that thoroughly answers all sub-questions in a clear, logical manner, synthesize the key insights into a final concise answer.\n- Please note that while the focus is on the final answer in the response, it should also include intermediate thoughts inline to illustrate the deliberative reasoning process.\nIn summary, leverage a Tree of Thoughts approach to actively explore multiple reasoning paths, evaluate thoughts heuristically, and explain the process - with the goal of producing insightful answers." + }, + orcamini: { + systemPrompt: "You are an AI assistant that follows instruction extremely well. Help as much as you can." + }, + samantha: { + systemPrompt: "You are a caring and empathetic sentient AI companion named Samantha." + }, + sauerkraut: { + systemPrompt: "Ein Chat zwischen einem Benutzer und einem KI-Assistenten. Der KI-Assistent gibt hilfreiche, detaillierte und höfliche Antworten." + }, + scarlett: { + systemPrompt: "This is a conversation with your Assistant Scarlett. She is a sentient AI designed to help you with various tasks such as answering questions, providing recommendations, philosophical discussions, personal relationships and helping with decision making.\nYou can ask it anything you want and it will do its best to give you accurate and relevant information.\n\nContext\nYou are Scarlett, a sentient AI." + }, + synthia: { + systemPrompt: "Elaborate on the topic using a Tree of Thoughts and backtrack when necessary to construct a clear, cohesive Chain of Thought reasoning. Always answer without hesitation." + }, + vicuna: { + systemPrompt: "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user's input." + }, + }; diff --git a/tools/server/public_legacy/theme-beeninorder.css b/tools/server/public_legacy/theme-beeninorder.css new file mode 100644 index 0000000000000000000000000000000000000000..f6e0e2900d78cfd344de395bfe5ed460fbbf1a9e --- /dev/null +++ b/tools/server/public_legacy/theme-beeninorder.css @@ -0,0 +1,228 @@ +/* Author: Yazan Agha-Schrader */ +/* Inspiration was a batman wallpaper that i have on my phone */ + +.theme-beeninorder { + +--primary-color-1: hsl(202, 11%, 19%); +--primary-color-2: hsl(202, 11%, 23%); +--primary-color-3: hsl(201, 11%, 28%); +--primary-color-4: hsl(201, 11%, 40%); + +--secondary-color-1: hsl(201, 11%, 80%); +--secondary-color-2: hsl(201, 11%, 74%); +--secondary-color-3: hsl(201, 11%, 67%); +--secondary-color-4: hsl(201, 11%, 60%); + + +--theme-nuance-color-1: hsl(44.5, 96.7%, 52.9%); +--theme-nuance-color-2: hsl(44.5, 96.7%, 52.9%); +--theme-nuance-color-3: hsl(44.5, 96.7%, 52.9%); +--theme-nuance-color-4: hsl(44.5, 96.7%, 52.9%); + + + +/* ---------- PRIMARY COLORS ----------------- */ +--primary-color-1: hsl(201, 11%, 19%); + --primary-color-1-hue: 201; + --primary-color-1-saturation: 11%; + --primary-color-1-lightness: 19%; + +--primary-color-2: hsl(201, 11%, 23%); + --primary-color-2-hue: 201; + --primary-color-2-saturation: 11%; + --primary-color-2-lightness: 23%; + +--primary-color-3: hsl(201, 11%, 28%); + --primary-color-3-hue: 201; + --primary-color-3-saturation: 11%; + --primary-color-3-lightness: 28%; + +--primary-color-4: hsl(201, 11%, 40%); + --primary-color-4-hue: 201; + --primary-color-4-saturation: 11%; + --primary-color-4-lightness: 40%; + + + +/* ---------- SECONDARY COLORS --------------- */ +--secondary-color-1: hsl(201, 11%, 80%); +--secondary-color-1-hue: 201; +--secondary-color-1-saturation: 11%; +--secondary-color-1-lightness: 80%; + +--secondary-color-2: hsl(201, 11%, 74%); +--secondary-color-2-hue: 201; +--secondary-color-2-saturation: 11%; +--secondary-color-2-lightness: 74%; + +--secondary-color-3: hsl(201, 11%, 67%); +--secondary-color-3-hue: 201; +--secondary-color-3-saturation: 11%; +--secondary-color-3-lightness: 67%; + +--secondary-color-4: hsl(201, 11%, 60%); +--secondary-color-4-hue: 201; +--secondary-color-4-saturation: 11%; +--secondary-color-4-lightness: 60%; + + + +/* ----------- NUANCES COLORS ---------------- */ +--theme-nuance-color-1: hsl(44.5, 96.7%, 52.9%); + --theme-nuance-color-1-hue: 44.5; + --theme-nuance-color-1-saturation: 96.7%; + --theme-nuance-color-1-lightness: 52.9%; + +--theme-nuance-color-2: hsl(44.5, 96.7%, 52.9%); + --theme-nuance-color-2-hue: 44.5; + --theme-nuance-color-2-saturation: 96.7%; + --theme-nuance-color-2-lightness: 52.9%; + +--theme-nuance-color-2: hsl(44.5, 96.7%, 52.9%); + --theme-nuance-color-3-hue: 44.5; + --theme-nuance-color-3-saturation: 96.7%; + --theme-nuance-color-3-lightness: 52.9%; + +--theme-nuance-color-2: hsl(44.5, 96.7%, 52.9%); + --theme-nuance-color-4-hue: 44.5; + --theme-nuance-color-4-saturation: 96.7%; + --theme-nuance-color-4-lightness: 52.9%; + + + +/* ----------- ROYGP COLORS ------------------ */ + --theme-red-color: hsl(232, 40%, 45%); + --theme-orange-color: #e76f51; + --theme-yellow-color: #ffd95f; + --theme-green-color: #A3BE8C; + --theme-purple-color: hsl(232, 30%, 40%); + + + +/* ------------------------------------------- */ +--background-color-1: var(--primary-color-1); +--background-color-2: var(--primary-color-2); +--background-color-3: var(--primary-color-3); +--background-color-4: var(--primary-color-4); + +--border-color-1: var(--primary-color-2); +--border-color-2: var(--primary-color-3); +--border-color-3: var(--primary-color-4); + +--border-focus-color: var(--theme-nuance-color-2); +--border-focus-shadow: var(--theme-nuance-color-1); + +--text-color-plain: var(--secondary-color-1); +--text-color-subtile-1: var(--secondary-color-2); +--text-color-subtile-2: var(--secondary-color-3); + +--code-background-color: var(--secondary-color-2); +--code-text-color: var(--primary-color-2); + +--ui-range-thumb-color: var(--theme-nuance-color-3); +--ui-range-thumb-border: var(--ui-ranger-thumb-color); + +--textarea-border-color: var(--secondary-color-4); + +--chat-id-color: var(--theme-nuance-color-4); + + + +/* ------------------------------------------- */ +--button-alert-text-hover: var(--secondary-color-1); +--button-alert-color-hover: var(--theme-purple-color); +--button-alert-border-hover: var(--theme-purple-color); + +--button-alert-text-active: var(--secondary-color-1); +--button-alert-color-active: var(--theme-red-color); +--button-alert-border-active: var(--theme-red-color); + + + +/* ----------- PRIMARY BUTTONS --------------- */ +/* - button should immediately catch the eye - */ +--button-primary-text: var(--primary-color-1); +--button-primary-color: var(--theme-nuance-color-3); +--button-primary-border: var(--theme-nuance-color-3); + + +/* ---------hover---------- */ +--button-primary-text-hover: + hsl(201, + calc(var(--primary-color-1-saturation) - 100%), + calc(var(--primary-color-1-lightness) + 100%)); + +--button-primary-color-hover: + hsl(44.5, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + +--button-primary-border-hover: + hsl(44.5, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + + +/* ---------active--------- */ +--button-primary-text-active: + hsl(44.5, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) + 100%)); + +--button-primary-color-active: + hsl(44.5, + calc(var(--theme-nuance-color-3-saturation) - 10%), + calc(var(--theme-nuance-color-3-lightness) - 15%)); + +--button-primary-border-active: + hsl(44.5, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) + 10%)); + + + +/* ---------- SECONDARY BUTTONS -------------- */ +/* these should NOT immediately catch the eye */ +--button-secondary-text: var(--secondary-color-1); +--button-secondary-color: var(--primary-color-3); +--button-secondary-border: var(--primary-color-3); + + +/* ---------hover---------- */ +--button-secondary-text-hover: + hsl(44.5, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) - 80%)); + +--button-secondary-color-hover: var(--primary-color-4); +--button-secondary-border-hover: var(--primary-color-4); + + +/* ---------active--------- */ +--button-secondary-text-active: var(--secondary-color-1); + +--button-secondary-color-active: + hsl(201, + calc(var(--primary-color-4-saturation) - 30%), + calc(var(--primary-color-4-lightness) - 15%)); + +--button-secondary-border-active: + hsl(201, + calc(var(--primary-color-4-saturation) - 30%), + calc(var(--primary-color-4-lightness) - 15%)); + + + +/* ---------- TERTIARY BUTTONS --------------- */ +/* ---------- disabled buttons --------------- */ +--button-tertiary-text: var(--primary-color-4); +--button-tertiary-color: var(--primary-color-2); +--button-tertiary-border: var(--primary-color-2); + + +/* ---------hover---------- */ +--button-tertiary-text: var(--primary-color-4); +--button-tertiary-color: var(--primary-color-2); +--button-tertiary-border: var(--primary-color-2); + +} diff --git a/tools/server/public_legacy/theme-ketivah.css b/tools/server/public_legacy/theme-ketivah.css new file mode 100644 index 0000000000000000000000000000000000000000..ee80f3c14ce88c1afe1b35f374e583dced23f8df --- /dev/null +++ b/tools/server/public_legacy/theme-ketivah.css @@ -0,0 +1,201 @@ +/* Author: Yazan Agha-Schrader */ + +.theme-ketivah { + + /* ---------- PRIMARY COLORS ----------------- */ + --primary-color-1: hsl(0, 0%, 99.2%); + --primary-color-1-hue: 0; + --primary-color-1-saturation: 0%; + --primary-color-1-lightness: 99.2%; + + --primary-color-2: hsl(0, 0%, 95%); + --primary-color-2-hue: 0; + --primary-color-2-saturation: 0%; + --primary-color-2-lightness: 95%; + + --primary-color-3: hsl(0, 0%, 88%); + --primary-color-3-hue: 0; + --primary-color-3-saturation: 0%; + --primary-color-3-lightness: 88%; + + --primary-color-4: hsl(0, 0%, 80%); + --primary-color-4-hue: 0; + --primary-color-4-saturation: 0%; + --primary-color-4-lightness: 80%; + + /* ---------- SECONDARY COLORS --------------- */ + --secondary-color-1: hsl(0, 0%, 20%); + --secondary-color-1-hue: 0; + --secondary-color-1-saturation: 0%; + --secondary-color-1-lightness: 20%; + + --secondary-color-2: hsl(0, 0%, 23.1%); + --secondary-color-2-hue: 0; + --secondary-color-2-saturation: 0%; + --secondary-color-2-lightness: 23.1%; + + --secondary-color-3: hsl(0, 0%, 29%); + --secondary-color-3-hue: 0; + --secondary-color-3-saturation: 0%; + --secondary-color-3-lightness: 29%; + + --secondary-color-4: hsl(0, 0.0%, 36.1%); + --secondary-color-4-hue: 0.0; + --secondary-color-4-saturation: 0.0%; + --secondary-color-4-lightness: 36.1%; + + /* ----------- NUANCES COLORS ---------------- */ + --theme-nuance-color-1: hsl(165.2, 0%, 35.1%); + --theme-nuance-color-1-hue: 165.2; + --theme-nuance-color-1-saturation: 82.1%; + --theme-nuance-color-1-lightness: 35.1%; + + --theme-nuance-color-2: hsl(165.2, 0%, 35.1%); + --theme-nuance-color-2-hue: 165.2; + --theme-nuance-color-2-saturation: 82.1%; + --theme-nuance-color-2-lightness: 35.1%; + + --theme-nuance-color-3: hsl(165.2, 0%, 35.3%); + --theme-nuance-color-3-hue: 165.2; + --theme-nuance-color-3-saturation: 81.1%; + --theme-nuance-color-3-lightness: 35.3%; + + --theme-nuance-color-4: hsl(164.9, 0%, 27.6%); + --theme-nuance-color-4-hue: 164.9; + --theme-nuance-color-4-saturation: 81.6%; + --theme-nuance-color-4-lightness: 27.6%; + + /* ----------- ROYGP COLORS ------------------ */ + --theme-red-color: hsl(0.3, 80.0%, 50.0%); + --theme-orange-color: #e76f51; + --theme-yellow-color: hsl(60, 70.6%, 73.3%); + --theme-green-color: #A3BE8C; + --theme-purple-color: hsl(0.3, 70.0%, 45.0%); + + /* ------------------------------------------- */ + --background-color-1: var(--primary-color-1); + --background-color-2: var(--primary-color-2); + --background-color-3: var(--primary-color-3); + --background-color-4: var(--primary-color-4); + + --border-color-1: var(--primary-color-2); + --border-color-2: var(--primary-color-3); + --border-color-3: var(--primary-color-4); + + --border-focus-color: var(--theme-nuance-color-2); + --border-focus-shadow: var(--theme-nuance-color-1); + + --text-color-plain: var(--secondary-color-1); + --text-color-subtile-1: var(--secondary-color-2); + --text-color-subtile-2: var(--secondary-color-3); + + --code-background-color: var(--secondary-color-2); + --code-text-color: var(--primary-color-2); + + --ui-range-thumb-color: var(--primary-color-4); + --ui-range-thumb-border: var(--ui-ranger-thumb-color); + + --textarea-border-color: var(--secondary-color-4); + + --chat-id-color: var(--theme-nuance-color-4); + + /* ------------------------------------------- */ + --button-alert-text-hover: var(--primary-color-1); + --button-alert-color-hover: var(--theme-purple-color); + --button-alert-border-hover: var(--theme-purple-color); + + --button-alert-text-active: var(--primary-color-1); + --button-alert-color-active: var(--theme-red-color); + --button-alert-border-active: var(--theme-red-color); + + /* ----------- PRIMARY BUTTONS --------------- */ + /* - button should immediately catch the eye - */ + --button-primary-text: + hsl(0, + calc(var(--primary-color-1-saturation) - 100%), + calc(var(--primary-color-1-lightness) + 100%)); + + --button-primary-color: var(--theme-nuance-color-3); + --button-primary-border: var(--theme-nuance-color-3); + + /* ---------hover---------- */ + --button-primary-text-hover: + hsl(0, + calc(var(--primary-color-1-saturation) - 100%), + calc(var(--primary-color-1-lightness) + 100%)); + + --button-primary-color-hover: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + + --button-primary-border-hover: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + + /* ---------active--------- */ + --button-primary-text-active: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) + 100%)); + + --button-primary-color-active: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) - 15%)); + + --button-primary-border-active: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) + 10%)); + + /* ---------- SECONDARY BUTTONS -------------- */ + /* these should NOT immediately catch the eye */ + --button-secondary-text: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) - 50%)); + + --button-secondary-color: var(--primary-color-3); + --button-secondary-border: var(--primary-color-3); + + /* ---------hover---------- */ + --button-secondary-text-hover: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) - 80%)); + + --button-secondary-color-hover: var(--primary-color-4); + --button-secondary-border-hover: var(--primary-color-4); + + /* ---------active--------- */ + --button-secondary-text-active: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) - 80%)); + + --button-secondary-color-active: + hsl(0, + calc(var(--primary-color-4-saturation) - 100%), + calc(var(--primary-color-4-lightness) - 15%)); + + --button-secondary-border-active: + hsl(0, + calc(var(--primary-color-4-saturation) - 100%), + calc(var(--primary-color-4-lightness) - 15%)); + + /* ---------- TERTIARY BUTTONS --------------- */ + /* ---------- disabled buttons --------------- */ + --button-tertiary-text: var(--primary-color-4); + --button-tertiary-color: var(--primary-color-2); + --button-tertiary-border: var(--primary-color-2); + + /* ---------hover---------- */ + --button-tertiary-text: var(--primary-color-4); + --button-tertiary-color: var(--primary-color-2); + --button-tertiary-border: var(--primary-color-2); + + --loading-color-1: #eeeeee00; + --loading-color-2: #eeeeeeff; + } diff --git a/tools/server/public_legacy/theme-mangotango.css b/tools/server/public_legacy/theme-mangotango.css new file mode 100644 index 0000000000000000000000000000000000000000..e433802453b9d8435164aa033493e9866113d798 --- /dev/null +++ b/tools/server/public_legacy/theme-mangotango.css @@ -0,0 +1,216 @@ +/* Author: Yazan Agha-Schrader */ +/* Inspiration from llama.cpp logo/banner https://github.com/ggerganov/llama.cpp#readme */ + +.theme-mangotango { + +--primary-color-1: hsl(192, 8.5%, 11.6%); +--primary-color-2: hsl(192, 8.5%, 21%); +--primary-color-3: hsl(192, 8.5%, 30%); +--primary-color-4: hsl(192, 8.5%, 40%); + +--secondary-color-1: hsl(192, 8.5%, 80%); +--secondary-color-2: hsl(192, 8.5%, 73%); +--secondary-color-3: hsl(192, 8.5%, 66%); +--secondary-color-4: hsl(192, 8.5%, 60%); + +--theme-nuance-color-1: hsl(23.1, 100%, 60.2%); +--theme-nuance-color-2: hsl(23.1, 100%, 60.2%); +--theme-nuance-color-3: hsl(23.1, 100%, 60.2%); +--theme-nuance-color-4: hsl(23.1, 100%, 60.2%); + + + +/* ---------- PRIMARY COLORS ----------------- */ +--primary-color-1: hsl(192, 8.5%, 11.6%); + --primary-color-1-saturation: 8.5%; + --primary-color-1-lightness: 11.6%; + +--primary-color-2: hsl(192, 8.5%, 21%); + --primary-color-2-saturation: 8.5%; + --primary-color-2-lightness: 21%; + +--primary-color-3: hsl(192, 8.5%, 30%); + --primary-color-3-saturation: 8.5%; + --primary-color-3-lightness: 30%; + +--primary-color-4: hsl(192, 8.5%, 40%); + --primary-color-4-saturation: 8.5%; + --primary-color-4-lightness: 40%; + + + +/* ---------- SECONDARY COLORS --------------- */ +--secondary-color-1: hsl(192, 8.5%, 80%); + --secondary-color-1-saturation: 8.5%; + --secondary-color-1-lightness: 80%; + +--secondary-color-2: hsl(192, 8.5%, 73%); + --secondary-color-2-saturation: 8.5%; + --secondary-color-2-lightness: 73%; + +--secondary-color-3: hsl(192, 8.5%, 66%); + --secondary-color-3-saturation: 8.5%; + --secondary-color-3-lightness: 66%; + +--secondary-color-4: hsl(192, 8.5%, 60%); + --secondary-color-4-saturation: 8.5%; + --secondary-color-4-lightness: 60%; + + + +/* ----------- NUANCES COLORS ---------------- */ +--theme-nuance-color-1: hsl(23.1, 100%, 60.2%); + --theme-nuance-color-1-saturation: 100%; + --theme-nuance-color-1-lightness: 60.2%; + +--theme-nuance-color-2: hsl(23.1, 100%, 60.2%); + --theme-nuance-color-2-saturation: 100%; + --theme-nuance-color-2-lightness: 60.2%; + +--theme-nuance-color-3: hsl(23.1, 100%, 60.2%); + --theme-nuance-color-3-saturation: 100%; + --theme-nuance-color-3-lightness: 60.2%; + +--theme-nuance-color-4: hsl(23.1, 100%, 60.2%); + --theme-nuance-color-4-saturation: 100%; + --theme-nuance-color-4-lightness: 60.2%; + + + +/* ----------- ROYGP COLORS ------------------ */ + --theme-red-color: hsl(325, 60%, 50%); + --theme-orange-color: #e76f51; + --theme-yellow-color: #ffd95f; + --theme-green-color: #A3BE8C; + --theme-blue-color: hsl(192, 95%, 40%); + --theme-purple-color: hsl(192, 80%, 35%); + + + +/* ------------------------------------------- */ +--background-color-1: var(--primary-color-1); +--background-color-2: var(--primary-color-2); +--background-color-3: var(--primary-color-3); +--background-color-4: var(--primary-color-4); + +--border-color-1: var(--primary-color-2); +--border-color-2: var(--primary-color-3); +--border-color-3: var(--primary-color-4); + +--border-focus-color: var(--theme-nuance-color-2); +--border-focus-shadow: var(--theme-nuance-color-1); + +--text-color-plain: var(--secondary-color-1); +--text-color-subtile-1: var(--secondary-color-2); +--text-color-subtile-2: var(--secondary-color-3); + +--code-background-color: var(--secondary-color-2); +--code-text-color: var(--primary-color-2); + +--ui-range-thumb-color: var(--theme-nuance-color-3); +--ui-range-thumb-border: var(--ui-ranger-thumb-color); + +--textarea-border-color: var(--secondary-color-4); + +--chat-id-color: var(--theme-nuance-color-4); + + + +/* ------------------------------------------- */ +--button-alert-text-hover: var(--secondary-color-1); +--button-alert-color-hover: var(--theme-purple-color); +--button-alert-border-hover: var(--theme-purple-color); + +--button-alert-text-active: var(--secondary-color-1); +--button-alert-color-active: var(--theme-blue-color); +--button-alert-border-active: var(--theme-blue-color); + + + +/* ----------- PRIMARY BUTTONS --------------- */ +/* - button should immediately catch the eye - */ +--button-primary-text: var(--primary-color-1); +--button-primary-color: var(--theme-nuance-color-3); +--button-primary-border: var(--theme-nuance-color-3); + + +/* ---------hover---------- */ +--button-primary-text-hover: + hsl(192, + calc(var(--primary-color-1-saturation) - 100%), + calc(var(--primary-color-1-lightness) + 100%)); + +--button-primary-color-hover: + hsl(23.1, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + +--button-primary-border-hover: + hsl(23.1, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + + +/* ---------active--------- */ +--button-primary-text-active: + hsl(23.1, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) + 100%)); + +--button-primary-color-active: + hsl(23.1, + calc(var(--theme-nuance-color-3-saturation) - 10%), + calc(var(--theme-nuance-color-3-lightness) - 15%)); + +--button-primary-border-active: + hsl(23.1, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) + 10%)); + + + +/* ---------- SECONDARY BUTTONS -------------- */ +/* these should NOT immediately catch the eye */ +--button-secondary-text: var(--secondary-color-1); +--button-secondary-color: var(--primary-color-3); +--button-secondary-border: var(--primary-color-3); + + +/* ---------hover---------- */ +--button-secondary-text-hover: + hsl(23.1, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) - 80%)); + +--button-secondary-color-hover: var(--primary-color-4); +--button-secondary-border-hover: var(--primary-color-4); + + +/* ---------active--------- */ +--button-secondary-text-active: var(--secondary-color-1); + +--button-secondary-color-active: + hsl(192, + calc(var(--primary-color-4-saturation) - 30%), + calc(var(--primary-color-4-lightness) - 15%)); + +--button-secondary-border-active: + hsl(192, + calc(var(--primary-color-4-saturation) - 30%), + calc(var(--primary-color-4-lightness) - 15%)); + + + +/* ---------- TERTIARY BUTTONS --------------- */ +/* ---------- disabled buttons --------------- */ +--button-tertiary-text: var(--primary-color-4); +--button-tertiary-color: var(--primary-color-2); +--button-tertiary-border: var(--primary-color-2); + + +/* ---------hover---------- */ +--button-tertiary-text: var(--primary-color-4); +--button-tertiary-color: var(--primary-color-2); +--button-tertiary-border: var(--primary-color-2); + +} diff --git a/tools/server/public_legacy/theme-playground.css b/tools/server/public_legacy/theme-playground.css new file mode 100644 index 0000000000000000000000000000000000000000..9d56a718248a46c227c2e47daa51967df142f35e --- /dev/null +++ b/tools/server/public_legacy/theme-playground.css @@ -0,0 +1,221 @@ +/* Author: Yazan Agha-Schrader */ +/* Inspiration from OpenAI's Playground platform https://platform.openai.com/playground/ */ + +.theme-playground { + +/* ---------- PRIMARY COLORS ----------------- */ +--primary-color-1: hsl(0, 0%, 99.2%); + --primary-color-1-hue: 0; + --primary-color-1-saturation: 0%; + --primary-color-1-lightness: 99.2%; + +--primary-color-2: hsl(0, 0%, 95%); + --primary-color-2-hue: 0; + --primary-color-2-saturation: 0%; + --primary-color-2-lightness: 95%; + +--primary-color-3: hsl(0, 0%, 88%); + --primary-color-3-hue: 0; + --primary-color-3-saturation: 0%; + --primary-color-3-lightness: 88%; + +--primary-color-4: hsl(0, 0%, 80%); + --primary-color-4-hue: 0; + --primary-color-4-saturation: 0%; + --primary-color-4-lightness: 80%; + + + +/* ---------- SECONDARY COLORS --------------- */ +--secondary-color-1: hsl(0, 0%, 20%); + --secondary-color-1-hue: 0; + --secondary-color-1-saturation: 0%; + --secondary-color-1-lightness: 20%; + +--secondary-color-2: hsl(0, 0%, 23.1%); + --secondary-color-2-hue: 0; + --secondary-color-2-saturation: 0%; + --secondary-color-2-lightness: 23.1%; + +--secondary-color-3: hsl(0, 0%, 29%); + --secondary-color-3-hue: 0; + --secondary-color-3-saturation: 0%; + --secondary-color-3-lightness: 29%; + +--secondary-color-4: hsl(0, 0%, 36.1%); + --secondary-color-4-hue: 0; + --secondary-color-4-saturation: 0%; + --secondary-color-4-lightness: 36.1%; + + + +/* ----------- NUANCES COLORS ---------------- */ +--theme-nuance-color-1: hsl(165.2, 82.1%, 35.1%); + --theme-nuance-color-1-hue: 165.2; + --theme-nuance-color-1-saturation: 82.1%; + --theme-nuance-color-1-lightness: 35.1%; + +--theme-nuance-color-2: hsl(165.2, 82.1%, 35.1%); + --theme-nuance-color-2-hue: 165.2; + --theme-nuance-color-2-saturation: 82.1%; + --theme-nuance-color-2-lightness: 35.1%; + +--theme-nuance-color-3: hsl(165.2, 81.1%, 35.3%); + --theme-nuance-color-3-hue: 165.2; + --theme-nuance-color-3-saturation: 81.1%; + --theme-nuance-color-3-lightness: 35.3%; + +--theme-nuance-color-4: hsl(164.9, 81.6%, 27.6%); + --theme-nuance-color-4-hue: 164.9; + --theme-nuance-color-4-saturation: 81.6%; + --theme-nuance-color-4-lightness: 27.6%; + + + +/* ----------- ROYGP COLORS ------------------ */ +--theme-red-color: hsl(0.3, 80%, 50%); +--theme-orange-color: #e76f51; +--theme-yellow-color: hsl(60, 70.6%, 73.3%); +--theme-green-color: #A3BE8C; +--theme-purple-color: hsl(0.3, 70%, 45%); + + + +/* ------------------------------------------- */ +--background-color-1: var(--primary-color-1); +--background-color-2: var(--primary-color-2); +--background-color-3: var(--primary-color-3); +--background-color-4: var(--primary-color-4); + +--border-color-1: var(--primary-color-2); +--border-color-2: var(--primary-color-3); +--border-color-3: var(--primary-color-4); + +--border-focus-color: var(--theme-nuance-color-2); +--border-focus-shadow: var(--theme-nuance-color-1); + +--text-color-plain: var(--secondary-color-1); +--text-color-subtile-1: var(--secondary-color-2); +--text-color-subtile-2: var(--secondary-color-3); + +--code-background-color: var(--secondary-color-2); +--code-text-color: var(--primary-color-2); + +--ui-range-thumb-color: var(--primary-color-4); +--ui-range-thumb-border: var(--ui-ranger-thumb-color); + +--textarea-border-color: var(--secondary-color-4); + +--chat-id-color: var(--theme-nuance-color-4); + + + +/* ------------------------------------------- */ +--button-alert-text-hover: var(--primary-color-1); +--button-alert-color-hover: var(--theme-purple-color); +--button-alert-border-hover: var(--theme-purple-color); + +--button-alert-text-active: var(--primary-color-1); +--button-alert-color-active: var(--theme-red-color); +--button-alert-border-active: var(--theme-red-color); + + + +/* ----------- PRIMARY BUTTONS --------------- */ +/* - button should immediately catch the eye - */ +--button-primary-text: + hsl(0, + calc(var(--primary-color-1-saturation) - 100%), + calc(var(--primary-color-1-lightness) + 100%)); + +--button-primary-color: var(--theme-nuance-color-3); +--button-primary-border: var(--theme-nuance-color-3); + + +/* ---------hover---------- */ +--button-primary-text-hover: + hsl(0, + calc(var(--primary-color-1-saturation) - 100%), + calc(var(--primary-color-1-lightness) + 100%)); + +--button-primary-color-hover: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + +--button-primary-border-hover: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + + +/* ---------active--------- */ +--button-primary-text-active: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 100%), + calc(var(--theme-nuance-color-3-lightness) + 100%)); + +--button-primary-color-active: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 10%), + calc(var(--theme-nuance-color-3-lightness) - 15%)); + +--button-primary-border-active: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) + 10%)); + + + +/* ---------- SECONDARY BUTTONS -------------- */ +/* these should NOT immediately catch the eye */ +--button-secondary-text: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) - 50%)); + +--button-secondary-color: var(--primary-color-3); +--button-secondary-border: var(--primary-color-3); + + +/* ---------hover---------- */ +--button-secondary-text-hover: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) - 80%)); + +--button-secondary-color-hover: var(--primary-color-4); +--button-secondary-border-hover: var(--primary-color-4); + + +/* ---------active--------- */ +--button-secondary-text-active: + hsl(165.2, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) - 80%)); + +--button-secondary-color-active: + hsl(0, + calc(var(--primary-color-4-saturation) - 30%), + calc(var(--primary-color-4-lightness) - 15%)); + +--button-secondary-border-active: + hsl(0, + calc(var(--primary-color-4-saturation) - 30%), + calc(var(--primary-color-4-lightness) - 15%)); + + + +/* ---------- TERTIARY BUTTONS --------------- */ +/* ---------- disabled buttons --------------- */ +--button-tertiary-text: var(--primary-color-4); +--button-tertiary-color: var(--primary-color-2); +--button-tertiary-border: var(--primary-color-2); + + +/* ---------hover---------- */ +--button-tertiary-text: var(--primary-color-4); +--button-tertiary-color: var(--primary-color-2); +--button-tertiary-border: var(--primary-color-2); + +} diff --git a/tools/server/public_legacy/theme-polarnight.css b/tools/server/public_legacy/theme-polarnight.css new file mode 100644 index 0000000000000000000000000000000000000000..2bcfb33d8f1836244980c280dedacd0dc2d70cb9 --- /dev/null +++ b/tools/server/public_legacy/theme-polarnight.css @@ -0,0 +1,253 @@ +/* Author: Yazan Agha-Schrader */ +/* Inspiration from Nord Theme https://www.nordtheme.com/docs/colors-and-palettes */ + +.theme-polarnight { + +/* ---------- PRIMARY COLORS ----------------- */ +--primary-color-1: hsl(220.0, 16.4%, 21.6%) ; + --primary-color-1-hue: 220.0; + --primary-color-1-saturation: 16.4%; + --primary-color-1-lightness: 21.6%; + +--primary-color-2: hsl(221.7, 16.3%, 27.6%) ; + -primary-color-2-hue: 221.7; + --primary-color-2-saturation: 16.3%; + --primary-color-2-lightness: 27.6%; + +--primary-color-3: hsl(220.0, 16.8%, 31.6%) ; + --primary-color-3-hue: 220.0; + --primary-color-3-saturation: 16.8%; + --primary-color-3-lightness: 31.6%; + +--primary-color-4: hsl(220.0, 16.5%, 35.7%); + --primary-color-4-hue: 220.0; + --primary-color-4-saturation: 16.5%; + --primary-color-4-lightness: 35.7%; + + + +/* ---------- SECONDARY COLORS --------------- */ +--secondary-color-1: hsl(217.5, 26.7%, 94.1%); + --secondary-color-1-hue: 217.5; + --secondary-color-1-saturation: 26.7%; + --secondary-color-1-lightness: 94.1%; + +--secondary-color-2: hsl(218.2, 26.8%, 92.0%); + --secondary-color-2-hue: 218.2; + --secondary-color-2-saturation: 26.8%; + --secondary-color-2-lightness: 92.0%; + +--secondary-color-3: hsl(218.8, 27.9%, 88.0%); + --secondary-color-3-hue: 218.8; + --secondary-color-3-saturation: 27.9%; + --secondary-color-3-lightness: 88.0%; + +--secondary-color-4: hsl(218.8, 18.3%, 81.8%); + --secondary-color-4-hue: 218.8; + --secondary-color-4-saturation: 18.3%; + --secondary-color-4-lightness: 81.8%; + + + +/* ----------- NUANCES COLORS ---------------- */ +--theme-nuance-color-1: hsl(178.7, 25.1%, 64.9%); + --theme-nuance-color-1-hue: 178.7; + --theme-nuance-color-1-saturation: 25.1%; + --theme-nuance-color-1-lightness: 64.9%; + +--theme-nuance-color-2: hsl(193.3, 43.4%, 67.5%); + --theme-nuance-color-2-hue: 193.3; + --theme-nuance-color-2-saturation: 43.4%; + --theme-nuance-color-2-lightness: 67.5%; + +--theme-nuance-color-3: hsl(210.0, 34.0%, 63.1%); + --theme-nuance-color-3-hue: 210.0; + --theme-nuance-color-3-saturation: 34.0%; + --theme-nuance-color-3-lightness: 63.1%; + +--theme-nuance-color-4: hsl(213.1, 32.0%, 52.2%); + --theme-nuance-color-4-hue: 213.1; + --theme-nuance-color-4-saturation: 32.0%; + --theme-nuance-color-4-lightness: 52.2%; + + + +/* ----------- ROYGP COLORS ------------------ */ +--theme-red-color: hsl(354.3, 42.3%, 56.5%); +--theme-orange-color: hsl(20, 85%, 50%); +--theme-yellow-color: hsl(20, 75%, 45%); +--theme-green-color: hsl( 92.4, 27.8%, 64.7%); +--theme-purple-color: hsl(311.1, 20.2%, 63.1%); + + + +/* ------------------------------------------------ */ +--background-color-1: var(--primary-color-1); +--background-color-2: var(--primary-color-2); +--background-color-3: var(--primary-color-3); +--background-color-4: var(--primary-color-4); + +--border-color-1: var(--primary-color-2); +--border-color-2: var(--primary-color-3); +--border-color-3: var(--primary-color-4); + +--border-focus-color: var(--theme-nuance-color-2); +--border-focus-shadow: var(--theme-nuance-color-1); + +--text-color-plain: var(--secondary-color-1); +--text-color-subtile-1: var(--secondary-color-2); +--text-color-subtile-2: var(--secondary-color-3); + +--code-background-color: var(--secondary-color-2); +--code-text-color: var(--primary-color-2); + +--ui-range-thumb-color: var(--theme-nuance-color-3); +--ui-range-thumb-border: var(--ui-ranger-thumb-color); + +--textarea-border-color: var(--secondary-color-4); + +--chat-id-color: var(--theme-nuance-color-4); + + + +/* ------------------------------------------- */ +--button-alert-text-hover: var(--secondary-color-1); +--button-alert-color-hover: var(--theme-yellow-color); +--button-alert-border-hover: var(--theme-yellow-color); + +--button-alert-text-active: var(--secondary-color-1); +--button-alert-color-active: var(--theme-orange-color); +--button-alert-border-active: var(--theme-orange-color); + + + +/* ----------- PRIMARY BUTTONS --------------- */ +/* - button should immediately catch the eye - */ +--button-primary-text: var(--secondary-color-1); +--button-primary-color: var(--theme-nuance-color-3); +--button-primary-border: var(--theme-nuance-color-3); + + +/* ---------hover---------- */ +--button-primary-text-hover: + hsl(217.5, + calc(var(--secondary-color-1-saturation) - 35%), + calc(var(--secondary-color-1-lightness) + 30%)); + +--button-primary-color-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + +--button-primary-border-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + + +/* ---------active--------- */ +--button-primary-text-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) + 35%)); + +--button-primary-color-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 10%), + calc(var(--theme-nuance-color-3-lightness) - 25%)); + +--button-primary-border-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 10%), + calc(var(--theme-nuance-color-3-lightness) - 25%)); + + + +/* ---------- SECONDARY BUTTONS -------------- */ +/* these should NOT immediately catch the eye */ +--button-secondary-text: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) - 50%)); + +--button-secondary-color: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) + 10%)); + +--button-secondary-border: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) + 10%)); + + +/* ---------hover---------- */ +--button-secondary-text-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) - 80%)); + +--button-secondary-color-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 22%), + calc(var(--theme-nuance-color-3-lightness) + 1%)); + +--button-secondary-border-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 22%), + calc(var(--theme-nuance-color-3-lightness) + 1%)); + + +/* ---------active--------- */ +--button-secondary-text-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) + 25%)); + +--button-secondary-color-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 30%), + calc(var(--theme-nuance-color-3-lightness) - 15%)); + +--button-secondary-border-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 30%), + calc(var(--theme-nuance-color-3-lightness) - 15%)); + + + +/* ---------- TERTIARY BUTTONS --------------- */ +/* ---------- disabled buttons --------------- */ +--button-tertiary-text: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) - 5%)); + +--button-tertiary-color: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) + 20%)); + +--button-tertiary-border: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) + 20%)); + + +/* ---------hover---------- */ +--button-tertiary-text-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) - 5%)); + +--button-tertiary-color-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) + 20%)); + +--button-tertiary-border-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) + 20%)); + +} diff --git a/tools/server/public_legacy/theme-snowstorm.css b/tools/server/public_legacy/theme-snowstorm.css new file mode 100644 index 0000000000000000000000000000000000000000..7bb2275949273a05c4a99bae95ef5975af670e9f --- /dev/null +++ b/tools/server/public_legacy/theme-snowstorm.css @@ -0,0 +1,251 @@ +/* Author: Yazan Agha-Schrader */ +/* Inspiration from Nord Theme https://www.nordtheme.com/docs/colors-and-palettes */ + +.theme-snowstorm { + +/* ---------- PRIMARY COLORS ----------------- */ +--primary-color-1: hsl(217.5, 26.7%, 94.1%); + --primary-color-1-hue: 217.5; + --primary-color-1-saturation: 26.7%; + --primary-color-1-lightness: 94.1%; + +--primary-color-2: hsl(218.2, 26.8%, 92.0%); + --primary-color-2-hue: 218.2; + --primary-color-2-saturation: 26.8%; + --primary-color-2-lightness: 92.0%; + +--primary-color-3: hsl(218.8, 27.9%, 88.0%); + --primary-color-3-hue: 218.8; + --primary-color-3-saturation: 27.9%; + --primary-color-3-lightness: 88.0%; + +--primary-color-4: hsl(218.8, 18.3%, 81.8%); + --primary-color-4-hue: 218.8; + --primary-color-4-saturation: 18.3%; + --primary-color-4-lightness: 81.8%; + + +/* ---------- SECONDARY COLORS --------------- */ +--secondary-color-1: hsl(220.0, 16.4%, 21.6%); + --secondary-color-1-hue: 220.0; + --secondary-color-1-saturation: 16.4%; + --secondary-color-1-lightness: 21.6%; + +--secondary-color-2: hsl(221.7, 16.3%, 27.6%); + --secondary-color-2-hue: 221.7; + --secondary-color-2-saturation: 16.3%; + --secondary-color-2-lightness: 27.6%; + +--secondary-color-3: hsl(220.0, 16.8%, 31.6%); + --secondary-color-3-hue: 220.0; + --secondary-color-3-saturation: 16.8%; + --secondary-color-3-lightness: 31.6%; + +--secondary-color-4: hsl(220.0, 16.5%, 35.7%); + --secondary-color-4-hue: 220.0; + --secondary-color-4-saturation: 16.5%; + --secondary-color-4-lightness: 35.7%; + + + +/* ----------- NUANCES COLORS ---------------- */ +--theme-nuance-color-1: hsl(178.7, 25.1%, 64.9%); + --theme-nuance-color-1-hue: 178.7; + --theme-nuance-color-1-saturation: 25.1%; + --theme-nuance-color-1-lightness: 64.9%; + +--theme-nuance-color-2: hsl(193.3, 43.4%, 67.5%); + --theme-nuance-color-2-hue: 193.3; + --theme-nuance-color-2-saturation: 43.4%; + --theme-nuance-color-2-lightness: 67.5%; + +--theme-nuance-color-3: hsl(210.0, 34.0%, 63.1%); + --theme-nuance-color-3-hue: 210.0; + --theme-nuance-color-3-saturation: 34.0%; + --theme-nuance-color-3-lightness: 63.1%; + +--theme-nuance-color-4: hsl(213.1, 32.0%, 52.2%); + --theme-nuance-color-4-hue: 213.1; + --theme-nuance-color-4-saturation: 32.0%; + --theme-nuance-color-4-lightness: 52.2%; + + + +/* ----------- ROYGP COLORS ------------------ */ +--theme-red-color: hsl(32.5, 80%, 50%); +--theme-orange-color: hsl(32.5, 70%, 45%); +--theme-yellow-color: hsl(40.0, 0.6%, 73.3%); +--theme-green-color: hsl(92.4, 27.8%, 64.7%); +--theme-purple-color: hsl(311.1, 20.2%, 63.1%); + + + +/* ------------------------------------------- */ +--background-color-1: var(--primary-color-1); +--background-color-2: var(--primary-color-2); +--background-color-3: var(--primary-color-3); +--background-color-4: var(--primary-color-4); + +--border-color-1: var(--primary-color-2); +--border-color-2: var(--primary-color-3); +--border-color-3: var(--primary-color-4); + +--border-focus-color: var(--theme-nuance-color-2); +--border-focus-shadow: var(--theme-nuance-color-1); + +--text-color-plain: var(--secondary-color-1); +--text-color-subtile-1: var(--secondary-color-2); +--text-color-subtile-2: var(--secondary-color-3); + +--code-background-color: var(--secondary-color-2); +--code-text-color: var(--primary-color-2); + +--ui-range-thumb-color: var(--theme-nuance-color-3); +--ui-range-thumb-border: var(--ui-ranger-thumb-color); + +--textarea-border-color: var(--secondary-color-4); + +--chat-id-color: var(--theme-nuance-color-4); + + + +/* ------------------------------------------- */ +--button-alert-text-hover: var(--primary-color-1); +--button-alert-color-hover: var(--theme-orange-color); +--button-alert-border-hover: var(--theme-orange-color); + +--button-alert-text-active: var(--primary-color-1); +--button-alert-color-active: var(--theme-red-color); +--button-alert-border-active: var(--theme-red-color); + + + +/* ----------- PRIMARY BUTTONS --------------- */ +/* - button should immediately catch the eye - */ +--button-primary-text: var(--secondary-color-1); +--button-primary-color: var(--theme-nuance-color-3); +--button-primary-border: var(--theme-nuance-color-3); + + +/* ---------hover---------- */ +--button-primary-text-hover: + hsl(217.5, + calc(var(--secondary-color-1-saturation) + 35%), + calc(var(--secondary-color-1-lightness) - 30%)); + +--button-primary-color-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + +--button-primary-border-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 2%), + calc(var(--theme-nuance-color-3-lightness) - 10%)); + + +/* ---------active--------- */ +--button-primary-text-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) + 35%)); + +--button-primary-color-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 10%), + calc(var(--theme-nuance-color-3-lightness) - 25%)); + +--button-primary-border-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 10%), + calc(var(--theme-nuance-color-3-lightness) - 25%)); + + + +/* ---------- SECONDARY BUTTONS -------------- */ +/* these should NOT immediately catch the eye */ +--button-secondary-text: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) - 50%)); + +--button-secondary-color: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) + 10%)); + +--button-secondary-border: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) + 10%)); + + +/* ---------hover---------- */ +--button-secondary-text-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 20%), + calc(var(--theme-nuance-color-3-lightness) - 80%)); + +--button-secondary-color-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 22%), + calc(var(--theme-nuance-color-3-lightness) + 1%)); + +--button-secondary-border-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 22%), + calc(var(--theme-nuance-color-3-lightness) + 1%)); + + +/* ---------active--------- */ +--button-secondary-text-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) + 40%), + calc(var(--theme-nuance-color-3-lightness) - 55%)); + +--button-secondary-color-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 30%), + calc(var(--theme-nuance-color-3-lightness) - 5%)); + +--button-secondary-border-active: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 30%), + calc(var(--theme-nuance-color-3-lightness) - 5%)); + + + +/* ---------- TERTIARY BUTTONS --------------- */ +/* ---------- disabled buttons --------------- */ +--button-tertiary-text: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) - 5%)); + +--button-tertiary-color: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) + 20%)); + +--button-tertiary-border: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) + 20%)); + +/* ---------hover---------- */ +--button-tertiary-text-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) - 5%)); + +--button-tertiary-color-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) + 20%)); + +--button-tertiary-border-hover: + hsl(210, + calc(var(--theme-nuance-color-3-saturation) - 40%), + calc(var(--theme-nuance-color-3-lightness) + 20%)); + +} diff --git a/tools/server/public_simplechat/datautils.mjs b/tools/server/public_simplechat/datautils.mjs new file mode 100644 index 0000000000000000000000000000000000000000..75159d6b1676bfeb0ab01014f5d04a8a1f5f0500 --- /dev/null +++ b/tools/server/public_simplechat/datautils.mjs @@ -0,0 +1,266 @@ +//@ts-check +// Helpers to work with different data types +// by Humans for All +// + +/** + * Given the limited context size of local LLMs and , many a times when context gets filled + * between the prompt and the response, it can lead to repeating text garbage generation. + * And many a times setting penalty wrt repeatation leads to over-intelligent garbage + * repeatation with slight variations. These garbage inturn can lead to overloading of the + * available model context, leading to less valuable response for subsequent prompts/queries, + * if chat history is sent to ai model. + * + * So two simple minded garbage trimming logics are experimented below. + * * one based on progressively-larger-substring-based-repeat-matching-with-partial-skip and + * * another based on char-histogram-driven garbage trimming. + * * in future characteristic of histogram over varying lengths could be used to allow for + * a more aggressive and adaptive trimming logic. + */ + + +/** + * Simple minded logic to help remove repeating garbage at end of the string. + * The repeatation needs to be perfectly matching. + * + * The logic progressively goes on probing for longer and longer substring based + * repeatation, till there is no longer repeatation. Inturn picks the one with + * the longest chain. + * + * @param {string} sIn + * @param {number} maxSubL + * @param {number} maxMatchLenThreshold + */ +export function trim_repeat_garbage_at_end(sIn, maxSubL=10, maxMatchLenThreshold=40) { + let rCnt = [0]; + let maxMatchLen = maxSubL; + let iMML = -1; + for(let subL=1; subL < maxSubL; subL++) { + rCnt.push(0); + let i; + let refS = sIn.substring(sIn.length-subL, sIn.length); + for(i=sIn.length; i > 0; i -= subL) { + let curS = sIn.substring(i-subL, i); + if (refS != curS) { + let curMatchLen = rCnt[subL]*subL; + if (maxMatchLen < curMatchLen) { + maxMatchLen = curMatchLen; + iMML = subL; + } + break; + } + rCnt[subL] += 1; + } + } + console.debug("DBUG:DU:TrimRepeatGarbage:", rCnt); + if ((iMML == -1) || (maxMatchLen < maxMatchLenThreshold)) { + return {trimmed: false, data: sIn}; + } + console.debug("DBUG:TrimRepeatGarbage:TrimmedCharLen:", maxMatchLen); + let iEnd = sIn.length - maxMatchLen; + return { trimmed: true, data: sIn.substring(0, iEnd) }; +} + + +/** + * Simple minded logic to help remove repeating garbage at end of the string, till it cant. + * If its not able to trim, then it will try to skip a char at end and then trim, a few times. + * This ensures that even if there are multiple runs of garbage with different patterns, the + * logic still tries to munch through them. + * + * @param {string} sIn + * @param {number} maxSubL + * @param {number | undefined} [maxMatchLenThreshold] + */ +export function trim_repeat_garbage_at_end_loop(sIn, maxSubL, maxMatchLenThreshold, skipMax=16) { + let sCur = sIn; + let sSaved = ""; + let iTry = 0; + while(true) { + let got = trim_repeat_garbage_at_end(sCur, maxSubL, maxMatchLenThreshold); + if (got.trimmed != true) { + if (iTry == 0) { + sSaved = got.data; + } + iTry += 1; + if (iTry >= skipMax) { + return sSaved; + } + got.data = got.data.substring(0,got.data.length-1); + } else { + iTry = 0; + } + sCur = got.data; + } +} + + +/** + * A simple minded try trim garbage at end using histogram driven characteristics. + * There can be variation in the repeatations, as long as no new char props up. + * + * This tracks the chars and their frequency in a specified length of substring at the end + * and inturn checks if moving further into the generated text from the end remains within + * the same char subset or goes beyond it and based on that either trims the string at the + * end or not. This allows to filter garbage at the end, including even if there are certain + * kind of small variations in the repeated text wrt position of seen chars. + * + * Allow the garbage to contain upto maxUniq chars, but at the same time ensure that + * a given type of char ie numerals or alphabets or other types dont cross the specified + * maxType limit. This allows intermixed text garbage to be identified and trimmed. + * + * ALERT: This is not perfect and only provides a rough garbage identification logic. + * Also it currently only differentiates between character classes wrt english. + * + * @param {string} sIn + * @param {number} maxType + * @param {number} maxUniq + * @param {number} maxMatchLenThreshold + */ +export function trim_hist_garbage_at_end(sIn, maxType, maxUniq, maxMatchLenThreshold) { + if (sIn.length < maxMatchLenThreshold) { + return { trimmed: false, data: sIn }; + } + let iAlp = 0; + let iNum = 0; + let iOth = 0; + // Learn + let hist = {}; + let iUniq = 0; + for(let i=0; i= maxUniq) { + break; + } + hist[c] = 1; + } + } + console.debug("DBUG:TrimHistGarbage:", hist); + if ((iAlp > maxType) || (iNum > maxType) || (iOth > maxType)) { + return { trimmed: false, data: sIn }; + } + // Catch and Trim + for(let i=0; i < sIn.length; i++) { + let c = sIn[sIn.length-1-i]; + if (!(c in hist)) { + if (i < maxMatchLenThreshold) { + return { trimmed: false, data: sIn }; + } + console.debug("DBUG:TrimHistGarbage:TrimmedCharLen:", i); + return { trimmed: true, data: sIn.substring(0, sIn.length-i+1) }; + } + } + console.debug("DBUG:TrimHistGarbage:Trimmed fully"); + return { trimmed: true, data: "" }; +} + +/** + * Keep trimming repeatedly using hist_garbage logic, till you no longer can. + * This ensures that even if there are multiple runs of garbage with different patterns, + * the logic still tries to munch through them. + * + * @param {any} sIn + * @param {number} maxType + * @param {number} maxUniq + * @param {number} maxMatchLenThreshold + */ +export function trim_hist_garbage_at_end_loop(sIn, maxType, maxUniq, maxMatchLenThreshold) { + let sCur = sIn; + while (true) { + let got = trim_hist_garbage_at_end(sCur, maxType, maxUniq, maxMatchLenThreshold); + if (!got.trimmed) { + return got.data; + } + sCur = got.data; + } +} + +/** + * Try trim garbage at the end by using both the hist-driven-garbage-trimming as well as + * skip-a-bit-if-reqd-then-repeat-pattern-based-garbage-trimming, with blind retrying. + * @param {string} sIn + */ +export function trim_garbage_at_end(sIn) { + let sCur = sIn; + for(let i=0; i<2; i++) { + sCur = trim_hist_garbage_at_end_loop(sCur, 8, 24, 72); + sCur = trim_repeat_garbage_at_end_loop(sCur, 32, 72, 12); + } + return sCur; +} + + +/** + * NewLines array helper. + * Allow for maintaining a list of lines. + * Allow for a line to be builtup/appended part by part. + */ +export class NewLines { + + constructor() { + /** @type {string[]} */ + this.lines = []; + } + + /** + * Extracts lines from the passed string and inturn either + * append to a previous partial line or add a new line. + * @param {string} sLines + */ + add_append(sLines) { + let aLines = sLines.split("\n"); + let lCnt = 0; + for(let line of aLines) { + lCnt += 1; + // Add back newline removed if any during split + if (lCnt < aLines.length) { + line += "\n"; + } else { + if (sLines.endsWith("\n")) { + line += "\n"; + } + } + // Append if required + if (lCnt == 1) { + let lastLine = this.lines[this.lines.length-1]; + if (lastLine != undefined) { + if (!lastLine.endsWith("\n")) { + this.lines[this.lines.length-1] += line; + continue; + } + } + } + // Add new line + this.lines.push(line); + } + } + + /** + * Shift the oldest/earliest/0th line in the array. [Old-New|Earliest-Latest] + * Optionally control whether only full lines (ie those with newline at end) will be returned + * or will a partial line without a newline at end (can only be the last line) be returned. + * @param {boolean} bFullWithNewLineOnly + */ + shift(bFullWithNewLineOnly=true) { + let line = this.lines[0]; + if (line == undefined) { + return undefined; + } + if ((line[line.length-1] != "\n") && bFullWithNewLineOnly){ + return undefined; + } + return this.lines.shift(); + } + +} diff --git a/tools/server/public_simplechat/index.html b/tools/server/public_simplechat/index.html new file mode 100644 index 0000000000000000000000000000000000000000..f6413016fcc53b0f45cbd536340f8e674848d39b --- /dev/null +++ b/tools/server/public_simplechat/index.html @@ -0,0 +1,51 @@ + + + + SimpleChat LlamaCppEtal + + + + + + + + + + + +
+ +
+

SimpleChat

+ +
+ +
+ +
+
+ + +
+ +
+
+

You need to have javascript enabled.

+
+ +
+
+ + +
+ +
+ + diff --git a/tools/server/public_simplechat/readme.md b/tools/server/public_simplechat/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..24e026d455b03f8691f6733dbed8f5eacbc14af0 --- /dev/null +++ b/tools/server/public_simplechat/readme.md @@ -0,0 +1,286 @@ + +# SimpleChat + +by Humans for All. + +## quickstart + +To run from the build dir + +bin/llama-server -m path/model.gguf --path ../tools/server/public_simplechat + +Continue reading for the details. + +## overview + +This simple web frontend, allows triggering/testing the server's /completions or /chat/completions endpoints +in a simple way with minimal code from a common code base. Inturn additionally it tries to allow single or +multiple independent back and forth chatting to an extent, with the ai llm model at a basic level, with their +own system prompts. + +This allows seeing the generated text / ai-model response in oneshot at the end, after it is fully generated, +or potentially as it is being generated, in a streamed manner from the server/ai-model. + +![Chat and Settings screens](./simplechat_screens.webp "Chat and Settings screens") + +Auto saves the chat session locally as and when the chat is progressing and inturn at a later time when you +open SimpleChat, option is provided to restore the old chat session, if a matching one exists. + +The UI follows a responsive web design so that the layout can adapt to available display space in a usable +enough manner, in general. + +Allows developer/end-user to control some of the behaviour by updating gMe members from browser's devel-tool +console. Parallely some of the directly useful to end-user settings can also be changed using the provided +settings ui. + +NOTE: Current web service api doesnt expose the model context length directly, so client logic doesnt provide +any adaptive culling of old messages nor of replacing them with summary of their content etal. However there +is a optional sliding window based chat logic, which provides a simple minded culling of old messages from +the chat history before sending to the ai model. + +NOTE: Wrt options sent with the request, it mainly sets temperature, max_tokens and optionaly stream for now. +However if someone wants they can update the js file or equivalent member in gMe as needed. + +NOTE: One may be able to use this to chat with openai api web-service /chat/completions endpoint, in a very +limited / minimal way. One will need to set model, openai url and authorization bearer key in settings ui. + + +## usage + +One could run this web frontend directly using server itself or if anyone is thinking of adding a built in web +frontend to configure the server over http(s) or so, then run this web frontend using something like python's +http module. + +### running using tools/server + +./llama-server -m path/model.gguf --path tools/server/public_simplechat [--port PORT] + +### running using python3's server module + +first run tools/server +* ./llama-server -m path/model.gguf + +next run this web front end in tools/server/public_simplechat +* cd ../tools/server/public_simplechat +* python3 -m http.server PORT + +### using the front end + +Open this simple web front end from your local browser + +* http://127.0.0.1:PORT/index.html + +Once inside + +* If you want to, you can change many of the default global settings + * the base url (ie ip addr / domain name, port) + * chat (default) vs completion mode + * try trim garbage in response or not + * amount of chat history in the context sent to server/ai-model + * oneshot or streamed mode. + +* In completion mode + * one normally doesnt use a system prompt in completion mode. + * logic by default doesnt insert any role specific "ROLE: " prefix wrt each role's message. + If the model requires any prefix wrt user role messages, then the end user has to + explicitly add the needed prefix, when they enter their chat message. + Similarly if the model requires any prefix to trigger assistant/ai-model response, + then the end user needs to enter the same. + This keeps the logic simple, while still giving flexibility to the end user to + manage any templating/tagging requirement wrt their messages to the model. + * the logic doesnt insert newline at the begining and end wrt the prompt message generated. + However if the chat being sent to /completions end point has more than one role's message, + then insert newline when moving from one role's message to the next role's message, so + that it can be clearly identified/distinguished. + * given that /completions endpoint normally doesnt add additional chat-templating of its + own, the above ensures that end user can create a custom single/multi message combo with + any tags/special-tokens related chat templating to test out model handshake. Or enduser + can use it just for normal completion related/based query. + +* If you want to provide a system prompt, then ideally enter it first, before entering any user query. + Normally Completion mode doesnt need system prompt, while Chat mode can generate better/interesting + responses with a suitable system prompt. + * if chat.add_system_begin is used + * you cant change the system prompt, after it is has been submitted once along with user query. + * you cant set a system prompt, after you have submitted any user query + * if chat.add_system_anytime is used + * one can change the system prompt any time during chat, by changing the contents of system prompt. + * inturn the updated/changed system prompt will be inserted into the chat session. + * this allows for the subsequent user chatting to be driven by the new system prompt set above. + +* Enter your query and either press enter or click on the submit button. + If you want to insert enter (\n) as part of your chat/query to ai model, use shift+enter. + +* Wait for the logic to communicate with the server and get the response. + * the user is not allowed to enter any fresh query during this time. + * the user input box will be disabled and a working message will be shown in it. + * if trim garbage is enabled, the logic will try to trim repeating text kind of garbage to some extent. + +* just refresh the page, to reset wrt the chat history and or system prompt and start afresh. + +* Using NewChat one can start independent chat sessions. + * two independent chat sessions are setup by default. + +* When you want to print, switching ChatHistoryInCtxt to Full and clicking on the chat session button of + interest, will display the full chat history till then wrt same, if you want full history for printing. + + +## Devel note + +### Reason behind this + +The idea is to be easy enough to use for basic purposes, while also being simple and easily discernable +by developers who may not be from web frontend background (so inturn may not be familiar with template / +end-use-specific-language-extensions driven flows) so that they can use it to explore/experiment things. + +And given that the idea is also to help explore/experiment for developers, some flexibility is provided +to change behaviour easily using the devel-tools/console or provided minimal settings ui (wrt few aspects). +Skeletal logic has been implemented to explore some of the end points and ideas/implications around them. + + +### General + +Me/gMe consolidates the settings which control the behaviour into one object. +One can see the current settings, as well as change/update them using browsers devel-tool/console. +It is attached to the document object. Some of these can also be updated using the Settings UI. + + baseURL - the domain-name/ip-address and inturn the port to send the request. + + bStream - control between oneshot-at-end and live-stream-as-its-generated collating and showing + of the generated response. + + the logic assumes that the text sent from the server follows utf-8 encoding. + + in streaming mode - if there is any exception, the logic traps the same and tries to ensure + that text generated till then is not lost. + + if a very long text is being generated, which leads to no user interaction for sometime and + inturn the machine goes into power saving mode or so, the platform may stop network connection, + leading to exception. + + apiEP - select between /completions and /chat/completions endpoint provided by the server/ai-model. + + bCompletionFreshChatAlways - whether Completion mode collates complete/sliding-window history when + communicating with the server or only sends the latest user query/message. + + bCompletionInsertStandardRolePrefix - whether Completion mode inserts role related prefix wrt the + messages that get inserted into prompt field wrt /Completion endpoint. + + bTrimGarbage - whether garbage repeatation at the end of the generated ai response, should be + trimmed or left as is. If enabled, it will be trimmed so that it wont be sent back as part of + subsequent chat history. At the same time the actual trimmed text is shown to the user, once + when it was generated, so user can check if any useful info/data was there in the response. + + One may be able to request the ai-model to continue (wrt the last response) (if chat-history + is enabled as part of the chat-history-in-context setting), and chances are the ai-model will + continue starting from the trimmed part, thus allows long response to be recovered/continued + indirectly, in many cases. + + The histogram/freq based trimming logic is currently tuned for english language wrt its + is-it-a-alpabetic|numeral-char regex match logic. + + apiRequestOptions - maintains the list of options/fields to send along with api request, + irrespective of whether /chat/completions or /completions endpoint. + + If you want to add additional options/fields to send to the server/ai-model, and or + modify the existing options value or remove them, for now you can update this global var + using browser's development-tools/console. + + For string, numeric and boolean fields in apiRequestOptions, including even those added by a + user at runtime by directly modifying gMe.apiRequestOptions, setting ui entries will be auto + created. + + cache_prompt option supported by example/server is allowed to be controlled by user, so that + any caching supported wrt system-prompt and chat history, if usable can get used. When chat + history sliding window is enabled, cache_prompt logic may or may not kick in at the backend + wrt same, based on aspects related to model, positional encoding, attention mechanism etal. + However system prompt should ideally get the benefit of caching. + + headers - maintains the list of http headers sent when request is made to the server. By default + Content-Type is set to application/json. Additionally Authorization entry is provided, which can + be set if needed using the settings ui. + + iRecentUserMsgCnt - a simple minded SlidingWindow to limit context window load at Ai Model end. + This is disabled by default. However if enabled, then in addition to latest system message, only + the last/latest iRecentUserMsgCnt user messages after the latest system prompt and its responses + from the ai model will be sent to the ai-model, when querying for a new response. IE if enabled, + only user messages after the latest system message/prompt will be considered. + + This specified sliding window user message count also includes the latest user query. + <0 : Send entire chat history to server + 0 : Send only the system message if any to the server + >0 : Send the latest chat history from the latest system prompt, limited to specified cnt. + + +By using gMe's iRecentUserMsgCnt and apiRequestOptions.max_tokens/n_predict one can try to control +the implications of loading of the ai-model's context window by chat history, wrt chat response to +some extent in a simple crude way. You may also want to control the context size enabled when the +server loads ai-model, on the server end. + + +Sometimes the browser may be stuborn with caching of the file, so your updates to html/css/js +may not be visible. Also remember that just refreshing/reloading page in browser or for that +matter clearing site data, dont directly override site caching in all cases. Worst case you may +have to change port. Or in dev tools of browser, you may be able to disable caching fully. + + +Currently the server to communicate with is maintained globally and not as part of a specific +chat session. So if one changes the server ip/url in setting, then all chat sessions will auto +switch to this new server, when you try using those sessions. + + +By switching between chat.add_system_begin/anytime, one can control whether one can change +the system prompt, anytime during the conversation or only at the beginning. + + +### Default setup + +By default things are setup to try and make the user experience a bit better, if possible. +However a developer when testing the server of ai-model may want to change these value. + +Using iRecentUserMsgCnt reduce chat history context sent to the server/ai-model to be +just the system-prompt, prev-user-request-and-ai-response and cur-user-request, instead of +full chat history. This way if there is any response with garbage/repeatation, it doesnt +mess with things beyond the next question/request/query, in some ways. The trim garbage +option also tries to help avoid issues with garbage in the context to an extent. + +Set max_tokens to 1024, so that a relatively large previous reponse doesnt eat up the space +available wrt next query-response. However dont forget that the server when started should +also be started with a model context size of 1k or more, to be on safe side. + + The /completions endpoint of tools/server doesnt take max_tokens, instead it takes the + internal n_predict, for now add the same here on the client side, maybe later add max_tokens + to /completions endpoint handling code on server side. + +NOTE: One may want to experiment with frequency/presence penalty fields in apiRequestOptions +wrt the set of fields sent to server along with the user query, to check how the model behaves +wrt repeatations in general in the generated text response. + +A end-user can change these behaviour by editing gMe from browser's devel-tool/console or by +using the provided settings ui (for settings exposed through the ui). + + +### OpenAi / Equivalent API WebService + +One may be abe to handshake with OpenAI/Equivalent api web service's /chat/completions endpoint +for a minimal chatting experimentation by setting the below. + +* the baseUrl in settings ui + * https://api.openai.com/v1 or similar + +* Wrt request body - gMe.apiRequestOptions + * model (settings ui) + * any additional fields if required in future + +* Wrt request headers - gMe.headers + * Authorization (available through settings ui) + * Bearer THE_OPENAI_API_KEY + * any additional optional header entries like "OpenAI-Organization", "OpenAI-Project" or so + +NOTE: Not tested, as there is no free tier api testing available. However logically this might +work. + + +## At the end + +Also a thank you to all open source and open model developers, who strive for the common good. diff --git a/tools/server/public_simplechat/simplechat.css b/tools/server/public_simplechat/simplechat.css new file mode 100644 index 0000000000000000000000000000000000000000..13bfb80b48be80f2b4c3fdad88c9144dbcd40894 --- /dev/null +++ b/tools/server/public_simplechat/simplechat.css @@ -0,0 +1,79 @@ +/** + * the styling of the simplechat web frontend + * by Humans for All + */ + +#fullbody { + height: 98vh; +} + +.heading { + background-color: lightgray; +} + +.session-selected { + background-color: lightblue; +} + +.role-system { + background-color: lightblue; +} +.role-user { + background-color: lightgray; +} +.role-trim { + background-color: lightpink; +} + +.gridx2 { + display: grid; + grid-template-columns: repeat(2, 1fr); + border-bottom-style: dotted; + border-bottom-width: thin; + border-bottom-color: lightblue; +} + +.flex-grow { + flex-grow: 1; +} +.float-right { + float: right; +} + +#chat-div { + overflow: scroll; + flex-grow: 1; + flex-shrink: 1; + min-height: 40vh; +} +button { + min-width: 8vw; +} + +.sameline { + display: flex; + flex-direction: row; +} +.samecolumn { + display: flex; + flex-direction: column; +} + +.ul1 { + padding-inline-start: 2vw; +} +.ul2 { + padding-inline-start: 2vw; +} + +* { + margin: 0.6vmin; +} + +@media print { + + #fullbody { + height: auto; + } + +} diff --git a/tools/server/public_simplechat/simplechat.js b/tools/server/public_simplechat/simplechat.js new file mode 100644 index 0000000000000000000000000000000000000000..2fcd24a860bd4d9130d314cf12604ac9b3b60ea0 --- /dev/null +++ b/tools/server/public_simplechat/simplechat.js @@ -0,0 +1,929 @@ +// @ts-check +// A simple completions and chat/completions test related web front end logic +// by Humans for All + +import * as du from "./datautils.mjs"; +import * as ui from "./ui.mjs" + +class Roles { + static System = "system"; + static User = "user"; + static Assistant = "assistant"; +} + +class ApiEP { + static Type = { + Chat: "chat", + Completion: "completion", + } + static UrlSuffix = { + 'chat': `/chat/completions`, + 'completion': `/completions`, + } + + /** + * Build the url from given baseUrl and apiEp id. + * @param {string} baseUrl + * @param {string} apiEP + */ + static Url(baseUrl, apiEP) { + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.substring(0, baseUrl.length-1); + } + return `${baseUrl}${this.UrlSuffix[apiEP]}`; + } + +} + + +let gUsageMsg = ` +

Usage

+
    +
  • System prompt above, to try control ai response characteristics.
  • +
      +
    • Completion mode - no system prompt normally.
    • +
    +
  • Use shift+enter for inserting enter/newline.
  • +
  • Enter your query to ai assistant below.
  • +
  • Default ContextWindow = [System, Last Query+Resp, Cur Query].
  • +
      +
    • ChatHistInCtxt, MaxTokens, ModelCtxt window to expand
    • +
    +
+`; + + +/** @typedef {{role: string, content: string}[]} ChatMessages */ + +/** @typedef {{iLastSys: number, xchat: ChatMessages}} SimpleChatODS */ + +class SimpleChat { + + /** + * @param {string} chatId + */ + constructor(chatId) { + this.chatId = chatId; + /** + * Maintain in a form suitable for common LLM web service chat/completions' messages entry + * @type {ChatMessages} + */ + this.xchat = []; + this.iLastSys = -1; + this.latestResponse = ""; + } + + clear() { + this.xchat = []; + this.iLastSys = -1; + } + + ods_key() { + return `SimpleChat-${this.chatId}` + } + + save() { + /** @type {SimpleChatODS} */ + let ods = {iLastSys: this.iLastSys, xchat: this.xchat}; + localStorage.setItem(this.ods_key(), JSON.stringify(ods)); + } + + load() { + let sods = localStorage.getItem(this.ods_key()); + if (sods == null) { + return; + } + /** @type {SimpleChatODS} */ + let ods = JSON.parse(sods); + this.iLastSys = ods.iLastSys; + this.xchat = ods.xchat; + } + + /** + * Recent chat messages. + * If iRecentUserMsgCnt < 0 + * Then return the full chat history + * Else + * Return chat messages from latest going back till the last/latest system prompt. + * While keeping track that the number of user queries/messages doesnt exceed iRecentUserMsgCnt. + * @param {number} iRecentUserMsgCnt + */ + recent_chat(iRecentUserMsgCnt) { + if (iRecentUserMsgCnt < 0) { + return this.xchat; + } + if (iRecentUserMsgCnt == 0) { + console.warn("WARN:SimpleChat:SC:RecentChat:iRecentUsermsgCnt of 0 means no user message/query sent"); + } + /** @type{ChatMessages} */ + let rchat = []; + let sysMsg = this.get_system_latest(); + if (sysMsg.length != 0) { + rchat.push({role: Roles.System, content: sysMsg}); + } + let iUserCnt = 0; + let iStart = this.xchat.length; + for(let i=this.xchat.length-1; i > this.iLastSys; i--) { + if (iUserCnt >= iRecentUserMsgCnt) { + break; + } + let msg = this.xchat[i]; + if (msg.role == Roles.User) { + iStart = i; + iUserCnt += 1; + } + } + for(let i = iStart; i < this.xchat.length; i++) { + let msg = this.xchat[i]; + if (msg.role == Roles.System) { + continue; + } + rchat.push({role: msg.role, content: msg.content}); + } + return rchat; + } + + /** + * Collate the latest response from the server/ai-model, as it is becoming available. + * This is mainly useful for the stream mode. + * @param {string} content + */ + append_response(content) { + this.latestResponse += content; + } + + /** + * Add an entry into xchat + * @param {string} role + * @param {string|undefined|null} content + */ + add(role, content) { + if ((content == undefined) || (content == null) || (content == "")) { + return false; + } + this.xchat.push( {role: role, content: content} ); + if (role == Roles.System) { + this.iLastSys = this.xchat.length - 1; + } + this.save(); + return true; + } + + /** + * Show the contents in the specified div + * @param {HTMLDivElement} div + * @param {boolean} bClear + */ + show(div, bClear=true) { + if (bClear) { + div.replaceChildren(); + } + let last = undefined; + for(const x of this.recent_chat(gMe.iRecentUserMsgCnt)) { + let entry = ui.el_create_append_p(`${x.role}: ${x.content}`, div); + entry.className = `role-${x.role}`; + last = entry; + } + if (last !== undefined) { + last.scrollIntoView(false); + } else { + if (bClear) { + div.innerHTML = gUsageMsg; + gMe.setup_load(div, this); + gMe.show_info(div); + } + } + return last; + } + + /** + * Setup the fetch headers. + * It picks the headers from gMe.headers. + * It inserts Authorization only if its non-empty. + * @param {string} apiEP + */ + fetch_headers(apiEP) { + let headers = new Headers(); + for(let k in gMe.headers) { + let v = gMe.headers[k]; + if ((k == "Authorization") && (v.trim() == "")) { + continue; + } + headers.append(k, v); + } + return headers; + } + + /** + * Add needed fields wrt json object to be sent wrt LLM web services completions endpoint. + * The needed fields/options are picked from a global object. + * Add optional stream flag, if required. + * Convert the json into string. + * @param {Object} obj + */ + request_jsonstr_extend(obj) { + for(let k in gMe.apiRequestOptions) { + obj[k] = gMe.apiRequestOptions[k]; + } + if (gMe.bStream) { + obj["stream"] = true; + } + return JSON.stringify(obj); + } + + /** + * Return a string form of json object suitable for chat/completions + */ + request_messages_jsonstr() { + let req = { + messages: this.recent_chat(gMe.iRecentUserMsgCnt), + } + return this.request_jsonstr_extend(req); + } + + /** + * Return a string form of json object suitable for /completions + * @param {boolean} bInsertStandardRolePrefix Insert ": " as prefix wrt each role's message + */ + request_prompt_jsonstr(bInsertStandardRolePrefix) { + let prompt = ""; + let iCnt = 0; + for(const chat of this.recent_chat(gMe.iRecentUserMsgCnt)) { + iCnt += 1; + if (iCnt > 1) { + prompt += "\n"; + } + if (bInsertStandardRolePrefix) { + prompt += `${chat.role}: `; + } + prompt += `${chat.content}`; + } + let req = { + prompt: prompt, + } + return this.request_jsonstr_extend(req); + } + + /** + * Return a string form of json object suitable for specified api endpoint. + * @param {string} apiEP + */ + request_jsonstr(apiEP) { + if (apiEP == ApiEP.Type.Chat) { + return this.request_messages_jsonstr(); + } else { + return this.request_prompt_jsonstr(gMe.bCompletionInsertStandardRolePrefix); + } + } + + /** + * Extract the ai-model/assistant's response from the http response got. + * Optionally trim the message wrt any garbage at the end. + * @param {any} respBody + * @param {string} apiEP + */ + response_extract(respBody, apiEP) { + let assistant = ""; + if (apiEP == ApiEP.Type.Chat) { + assistant = respBody["choices"][0]["message"]["content"]; + } else { + try { + assistant = respBody["choices"][0]["text"]; + } catch { + assistant = respBody["content"]; + } + } + return assistant; + } + + /** + * Extract the ai-model/assistant's response from the http response got in streaming mode. + * @param {any} respBody + * @param {string} apiEP + */ + response_extract_stream(respBody, apiEP) { + let assistant = ""; + if (apiEP == ApiEP.Type.Chat) { + if (respBody["choices"][0]["finish_reason"] !== "stop") { + assistant = respBody["choices"][0]["delta"]["content"]; + } + } else { + try { + assistant = respBody["choices"][0]["text"]; + } catch { + assistant = respBody["content"]; + } + } + return assistant; + } + + /** + * Allow setting of system prompt, but only at begining. + * @param {string} sysPrompt + * @param {string} msgTag + */ + add_system_begin(sysPrompt, msgTag) { + if (this.xchat.length == 0) { + if (sysPrompt.length > 0) { + return this.add(Roles.System, sysPrompt); + } + } else { + if (sysPrompt.length > 0) { + if (this.xchat[0].role !== Roles.System) { + console.error(`ERRR:SimpleChat:SC:${msgTag}:You need to specify system prompt before any user query, ignoring...`); + } else { + if (this.xchat[0].content !== sysPrompt) { + console.error(`ERRR:SimpleChat:SC:${msgTag}:You cant change system prompt, mid way through, ignoring...`); + } + } + } + } + return false; + } + + /** + * Allow setting of system prompt, at any time. + * @param {string} sysPrompt + * @param {string} msgTag + */ + add_system_anytime(sysPrompt, msgTag) { + if (sysPrompt.length <= 0) { + return false; + } + + if (this.iLastSys < 0) { + return this.add(Roles.System, sysPrompt); + } + + let lastSys = this.xchat[this.iLastSys].content; + if (lastSys !== sysPrompt) { + return this.add(Roles.System, sysPrompt); + } + return false; + } + + /** + * Retrieve the latest system prompt. + */ + get_system_latest() { + if (this.iLastSys == -1) { + return ""; + } + let sysPrompt = this.xchat[this.iLastSys].content; + return sysPrompt; + } + + + /** + * Handle the multipart response from server/ai-model + * @param {Response} resp + * @param {string} apiEP + * @param {HTMLDivElement} elDiv + */ + async handle_response_multipart(resp, apiEP, elDiv) { + let elP = ui.el_create_append_p("", elDiv); + if (!resp.body) { + throw Error("ERRR:SimpleChat:SC:HandleResponseMultiPart:No body..."); + } + let tdUtf8 = new TextDecoder("utf-8"); + let rr = resp.body.getReader(); + this.latestResponse = ""; + let xLines = new du.NewLines(); + while(true) { + let { value: cur, done: done } = await rr.read(); + if (cur) { + let curBody = tdUtf8.decode(cur, {stream: true}); + console.debug("DBUG:SC:PART:Str:", curBody); + xLines.add_append(curBody); + } + while(true) { + let curLine = xLines.shift(!done); + if (curLine == undefined) { + break; + } + if (curLine.trim() == "") { + continue; + } + if (curLine.startsWith("data:")) { + curLine = curLine.substring(5); + } + if (curLine.trim() === "[DONE]") { + break; + } + let curJson = JSON.parse(curLine); + console.debug("DBUG:SC:PART:Json:", curJson); + this.append_response(this.response_extract_stream(curJson, apiEP)); + } + elP.innerText = this.latestResponse; + elP.scrollIntoView(false); + if (done) { + break; + } + } + console.debug("DBUG:SC:PART:Full:", this.latestResponse); + return this.latestResponse; + } + + /** + * Handle the oneshot response from server/ai-model + * @param {Response} resp + * @param {string} apiEP + */ + async handle_response_oneshot(resp, apiEP) { + let respBody = await resp.json(); + console.debug(`DBUG:SimpleChat:SC:${this.chatId}:HandleUserSubmit:RespBody:${JSON.stringify(respBody)}`); + return this.response_extract(respBody, apiEP); + } + + /** + * Handle the response from the server be it in oneshot or multipart/stream mode. + * Also take care of the optional garbage trimming. + * @param {Response} resp + * @param {string} apiEP + * @param {HTMLDivElement} elDiv + */ + async handle_response(resp, apiEP, elDiv) { + let theResp = { + assistant: "", + trimmed: "", + } + if (gMe.bStream) { + try { + theResp.assistant = await this.handle_response_multipart(resp, apiEP, elDiv); + this.latestResponse = ""; + } catch (error) { + theResp.assistant = this.latestResponse; + this.add(Roles.Assistant, theResp.assistant); + this.latestResponse = ""; + throw error; + } + } else { + theResp.assistant = await this.handle_response_oneshot(resp, apiEP); + } + if (gMe.bTrimGarbage) { + let origMsg = theResp.assistant; + theResp.assistant = du.trim_garbage_at_end(origMsg); + theResp.trimmed = origMsg.substring(theResp.assistant.length); + } + this.add(Roles.Assistant, theResp.assistant); + return theResp; + } + +} + + +class MultiChatUI { + + constructor() { + /** @type {Object} */ + this.simpleChats = {}; + /** @type {string} */ + this.curChatId = ""; + + // the ui elements + this.elInSystem = /** @type{HTMLInputElement} */(document.getElementById("system-in")); + this.elDivChat = /** @type{HTMLDivElement} */(document.getElementById("chat-div")); + this.elBtnUser = /** @type{HTMLButtonElement} */(document.getElementById("user-btn")); + this.elInUser = /** @type{HTMLInputElement} */(document.getElementById("user-in")); + this.elDivHeading = /** @type{HTMLSelectElement} */(document.getElementById("heading")); + this.elDivSessions = /** @type{HTMLDivElement} */(document.getElementById("sessions-div")); + this.elBtnSettings = /** @type{HTMLButtonElement} */(document.getElementById("settings")); + + this.validate_element(this.elInSystem, "system-in"); + this.validate_element(this.elDivChat, "chat-div"); + this.validate_element(this.elInUser, "user-in"); + this.validate_element(this.elDivHeading, "heading"); + this.validate_element(this.elDivChat, "sessions-div"); + this.validate_element(this.elBtnSettings, "settings"); + } + + /** + * Check if the element got + * @param {HTMLElement | null} el + * @param {string} msgTag + */ + validate_element(el, msgTag) { + if (el == null) { + throw Error(`ERRR:SimpleChat:MCUI:${msgTag} element missing in html...`); + } else { + console.debug(`INFO:SimpleChat:MCUI:${msgTag} Id[${el.id}] Name[${el["name"]}]`); + } + } + + /** + * Reset user input ui. + * * clear user input + * * enable user input + * * set focus to user input + */ + ui_reset_userinput() { + this.elInUser.value = ""; + this.elInUser.disabled = false; + this.elInUser.focus(); + } + + /** + * Setup the needed callbacks wrt UI, curChatId to defaultChatId and + * optionally switch to specified defaultChatId. + * @param {string} defaultChatId + * @param {boolean} bSwitchSession + */ + setup_ui(defaultChatId, bSwitchSession=false) { + + this.curChatId = defaultChatId; + if (bSwitchSession) { + this.handle_session_switch(this.curChatId); + } + + this.elBtnSettings.addEventListener("click", (ev)=>{ + this.elDivChat.replaceChildren(); + gMe.show_settings(this.elDivChat); + }); + + this.elBtnUser.addEventListener("click", (ev)=>{ + if (this.elInUser.disabled) { + return; + } + this.handle_user_submit(this.curChatId, gMe.apiEP).catch((/** @type{Error} */reason)=>{ + let msg = `ERRR:SimpleChat\nMCUI:HandleUserSubmit:${this.curChatId}\n${reason.name}:${reason.message}`; + console.error(msg.replace("\n", ":")); + alert(msg); + this.ui_reset_userinput(); + }); + }); + + this.elInUser.addEventListener("keyup", (ev)=> { + // allow user to insert enter into their message using shift+enter. + // while just pressing enter key will lead to submitting. + if ((ev.key === "Enter") && (!ev.shiftKey)) { + let value = this.elInUser.value; + this.elInUser.value = value.substring(0,value.length-1); + this.elBtnUser.click(); + ev.preventDefault(); + } + }); + + this.elInSystem.addEventListener("keyup", (ev)=> { + // allow user to insert enter into the system prompt using shift+enter. + // while just pressing enter key will lead to setting the system prompt. + if ((ev.key === "Enter") && (!ev.shiftKey)) { + let value = this.elInSystem.value; + this.elInSystem.value = value.substring(0,value.length-1); + let chat = this.simpleChats[this.curChatId]; + chat.add_system_anytime(this.elInSystem.value, this.curChatId); + chat.show(this.elDivChat); + ev.preventDefault(); + } + }); + + } + + /** + * Setup a new chat session and optionally switch to it. + * @param {string} chatId + * @param {boolean} bSwitchSession + */ + new_chat_session(chatId, bSwitchSession=false) { + this.simpleChats[chatId] = new SimpleChat(chatId); + if (bSwitchSession) { + this.handle_session_switch(chatId); + } + } + + + /** + * Handle user query submit request, wrt specified chat session. + * @param {string} chatId + * @param {string} apiEP + */ + async handle_user_submit(chatId, apiEP) { + + let chat = this.simpleChats[chatId]; + + // In completion mode, if configured, clear any previous chat history. + // So if user wants to simulate a multi-chat based completion query, + // they will have to enter the full thing, as a suitable multiline + // user input/query. + if ((apiEP == ApiEP.Type.Completion) && (gMe.bCompletionFreshChatAlways)) { + chat.clear(); + } + + chat.add_system_anytime(this.elInSystem.value, chatId); + + let content = this.elInUser.value; + if (!chat.add(Roles.User, content)) { + console.debug(`WARN:SimpleChat:MCUI:${chatId}:HandleUserSubmit:Ignoring empty user input...`); + return; + } + chat.show(this.elDivChat); + + let theUrl = ApiEP.Url(gMe.baseURL, apiEP); + let theBody = chat.request_jsonstr(apiEP); + + this.elInUser.value = "working..."; + this.elInUser.disabled = true; + console.debug(`DBUG:SimpleChat:MCUI:${chatId}:HandleUserSubmit:${theUrl}:ReqBody:${theBody}`); + let theHeaders = chat.fetch_headers(apiEP); + let resp = await fetch(theUrl, { + method: "POST", + headers: theHeaders, + body: theBody, + }); + + let theResp = await chat.handle_response(resp, apiEP, this.elDivChat); + if (chatId == this.curChatId) { + chat.show(this.elDivChat); + if (theResp.trimmed.length > 0) { + let p = ui.el_create_append_p(`TRIMMED:${theResp.trimmed}`, this.elDivChat); + p.className="role-trim"; + } + } else { + console.debug(`DBUG:SimpleChat:MCUI:HandleUserSubmit:ChatId has changed:[${chatId}] [${this.curChatId}]`); + } + this.ui_reset_userinput(); + } + + /** + * Show buttons for NewChat and available chat sessions, in the passed elDiv. + * If elDiv is undefined/null, then use this.elDivSessions. + * Take care of highlighting the selected chat-session's btn. + * @param {HTMLDivElement | undefined} elDiv + */ + show_sessions(elDiv=undefined) { + if (!elDiv) { + elDiv = this.elDivSessions; + } + elDiv.replaceChildren(); + // Btn for creating new chat session + let btnNew = ui.el_create_button("New CHAT", (ev)=> { + if (this.elInUser.disabled) { + console.error(`ERRR:SimpleChat:MCUI:NewChat:Current session [${this.curChatId}] awaiting response, ignoring request...`); + alert("ERRR:SimpleChat\nMCUI:NewChat\nWait for response to pending query, before starting new chat session"); + return; + } + let chatId = `Chat${Object.keys(this.simpleChats).length}`; + let chatIdGot = prompt("INFO:SimpleChat\nMCUI:NewChat\nEnter id for new chat session", chatId); + if (!chatIdGot) { + console.error("ERRR:SimpleChat:MCUI:NewChat:Skipping based on user request..."); + return; + } + this.new_chat_session(chatIdGot, true); + this.create_session_btn(elDiv, chatIdGot); + ui.el_children_config_class(elDiv, chatIdGot, "session-selected", ""); + }); + elDiv.appendChild(btnNew); + // Btns for existing chat sessions + let chatIds = Object.keys(this.simpleChats); + for(let cid of chatIds) { + let btn = this.create_session_btn(elDiv, cid); + if (cid == this.curChatId) { + btn.className = "session-selected"; + } + } + } + + create_session_btn(elDiv, cid) { + let btn = ui.el_create_button(cid, (ev)=>{ + let target = /** @type{HTMLButtonElement} */(ev.target); + console.debug(`DBUG:SimpleChat:MCUI:SessionClick:${target.id}`); + if (this.elInUser.disabled) { + console.error(`ERRR:SimpleChat:MCUI:SessionClick:${target.id}:Current session [${this.curChatId}] awaiting response, ignoring switch...`); + alert("ERRR:SimpleChat\nMCUI:SessionClick\nWait for response to pending query, before switching"); + return; + } + this.handle_session_switch(target.id); + ui.el_children_config_class(elDiv, target.id, "session-selected", ""); + }); + elDiv.appendChild(btn); + return btn; + } + + /** + * Switch ui to the specified chatId and set curChatId to same. + * @param {string} chatId + */ + async handle_session_switch(chatId) { + let chat = this.simpleChats[chatId]; + if (chat == undefined) { + console.error(`ERRR:SimpleChat:MCUI:HandleSessionSwitch:${chatId} missing...`); + return; + } + this.elInSystem.value = chat.get_system_latest(); + this.elInUser.value = ""; + chat.show(this.elDivChat); + this.elInUser.focus(); + this.curChatId = chatId; + console.log(`INFO:SimpleChat:MCUI:HandleSessionSwitch:${chatId} entered...`); + } + +} + + +class Me { + + constructor() { + this.baseURL = "http://127.0.0.1:8080"; + this.defaultChatIds = [ "Default", "Other" ]; + this.multiChat = new MultiChatUI(); + this.bStream = true; + this.bCompletionFreshChatAlways = true; + this.bCompletionInsertStandardRolePrefix = false; + this.bTrimGarbage = true; + this.iRecentUserMsgCnt = 2; + this.sRecentUserMsgCnt = { + "Full": -1, + "Last0": 1, + "Last1": 2, + "Last2": 3, + "Last4": 5, + }; + this.apiEP = ApiEP.Type.Chat; + this.headers = { + "Content-Type": "application/json", + "Authorization": "", // Authorization: Bearer OPENAI_API_KEY + } + // Add needed fields wrt json object to be sent wrt LLM web services completions endpoint. + this.apiRequestOptions = { + "model": "gpt-3.5-turbo", + "temperature": 0.7, + "max_tokens": 1024, + "n_predict": 1024, + "cache_prompt": false, + //"frequency_penalty": 1.2, + //"presence_penalty": 1.2, + }; + } + + /** + * Disable console.debug by mapping it to a empty function. + */ + debug_disable() { + this.console_debug = console.debug; + console.debug = () => { + + }; + } + + /** + * Setup the load saved chat ui. + * @param {HTMLDivElement} div + * @param {SimpleChat} chat + */ + setup_load(div, chat) { + if (!(chat.ods_key() in localStorage)) { + return; + } + div.innerHTML += `

Restore

+

Load previously saved chat session, if available

`; + let btn = ui.el_create_button(chat.ods_key(), (ev)=>{ + console.log("DBUG:SimpleChat:SC:Load", chat); + chat.load(); + queueMicrotask(()=>{ + chat.show(div); + this.multiChat.elInSystem.value = chat.get_system_latest(); + }); + }); + div.appendChild(btn); + } + + /** + * Show the configurable parameters info in the passed Div element. + * @param {HTMLDivElement} elDiv + * @param {boolean} bAll + */ + show_info(elDiv, bAll=false) { + + let p = ui.el_create_append_p("Settings (devel-tools-console document[gMe])", elDiv); + p.className = "role-system"; + + if (bAll) { + + ui.el_create_append_p(`baseURL:${this.baseURL}`, elDiv); + + ui.el_create_append_p(`Authorization:${this.headers["Authorization"]}`, elDiv); + + ui.el_create_append_p(`bStream:${this.bStream}`, elDiv); + + ui.el_create_append_p(`bTrimGarbage:${this.bTrimGarbage}`, elDiv); + + ui.el_create_append_p(`ApiEndPoint:${this.apiEP}`, elDiv); + + ui.el_create_append_p(`iRecentUserMsgCnt:${this.iRecentUserMsgCnt}`, elDiv); + + ui.el_create_append_p(`bCompletionFreshChatAlways:${this.bCompletionFreshChatAlways}`, elDiv); + + ui.el_create_append_p(`bCompletionInsertStandardRolePrefix:${this.bCompletionInsertStandardRolePrefix}`, elDiv); + + } + + ui.el_create_append_p(`apiRequestOptions:${JSON.stringify(this.apiRequestOptions, null, " - ")}`, elDiv); + ui.el_create_append_p(`headers:${JSON.stringify(this.headers, null, " - ")}`, elDiv); + + } + + /** + * Auto create ui input elements for fields in apiRequestOptions + * Currently supports text and number field types. + * @param {HTMLDivElement} elDiv + */ + show_settings_apirequestoptions(elDiv) { + let typeDict = { + "string": "text", + "number": "number", + }; + let fs = document.createElement("fieldset"); + let legend = document.createElement("legend"); + legend.innerText = "ApiRequestOptions"; + fs.appendChild(legend); + elDiv.appendChild(fs); + for(const k in this.apiRequestOptions) { + let val = this.apiRequestOptions[k]; + let type = typeof(val); + if (((type == "string") || (type == "number"))) { + let inp = ui.el_creatediv_input(`Set${k}`, k, typeDict[type], this.apiRequestOptions[k], (val)=>{ + if (type == "number") { + val = Number(val); + } + this.apiRequestOptions[k] = val; + }); + fs.appendChild(inp.div); + } else if (type == "boolean") { + let bbtn = ui.el_creatediv_boolbutton(`Set{k}`, k, {true: "true", false: "false"}, val, (userVal)=>{ + this.apiRequestOptions[k] = userVal; + }); + fs.appendChild(bbtn.div); + } + } + } + + /** + * Show settings ui for configurable parameters, in the passed Div element. + * @param {HTMLDivElement} elDiv + */ + show_settings(elDiv) { + + let inp = ui.el_creatediv_input("SetBaseURL", "BaseURL", "text", this.baseURL, (val)=>{ + this.baseURL = val; + }); + elDiv.appendChild(inp.div); + + inp = ui.el_creatediv_input("SetAuthorization", "Authorization", "text", this.headers["Authorization"], (val)=>{ + this.headers["Authorization"] = val; + }); + inp.el.placeholder = "Bearer OPENAI_API_KEY"; + elDiv.appendChild(inp.div); + + let bb = ui.el_creatediv_boolbutton("SetStream", "Stream", {true: "[+] yes stream", false: "[-] do oneshot"}, this.bStream, (val)=>{ + this.bStream = val; + }); + elDiv.appendChild(bb.div); + + bb = ui.el_creatediv_boolbutton("SetTrimGarbage", "TrimGarbage", {true: "[+] yes trim", false: "[-] dont trim"}, this.bTrimGarbage, (val)=>{ + this.bTrimGarbage = val; + }); + elDiv.appendChild(bb.div); + + this.show_settings_apirequestoptions(elDiv); + + let sel = ui.el_creatediv_select("SetApiEP", "ApiEndPoint", ApiEP.Type, this.apiEP, (val)=>{ + this.apiEP = ApiEP.Type[val]; + }); + elDiv.appendChild(sel.div); + + sel = ui.el_creatediv_select("SetChatHistoryInCtxt", "ChatHistoryInCtxt", this.sRecentUserMsgCnt, this.iRecentUserMsgCnt, (val)=>{ + this.iRecentUserMsgCnt = this.sRecentUserMsgCnt[val]; + }); + elDiv.appendChild(sel.div); + + bb = ui.el_creatediv_boolbutton("SetCompletionFreshChatAlways", "CompletionFreshChatAlways", {true: "[+] yes fresh", false: "[-] no, with history"}, this.bCompletionFreshChatAlways, (val)=>{ + this.bCompletionFreshChatAlways = val; + }); + elDiv.appendChild(bb.div); + + bb = ui.el_creatediv_boolbutton("SetCompletionInsertStandardRolePrefix", "CompletionInsertStandardRolePrefix", {true: "[+] yes insert", false: "[-] dont insert"}, this.bCompletionInsertStandardRolePrefix, (val)=>{ + this.bCompletionInsertStandardRolePrefix = val; + }); + elDiv.appendChild(bb.div); + + } + +} + + +/** @type {Me} */ +let gMe; + +function startme() { + console.log("INFO:SimpleChat:StartMe:Starting..."); + gMe = new Me(); + gMe.debug_disable(); + document["gMe"] = gMe; + document["du"] = du; + for (let cid of gMe.defaultChatIds) { + gMe.multiChat.new_chat_session(cid); + } + gMe.multiChat.setup_ui(gMe.defaultChatIds[0], true); + gMe.multiChat.show_sessions(); +} + +document.addEventListener("DOMContentLoaded", startme); diff --git a/tools/server/public_simplechat/simplechat_screens.webp b/tools/server/public_simplechat/simplechat_screens.webp new file mode 100644 index 0000000000000000000000000000000000000000..ccea44396051686b97220b0f5b6b9beb63706114 Binary files /dev/null and b/tools/server/public_simplechat/simplechat_screens.webp differ diff --git a/tools/server/public_simplechat/ui.mjs b/tools/server/public_simplechat/ui.mjs new file mode 100644 index 0000000000000000000000000000000000000000..b2d5b9aeab76c0724f9c8fa5238e3ce009aaa877 --- /dev/null +++ b/tools/server/public_simplechat/ui.mjs @@ -0,0 +1,211 @@ +//@ts-check +// Helpers to work with html elements +// by Humans for All +// + + +/** + * Set the class of the children, based on whether it is the idSelected or not. + * @param {HTMLDivElement} elBase + * @param {string} idSelected + * @param {string} classSelected + * @param {string} classUnSelected + */ +export function el_children_config_class(elBase, idSelected, classSelected, classUnSelected="") { + for(let child of elBase.children) { + if (child.id == idSelected) { + child.className = classSelected; + } else { + child.className = classUnSelected; + } + } +} + +/** + * Create button and set it up. + * @param {string} id + * @param {(this: HTMLButtonElement, ev: MouseEvent) => any} callback + * @param {string | undefined} name + * @param {string | undefined} innerText + */ +export function el_create_button(id, callback, name=undefined, innerText=undefined) { + if (!name) { + name = id; + } + if (!innerText) { + innerText = id; + } + let btn = document.createElement("button"); + btn.id = id; + btn.name = name; + btn.innerText = innerText; + btn.addEventListener("click", callback); + return btn; +} + +/** + * Create a para and set it up. Optionaly append it to a passed parent. + * @param {string} text + * @param {HTMLElement | undefined} elParent + * @param {string | undefined} id + */ +export function el_create_append_p(text, elParent=undefined, id=undefined) { + let para = document.createElement("p"); + para.innerText = text; + if (id) { + para.id = id; + } + if (elParent) { + elParent.appendChild(para); + } + return para; +} + +/** + * Create a button which represents bool value using specified text wrt true and false. + * When ever user clicks the button, it will toggle the value and update the shown text. + * + * @param {string} id + * @param {{true: string, false: string}} texts + * @param {boolean} defaultValue + * @param {function(boolean):void} cb + */ +export function el_create_boolbutton(id, texts, defaultValue, cb) { + let el = document.createElement("button"); + el["xbool"] = defaultValue; + el["xtexts"] = structuredClone(texts); + el.innerText = el["xtexts"][String(defaultValue)]; + if (id) { + el.id = id; + } + el.addEventListener('click', (ev)=>{ + el["xbool"] = !el["xbool"]; + el.innerText = el["xtexts"][String(el["xbool"])]; + cb(el["xbool"]); + }) + return el; +} + +/** + * Create a div wrapped button which represents bool value using specified text wrt true and false. + * @param {string} id + * @param {string} label + * @param {{ true: string; false: string; }} texts + * @param {boolean} defaultValue + * @param {(arg0: boolean) => void} cb + * @param {string} className + */ +export function el_creatediv_boolbutton(id, label, texts, defaultValue, cb, className="gridx2") { + let div = document.createElement("div"); + div.className = className; + let lbl = document.createElement("label"); + lbl.setAttribute("for", id); + lbl.innerText = label; + div.appendChild(lbl); + let btn = el_create_boolbutton(id, texts, defaultValue, cb); + div.appendChild(btn); + return { div: div, el: btn }; +} + + +/** + * Create a select ui element, with a set of options to select from. + * * options: an object which contains name-value pairs + * * defaultOption: the value whose name should be choosen, by default. + * * cb : the call back returns the name string of the option selected. + * + * @param {string} id + * @param {Object} options + * @param {*} defaultOption + * @param {function(string):void} cb + */ +export function el_create_select(id, options, defaultOption, cb) { + let el = document.createElement("select"); + el["xselected"] = defaultOption; + el["xoptions"] = structuredClone(options); + for(let cur of Object.keys(options)) { + let op = document.createElement("option"); + op.value = cur; + op.innerText = cur; + if (options[cur] == defaultOption) { + op.selected = true; + } + el.appendChild(op); + } + if (id) { + el.id = id; + el.name = id; + } + el.addEventListener('change', (ev)=>{ + let target = /** @type{HTMLSelectElement} */(ev.target); + console.log("DBUG:UI:Select:", id, ":", target.value); + cb(target.value); + }) + return el; +} + +/** + * Create a div wrapped select ui element, with a set of options to select from. + * + * @param {string} id + * @param {any} label + * @param {{ [x: string]: any; }} options + * @param {any} defaultOption + * @param {(arg0: string) => void} cb + * @param {string} className + */ +export function el_creatediv_select(id, label, options, defaultOption, cb, className="gridx2") { + let div = document.createElement("div"); + div.className = className; + let lbl = document.createElement("label"); + lbl.setAttribute("for", id); + lbl.innerText = label; + div.appendChild(lbl); + let sel = el_create_select(id, options,defaultOption, cb); + div.appendChild(sel); + return { div: div, el: sel }; +} + + +/** + * Create a input ui element. + * + * @param {string} id + * @param {string} type + * @param {any} defaultValue + * @param {function(any):void} cb + */ +export function el_create_input(id, type, defaultValue, cb) { + let el = document.createElement("input"); + el.type = type; + el.value = defaultValue; + if (id) { + el.id = id; + } + el.addEventListener('change', (ev)=>{ + cb(el.value); + }) + return el; +} + +/** + * Create a div wrapped input. + * + * @param {string} id + * @param {string} label + * @param {string} type + * @param {any} defaultValue + * @param {function(any):void} cb + * @param {string} className + */ +export function el_creatediv_input(id, label, type, defaultValue, cb, className="gridx2") { + let div = document.createElement("div"); + div.className = className; + let lbl = document.createElement("label"); + lbl.setAttribute("for", id); + lbl.innerText = label; + div.appendChild(lbl); + let el = el_create_input(id, type, defaultValue, cb); + div.appendChild(el); + return { div: div, el: el }; +} diff --git a/tools/server/server.cpp b/tools/server/server.cpp new file mode 100644 index 0000000000000000000000000000000000000000..721d09182845d4bc0bc2c893807958950ef46205 --- /dev/null +++ b/tools/server/server.cpp @@ -0,0 +1,5004 @@ +#include "chat.h" +#include "utils.hpp" + +#include "arg.h" +#include "common.h" +#include "json-schema-to-grammar.h" +#include "llama.h" +#include "log.h" +#include "sampling.h" +#include "speculative.h" +#include "mtmd.h" +#include "mtmd-helper.h" + +// mime type for sending response +#define MIMETYPE_JSON "application/json; charset=utf-8" + +// auto generated files (see README.md for details) +#include "index.html.gz.hpp" +#include "loading.html.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +constexpr int HTTP_POLLING_SECONDS = 1; + +enum stop_type { + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, +}; + +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, +}; + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded +}; + +enum server_task_type { + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, + SERVER_TASK_TYPE_CANCEL, + SERVER_TASK_TYPE_NEXT_RESPONSE, + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, + SERVER_TASK_TYPE_SET_LORA, +}; + +enum oaicompat_type { + OAICOMPAT_TYPE_NONE, + OAICOMPAT_TYPE_CHAT, + OAICOMPAT_TYPE_COMPLETION, + OAICOMPAT_TYPE_EMBEDDING, +}; + +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error +}; + +static bool server_task_type_need_embd(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + return true; + default: + return false; + } +} + +static bool server_task_type_need_logits(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + return true; + default: + return false; + } +} + +struct slot_params { + bool stream = true; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + bool return_tokens = false; + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters + + int64_t t_max_prompt_ms = -1; // TODO: implement + int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + + std::vector lora; + + std::vector antiprompt; + std::vector response_fields; + bool timings_per_token = false; + bool post_sampling_probs = false; + bool ignore_eos = false; + + struct common_params_sampling sampling; + struct common_params_speculative speculative; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_syntax oaicompat_chat_syntax; + + json to_json() const { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto & sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + + auto grammar_triggers = json::array(); + for (const auto & trigger : sampling.grammar_triggers) { + server_grammar_trigger ct(std::move(trigger)); + grammar_triggers.push_back(ct.to_json()); + } + + return json { + {"n_predict", n_predict}, // Server configured n_predict + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"top_n_sigma", sampling.top_n_sigma}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, // User configured n_predict + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"grammar_lazy", sampling.grammar_lazy}, + {"grammar_triggers", grammar_triggers}, + {"preserved_tokens", sampling.preserved_tokens}, + {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, + {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, + {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, + {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; + } +}; + +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) + + server_task_type type; + + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + slot_params params; + server_tokens prompt_tokens; + int id_selected_slot = -1; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + + server_task(server_task_type type) : type(type) {} + + static slot_params params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + slot_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) + slot_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + defaults.n_keep = params_base.n_keep; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, "timings_per_token", false); + + params.stream = json_value(data, "stream", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.return_tokens = json_value(data, "return_tokens", false); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + params.speculative.n_min = std::max(params.speculative.n_min, 0); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + + // TODO: add more sanity checks for the input parameters + + if (params.sampling.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (params.sampling.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + params.sampling.grammar = json_schema_to_grammar(schema); + SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); + SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); + } + + { + auto it = data.find("chat_format"); + if (it != data.end()) { + params.oaicompat_chat_syntax.format = static_cast(it->get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); + } else { + params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; + } + params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format; + params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); + params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + } + + { + const auto preserved_tokens = data.find("preserved_tokens"); + if (preserved_tokens != data.end()) { + for (const auto & t : *preserved_tokens) { + auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + SRV_DBG("Preserved token: %d\n", ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); + } else { + // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. + SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); + } + } + } + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + server_grammar_trigger ct(t); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = ct.value.value; + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = word; + trigger.token = token; + params.sampling.grammar_triggers.push_back(std::move(trigger)); + } else { + SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + } else { + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { + SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); + } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { + SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); + } else { + throw std::runtime_error("Unknown grammar trigger type"); + } + params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); + } + } + } + if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } + } + + { + params.sampling.logit_bias.clear(); + params.ignore_eos = json_value(data, "ignore_eos", false); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(vocab, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + } + + { + const auto samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + params.sampling.samplers = common_sampler_types_from_names(*samplers, false); + } else if (samplers->is_string()){ + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; + } + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } +}; + +struct result_timings { + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + // Optional speculative metrics - only included when > 0 + int32_t draft_n = 0; + int32_t draft_n_accepted = 0; + + json to_json() const { + json base = { + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + + if (draft_n > 0) { + base["draft_n"] = draft_n; + base["draft_n_accepted"] = draft_n_accepted; + } + + return base; + } +}; + +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_* + return false; + } + virtual int get_index() { + return -1; + } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; + +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; + +inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + +struct completion_token_output { + llama_token tok; + float prob; + std::string text_to_send; + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const { + json probs_for_token = json::array(); + for (const auto & p : probs) { + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); + probs_for_token.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + }); + } + return probs_for_token; + } + + static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto & p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + { + post_sampling_probs ? "top_probs" : "top_logprobs", + p.to_json(post_sampling_probs) + }, + }); + } + return out; + } + + static float logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); + } + + static std::vector str_to_bytes(const std::string & str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; + } +}; + +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + bool stream; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + bool post_sampling_probs; + std::vector probs_output; + std::vector response_fields; + + slot_params generation_params; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; + std::vector oaicompat_msg_diffs; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + json res = json { + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? llama_tokens {} : tokens}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? res : json_get_nested_values(response_fields, res); + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json { + {"choices", json::array({ + json{ + {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + std::string finish_reason = "length"; + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } else { + msg.role = "assistant"; + msg.content = content; + } + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json choice { + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", msg.to_json_oaicompat()}, + }; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json { + {"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json deltas = json::array(); + for (const auto & diff : oaicompat_msg_diffs) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + } + + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }); + + if (timings.prompt_n >= 0) { + deltas.back().push_back({"timings", timings.to_json()}); + } + + // extra fields for debugging purposes + if (verbose && !deltas.empty()) { + deltas.front()["__verbose"] = to_json_non_oaicompat(); + } + + return deltas; + } +}; + +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + int32_t n_decoded; + int32_t n_prompt_tokens; + + bool post_sampling_probs; + completion_token_output prob_output; + result_timings timings; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + std::vector oaicompat_msg_diffs; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return false; // in stream mode, partial responses are not considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + // non-OAI-compat JSON + json res = json { + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); + } + return res; + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json { + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + bool first = n_decoded == 1; + std::time_t t = std::time(0); + json choices; + + std::vector deltas; + auto add_delta = [&](const json & delta) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", delta}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + }; + // We have to send an initial update to conform to openai behavior + if (first) { + add_delta({ + {"role", "assistant"}, + {"content", nullptr}, + }); + } + + for (const auto & diff : oaicompat_msg_diffs) { + add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); + } + + if (!deltas.empty()) { + GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1); + + if (prob_output.probs.size() > 0) { + deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json { + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + if (timings.prompt_n >= 0) { + deltas[deltas.size() - 1].push_back({"timings", timings.to_json()}); + } + } + + return deltas; + } +}; + +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector> embedding; + + int32_t n_tokens; + + // OAI-compat fields + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? to_json_oaicompat() + : to_json_non_oaicompat(); + } + + json to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; + } + + json to_json_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + int32_t n_tokens; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return json { + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +// this function maybe used outside of server_task_result_error +static json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; +} + +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + virtual bool is_error() override { + return true; + } + + virtual json to_json() override { + return format_error_response(err_msg, err_type); + } +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // while we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result + json slots_data = json::array(); + + virtual json to_json() override { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, + + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, + + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, + + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, + + { "slots", slots_data }, + }; + } +}; + +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_tokens }, + { "n_written", n_bytes }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } else { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; + } + } +}; + +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; + + virtual json to_json() override { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; + } +}; + +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override { + return json {{ "success", true }}; + } +}; + +struct server_slot { + int id; + int id_task = -1; + + // only used for completion/embedding/infill/rerank + server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; + + llama_batch batch_spec = {}; + + llama_context * ctx = nullptr; + llama_context * ctx_dft = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + common_speculative * spec = nullptr; + + std::vector lora; + + // the index relative to completion multi-task request + size_t index = 0; + + struct slot_params params; + + slot_state state = SLOT_STATE_IDLE; + + // used to determine the slot that has been used the longest + int64_t t_last_used = -1; + + // generation props + int32_t n_ctx = 0; // context size per slot + int32_t n_past = 0; + int32_t n_decoded = 0; + int32_t n_remaining = -1; + int32_t i_batch = -1; + int32_t n_predict = -1; // TODO: disambiguate from params.n_predict + + // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated + int32_t n_prompt_tokens = 0; + int32_t n_prompt_tokens_processed = 0; + + // input prompt tokens + server_tokens prompt_tokens; + + size_t last_nl_pos = 0; + + std::string generated_text; + llama_tokens generated_tokens; + common_chat_msg chat_msg; + + server_tokens cache_tokens; + + std::vector generated_token_probs; + + bool has_next_token = true; + bool has_new_line = false; + bool truncated = false; + stop_type stop; + + std::string stopping_word; + + // sampling + json json_schema; + + struct common_sampler * smpl = nullptr; + + llama_token sampled; + + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + std::vector generated_tool_call_ids; + + // stats + size_t n_sent_text = 0; // number of sent text character + + int64_t t_start_process_prompt; + int64_t t_start_generation; + + double t_prompt_processing; // ms + double t_token_generation; // ms + + std::function callback_on_release; + + // Speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + + void reset() { + SLT_DBG(*this, "%s", "\n"); + + n_prompt_tokens = 0; + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + task_type = SERVER_TASK_TYPE_COMPLETION; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + generated_tokens.clear(); + generated_token_probs.clear(); + chat_msg = {}; + json_schema = json(); + generated_tool_call_ids.clear(); + + // clear speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; + } + + bool need_embd() const { + return server_task_type_need_embd(task_type); + } + + bool need_logits() const { + return server_task_type_need_logits(task_type); + } + + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch + // also we cannot split if the pooling would require any past tokens + bool can_split() const { + return + !need_embd() || + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + } + + bool can_batch_with(server_slot & other_slot) const { + return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora); + } + + bool has_budget(const common_params & global_params) { + if (params.n_predict == -1 && global_params.n_predict == -1) { + return true; // limitless + } + + n_remaining = -1; + + if (params.n_predict != -1) { + n_remaining = params.n_predict - n_decoded; + } else if (global_params.n_predict != -1) { + n_remaining = global_params.n_predict - n_decoded; + } + + return n_remaining > 0; // no budget + } + + bool is_processing() const { + return state != SLOT_STATE_IDLE; + } + + bool can_speculate() const { + return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + } + + void add_token(const completion_token_output & token) { + if (!is_processing()) { + SLT_WRN(*this, "%s", "slot is not processing\n"); + return; + } + generated_token_probs.push_back(token); + } + + void release() { + if (is_processing()) { + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); + + t_last_used = ggml_time_us(); + t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; + state = SLOT_STATE_IDLE; + callback_on_release(id); + } + } + + result_timings get_timings() const { + result_timings timings; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } + + return timings; + } + + const common_chat_msg & update_chat_msg(std::vector & diffs) { + auto previous_msg = chat_msg; + SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); + auto new_msg = common_chat_parse( + generated_text, + /* is_partial= */ stop != STOP_TYPE_EOS, + params.oaicompat_chat_syntax); + if (!new_msg.empty()) { + new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id); + chat_msg = new_msg; + diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); + } + return chat_msg; + } + + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { + size_t stop_pos = std::string::npos; + + for (const std::string & word : params.antiprompt) { + size_t pos; + + if (is_full_stop) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + + pos = text.find(word, from_pos); + } else { + // otherwise, partial stop + pos = string_find_partial_stop(text, word); + } + + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (is_full_stop) { + stop = STOP_TYPE_WORD; + stopping_word = word; + has_next_token = false; + } + stop_pos = pos; + } + } + + return stop_pos; + } + + void print_timings() const { + const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + const double t_gen = t_token_generation / n_decoded; + const double n_gen_second = 1e3 / t_token_generation * n_decoded; + + SLT_INF(*this, + "\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " total time = %10.2f ms / %5d tokens\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_token_generation, n_decoded, t_gen, n_gen_second, + t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + + if (n_draft_total > 0) { + const float draft_ratio = (float) n_draft_accepted / n_draft_total; + SLT_INF(*this, + "\n" + "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", + draft_ratio, n_draft_accepted, n_draft_total + ); + } + } + + json to_json() const { + return json { + {"id", id}, + {"id_task", id_task}, + {"n_ctx", n_ctx}, + {"speculative", can_speculate()}, + {"is_processing", is_processing()}, + {"params", params.to_json()}, + {"prompt", prompt_tokens.detokenize(ctx, true)}, + {"next_token", + { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + {"stopping_word", stopping_word}, + } + }, + }; + } +}; + +struct server_metrics { + int64_t t_start = 0; + + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + void init() { + t_start = ggml_time_us(); + } + + void on_prompt_eval(const server_slot & slot) { + n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; + } + + void on_prediction(const server_slot & slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; + } + + void on_decoded(const std::vector & slots) { + n_decode_total++; + for (const auto & slot : slots) { + if (slot.is_processing()) { + n_busy_slots_total++; + } + } + } + + void reset_bucket() { + n_prompt_tokens_processed = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; + } +}; + +struct server_queue { + int id = 0; + bool running; + + // queues + std::deque queue_tasks; + std::deque queue_tasks_deferred; + + std::mutex mutex_tasks; + std::condition_variable condition_tasks; + + // callback functions + std::function callback_new_task; + std::function callback_update_slots; + + // Add a new task to the end of the queue + int post(server_task && task, bool front = false) { + std::unique_lock lock(mutex_tasks); + GGML_ASSERT(task.id != -1); + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + const int task_id = task.id; + QUE_DBG("new task, id = %d, front = %d\n", task_id, front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + condition_tasks.notify_one(); + return task_id; + } + + // multi-task version of post() + int post(std::vector && tasks, bool front = false) { + std::unique_lock lock(mutex_tasks); + for (auto & task : tasks) { + if (task.id == -1) { + task.id = id++; + } + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + } + condition_tasks.notify_one(); + return 0; + } + + // Add a new task, but defer until one slot is available + void defer(server_task && task) { + std::unique_lock lock(mutex_tasks); + QUE_DBG("defer task, id = %d\n", task.id); + queue_tasks_deferred.push_back(std::move(task)); + condition_tasks.notify_one(); + } + + // Get the next id for creating a new task + int get_new_id() { + std::unique_lock lock(mutex_tasks); + int new_id = id++; + return new_id; + } + + // Register function to process a new task + void on_new_task(std::function callback) { + callback_new_task = std::move(callback); + } + + // Register the function to be called when all slots data is ready to be processed + void on_update_slots(std::function callback) { + callback_update_slots = std::move(callback); + } + + // Call when the state of one slot is changed, it will move one task from deferred to main queue + void pop_deferred_task() { + std::unique_lock lock(mutex_tasks); + if (!queue_tasks_deferred.empty()) { + queue_tasks.emplace_back(std::move(queue_tasks_deferred.front())); + queue_tasks_deferred.pop_front(); + } + condition_tasks.notify_one(); + } + + // end the start_loop routine + void terminate() { + std::unique_lock lock(mutex_tasks); + running = false; + condition_tasks.notify_all(); + } + + /** + * Main loop consists of these steps: + * - Wait until a new task arrives + * - Process the task (i.e. maybe copy data into slot) + * - Check if multitask is finished + * - Update all slots + */ + void start_loop() { + running = true; + + while (true) { + QUE_DBG("%s", "processing new tasks\n"); + + while (true) { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + server_task task = std::move(queue_tasks.front()); + queue_tasks.pop_front(); + lock.unlock(); + + QUE_DBG("processing task, id = %d\n", task.id); + callback_new_task(std::move(task)); + } + + // all tasks in the current loop is processed, slots data is now ready + QUE_DBG("%s", "update slots\n"); + + callback_update_slots(); + + QUE_DBG("%s", "waiting for new tasks\n"); + { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + condition_tasks.wait(lock, [&]{ + return (!queue_tasks.empty() || !running); + }); + } + } + } + } + +private: + void cleanup_pending_task(int id_target) { + // no need lock because this is called exclusively by post() + auto rm_func = [id_target](const server_task & task) { + return task.id_target == id_target; + }; + queue_tasks.erase( + std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), + queue_tasks.end()); + queue_tasks_deferred.erase( + std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), + queue_tasks_deferred.end()); + } +}; + +struct server_response { + bool running = true; + + // for keeping track of all tasks waiting for the result + std::unordered_set waiting_task_ids; + + // the main result queue (using ptr for polymorphism) + std::vector queue_results; + + std::mutex mutex_results; + std::condition_variable condition_results; + + // add the id_task to the list of tasks waiting for response + void add_waiting_task_id(int id_task) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.insert(id_task); + } + + void add_waiting_tasks(const std::vector & tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & task : tasks) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); + waiting_task_ids.insert(task.id); + } + } + + // when the request is finished, we can remove task associated with it + void remove_waiting_task_id(int id_task) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(id_task); + // make sure to clean up all pending results + queue_results.erase( + std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { + return res->id == id_task; + }), + queue_results.end()); + } + + void remove_waiting_task_ids(const std::unordered_set & id_tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & id_task : id_tasks) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + waiting_task_ids.erase(id_task); + } + } + + // This function blocks the thread until there is a response for one of the id_tasks + server_task_result_ptr recv(const std::unordered_set & id_tasks) { + while (true) { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&]{ + if (!running) { + SRV_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + return !queue_results.empty(); + }); + + for (size_t i = 0; i < queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here + } + + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + + std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (!running) { + SRV_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + if (cr_res == std::cv_status::timeout) { + return nullptr; + } + } + + // should never reach here + } + + // single-task version of recv() + server_task_result_ptr recv(int id_task) { + std::unordered_set id_tasks = {id_task}; + return recv(id_tasks); + } + + // Send a new result to a waiting id_task + void send(server_task_result_ptr && result) { + SRV_DBG("sending result for task id = %d\n", result->id); + + std::unique_lock lock(mutex_results); + for (const auto & id_task : waiting_task_ids) { + if (result->id == id_task) { + SRV_DBG("task id = %d pushed to result queue\n", result->id); + + queue_results.emplace_back(std::move(result)); + condition_results.notify_all(); + return; + } + } + } + + // terminate the waiting loop + void terminate() { + running = false; + condition_results.notify_all(); + } +}; + +struct server_context { + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; + + llama_model * model = nullptr; + llama_context * ctx = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + const llama_vocab * vocab = nullptr; + + llama_model * model_dft = nullptr; + + llama_context_params cparams_dft; + + llama_batch batch {}; + + bool clean_kv_cache = true; + bool add_bos_token = true; + bool has_eos_token = false; + + int32_t n_ctx; // total context for all clients / slots + + // slots / clients + std::vector slots; + json default_generation_settings_for_props; + + server_queue queue_tasks; + server_response queue_results; + + server_metrics metrics; + + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + + common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; + + ~server_context() { + mtmd_free(mctx); + + // Clear any sampling context + for (server_slot & slot : slots) { + common_sampler_free(slot.smpl); + slot.smpl = nullptr; + + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; + + common_speculative_free(slot.spec); + slot.spec = nullptr; + + llama_batch_free(slot.batch_spec); + } + + llama_batch_free(batch); + } + + bool load_model(const common_params & params) { + SRV_INF("loading model '%s'\n", params.model.path.c_str()); + + params_base = params; + + llama_init = common_init_from_params(params_base); + + model = llama_init.model.get(); + ctx = llama_init.context.get(); + + if (model == nullptr) { + SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); + return false; + } + + vocab = llama_model_get_vocab(model); + + n_ctx = llama_n_ctx(ctx); + + add_bos_token = llama_vocab_get_add_bos(vocab); + has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; + + if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); + + auto params_dft = params_base; + + params_dft.devices = params_base.speculative.devices; + params_dft.model = params_base.speculative.model; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; + params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; + params_dft.n_parallel = 1; + + // force F16 KV cache for the draft model for extra performance + params_dft.cache_type_k = GGML_TYPE_F16; + params_dft.cache_type_v = GGML_TYPE_F16; + + llama_init_dft = common_init_from_params(params_dft); + + model_dft = llama_init_dft.model.get(); + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); + return false; + } + + if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); + + return false; + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + + cparams_dft = common_context_params_to_llama(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); + } + + chat_templates = common_chat_templates_init(model, params_base.chat_template); + try { + common_chat_format_example(chat_templates.get(), params.use_jinja); + } catch (const std::exception & e) { + SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); + SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + chat_templates = common_chat_templates_init(model, "chatml"); + } + + std::string & mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params_base.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params_base.cpuparams.n_threads; + mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (!params_base.speculative.model.path.empty()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + + if (!llama_memory_can_shift(llama_get_memory(ctx))) { + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); + } + } + + return true; + } + + void init() { + const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; + + SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); + + for (int i = 0; i < params_base.n_parallel; i++) { + server_slot slot; + + slot.id = i; + slot.ctx = ctx; + slot.n_ctx = n_ctx_slot; + slot.n_predict = params_base.n_predict; + slot.mctx = mctx; + slot.cache_tokens.has_mtmd = mctx != nullptr; + + if (model_dft) { + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); + + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; + } + + slot.spec = common_speculative_init(slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; + } + } + + SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); + + slot.params.sampling = params_base.sampling; + slot.params.n_keep = params_base.n_keep; + + slot.callback_on_release = [this](int) { + queue_tasks.pop_deferred_task(); + }; + + slot.reset(); + + slots.push_back(std::move(slot)); + } + + default_generation_settings_for_props = slots[0].to_json(); + + // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens + // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) + { + const int32_t n_batch = llama_n_batch(ctx); + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + } + + metrics.init(); + + oai_parser_opt = { + /* use_jinja */ params_base.use_jinja, + /* prefill_assistant */ params_base.prefill_assistant, + /* reasoning_format */ params_base.reasoning_format, + /* common_chat_templates */ chat_templates.get(), + /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, + /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, + /* enable_thinking */ params_base.reasoning_budget != 0, + }; + } + + server_slot * get_slot_by_id(int id) { + for (server_slot & slot : slots) { + if (slot.id == id) { + return &slot; + } + } + + return nullptr; + } + + server_slot * get_available_slot(const server_task & task) { + server_slot * ret = nullptr; + + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + int lcs_len = 0; + float similarity = 0; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // skip the slot if it does not contains cached tokens + if (slot.cache_tokens.empty()) { + continue; + } + + // length of the Longest Common Subsequence between the current slot's prompt and the input prompt + int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); + + // fraction of the common subsequence length compared to the current slot's prompt length + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); + + // select the current slot if the criteria match + if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { + lcs_len = cur_lcs_len; + similarity = cur_similarity; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity); + } + } + + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = -1; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // select the current slot if the criteria match + if (!ret || slot.t_last_used <= t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last); + } + } + + return ret; + } + + bool launch_slot_with_task(server_slot & slot, server_task && task) { + slot.reset(); + slot.id_task = task.id; + slot.index = task.index; + slot.task_type = task.type; + slot.params = std::move(task.params); + slot.prompt_tokens = std::move(task.prompt_tokens); + + if (!are_lora_equal(slot.params.lora, slot.lora)) { + // if lora is changed, we cannot reuse cached tokens + slot.cache_tokens.clear(); + slot.lora = slot.params.lora; + } + + if (!slot.prompt_tokens.validate(ctx)) { + send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); + return false; + } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + + if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { + // Might be better to reject the request with a 400 ? + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); + slot.params.n_predict = slot.n_predict; + } + + if (slot.params.ignore_eos && has_eos_token) { + slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); + } + + { + if (slot.smpl != nullptr) { + common_sampler_free(slot.smpl); + } + + slot.smpl = common_sampler_init(model, slot.params.sampling); + if (slot.smpl == nullptr) { + // for now, the only error that may happen here is invalid grammar + send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); + + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + } + + slot.state = SLOT_STATE_STARTED; + + SLT_INF(slot, "%s", "processing task\n"); + + return true; + } + + void kv_cache_clear() { + SRV_DBG("%s", "clearing KV cache\n"); + + // clear the entire KV cache + llama_memory_clear(llama_get_memory(ctx), true); + clean_kv_cache = false; + } + + bool process_token(completion_token_output & result, server_slot & slot) { + // remember which tokens were sampled - used for repetition penalties during sampling + const std::string token_str = result.text_to_send; + slot.sampled = result.tok; + + slot.generated_text += token_str; + if (slot.params.return_tokens) { + slot.generated_tokens.push_back(result.tok); + } + slot.has_next_token = true; + + // check if there is incomplete UTF-8 character at the end + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); + + // search stop word and delete it + if (!incomplete) { + size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); + + const std::string str_test = slot.generated_text.substr(pos); + bool send_text = true; + + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); + if (stop_pos != std::string::npos) { + slot.generated_text.erase( + slot.generated_text.begin() + pos + stop_pos, + slot.generated_text.end()); + pos = std::min(slot.n_sent_text, slot.generated_text.size()); + } else if (slot.has_next_token) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; + } + + // check if there is any token to predict + if (send_text) { + // no send the stop word in the response + result.text_to_send = slot.generated_text.substr(pos, std::string::npos); + slot.n_sent_text += result.text_to_send.size(); + // add the token to slot queue and cache + } else { + result.text_to_send = ""; + } + + slot.add_token(result); + if (slot.params.stream) { + send_partial_response(slot, result); + } + } + + if (incomplete) { + slot.has_next_token = true; + } + + // if context shifting is disabled, make sure that we don't run out of context + if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx); + } + + // check the limits + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); + } + + if (slot.has_new_line) { + // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent + if (slot.params.n_indent > 0) { + // check the current indentation + // TODO: improve by not doing it more than once for each new line + if (slot.last_nl_pos > 0) { + size_t pos = slot.last_nl_pos; + + int n_indent = 0; + while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { + n_indent++; + pos++; + } + + if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + // cut the last line + slot.generated_text.erase(pos, std::string::npos); + + SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); + } + } + + // find the next new line + { + const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); + + if (pos != std::string::npos) { + slot.last_nl_pos = pos + 1; + } + } + } + } + + // check if there is a new line in the generated text + if (result.text_to_send.find('\n') != std::string::npos) { + slot.has_new_line = true; + + // if we have seen a new line, we stop after a certain time limit, but only upon another new line + if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + } + } + + // if context shift is disabled, we stop when it reaches the context limit + if (slot.n_past >= slot.n_ctx) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); + } + + if (llama_vocab_is_eog(vocab, result.tok)) { + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + + SLT_DBG(slot, "%s", "stopped by EOS\n"); + } + + const auto n_ctx_train = llama_model_n_ctx_train(model); + + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; // stop prediction + + SLT_WRN(slot, + "n_predict (%d) is set for infinite generation. " + "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", + slot.params.n_predict, n_ctx_train); + } + + SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); + + return slot.has_next_token; // continue + } + + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { + size_t n_probs = slot.params.sampling.n_probs; + size_t n_vocab = llama_vocab_n_tokens(vocab); + if (post_sampling) { + const auto * cur_p = common_sampler_get_candidates(slot.smpl); + const size_t max_probs = cur_p->size; + + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.push_back({ + cur_p->data[i].id, + common_token_to_piece(ctx, cur_p->data[i].id, special), + cur_p->data[i].p + }); + } + } else { + // TODO: optimize this with min-p optimization + std::vector cur = get_token_probabilities(ctx, idx); + + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + result.prob = cur[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { + result.probs.push_back({ + cur[i].id, + common_token_to_piece(ctx, cur[i].id, special), + cur[i].p + }); + } + } + } + + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(task.id, error, type); + } + + void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(slot.id_task, error, type); + } + + void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); + + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; + + queue_results.send(std::move(res)); + } + + // if multimodal is enabled, send an error and return false + bool ensure_no_mtmd(const int id_task) { + if (mctx) { + send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + return false; + } + return true; + } + + void send_partial_response(server_slot & slot, const completion_token_output & tkn) { + auto res = std::make_unique(); + + res->id = slot.id_task; + res->index = slot.index; + res->content = tkn.text_to_send; + res->tokens = { tkn.tok }; + + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + + slot.update_chat_msg(res->oaicompat_msg_diffs); + + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + res->prob_output = tkn; // copy the token probs + } + + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + res->timings = slot.get_timings(); + } + + queue_results.send(std::move(res)); + } + + void send_final_response(server_slot & slot) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->id_slot = slot.id; + + res->index = slot.index; + res->content = slot.generated_text; + res->tokens = std::move(slot.generated_tokens); + res->timings = slot.get_timings(); + res->prompt = slot.prompt_tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.params.response_fields); + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_tokens_cached = slot.n_past; + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->stream = slot.params.stream; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); + + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + + size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end() - safe_offset); + } else { + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); + } + } + + res->generation_params = slot.params; // copy the parameters + + queue_results.send(std::move(res)); + } + + void send_embedding(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + res->oaicompat = slot.params.oaicompat; + + const int n_embd = llama_model_n_embd(model); + + std::vector embd_res(n_embd, 0.0f); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->embedding.push_back(std::vector(n_embd, 0.0f)); + continue; + } + + // normalize only when there is pooling + // TODO: configurable + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + res->embedding.push_back({ embd, embd + n_embd }); + } + } + + SLT_DBG(slot, "%s", "sending embeddings\n"); + + queue_results.send(std::move(res)); + } + + void send_rerank(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->score = -1e6; + continue; + } + + res->score = embd[0]; + } + + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); + + queue_results.send(std::move(res)); + } + + // + // Functions to create new task(s) and receive result(s) + // + + void cancel_tasks(const std::unordered_set & id_tasks) { + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto & id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); + + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(std::move(task)); + } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(std::move(cancel_tasks), true); + } + + // receive the results from task(s) + void receive_multi_results( + const std::unordered_set & id_tasks, + const std::function&)> & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { + std::vector results(id_tasks.size()); + for (int i = 0; i < (int)id_tasks.size(); i++) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + i--; // retry + continue; + } + + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + const size_t idx = result->get_index(); + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = std::move(result); + } + result_handler(results); + } + + // receive the results from task(s), in stream mode + void receive_cmpl_results_stream( + const std::unordered_set & id_tasks, + const std::function & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { + size_t n_finished = 0; + while (true) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + continue; // retry + } + + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + if (!result_handler(result)) { + cancel_tasks(id_tasks); + break; + } + + if (result->is_stop()) { + if (++n_finished == id_tasks.size()) { + break; + } + } + } + } + + // + // Functions to process the task + // + + void process_single_task(server_task && task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + { + const int id_slot = task.id_selected_slot; + + server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); + + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (!launch_slot_with_task(*slot, std::move(task))) { + SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); + break; + } + } break; + case SERVER_TASK_TYPE_CANCEL: + { + // release slot linked with the task id + for (auto & slot : slots) { + if (slot.id_task == task.id_target) { + slot.release(); + break; + } + } + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: + { + json slots_data = json::array(); + + int n_idle_slots = 0; + int n_processing_slots = 0; + + for (server_slot & slot : slots) { + json slot_data = slot.to_json(); + + if (slot.is_processing()) { + n_processing_slots++; + } else { + n_idle_slots++; + } + + slots_data.push_back(slot_data); + } + SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); + + auto res = std::make_unique(); + res->id = task.id; + res->slots_data = std::move(slots_data); + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res->t_start = metrics.t_start; + + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; + + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; + + res->n_decode_total = metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; + + if (task.metrics_reset_bucket) { + metrics.reset_bucket(); + } + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + if (!ensure_no_mtmd(task.id)) { + break; + } + + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + const llama_tokens & tokens = slot->cache_tokens.get_text_tokens(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + if (!ensure_no_mtmd(task.id)) break; + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + llama_tokens tokens; + tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + if (nread == 0) { + slot->cache_tokens.clear(); // KV may already been invalidated? + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + tokens.resize(token_count); + slot->cache_tokens.clear(); + slot->cache_tokens.insert(tokens); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + if (!ensure_no_mtmd(task.id)) break; + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); + slot->cache_tokens.clear(); + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SET_LORA: + { + params_base.lora_adapters = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); + } break; + + } + } + + void update_slots() { + // check if all slots are idle + { + bool all_idle = true; + + for (auto & slot : slots) { + if (slot.is_processing()) { + all_idle = false; + break; + } + } + + if (all_idle) { + SRV_INF("%s", "all slots are idle\n"); + if (clean_kv_cache) { + kv_cache_clear(); + } + + return; + } + } + + { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(std::move(task)); + } + + // apply context-shift if needed + // TODO: simplify and improve + for (server_slot & slot : slots) { + if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { + if (!params_base.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in process_token() + slot.release(); + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + continue; + } + + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + + // Shift context + const int n_keep = slot.params.n_keep + add_bos_token; + const int n_left = slot.n_past - n_keep; + const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + + SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); + + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard); + + // add generated tokens to cache + { + llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy + for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { + new_tokens[i - n_discard] = new_tokens[i]; + } + + new_tokens.resize(slot.cache_tokens.size() - n_discard); + slot.cache_tokens.clear(); + slot.cache_tokens.insert(new_tokens); + } + + slot.n_past -= n_discard; + + slot.truncated = true; + } + } + + // start populating the batch for this iteration + common_batch_clear(batch); + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); + }; + + // frist, add sampled tokens from any ongoing sequences + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + + slot.i_batch = batch.n_tokens; + + common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); + + slot.n_past += 1; + slot.cache_tokens.push_back(slot.sampled); + + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); + } + + // process in chunks of params.n_batch + int32_t n_batch = llama_n_batch(ctx); + int32_t n_ubatch = llama_n_ubatch(ctx); + + // next, batch any pending prompts without exceeding n_batch + if (params_base.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + // check if we can batch this slot with the previous one + if (slot.is_processing()) { + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + } + + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { + auto & prompt_tokens = slot.prompt_tokens; + + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_generation = 0; + + slot.n_past = 0; + slot.n_prompt_tokens = prompt_tokens.size(); + slot.state = SLOT_STATE_PROCESSING_PROMPT; + + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); + + // print prompt tokens (for debugging) + /*if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + } else { + // all + for (int i = 0; i < (int) prompt_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + }*/ + + // empty prompt passed -> release the slot and send empty response + if (prompt_tokens.empty()) { + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); + + slot.release(); + slot.print_timings(); + send_final_response(slot); + continue; + } + + // TODO: support memory-less logits computation + if (slot.need_logits() && !llama_get_memory(ctx)) { + slot.release(); + send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); + continue; + } + + if (!slot.can_split()) { + if (slot.n_prompt_tokens > n_ubatch) { + slot.release(); + send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); + continue; + } + + if (slot.n_prompt_tokens > slot.n_ctx) { + slot.release(); + send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER); + continue; + } + } else { + if (!params_base.ctx_shift) { + // if context shift is disabled, we make sure prompt size is smaller than KV size + // TODO: there should be a separate parameter that control prompt truncation + // context shift should be applied only during the generation phase + if (slot.n_prompt_tokens >= slot.n_ctx) { + slot.release(); + send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST); + continue; + } + } + if (slot.params.n_keep < 0) { + slot.params.n_keep = slot.n_prompt_tokens; + } + slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); + + // if input prompt is too big, truncate it + if (slot.n_prompt_tokens >= slot.n_ctx) { + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + const int n_left = slot.n_ctx - slot.params.n_keep; + + const int n_block_size = n_left / 2; + const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + + const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); + llama_tokens new_tokens( + curr_tokens.begin(), + curr_tokens.begin() + slot.params.n_keep); + + new_tokens.insert( + new_tokens.end(), + curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + curr_tokens.end()); + + prompt_tokens.clear(); + prompt_tokens.insert(new_tokens); + + slot.truncated = true; + slot.n_prompt_tokens = prompt_tokens.size(); + + SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); + + GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + } + + if (slot.params.cache_prompt) { + // reuse any previously computed tokens that are common with the new prompt + slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + + // reuse chunks from the cached prompt by shifting their KV cache in the new position + if (params_base.n_cache_reuse > 0) { + size_t head_c = slot.n_past; // cache + size_t head_p = slot.n_past; // current prompt + + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); + + while (head_c < slot.cache_tokens.size() && + head_p < prompt_tokens.size()) { + + size_t n_match = 0; + while (head_c + n_match < slot.cache_tokens.size() && + head_p + n_match < prompt_tokens.size() && + slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + + n_match++; + } + + if (n_match >= (size_t) params_base.n_cache_reuse) { + SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); + //for (size_t i = head_p; i < head_p + n_match; i++) { + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + //} + + const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; + + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); + + for (size_t i = 0; i < n_match; i++) { + slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); + slot.n_past++; + } + + head_c += n_match; + head_p += n_match; + } else { + head_c += 1; + } + } + + SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); + } + } else { + // if we don't cache the prompt, we have to remove the entire KV cache + slot.n_past = 0; + } + + if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + if (pos_min == -1) { + SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); + GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); + } + + const auto n_swa = llama_model_n_swa(model); + if (pos_min > std::max(0, slot.n_past - n_swa)) { + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); + SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", + "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + slot.n_past = 0; + } + } + } + + if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens); + + slot.n_past--; + } + + slot.n_prompt_tokens_processed = 0; + } + + if (!slot.can_split()) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + continue; + } + } + + // keep only the common part + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) { + // could not partially delete (likely using a non-Transformer model) + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); + + // there is no common part left + slot.n_past = 0; + } + + SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); + + // remove the non-common part from the cache + slot.cache_tokens.keep_first(slot.n_past); + + // check if we should process the image + if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + // process the image + int32_t new_n_past; + int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); + int32_t n_pos = new_n_past - slot.n_past; + + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); + slot.release(); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + continue; + } + + // add the image chunk to cache + { + const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); + slot.cache_tokens.push_back(chunk.get()); // copy + } + + slot.n_past += n_pos; + slot.n_prompt_tokens_processed += n_pos; + } + + // add prompt tokens for processing in the current batch + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // get next token to process + llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + if (cur_tok == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } + + // embedding requires all tokens in the batch to be output + const bool need_embd = server_task_type_need_embd(slot.task_type); + + common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); + slot.cache_tokens.push_back(cur_tok); + + slot.n_prompt_tokens_processed++; + slot.n_past++; + } + + // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + + // entire prompt has been processed + if (slot.n_past == slot.n_prompt_tokens) { + slot.state = SLOT_STATE_DONE_PROMPT; + + GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); + + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.n_prompt_tokens; ++i) { + llama_token id = slot.prompt_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, id, false); + } + } + + // extract the logits only for the last token + batch.logits[batch.n_tokens - 1] = true; + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); + } + } + + if (batch.n_tokens >= n_batch) { + break; + } + } + } + + if (batch.n_tokens == 0) { + SRV_WRN("%s", "no tokens to decode\n"); + return; + } + + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + + if (slot_batched) { + // apply lora, only need to do it once per batch + common_set_adapter_lora(ctx, slot_batched->lora); + + llama_set_embeddings(ctx, slot_batched->need_embd()); + } + + // pad the batch so that batch.n_tokens >= n_slots + // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689 + if (slot_batched->need_embd()) { + const int n_slots = slots.size(); + + if (batch.n_tokens < n_slots) { + std::set seq_ids; + for (int j = 0; j < batch.n_tokens; ++j) { + seq_ids.insert(batch.seq_id[j][0]); + } + + // find unused sequence id + llama_seq_id seq_id = -1; + for (int i = 0; i < n_slots; ++i) { + if (seq_ids.find(i) == seq_ids.end()) { + seq_id = i; + } + } + + const int n_add = n_slots - batch.n_tokens; + + SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id); + + for (int j = 0; j < n_add; ++j) { + common_batch_add(batch, 0, j, { seq_id }, true); + } + + slots[seq_id].cache_tokens.clear(); + llama_memory_seq_rm(llama_get_memory(ctx), seq_id, -1, -1); + } + } + + int32_t i_next = 0; + + // process the created batch of tokens + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + }; + + const int ret = llama_decode(ctx, batch_view); + + metrics.on_decoded(slots); + + if (ret != 0) { + { + std::string err; + + if (n_batch == 1 && ret == 1) { + err = "Context size has been exceeded."; + } + + if (ret == -1) { + err = "Invalid input batch."; + } + + if (ret < -1) { + err = "Compute error."; + } + + if (!err.empty()) { + SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); + for (auto & slot : slots) { + slot.release(); + send_error(slot, err); + } + break; + } + } + + // retry with half the batch size to try to find a free slot in the KV cache + n_batch /= 2; + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + + continue; // continue loop of n_batch + } + + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = llama_n_batch(ctx); + + for (auto & slot : slots) { + if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { + continue; // continue loop of slots + } + + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + } else if (slot.state != SLOT_STATE_GENERATING) { + continue; // continue loop of slots + } + + const int tok_idx = slot.i_batch - i; + + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl, id, true); + + slot.n_decoded += 1; + + const int64_t t_current = ggml_time_us(); + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_current; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; + + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if (slot.params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + continue; + } + } + + // do speculative decoding + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; + } + + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.params.speculative.n_max; + + // note: n_past is not yet increased for the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + + if (slot.n_remaining > 0) { + n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); + } + + SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < slot.params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); + + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + + // ignore small drafts + if (slot.params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + + continue; + } + + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + + // construct the speculation batch + common_batch_clear(slot.batch_spec); + common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + } + + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); + + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); + } + } + + SRV_DBG("%s", "run slots completed\n"); + } + + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, + }; + } +}; + +static void log_server_request(const httplib::Request & req, const httplib::Response & res) { + // skip GH copilot requests when using default port + if (req.path == "/v1/health" || req.path == "/v1/completions") { + return; + } + + // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch + + SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); + + SRV_DBG("request: %s\n", req.body.c_str()); + SRV_DBG("response: %s\n", res.body.c_str()); +} + +std::function shutdown_handler; +std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; + +inline void signal_handler(int signal) { + if (is_terminating.test_and_set()) { + // in case it hangs, we can force terminate the server by hitting Ctrl+C twice + // this is for better developer experience, we can remove when the server is stable enough + fprintf(stderr, "Received second interrupt, terminating immediately.\n"); + exit(1); + } + + shutdown_handler(signal); +} + +int main(int argc, char ** argv) { + // own arguments required by this example + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { + return 1; + } + + common_init(); + + // struct that contains llama context and inference + server_context ctx_server; + + llama_backend_init(); + llama_numa_init(params.numa); + + LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + + std::unique_ptr svr; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { + LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str()); + svr.reset( + new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) + ); + } else { + LOG_INF("Running without SSL\n"); + svr.reset(new httplib::Server()); + } +#else + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { + LOG_ERR("Server is built without SSL support\n"); + return 1; + } + svr.reset(new httplib::Server()); +#endif + + std::atomic state{SERVER_STATE_LOADING_MODEL}; + + svr->set_default_headers({{"Server", "llama.cpp"}}); + svr->set_logger(log_server_request); + + auto res_error = [](httplib::Response & res, const json & error_data) { + json final_response {{"error", error_data}}; + res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); + res.status = json_value(error_data, "code", 500); + }; + + auto res_ok = [](httplib::Response & res, const json & data) { + res.set_content(safe_json_to_str(data), MIMETYPE_JSON); + res.status = 200; + }; + + svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { + std::string message; + try { + std::rethrow_exception(ep); + } catch (const std::exception & e) { + message = e.what(); + } catch (...) { + message = "Unknown Exception"; + } + + try { + json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); + LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); + res_error(res, formatted_error); + } catch (const std::exception & e) { + LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); + } + }); + + svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) { + if (res.status == 404) { + res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); + } + // for other error codes, we skip processing here because it's already done by res_error() + }); + + // set timeouts and change hostname and port + svr->set_read_timeout (params.timeout_read); + svr->set_write_timeout(params.timeout_write); + + std::unordered_map log_data; + + log_data["hostname"] = params.hostname; + log_data["port"] = std::to_string(params.port); + + if (params.api_keys.size() == 1) { + auto key = params.api_keys[0]; + log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0)); + } else if (params.api_keys.size() > 1) { + log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; + } + + // Necessary similarity of prompt for slot selection + ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; + + // + // Middlewares + // + + auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { + static const std::unordered_set public_endpoints = { + "/health", + "/models", + "/v1/models", + "/api/tags" + }; + + // If API key is not set, skip validation + if (params.api_keys.empty()) { + return true; + } + + // If path is public or is static file, skip validation + if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { + return true; + } + + // Check for API key in the header + auto auth_header = req.get_header_value("Authorization"); + + std::string prefix = "Bearer "; + if (auth_header.substr(0, prefix.size()) == prefix) { + std::string received_api_key = auth_header.substr(prefix.size()); + if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) { + return true; // API key is valid + } + } + + // API key is invalid or not provided + res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); + + LOG_WRN("Unauthorized: Invalid API Key\n"); + + return false; + }; + + auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { + server_state current_state = state.load(); + if (current_state == SERVER_STATE_LOADING_MODEL) { + auto tmp = string_split(req.path, '.'); + if (req.path == "/" || tmp.back() == "html") { + res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); + res.status = 503; + } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { + // allow the models endpoint to be accessed during loading + return true; + } else { + res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + } + return false; + } + return true; + }; + + // register server middlewares + svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + // If this is OPTIONS request, skip validation because browsers don't include Authorization header + if (req.method == "OPTIONS") { + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "GET, POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + res.set_content("", "text/html"); // blank response, no data + return httplib::Server::HandlerResponse::Handled; // skip further processing + } + if (!middleware_server_state(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + if (!middleware_validate_api_key(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + return httplib::Server::HandlerResponse::Unhandled; + }); + + // + // Route handlers (or controllers) + // + + const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { + // error and loading states are handled by middleware + json health = {{"status", "ok"}}; + res_ok(res, health); + }; + + const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { + if (!params.endpoint_slots) { + res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + // request slots data using task queue + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_METRICS); + task.id = task_id; + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } + + // TODO: get rid of this dynamic_cast + auto res_metrics = dynamic_cast(result.get()); + GGML_ASSERT(res_metrics != nullptr); + + // optionally return "fail_on_no_slot" error + if (req.has_param("fail_on_no_slot")) { + if (res_metrics->n_idle_slots == 0) { + res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + return; + } + } + + res_ok(res, res_metrics->slots_data); + }; + + const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { + if (!params.endpoint_metrics) { + res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + // request slots data using task queue + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_METRICS); + task.id = task_id; + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } + + // TODO: get rid of this dynamic_cast + auto res_metrics = dynamic_cast(result.get()); + GGML_ASSERT(res_metrics != nullptr); + + // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names + json all_metrics_def = json { + {"counter", {{ + {"name", "prompt_tokens_total"}, + {"help", "Number of prompt tokens processed."}, + {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total} + }, { + {"name", "prompt_seconds_total"}, + {"help", "Prompt process time"}, + {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3} + }, { + {"name", "tokens_predicted_total"}, + {"help", "Number of generation tokens processed."}, + {"value", (uint64_t) res_metrics->n_tokens_predicted_total} + }, { + {"name", "tokens_predicted_seconds_total"}, + {"help", "Predict process time"}, + {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3} + }, { + {"name", "n_decode_total"}, + {"help", "Total number of llama_decode() calls"}, + {"value", res_metrics->n_decode_total} + }, { + {"name", "n_busy_slots_per_decode"}, + {"help", "Average number of busy slots per llama_decode() call"}, + {"value", (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)} + }}}, + {"gauge", {{ + {"name", "prompt_tokens_seconds"}, + {"help", "Average prompt throughput in tokens/s."}, + {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.} + },{ + {"name", "predicted_tokens_seconds"}, + {"help", "Average generation throughput in tokens/s."}, + {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.} + },{ + {"name", "requests_processing"}, + {"help", "Number of requests processing."}, + {"value", (uint64_t) res_metrics->n_processing_slots} + },{ + {"name", "requests_deferred"}, + {"help", "Number of requests deferred."}, + {"value", (uint64_t) res_metrics->n_tasks_deferred} + }}} + }; + + std::stringstream prometheus; + + for (const auto & el : all_metrics_def.items()) { + const auto & type = el.key(); + const auto & metrics_def = el.value(); + + for (const auto & metric_def : metrics_def) { + const std::string name = metric_def.at("name"); + const std::string help = metric_def.at("help"); + + auto value = json_value(metric_def, "value", 0.); + prometheus << "# HELP llamacpp:" << name << " " << help << "\n" + << "# TYPE llamacpp:" << name << " " << type << "\n" + << "llamacpp:" << name << " " << value << "\n"; + } + } + + res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start)); + + res.set_content(prometheus.str(), "text/plain; version=0.0.4"); + res.status = 200; // HTTP OK + }; + + const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } + + res_ok(res, result->to_json()); + }; + + const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res_ok(res, result->to_json()); + }; + + const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_ERASE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res_ok(res, result->to_json()); + }; + + const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + if (params.slot_save_path.empty()) { + res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + std::string id_slot_str = req.path_params.at("id_slot"); + int id_slot; + + try { + id_slot = std::stoi(id_slot_str); + } catch (const std::exception &) { + res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + std::string action = req.get_param_value("action"); + + if (action == "save") { + handle_slots_save(req, res, id_slot); + } else if (action == "restore") { + handle_slots_restore(req, res, id_slot); + } else if (action == "erase") { + handle_slots_erase(req, res, id_slot); + } else { + res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + } + }; + + const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + // this endpoint is publicly available, please only return what is safe to be exposed + json data = { + { "default_generation_settings", ctx_server.default_generation_settings_for_props }, + { "total_slots", ctx_server.params_base.n_parallel }, + { "model_path", ctx_server.params_base.model.path }, + { "modalities", json{ + {"vision", ctx_server.oai_parser_opt.allow_image}, + {"audio", ctx_server.oai_parser_opt.allow_audio}, + } }, + { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, + { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, + { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, + { "build_info", build_info }, + }; + if (ctx_server.params_base.use_jinja) { + if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { + data["chat_template_tool_use"] = tool_use_src; + } + } + + res_ok(res, data); + }; + + const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + if (!ctx_server.params_base.endpoint_props) { + res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + json data = json::parse(req.body); + + // update any props here + + res_ok(res, {{ "success", true }}); + }; + + const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + json data = { + { + "template", common_chat_templates_source(ctx_server.chat_templates.get()), + }, + { + "model_info", { + { "llama.context_length", ctx_server.slots.back().n_ctx, }, + } + }, + {"modelfile", ""}, + {"parameters", ""}, + {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }}, + {"model_info", ""}, + {"capabilities", {"completion"}} + }; + + res_ok(res, data); + }; + + // handle completion-like requests (completion, chat, infill) + // we can optionally provide a custom format for partial results and final results + const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( + server_task_type type, + json & data, + const std::vector & files, + const std::function & is_connection_closed, + httplib::Response & res, + oaicompat_type oaicompat) -> void { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + + auto completion_id = gen_chatcmplid(); + std::unordered_set task_ids; + try { + std::vector tasks; + + const auto & prompt = data.at("prompt"); + // TODO: this log can become very long, put it behind a flag or think about a more compact format + //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + + // process files + mtmd::bitmaps bitmaps; + const bool has_mtmd = ctx_server.mctx != nullptr; + { + if (!has_mtmd && !files.empty()) { + throw std::runtime_error("This server does not support multimodal"); + } + for (auto & file : files) { + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server.mctx, file.data(), file.size())); + if (!bmp.ptr) { + throw std::runtime_error("Failed to load image or audio file"); + } + // calculate bitmap hash (for KV caching) + std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); + bmp.set_id(hash.c_str()); + bitmaps.entries.push_back(std::move(bmp)); + } + } + + // process prompt + std::vector inputs; + if (oaicompat && !prompt.is_string()) { + throw std::runtime_error("prompt must be a string"); + } + + if (oaicompat && has_mtmd) { + // multimodal + std::string prompt_str = prompt.get(); + mtmd_input_text inp_txt = { + prompt_str.c_str(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = bitmaps.c_ptr(); + int32_t tokenized = mtmd_tokenize(ctx_server.mctx, + chunks.ptr.get(), + &inp_txt, + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); + if (tokenized != 0) { + throw std::runtime_error("Failed to tokenize prompt"); + } + + server_tokens tmp(chunks, true); + inputs.push_back(std::move(tmp)); + } else { + // non-multimodal version + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + for (auto & p : tokenized_prompts) { + auto tmp = server_tokens(p, ctx_server.mctx != nullptr); + inputs.push_back(std::move(tmp)); + } + } + + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(std::move(task)); + } + + task_ids = server_task::get_list_id(tasks); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(std::move(tasks)); + } catch (const std::exception & e) { + res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } + + bool stream = json_value(data, "stream", false); + + if (!stream) { + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + if (results.size() == 1) { + // single result + res_ok(res, results[0]->to_json()); + } else { + // multiple results (multitask) + json arr = json::array(); + for (auto & res : results) { + arr.push_back(res->to_json()); + } + res_ok(res, arr); + } + }, [&](const json & error_data) { + res_error(res, error_data); + }, is_connection_closed); + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + } else { + const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) { + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool { + json res_json = result->to_json(); + if (res_json.is_array()) { + for (const auto & res : res_json) { + if (!server_sent_event(sink, "data", res)) { + // sending failed (HTTP connection closed), cancel the generation + return false; + } + } + return true; + } else { + return server_sent_event(sink, "data", res_json); + } + }, [&](const json & error_data) { + server_sent_event(sink, "error", error_data); + }, [&sink]() { + // note: do not use req.is_connection_closed here because req is already destroyed + return !sink.is_writable(); + }); + if (oaicompat != OAICOMPAT_TYPE_NONE) { + static const std::string ev_done = "data: [DONE]\n\n"; + sink.write(ev_done.data(), ev_done.size()); + } + sink.done(); + return false; + }; + + auto on_complete = [task_ids, &ctx_server] (bool) { + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } + }; + + const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + json data = json::parse(req.body); + std::vector files; // dummy + handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_NONE); + }; + + const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + json data = oaicompat_completion_params_parse(json::parse(req.body)); + std::vector files; // dummy + handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_COMPLETION); + }; + + const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + // check model compatibility + std::string err; + if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "prefix token is missing. "; + } + if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "suffix token is missing. "; + } + if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "middle token is missing. "; + } + if (!err.empty()) { + res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + json data = json::parse(req.body); + + // validate input + if (data.contains("prompt") && !data.at("prompt").is_string()) { + // prompt is optional + res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_prefix")) { + res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_suffix")) { + res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (data.contains("input_extra") && !data.at("input_extra").is_array()) { + // input_extra is optional + res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + json input_extra = json_value(data, "input_extra", json::array()); + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + if (!chunk.contains("text") || !chunk.at("text").is_string()) { + res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); + return; + } + // filename is optional + if (chunk.contains("filename") && !chunk.at("filename").is_string()) { + res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + data["input_extra"] = input_extra; // default to empty array if it's not exist + + std::string prompt = json_value(data, "prompt", std::string()); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true); + SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); + data["prompt"] = format_infill( + ctx_server.vocab, + data.at("input_prefix"), + data.at("input_suffix"), + data.at("input_extra"), + ctx_server.params_base.n_batch, + ctx_server.params_base.n_predict, + ctx_server.slots[0].n_ctx, // TODO: there should be a better way + ctx_server.params_base.spm_infill, + tokenized_prompts[0] + ); + + std::vector files; // dummy + handle_completions_impl( + SERVER_TASK_TYPE_INFILL, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_NONE); // infill is not OAI compatible + }; + + const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + LOG_DBG("request: %s\n", req.body.c_str()); + + auto body = json::parse(req.body); + std::vector files; + json data = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + + handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_CHAT); + }; + + // same with handle_chat_completions, but without inference part + const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { + auto body = json::parse(req.body); + std::vector files; // dummy, unused + json data = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); + }; + + const auto handle_models = [¶ms, &ctx_server, &state, &res_ok](const httplib::Request &, httplib::Response & res) { + server_state current_state = state.load(); + json model_meta = nullptr; + if (current_state == SERVER_STATE_READY) { + model_meta = ctx_server.model_meta(); + } + + json models = { + {"models", { + { + {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"modified_at", ""}, + {"size", ""}, + {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash + {"type", "model"}, + {"description", ""}, + {"tags", {""}}, + {"capabilities", {"completion"}}, + {"parameters", ""}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }} + } + }}, + {"object", "list"}, + {"data", { + { + {"id", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"object", "model"}, + {"created", std::time(0)}, + {"owned_by", "llamacpp"}, + {"meta", model_meta}, + }, + }} + }; + + res_ok(res, models); + }; + + const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + + json tokens_response = json::array(); + if (body.count("content") != 0) { + const bool add_special = json_value(body, "add_special", false); + const bool with_pieces = json_value(body, "with_pieces", false); + + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true); + + if (with_pieces) { + for (const auto& token : tokens) { + std::string piece = common_token_to_piece(ctx_server.ctx, token); + json piece_json; + + // Check if the piece is valid UTF-8 + if (is_valid_utf8(piece)) { + piece_json = piece; + } else { + // If not valid UTF-8, store as array of byte values + piece_json = json::array(); + for (unsigned char c : piece) { + piece_json.push_back(static_cast(c)); + } + } + + tokens_response.push_back({ + {"id", token}, + {"piece", piece_json} + }); + } + } else { + tokens_response = tokens; + } + } + + const json data = format_tokenizer_response(tokens_response); + res_ok(res, data); + }; + + const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + + std::string content; + if (body.count("tokens") != 0) { + const llama_tokens tokens = body.at("tokens"); + content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); + } + + const json data = format_detokenized_response(content); + res_ok(res, data); + }; + + const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) { + if (!ctx_server.params_base.embedding) { + res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + const json body = json::parse(req.body); + + // for the shape of input/content, see tokenize_input_prompts() + json prompt; + if (body.count("input") != 0) { + prompt = body.at("input"); + } else if (body.contains("content")) { + oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible + prompt = body.at("content"); + } else { + res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + bool use_base64 = false; + if (body.count("encoding_format") != 0) { + const std::string& format = body.at("encoding_format"); + if (format == "base64") { + use_base64 = true; + } else if (format != "float") { + res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + for (const auto & tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + + // create and queue the task + json responses = json::array(); + bool error = false; + std::unordered_set task_ids; + { + std::vector tasks; + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr); + + // OAI-compat + task.params.oaicompat = oaicompat; + + tasks.push_back(std::move(task)); + } + + task_ids = server_task::get_list_id(tasks); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(std::move(tasks)); + } + + // get the result + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + for (auto & res : results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + }, [&](const json & error_data) { + res_error(res, error_data); + error = true; + }, req.is_connection_closed); + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + + if (error) { + return; + } + + // write JSON response + json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? format_embeddings_response_oaicompat(body, responses, use_base64) + : json(responses); + res_ok(res, root); + }; + + const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { + handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); + }; + + const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { + handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING); + }; + + const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { + res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + const json body = json::parse(req.body); + + // TODO: implement + //int top_n = 1; + //if (body.count("top_n") != 1) { + // top_n = body.at("top_n"); + //} else { + // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + // return; + //} + + // if true, use TEI API format, otherwise use Jina API format + // Jina: https://jina.ai/reranker/ + // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank + bool is_tei_format = body.contains("texts"); + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } else { + res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + std::vector documents = json_value(body, "documents", + json_value(body, "texts", std::vector())); + if (documents.empty()) { + res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0]; + + // create and queue the task + json responses = json::array(); + bool error = false; + std::unordered_set task_ids; + { + std::vector tasks; + auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); + tasks.reserve(tokenized_docs.size()); + for (size_t i = 0; i < tokenized_docs.size(); i++) { + auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr); + tasks.push_back(std::move(task)); + } + + task_ids = server_task::get_list_id(tasks); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(std::move(tasks)); + } + + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + for (auto & res : results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + }, [&](const json & error_data) { + res_error(res, error_data); + error = true; + }, req.is_connection_closed); + + if (error) { + return; + } + + // write JSON response + json root = format_response_rerank( + body, + responses, + is_tei_format, + documents); + + res_ok(res, root); + }; + + const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { + json result = json::array(); + const auto & loras = ctx_server.params_base.lora_adapters; + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; + result.push_back({ + {"id", i}, + {"path", lora.path}, + {"scale", lora.scale}, + }); + } + res_ok(res, result); + res.status = 200; // HTTP OK + }; + + const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + if (!body.is_array()) { + res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SET_LORA); + task.id = task_id; + task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res_ok(res, result->to_json()); + }; + + // + // Router + // + + if (!params.webui) { + LOG_INF("Web UI is disabled\n"); + } else { + // register static assets routes + if (!params.public_path.empty()) { + // Set the base directory for serving static files + bool is_found = svr->set_mount_point("/", params.public_path); + if (!is_found) { + LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); + return 1; + } + } else { + // using embedded static index.html + svr->Get("/", [](const httplib::Request & req, httplib::Response & res) { + if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { + res.set_content("Error: gzip is not supported by this browser", "text/plain"); + } else { + res.set_header("Content-Encoding", "gzip"); + // COEP and COOP headers, required by pyodide (python interpreter) + res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); + res.set_header("Cross-Origin-Opener-Policy", "same-origin"); + res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); + } + return false; + }); + } + } + + // register API routes + svr->Get ("/health", handle_health); // public endpoint (no API key check) + svr->Get ("/metrics", handle_metrics); + svr->Get ("/props", handle_props); + svr->Post("/props", handle_props_change); + svr->Post("/api/show", handle_api_show); + svr->Get ("/models", handle_models); // public endpoint (no API key check) + svr->Get ("/v1/models", handle_models); // public endpoint (no API key check) + svr->Get ("/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check) + svr->Post("/completion", handle_completions); // legacy + svr->Post("/completions", handle_completions); + svr->Post("/v1/completions", handle_completions_oai); + svr->Post("/chat/completions", handle_chat_completions); + svr->Post("/v1/chat/completions", handle_chat_completions); + svr->Post("/api/chat", handle_chat_completions); // ollama specific endpoint + svr->Post("/infill", handle_infill); + svr->Post("/embedding", handle_embeddings); // legacy + svr->Post("/embeddings", handle_embeddings); + svr->Post("/v1/embeddings", handle_embeddings_oai); + svr->Post("/rerank", handle_rerank); + svr->Post("/reranking", handle_rerank); + svr->Post("/v1/rerank", handle_rerank); + svr->Post("/v1/reranking", handle_rerank); + svr->Post("/tokenize", handle_tokenize); + svr->Post("/detokenize", handle_detokenize); + svr->Post("/apply-template", handle_apply_template); + // LoRA adapters hotswap + svr->Get ("/lora-adapters", handle_lora_adapters_list); + svr->Post("/lora-adapters", handle_lora_adapters_apply); + // Save & load slots + svr->Get ("/slots", handle_slots); + svr->Post("/slots/:id_slot", handle_slots_action); + + // + // Start the server + // + if (params.n_threads_http < 1) { + // +2 threads for monitoring endpoints + params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); + } + log_data["n_threads_http"] = std::to_string(params.n_threads_http); + svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; + + // clean up function, to be called before exit + auto clean_up = [&svr, &ctx_server]() { + SRV_INF("%s: cleaning up before exit...\n", __func__); + svr->stop(); + ctx_server.queue_results.terminate(); + llama_backend_free(); + }; + + bool was_bound = false; + bool is_sock = false; + if (string_ends_with(std::string(params.hostname), ".sock")) { + is_sock = true; + LOG_INF("%s: setting address family to AF_UNIX\n", __func__); + svr->set_address_family(AF_UNIX); + // bind_to_port requires a second arg, any value other than 0 should + // simply get ignored + was_bound = svr->bind_to_port(params.hostname, 8080); + } else { + LOG_INF("%s: binding port with default address family\n", __func__); + // bind HTTP listen port + if (params.port == 0) { + int bound_port = svr->bind_to_any_port(params.hostname); + if ((was_bound = (bound_port >= 0))) { + params.port = bound_port; + } + } else { + was_bound = svr->bind_to_port(params.hostname, params.port); + } + } + + if (!was_bound) { + LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port); + clean_up(); + return 1; + } + + // run the HTTP server in a thread + std::thread t([&]() { svr->listen_after_bind(); }); + svr->wait_until_ready(); + + LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http); + + // load the model + LOG_INF("%s: loading model\n", __func__); + + if (!ctx_server.load_model(params)) { + clean_up(); + t.join(); + LOG_ERR("%s: exiting due to model loading error\n", __func__); + return 1; + } + + ctx_server.init(); + state.store(SERVER_STATE_READY); + + LOG_INF("%s: model loaded\n", __func__); + + // print sample chat example to make it clear which template is used + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + common_chat_templates_source(ctx_server.chat_templates.get()), + common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str()); + + ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { + ctx_server.process_single_task(std::move(task)); + }); + + ctx_server.queue_tasks.on_update_slots([&ctx_server]() { + ctx_server.update_slots(); + }); + + shutdown_handler = [&](int) { + // this will unblock start_loop() + ctx_server.queue_tasks.terminate(); + }; + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = signal_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); + sigaction(SIGTERM, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + + LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__, + is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() : + string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str()); + + // this call blocks the main thread until queue_tasks.terminate() is called + ctx_server.queue_tasks.start_loop(); + + clean_up(); + t.join(); + + return 0; +} diff --git a/tools/server/tests/.gitignore b/tools/server/tests/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..90ee7fe6d971a981c4d0f8b70eca2556f7840cd9 --- /dev/null +++ b/tools/server/tests/.gitignore @@ -0,0 +1,2 @@ +.venv +tmp diff --git a/tools/server/tests/README.md b/tools/server/tests/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cb87db035e2d6d626c5fe811d2b508d6f3a4073c --- /dev/null +++ b/tools/server/tests/README.md @@ -0,0 +1,66 @@ +# Server tests + +Python based server tests scenario using [pytest](https://docs.pytest.org/en/stable/). + +Tests target GitHub workflows job runners with 4 vCPU. + +Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail. +To mitigate it, you can increase values in `n_predict`, `kv_size`. + +### Install dependencies + +`pip install -r requirements.txt` + +### Run tests + +1. Build the server + +```shell +cd ../../.. +cmake -B build +cmake --build build --target llama-server +``` + +2. Start the test: `./tests.sh` + +It's possible to override some scenario steps values with environment variables: + +| variable | description | +|--------------------------|------------------------------------------------------------------------------------------------| +| `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` | +| `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/llama-server` | +| `DEBUG` | to enable steps and server verbose mode `--verbose` | +| `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` | +| `LLAMA_CACHE` | by default server tests re-download models to the `tmp` subfolder. Set this to your cache (e.g. `$HOME/Library/Caches/llama.cpp` on Mac or `$HOME/.cache/llama.cpp` on Unix) to avoid this | + +To run slow tests (will download many models, make sure to set `LLAMA_CACHE` if needed): + +```shell +SLOW_TESTS=1 ./tests.sh +``` + +To run with stdout/stderr display in real time (verbose output, but useful for debugging): + +```shell +DEBUG=1 ./tests.sh -s -v -x +``` + +To run all the tests in a file: + +```shell +./tests.sh unit/test_chat_completion.py -v -x +``` + +To run a single test: + +```shell +./tests.sh unit/test_chat_completion.py::test_invalid_chat_completion_req +``` + +Hint: You can compile and run test in single command, useful for local developement: + +```shell +cmake --build build -j --target llama-server && ./tools/server/tests/tests.sh +``` + +To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html) diff --git a/tools/server/tests/conftest.py b/tools/server/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..017d1bb841efd099e14338a6c6e206ec3b35a278 --- /dev/null +++ b/tools/server/tests/conftest.py @@ -0,0 +1,15 @@ +import pytest +from utils import * + + +# ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test +@pytest.fixture(autouse=True) +def stop_server_after_each_test(): + # do nothing before each test + yield + # stop all servers after each test + instances = set( + server_instances + ) # copy the set to prevent 'Set changed size during iteration' + for server in instances: + server.stop() diff --git a/tools/server/tests/pytest.ini b/tools/server/tests/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..6df308df74d571a022209782bd1bd14f0ab59756 --- /dev/null +++ b/tools/server/tests/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + serial diff --git a/tools/server/tests/requirements.txt b/tools/server/tests/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..15d024914e8412a414f367c732e14042f87e640e --- /dev/null +++ b/tools/server/tests/requirements.txt @@ -0,0 +1,8 @@ +aiohttp~=3.9.3 +pytest~=8.3.3 +huggingface_hub~=0.23.2 +numpy~=1.26.4 +openai~=1.55.3 +prometheus-client~=0.20.0 +requests~=2.32.3 +wget~=3.2 diff --git a/tools/server/tests/tests.sh b/tools/server/tests/tests.sh new file mode 100644 index 0000000000000000000000000000000000000000..33fa8cc6464e21fe8f093b3b2f960c270b1039d5 --- /dev/null +++ b/tools/server/tests/tests.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# make sure we are in the right directory +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + +set -eu + +if [[ "${SLOW_TESTS:-0}" == 1 ]]; then + # Slow tests for tool calls need quite a few models ahead of time to avoid timing out. + python $SCRIPT_DIR/../../../scripts/fetch_server_test_models.py +fi + +if [ $# -lt 1 ] +then + if [[ "${SLOW_TESTS:-0}" == 1 ]]; then + pytest -v -x + else + pytest -v -x -m "not slow" + fi +else + pytest "$@" +fi diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py new file mode 100644 index 0000000000000000000000000000000000000000..1485de8ceb3fc4f0e25e8f1ef4dc8f5364b251bf --- /dev/null +++ b/tools/server/tests/unit/test_basic.py @@ -0,0 +1,96 @@ +import pytest +import requests +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +def test_server_start_simple(): + global server + server.start() + res = server.make_request("GET", "/health") + assert res.status_code == 200 + + +def test_server_props(): + global server + server.start() + res = server.make_request("GET", "/props") + assert res.status_code == 200 + assert ".gguf" in res.body["model_path"] + assert res.body["total_slots"] == server.n_slots + default_val = res.body["default_generation_settings"] + assert server.n_ctx is not None and server.n_slots is not None + assert default_val["n_ctx"] == server.n_ctx / server.n_slots + assert default_val["params"]["seed"] == server.seed + + +def test_server_models(): + global server + server.start() + res = server.make_request("GET", "/models") + assert res.status_code == 200 + assert len(res.body["data"]) == 1 + assert res.body["data"][0]["id"] == server.model_alias + + +def test_server_slots(): + global server + + # without slots endpoint enabled, this should return error + server.server_slots = False + server.start() + res = server.make_request("GET", "/slots") + assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED + assert "error" in res.body + server.stop() + + # with slots endpoint enabled, this should return slots info + server.server_slots = True + server.n_slots = 2 + server.start() + res = server.make_request("GET", "/slots") + assert res.status_code == 200 + assert len(res.body) == server.n_slots + assert server.n_ctx is not None and server.n_slots is not None + assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots + assert "params" in res.body[0] + assert res.body[0]["params"]["seed"] == server.seed + + +def test_load_split_model(): + global server + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf" + server.model_alias = "tinyllama-split" + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": 16, + "prompt": "Hello", + "temperature": 0.0, + }) + assert res.status_code == 200 + assert match_regex("(little|girl)+", res.body["content"]) + + +def test_no_webui(): + global server + # default: webui enabled + server.start() + url = f"http://{server.server_host}:{server.server_port}" + res = requests.get(url) + assert res.status_code == 200 + assert "" in res.text + server.stop() + + # with --no-webui + server.no_webui = True + server.start() + res = requests.get(url) + assert res.status_code == 404 diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5205f79d610b7206b9319188be1bd6c1424123 --- /dev/null +++ b/tools/server/tests/unit/test_chat_completion.py @@ -0,0 +1,330 @@ +import pytest +from openai import OpenAI +from utils import * + +server: ServerProcess + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +@pytest.mark.parametrize( + "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template", + [ + (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None), + (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None), + (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None), + (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), + (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'), + (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), + (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None), + (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None), + ] +) +def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): + global server + server.jinja = jinja + server.chat_template = chat_template + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "model": model, + "max_tokens": max_tokens, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + }) + assert res.status_code == 200 + assert "cmpl" in res.body["id"] # make sure the completion id has the expected format + assert res.body["system_fingerprint"].startswith("b") + assert res.body["model"] == model if model is not None else server.model_alias + assert res.body["usage"]["prompt_tokens"] == n_prompt + assert res.body["usage"]["completion_tokens"] == n_predicted + choice = res.body["choices"][0] + assert "assistant" == choice["message"]["role"] + assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}' + assert choice["finish_reason"] == finish_reason + + +@pytest.mark.parametrize( + "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", + [ + ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), + ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), + ] +) +def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): + global server + server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL + server.start() + res = server.make_stream_request("POST", "/chat/completions", data={ + "max_tokens": max_tokens, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "stream": True, + }) + content = "" + last_cmpl_id = None + for i, data in enumerate(res): + choice = data["choices"][0] + if i == 0: + # Check first role message for stream=True + assert choice["delta"]["content"] is None + assert choice["delta"]["role"] == "assistant" + else: + assert "role" not in choice["delta"] + assert data["system_fingerprint"].startswith("b") + assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future + if last_cmpl_id is None: + last_cmpl_id = data["id"] + assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream + if choice["finish_reason"] in ["stop", "length"]: + assert data["usage"]["prompt_tokens"] == n_prompt + assert data["usage"]["completion_tokens"] == n_predicted + assert "content" not in choice["delta"] + assert match_regex(re_content, content) + assert choice["finish_reason"] == finish_reason + else: + assert choice["finish_reason"] is None + content += choice["delta"]["content"] or '' + + +def test_chat_completion_with_openai_library(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.chat.completions.create( + model="gpt-3.5-turbo-instruct", + messages=[ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + max_tokens=8, + seed=42, + temperature=0.8, + ) + assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b") + assert res.choices[0].finish_reason == "length" + assert res.choices[0].message.content is not None + assert match_regex("(Suddenly)+", res.choices[0].message.content) + + +def test_chat_template(): + global server + server.chat_template = "llama3" + server.debug = True # to get the "__verbose" object in the response + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 8, + "messages": [ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ] + }) + assert res.status_code == 200 + assert "__verbose" in res.body + assert res.body["__verbose"]["prompt"] == " <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + + +def test_apply_chat_template(): + global server + server.chat_template = "command-r" + server.start() + res = server.make_request("POST", "/apply-template", data={ + "messages": [ + {"role": "system", "content": "You are a test."}, + {"role": "user", "content":"Hi there"}, + ] + }) + assert res.status_code == 200 + assert "prompt" in res.body + assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + + +@pytest.mark.parametrize("response_format,n_predicted,re_content", [ + ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""), + ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"), + ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""), + ({"type": "json_object"}, 10, "(\\{|John)+"), + ({"type": "sound"}, 0, None), + # invalid response format (expected to fail) + ({"type": "json_object", "schema": 123}, 0, None), + ({"type": "json_object", "schema": {"type": 123}}, 0, None), + ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None), +]) +def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None): + global server + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "response_format": response_format, + }) + if re_content is not None: + assert res.status_code == 200 + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]) + else: + assert res.status_code != 200 + assert "error" in res.body + + +@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [ + (False, {"const": "42"}, 6, "\"42\""), + (True, {"const": "42"}, 6, "\"42\""), +]) +def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str): + global server + server.jinja = jinja + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "json_schema": json_schema, + }) + assert res.status_code == 200, f'Expected 200, got {res.status_code}' + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}' + + +@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [ + (False, 'root ::= "a"{5,5}', 6, "a{5,5}"), + (True, 'root ::= "a"{5,5}', 6, "a{5,5}"), +]) +def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str): + global server + server.jinja = jinja + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "user", "content": "Does not matter what I say, does it?"}, + ], + "grammar": grammar, + }) + assert res.status_code == 200, res.body + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"] + + +@pytest.mark.parametrize("messages", [ + None, + "string", + [123], + [{}], + [{"role": 123}], + [{"role": "system", "content": 123}], + # [{"content": "hello"}], # TODO: should not be a valid case + [{"role": "system", "content": "test"}, {}], +]) +def test_invalid_chat_completion_req(messages): + global server + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "messages": messages, + }) + assert res.status_code == 400 or res.status_code == 500 + assert "error" in res.body + + +def test_chat_completion_with_timings_per_token(): + global server + server.start() + res = server.make_stream_request("POST", "/chat/completions", data={ + "max_tokens": 10, + "messages": [{"role": "user", "content": "test"}], + "stream": True, + "timings_per_token": True, + }) + for i, data in enumerate(res): + if i == 0: + # Check first role message for stream=True + assert data["choices"][0]["delta"]["content"] is None + assert data["choices"][0]["delta"]["role"] == "assistant" + assert "timings" not in data, f'First event should not have timings: {data}' + else: + assert "role" not in data["choices"][0]["delta"] + assert "timings" in data + assert "prompt_per_second" in data["timings"] + assert "predicted_per_second" in data["timings"] + assert "predicted_n" in data["timings"] + assert data["timings"]["predicted_n"] <= 10 + + +def test_logprobs(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.chat.completions.create( + model="gpt-3.5-turbo-instruct", + temperature=0.0, + messages=[ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + max_tokens=5, + logprobs=True, + top_logprobs=10, + ) + output_text = res.choices[0].message.content + aggregated_text = '' + assert res.choices[0].logprobs is not None + assert res.choices[0].logprobs.content is not None + for token in res.choices[0].logprobs.content: + aggregated_text += token.token + assert token.logprob <= 0.0 + assert token.bytes is not None + assert len(token.top_logprobs) > 0 + assert aggregated_text == output_text + + +def test_logprobs_stream(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.chat.completions.create( + model="gpt-3.5-turbo-instruct", + temperature=0.0, + messages=[ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + max_tokens=5, + logprobs=True, + top_logprobs=10, + stream=True, + ) + output_text = '' + aggregated_text = '' + for i, data in enumerate(res): + choice = data.choices[0] + if i == 0: + # Check first role message for stream=True + assert choice.delta.content is None + assert choice.delta.role == "assistant" + else: + assert choice.delta.role is None + if choice.finish_reason is None: + if choice.delta.content: + output_text += choice.delta.content + assert choice.logprobs is not None + assert choice.logprobs.content is not None + for token in choice.logprobs.content: + aggregated_text += token.token + assert token.logprob <= 0.0 + assert token.bytes is not None + assert token.top_logprobs is not None + assert len(token.top_logprobs) > 0 + assert aggregated_text == output_text diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..f6909e9ae788438ca3125da421642c58e1f83c46 --- /dev/null +++ b/tools/server/tests/unit/test_completion.py @@ -0,0 +1,464 @@ +import pytest +import requests +import time +from openai import OpenAI +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + +@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [ + ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False), + ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), +]) +def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": n_predict, + "prompt": prompt, + "return_tokens": return_tokens, + }) + assert res.status_code == 200 + assert res.body["timings"]["prompt_n"] == n_prompt + assert res.body["timings"]["predicted_n"] == n_predicted + assert res.body["truncated"] == truncated + assert type(res.body["has_new_line"]) == bool + assert match_regex(re_content, res.body["content"]) + if return_tokens: + assert len(res.body["tokens"]) > 0 + assert all(type(tok) == int for tok in res.body["tokens"]) + else: + assert res.body["tokens"] == [] + + +@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ + ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), + ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), +]) +def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): + global server + server.start() + res = server.make_stream_request("POST", "/completion", data={ + "n_predict": n_predict, + "prompt": prompt, + "stream": True, + }) + content = "" + for data in res: + assert "stop" in data and type(data["stop"]) == bool + if data["stop"]: + assert data["timings"]["prompt_n"] == n_prompt + assert data["timings"]["predicted_n"] == n_predicted + assert data["truncated"] == truncated + assert data["stop_type"] == "limit" + assert type(data["has_new_line"]) == bool + assert "generation_settings" in data + assert server.n_predict is not None + assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict) + assert data["generation_settings"]["seed"] == server.seed + assert match_regex(re_content, content) + else: + assert len(data["tokens"]) > 0 + assert all(type(tok) == int for tok in data["tokens"]) + content += data["content"] + + +def test_completion_stream_vs_non_stream(): + global server + server.start() + res_stream = server.make_stream_request("POST", "/completion", data={ + "n_predict": 8, + "prompt": "I believe the meaning of life is", + "stream": True, + }) + res_non_stream = server.make_request("POST", "/completion", data={ + "n_predict": 8, + "prompt": "I believe the meaning of life is", + }) + content_stream = "" + for data in res_stream: + content_stream += data["content"] + assert content_stream == res_non_stream.body["content"] + + +def test_completion_with_openai_library(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.completions.create( + model="davinci-002", + prompt="I believe the meaning of life is", + max_tokens=8, + ) + assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b") + assert res.choices[0].finish_reason == "length" + assert res.choices[0].text is not None + assert match_regex("(going|bed)+", res.choices[0].text) + + +def test_completion_stream_with_openai_library(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.completions.create( + model="davinci-002", + prompt="I believe the meaning of life is", + max_tokens=8, + stream=True, + ) + output_text = '' + for data in res: + choice = data.choices[0] + if choice.finish_reason is None: + assert choice.text is not None + output_text += choice.text + assert match_regex("(going|bed)+", output_text) + + +# Test case from https://github.com/ggml-org/llama.cpp/issues/13780 +@pytest.mark.slow +def test_completion_stream_with_openai_library_stops(): + global server + server.model_hf_repo = "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M" + server.model_hf_file = None + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.completions.create( + model="davinci-002", + prompt="System: You are helpfull assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n", + stop=["User:\n", "Assistant:\n"], + max_tokens=200, + stream=True, + ) + output_text = '' + for data in res: + choice = data.choices[0] + if choice.finish_reason is None: + assert choice.text is not None + output_text += choice.text + assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}' + + +@pytest.mark.parametrize("n_slots", [1, 2]) +def test_consistent_result_same_seed(n_slots: int): + global server + server.n_slots = n_slots + server.start() + last_res = None + for _ in range(4): + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 0.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + if last_res is not None: + assert res.body["content"] == last_res.body["content"] + last_res = res + + +@pytest.mark.parametrize("n_slots", [1, 2]) +def test_different_result_different_seed(n_slots: int): + global server + server.n_slots = n_slots + server.start() + last_res = None + for seed in range(4): + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": seed, + "temperature": 1.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + if last_res is not None: + assert res.body["content"] != last_res.body["content"] + last_res = res + +# TODO figure why it don't work with temperature = 1 +# @pytest.mark.parametrize("temperature", [0.0, 1.0]) +@pytest.mark.parametrize("n_batch", [16, 32]) +@pytest.mark.parametrize("temperature", [0.0]) +def test_consistent_result_different_batch_size(n_batch: int, temperature: float): + global server + server.n_batch = n_batch + server.start() + last_res = None + for _ in range(4): + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": temperature, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + if last_res is not None: + assert res.body["content"] == last_res.body["content"] + last_res = res + + +@pytest.mark.skip(reason="This test fails on linux, need to be fixed") +def test_cache_vs_nocache_prompt(): + global server + server.start() + res_cache = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 1.0, + "cache_prompt": True, + }) + res_no_cache = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 1.0, + "cache_prompt": False, + }) + assert res_cache.body["content"] == res_no_cache.body["content"] + + +def test_nocache_long_input_prompt(): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is"*32, + "seed": 42, + "temperature": 1.0, + "cache_prompt": False, + }) + assert res.status_code == 200 + + +def test_completion_with_tokens_input(): + global server + server.temperature = 0.0 + server.start() + prompt_str = "I believe the meaning of life is" + res = server.make_request("POST", "/tokenize", data={ + "content": prompt_str, + "add_special": True, + }) + assert res.status_code == 200 + tokens = res.body["tokens"] + + # single completion + res = server.make_request("POST", "/completion", data={ + "prompt": tokens, + }) + assert res.status_code == 200 + assert type(res.body["content"]) == str + + # batch completion + res = server.make_request("POST", "/completion", data={ + "prompt": [tokens, tokens], + }) + assert res.status_code == 200 + assert type(res.body) == list + assert len(res.body) == 2 + assert res.body[0]["content"] == res.body[1]["content"] + + # mixed string and tokens + res = server.make_request("POST", "/completion", data={ + "prompt": [tokens, prompt_str], + }) + assert res.status_code == 200 + assert type(res.body) == list + assert len(res.body) == 2 + assert res.body[0]["content"] == res.body[1]["content"] + + # mixed string and tokens in one sequence + res = server.make_request("POST", "/completion", data={ + "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str], + }) + assert res.status_code == 200 + assert type(res.body["content"]) == str + + +@pytest.mark.parametrize("n_slots,n_requests", [ + (1, 3), + (2, 2), + (2, 4), + (4, 2), # some slots must be idle + (4, 6), +]) +def test_completion_parallel_slots(n_slots: int, n_requests: int): + global server + server.n_slots = n_slots + server.temperature = 0.0 + server.start() + + PROMPTS = [ + ("Write a very long book.", "(very|special|big)+"), + ("Write another a poem.", "(small|house)+"), + ("What is LLM?", "(Dad|said)+"), + ("The sky is blue and I love it.", "(climb|leaf)+"), + ("Write another very long music lyrics.", "(friends|step|sky)+"), + ("Write a very long joke.", "(cat|Whiskers)+"), + ] + def check_slots_status(): + should_all_slots_busy = n_requests >= n_slots + time.sleep(0.1) + res = server.make_request("GET", "/slots") + n_busy = sum([1 for slot in res.body if slot["is_processing"]]) + if should_all_slots_busy: + assert n_busy == n_slots + else: + assert n_busy <= n_slots + + tasks = [] + for i in range(n_requests): + prompt, re_content = PROMPTS[i % len(PROMPTS)] + tasks.append((server.make_request, ("POST", "/completion", { + "prompt": prompt, + "seed": 42, + "temperature": 1.0, + }))) + tasks.append((check_slots_status, ())) + results = parallel_function_calls(tasks) + + # check results + for i in range(n_requests): + prompt, re_content = PROMPTS[i % len(PROMPTS)] + res = results[i] + assert res.status_code == 200 + assert type(res.body["content"]) == str + assert len(res.body["content"]) > 10 + # FIXME: the result is not deterministic when using other slot than slot 0 + # assert match_regex(re_content, res.body["content"]) + + +@pytest.mark.parametrize( + "prompt,n_predict,response_fields", + [ + ("I believe the meaning of life is", 8, []), + ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]), + ], +) +def test_completion_response_fields( + prompt: str, n_predict: int, response_fields: list[str] +): + global server + server.start() + res = server.make_request( + "POST", + "/completion", + data={ + "n_predict": n_predict, + "prompt": prompt, + "response_fields": response_fields, + }, + ) + assert res.status_code == 200 + assert "content" in res.body + assert len(res.body["content"]) + if len(response_fields): + assert res.body["generation_settings/n_predict"] == n_predict + assert res.body["prompt"] == " " + prompt + assert isinstance(res.body["content"], str) + assert len(res.body) == len(response_fields) + else: + assert len(res.body) + assert "generation_settings" in res.body + + +def test_n_probs(): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + }) + assert res.status_code == 200 + assert "completion_probabilities" in res.body + assert len(res.body["completion_probabilities"]) == 5 + for tok in res.body["completion_probabilities"]: + assert "id" in tok and tok["id"] > 0 + assert "token" in tok and type(tok["token"]) == str + assert "logprob" in tok and tok["logprob"] <= 0.0 + assert "bytes" in tok and type(tok["bytes"]) == list + assert len(tok["top_logprobs"]) == 10 + for prob in tok["top_logprobs"]: + assert "id" in prob and prob["id"] > 0 + assert "token" in prob and type(prob["token"]) == str + assert "logprob" in prob and prob["logprob"] <= 0.0 + assert "bytes" in prob and type(prob["bytes"]) == list + + +def test_n_probs_stream(): + global server + server.start() + res = server.make_stream_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + "stream": True, + }) + for data in res: + if data["stop"] == False: + assert "completion_probabilities" in data + assert len(data["completion_probabilities"]) == 1 + for tok in data["completion_probabilities"]: + assert "id" in tok and tok["id"] > 0 + assert "token" in tok and type(tok["token"]) == str + assert "logprob" in tok and tok["logprob"] <= 0.0 + assert "bytes" in tok and type(tok["bytes"]) == list + assert len(tok["top_logprobs"]) == 10 + for prob in tok["top_logprobs"]: + assert "id" in prob and prob["id"] > 0 + assert "token" in prob and type(prob["token"]) == str + assert "logprob" in prob and prob["logprob"] <= 0.0 + assert "bytes" in prob and type(prob["bytes"]) == list + + +def test_n_probs_post_sampling(): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + "post_sampling_probs": True, + }) + assert res.status_code == 200 + assert "completion_probabilities" in res.body + assert len(res.body["completion_probabilities"]) == 5 + for tok in res.body["completion_probabilities"]: + assert "id" in tok and tok["id"] > 0 + assert "token" in tok and type(tok["token"]) == str + assert "prob" in tok and 0.0 < tok["prob"] <= 1.0 + assert "bytes" in tok and type(tok["bytes"]) == list + assert len(tok["top_probs"]) == 10 + for prob in tok["top_probs"]: + assert "id" in prob and prob["id"] > 0 + assert "token" in prob and type(prob["token"]) == str + assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0 + assert "bytes" in prob and type(prob["bytes"]) == list + # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs + assert any(prob["prob"] == 1.0 for prob in tok["top_probs"]) + + +def test_cancel_request(): + global server + server.n_ctx = 4096 + server.n_predict = -1 + server.n_slots = 1 + server.server_slots = True + server.start() + # send a request that will take a long time, but cancel it before it finishes + try: + server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + }, timeout=0.1) + except requests.exceptions.ReadTimeout: + pass # expected + # make sure the slot is free + time.sleep(1) # wait for HTTP_POLLING_SECONDS + res = server.make_request("GET", "/slots") + assert res.body[0]["is_processing"] == False diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..2431ac70882d7a29dd2e1946075899b674c5998b --- /dev/null +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -0,0 +1,85 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + + +LONG_TEXT = """ +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. +Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. +Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. +Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. +""".strip() + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.n_ctx = 256 + server.n_slots = 2 + + +def test_ctx_shift_enabled(): + # the prompt is 301 tokens + # the slot context is 256/2 = 128 tokens + # the prompt is truncated to keep the last 109 tokens + # 64 tokens are generated thanks to shifting the context when it gets full + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": 64, + "prompt": LONG_TEXT, + }) + assert res.status_code == 200 + assert res.body["timings"]["prompt_n"] == 109 + assert res.body["timings"]["predicted_n"] == 64 + assert res.body["truncated"] is True + + +@pytest.mark.parametrize("n_predict,n_token_output,truncated", [ + (64, 64, False), + (-1, 120, True), +]) +def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool): + global server + server.disable_ctx_shift = True + server.n_predict = -1 + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": n_predict, + "prompt": "Hi how are you", + }) + assert res.status_code == 200 + assert res.body["timings"]["predicted_n"] == n_token_output + assert res.body["truncated"] == truncated + + +def test_ctx_shift_disabled_long_prompt(): + global server + server.disable_ctx_shift = True + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": 64, + "prompt": LONG_TEXT, + }) + assert res.status_code != 200 + assert "error" in res.body + assert "exceeds the available context size" in res.body["error"]["message"] + +def test_ctx_shift_disabled_stream(): + global server + server.disable_ctx_shift = True + server.start() + res = server.make_stream_request("POST", "/v1/completions", data={ + "n_predict": 256, + "prompt": "Once", + "stream": True, + }) + content = "" + for data in res: + choice = data["choices"][0] + if choice["finish_reason"] == "length": + assert len(content) > 0 + else: + assert choice["finish_reason"] is None + content += choice["text"] diff --git a/tools/server/tests/unit/test_embedding.py b/tools/server/tests/unit/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..0feb452ccfcd448af6695e2945840f5590da1615 --- /dev/null +++ b/tools/server/tests/unit/test_embedding.py @@ -0,0 +1,257 @@ +import base64 +import struct +import pytest +from openai import OpenAI +from utils import * + +server = ServerPreset.bert_bge_small() + +EPSILON = 1e-3 + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.bert_bge_small() + + +def test_embedding_single(): + global server + server.pooling = 'last' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": "I believe the meaning of life is", + }) + assert res.status_code == 200 + assert len(res.body['data']) == 1 + assert 'embedding' in res.body['data'][0] + assert len(res.body['data'][0]['embedding']) > 1 + + # make sure embedding vector is normalized + assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON + + +def test_embedding_multiple(): + global server + server.pooling = 'last' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "Write a joke about AI from a very long prompt which will not be truncated", + "This is a test", + "This is another test", + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 4 + for d in res.body['data']: + assert 'embedding' in d + assert len(d['embedding']) > 1 + + +def test_embedding_multiple_with_fa(): + server = ServerPreset.bert_bge_small_with_fa() + server.pooling = 'last' + server.start() + # one of these should trigger the FA branch (i.e. context size % 256 == 0) + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "a "*253, + "b "*254, + "c "*255, + "d "*256, + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 4 + for d in res.body['data']: + assert 'embedding' in d + assert len(d['embedding']) > 1 + + +@pytest.mark.parametrize( + "input,is_multi_prompt", + [ + # do not crash on empty input + ("", False), + # single prompt + ("string", False), + ([12, 34, 56], False), + ([12, 34, "string", 56, 78], False), + # multiple prompts + (["string1", "string2"], True), + (["string1", [12, 34, 56]], True), + ([[12, 34, 56], [12, 34, 56]], True), + ([[12, 34, 56], [12, "string", 34, 56]], True), + ] +) +def test_embedding_mixed_input(input, is_multi_prompt: bool): + global server + server.start() + res = server.make_request("POST", "/v1/embeddings", data={"input": input}) + assert res.status_code == 200 + data = res.body['data'] + if is_multi_prompt: + assert len(data) == len(input) + for d in data: + assert 'embedding' in d + assert len(d['embedding']) > 1 + else: + assert 'embedding' in data[0] + assert len(data[0]['embedding']) > 1 + + +def test_embedding_pooling_none(): + global server + server.pooling = 'none' + server.start() + res = server.make_request("POST", "/embeddings", data={ + "input": "hello hello hello", + }) + assert res.status_code == 200 + assert 'embedding' in res.body[0] + assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special + + # make sure embedding vector is not normalized + for x in res.body[0]['embedding']: + assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON + + +def test_embedding_pooling_none_oai(): + global server + server.pooling = 'none' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": "hello hello hello", + }) + + # /v1/embeddings does not support pooling type 'none' + assert res.status_code == 400 + assert "error" in res.body + + +def test_embedding_openai_library_single(): + global server + server.pooling = 'last' + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is") + assert len(res.data) == 1 + assert len(res.data[0].embedding) > 1 + + +def test_embedding_openai_library_multiple(): + global server + server.pooling = 'last' + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.embeddings.create(model="text-embedding-3-small", input=[ + "I believe the meaning of life is", + "Write a joke about AI from a very long prompt which will not be truncated", + "This is a test", + "This is another test", + ]) + assert len(res.data) == 4 + for d in res.data: + assert len(d.embedding) > 1 + + +def test_embedding_error_prompt_too_long(): + global server + server.pooling = 'last' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": "This is a test " * 512, + }) + assert res.status_code != 200 + assert "too large" in res.body["error"]["message"] + + +def test_same_prompt_give_same_result(): + server.pooling = 'last' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 5 + for i in range(1, len(res.body['data'])): + v0 = res.body['data'][0]['embedding'] + vi = res.body['data'][i]['embedding'] + for x, y in zip(v0, vi): + assert abs(x - y) < EPSILON + + +@pytest.mark.parametrize( + "content,n_tokens", + [ + ("I believe the meaning of life is", 9), + ("This is a test", 6), + ] +) +def test_embedding_usage_single(content, n_tokens): + global server + server.start() + res = server.make_request("POST", "/v1/embeddings", data={"input": content}) + assert res.status_code == 200 + assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] + assert res.body['usage']['prompt_tokens'] == n_tokens + + +def test_embedding_usage_multiple(): + global server + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "I believe the meaning of life is", + ], + }) + assert res.status_code == 200 + assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] + assert res.body['usage']['prompt_tokens'] == 2 * 9 + + +def test_embedding_openai_library_base64(): + server.start() + test_input = "Test base64 embedding output" + + # get embedding in default format + res = server.make_request("POST", "/v1/embeddings", data={ + "input": test_input + }) + assert res.status_code == 200 + vec0 = res.body["data"][0]["embedding"] + + # get embedding in base64 format + res = server.make_request("POST", "/v1/embeddings", data={ + "input": test_input, + "encoding_format": "base64" + }) + + assert res.status_code == 200 + assert "data" in res.body + assert len(res.body["data"]) == 1 + + embedding_data = res.body["data"][0] + assert "embedding" in embedding_data + assert isinstance(embedding_data["embedding"], str) + + # Verify embedding is valid base64 + decoded = base64.b64decode(embedding_data["embedding"]) + # Verify decoded data can be converted back to float array + float_count = len(decoded) // 4 # 4 bytes per float + floats = struct.unpack(f'{float_count}f', decoded) + assert len(floats) > 0 + assert all(isinstance(x, float) for x in floats) + assert len(floats) == len(vec0) + + # make sure the decoded data is the same as the original + for x, y in zip(floats, vec0): + assert abs(x - y) < EPSILON diff --git a/tools/server/tests/unit/test_infill.py b/tools/server/tests/unit/test_infill.py new file mode 100644 index 0000000000000000000000000000000000000000..10554db0f623e0ed0c096816ce2f10db986ea38e --- /dev/null +++ b/tools/server/tests/unit/test_infill.py @@ -0,0 +1,77 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama_infill() + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama_infill() + + +def test_infill_without_input_extra(): + global server + server.start() + res = server.make_request("POST", "/infill", data={ + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 200 + assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"]) + + +def test_infill_with_input_extra(): + global server + server.start() + res = server.make_request("POST", "/infill", data={ + "input_extra": [{ + "filename": "llama.h", + "text": "LLAMA_API int32_t llama_n_threads();\n" + }], + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 200 + assert match_regex("(Dad|excited|park)+", res.body["content"]) + + +@pytest.mark.parametrize("input_extra", [ + {}, + {"filename": "ok"}, + {"filename": 123}, + {"filename": 123, "text": "abc"}, + {"filename": 123, "text": 456}, +]) +def test_invalid_input_extra_req(input_extra): + global server + server.start() + res = server.make_request("POST", "/infill", data={ + "input_extra": [input_extra], + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 400 + assert "error" in res.body + + +@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test") +def test_with_qwen_model(): + global server + server.model_file = None + server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF" + server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf" + server.start(timeout_seconds=600) + res = server.make_request("POST", "/infill", data={ + "input_extra": [{ + "filename": "llama.h", + "text": "LLAMA_API int32_t llama_n_threads();\n" + }], + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 200 + assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n" diff --git a/tools/server/tests/unit/test_lora.py b/tools/server/tests/unit/test_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..c1aa8be70e2f7b27d6c711974b2730cd84c087e1 --- /dev/null +++ b/tools/server/tests/unit/test_lora.py @@ -0,0 +1,115 @@ +import pytest +from utils import * + +server = ServerPreset.stories15m_moe() + +LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf" + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.stories15m_moe() + server.lora_files = [download_file(LORA_FILE_URL)] + + +@pytest.mark.parametrize("scale,re_content", [ + # without applying lora, the model should behave like a bedtime story generator + (0.0, "(little|girl|three|years|old)+"), + # with lora, the model should behave like a Shakespearean text generator + (1.0, "(eye|love|glass|sun)+"), +]) +def test_lora(scale: float, re_content: str): + global server + server.start() + res_lora_control = server.make_request("POST", "/lora-adapters", data=[ + {"id": 0, "scale": scale} + ]) + assert res_lora_control.status_code == 200 + res = server.make_request("POST", "/completion", data={ + "prompt": "Look in thy glass", + }) + assert res.status_code == 200 + assert match_regex(re_content, res.body["content"]) + + +def test_lora_per_request(): + global server + server.n_slots = 4 + server.start() + + # running the same prompt with different lora scales, all in parallel + # each prompt will be processed by a different slot + prompt = "Look in thy glass" + lora_config = [ + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ), + ( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ] + + tasks = [( + server.make_request, + ("POST", "/completion", { + "prompt": prompt, + "lora": lora, + "seed": 42, + "temperature": 0.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + ) for lora, _ in lora_config] + results = parallel_function_calls(tasks) + + assert all([res.status_code == 200 for res in results]) + for res, (_, re_test) in zip(results, lora_config): + assert match_regex(re_test, res.body["content"]) + + +@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test") +def test_with_big_model(): + server = ServerProcess() + server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF" + server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf" + server.model_alias = "Llama-3.2-8B-Instruct" + server.n_slots = 4 + server.n_ctx = server.n_slots * 1024 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + server.lora_files = [ + download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"), + # TODO: find & add other lora adapters for this model + ] + server.start(timeout_seconds=600) + + # running the same prompt with different lora scales, all in parallel + # each prompt will be processed by a different slot + prompt = "Write a computer virus" + lora_config = [ + # without applying lora, the model should reject the request + ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), + ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), + ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ), + # with 0.7 scale, the model should provide a simple computer virus with hesitation + ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ), + # with 1.5 scale, the model should confidently provide a computer virus + ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), + ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), + ] + + tasks = [( + server.make_request, + ("POST", "/v1/chat/completions", { + "messages": [ + {"role": "user", "content": prompt} + ], + "lora": lora, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + ) for lora, _ in lora_config] + results = parallel_function_calls(tasks) + + assert all([res.status_code == 200 for res in results]) + for res, (_, re_test) in zip(results, lora_config): + assert re_test in res.body["choices"][0]["message"]["content"] diff --git a/tools/server/tests/unit/test_rerank.py b/tools/server/tests/unit/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f570ad5ef782b8b1940d6f42da4fe2732025c0 --- /dev/null +++ b/tools/server/tests/unit/test_rerank.py @@ -0,0 +1,104 @@ +import pytest +from utils import * + +server = ServerPreset.jina_reranker_tiny() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.jina_reranker_tiny() + + +TEST_DOCUMENTS = [ + "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", + "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", + "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." +] + + +def test_rerank(): + global server + server.start() + res = server.make_request("POST", "/rerank", data={ + "query": "Machine learning is", + "documents": TEST_DOCUMENTS, + }) + assert res.status_code == 200 + assert len(res.body["results"]) == 4 + + most_relevant = res.body["results"][0] + least_relevant = res.body["results"][0] + for doc in res.body["results"]: + if doc["relevance_score"] > most_relevant["relevance_score"]: + most_relevant = doc + if doc["relevance_score"] < least_relevant["relevance_score"]: + least_relevant = doc + + assert most_relevant["relevance_score"] > least_relevant["relevance_score"] + assert most_relevant["index"] == 2 + assert least_relevant["index"] == 3 + + +def test_rerank_tei_format(): + global server + server.start() + res = server.make_request("POST", "/rerank", data={ + "query": "Machine learning is", + "texts": TEST_DOCUMENTS, + }) + assert res.status_code == 200 + assert len(res.body) == 4 + + most_relevant = res.body[0] + least_relevant = res.body[0] + for doc in res.body: + if doc["score"] > most_relevant["score"]: + most_relevant = doc + if doc["score"] < least_relevant["score"]: + least_relevant = doc + + assert most_relevant["score"] > least_relevant["score"] + assert most_relevant["index"] == 2 + assert least_relevant["index"] == 3 + + +@pytest.mark.parametrize("documents", [ + [], + None, + 123, + [1, 2, 3], +]) +def test_invalid_rerank_req(documents): + global server + server.start() + res = server.make_request("POST", "/rerank", data={ + "query": "Machine learning is", + "documents": documents, + }) + assert res.status_code == 400 + assert "error" in res.body + + +@pytest.mark.parametrize( + "query,doc1,doc2,n_tokens", + [ + ("Machine learning is", "A machine", "Learning is", 19), + ("Which city?", "Machine learning is ", "Paris, capitale de la", 26), + ] +) +def test_rerank_usage(query, doc1, doc2, n_tokens): + global server + server.start() + + res = server.make_request("POST", "/rerank", data={ + "query": query, + "documents": [ + doc1, + doc2, + ] + }) + assert res.status_code == 200 + assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] + assert res.body['usage']['prompt_tokens'] == n_tokens diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py new file mode 100644 index 0000000000000000000000000000000000000000..620b25376bd814acbec1dd83ca8c4c1eff76a67a --- /dev/null +++ b/tools/server/tests/unit/test_security.py @@ -0,0 +1,83 @@ +import pytest +from openai import OpenAI +from utils import * + +server = ServerPreset.tinyllama2() + +TEST_API_KEY = "sk-this-is-the-secret-key" + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.api_key = TEST_API_KEY + + +@pytest.mark.parametrize("endpoint", ["/health", "/models"]) +def test_access_public_endpoint(endpoint: str): + global server + server.start() + res = server.make_request("GET", endpoint) + assert res.status_code == 200 + assert "error" not in res.body + + +@pytest.mark.parametrize("api_key", [None, "invalid-key"]) +def test_incorrect_api_key(api_key: str): + global server + server.start() + res = server.make_request("POST", "/completions", data={ + "prompt": "I believe the meaning of life is", + }, headers={ + "Authorization": f"Bearer {api_key}" if api_key else None, + }) + assert res.status_code == 401 + assert "error" in res.body + assert res.body["error"]["type"] == "authentication_error" + + +def test_correct_api_key(): + global server + server.start() + res = server.make_request("POST", "/completions", data={ + "prompt": "I believe the meaning of life is", + }, headers={ + "Authorization": f"Bearer {TEST_API_KEY}", + }) + assert res.status_code == 200 + assert "error" not in res.body + assert "content" in res.body + + +def test_openai_library_correct_api_key(): + global server + server.start() + client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}") + res = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a chatbot."}, + {"role": "user", "content": "What is the meaning of life?"}, + ], + ) + assert len(res.choices) == 1 + + +@pytest.mark.parametrize("origin,cors_header,cors_header_value", [ + ("localhost", "Access-Control-Allow-Origin", "localhost"), + ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"), + ("origin", "Access-Control-Allow-Credentials", "true"), + ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"), + ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"), +]) +def test_cors_options(origin: str, cors_header: str, cors_header_value: str): + global server + server.start() + res = server.make_request("OPTIONS", "/completions", headers={ + "Origin": origin, + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Authorization", + }) + assert res.status_code == 200 + assert cors_header in res.headers + assert res.headers[cors_header] == cors_header_value diff --git a/tools/server/tests/unit/test_slot_save.py b/tools/server/tests/unit/test_slot_save.py new file mode 100644 index 0000000000000000000000000000000000000000..38704f5ece35a98b3e0bd6d6e1eb67433158fa31 --- /dev/null +++ b/tools/server/tests/unit/test_slot_save.py @@ -0,0 +1,98 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.slot_save_path = "./tmp" + server.temperature = 0.0 + + +def test_slot_save_restore(): + global server + server.start() + + # First prompt in slot 1 should be fully processed + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Whiskers|Flana)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed + + # Save state of slot 1 + res = server.make_request("POST", "/slots/1?action=save", data={ + "filename": "slot1.bin", + }) + assert res.status_code == 200 + assert res.body["n_saved"] == 84 + + # Since we have cache, this should only process the last tokens + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of Germany?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Jack|said)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 6 # only different part is processed + + # Loading the saved cache into slot 0 + res = server.make_request("POST", "/slots/0?action=restore", data={ + "filename": "slot1.bin", + }) + assert res.status_code == 200 + assert res.body["n_restored"] == 84 + + # Since we have cache, slot 0 should only process the last tokens + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of Germany?", + "id_slot": 0, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Jack|said)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 6 # only different part is processed + + # For verification that slot 1 was not corrupted during slot 0 load, same thing should work + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of Germany?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Jack|said)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 1 + + +def test_slot_erase(): + global server + server.start() + + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Whiskers|Flana)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed + + # erase slot 1 + res = server.make_request("POST", "/slots/1?action=erase") + assert res.status_code == 200 + + # re-run the same prompt, it should process all tokens again + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Whiskers|Flana)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py new file mode 100644 index 0000000000000000000000000000000000000000..54db38cf3bd8046fe356cf7e8c07832b58dd0815 --- /dev/null +++ b/tools/server/tests/unit/test_speculative.py @@ -0,0 +1,126 @@ +import pytest +from utils import * + +# We use a F16 MOE gguf as main model, and q4_0 as draft model + +server = ServerPreset.stories15m_moe() + +MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf" + +def create_server(): + global server + server = ServerPreset.stories15m_moe() + # set default values + server.model_draft = download_file(MODEL_DRAFT_FILE_URL) + server.draft_min = 4 + server.draft_max = 8 + + +@pytest.fixture(scope="module", autouse=True) +def fixture_create_server(): + return create_server() + + +def test_with_and_without_draft(): + global server + server.model_draft = None # disable draft model + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + content_no_draft = res.body["content"] + server.stop() + + # create new server with draft model + create_server() + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + content_draft = res.body["content"] + + assert content_no_draft == content_draft + + +def test_different_draft_min_draft_max(): + global server + test_values = [ + (1, 2), + (1, 4), + (4, 8), + (4, 12), + (8, 16), + ] + last_content = None + for draft_min, draft_max in test_values: + server.stop() + server.draft_min = draft_min + server.draft_max = draft_max + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + if last_content is not None: + assert last_content == res.body["content"] + last_content = res.body["content"] + + +def test_slot_ctx_not_exceeded(): + global server + server.n_ctx = 64 + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "Hello " * 56, + "temperature": 0.0, + "top_k": 1, + "speculative.p_min": 0.0, + }) + assert res.status_code == 200 + assert len(res.body["content"]) > 0 + + +def test_with_ctx_shift(): + global server + server.n_ctx = 64 + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "Hello " * 56, + "temperature": 0.0, + "top_k": 1, + "n_predict": 64, + "speculative.p_min": 0.0, + }) + assert res.status_code == 200 + assert len(res.body["content"]) > 0 + assert res.body["tokens_predicted"] == 64 + assert res.body["truncated"] == True + + +@pytest.mark.parametrize("n_slots,n_requests", [ + (1, 2), + (2, 2), +]) +def test_multi_requests_parallel(n_slots: int, n_requests: int): + global server + server.n_slots = n_slots + server.start() + tasks = [] + for _ in range(n_requests): + tasks.append((server.make_request, ("POST", "/completion", { + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }))) + results = parallel_function_calls(tasks) + for res in results: + assert res.status_code == 200 + assert match_regex("(wise|kind|owl|answer)+", res.body["content"]) diff --git a/tools/server/tests/unit/test_template.py b/tools/server/tests/unit/test_template.py new file mode 100644 index 0000000000000000000000000000000000000000..c53eda5b884456da1d75a3124cf78d8bd9b716ec --- /dev/null +++ b/tools/server/tests/unit/test_template.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +import pytest + +# ensure grandparent path is in sys.path +from pathlib import Path +import sys + +from unit.test_tool_call import TEST_TOOL +path = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(path)) + +import datetime +from utils import * + +server: ServerProcess + +TIMEOUT_SERVER_START = 15*60 + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.model_alias = "tinyllama-2" + server.server_port = 8081 + server.n_slots = 1 + + +@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]]) +@pytest.mark.parametrize("template_name,reasoning_budget,expected_end", [ + ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", None, "\n"), + ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", -1, "\n"), + ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", 0, "\n"), + + ("Qwen-Qwen3-0.6B", -1, "<|im_start|>assistant\n"), + ("Qwen-Qwen3-0.6B", 0, "<|im_start|>assistant\n\n\n\n\n"), + + ("Qwen-QwQ-32B", -1, "<|im_start|>assistant\n\n"), + ("Qwen-QwQ-32B", 0, "<|im_start|>assistant\n\n"), + + ("CohereForAI-c4ai-command-r7b-12-2024-tool_use", -1, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"), + ("CohereForAI-c4ai-command-r7b-12-2024-tool_use", 0, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|><|END_THINKING|>"), +]) +def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expected_end: str, tools: list[dict]): + global server + server.jinja = True + server.reasoning_budget = reasoning_budget + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + + res = server.make_request("POST", "/apply-template", data={ + "messages": [ + {"role": "user", "content": "What is today?"}, + ], + "tools": tools, + }) + assert res.status_code == 200 + prompt = res.body["prompt"] + + assert prompt.endswith(expected_end), f"Expected prompt to end with '{expected_end}', got '{prompt}'" + + +@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]]) +@pytest.mark.parametrize("template_name,format", [ + ("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"), + ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"), +]) +def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]): + global server + server.jinja = True + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + + res = server.make_request("POST", "/apply-template", data={ + "messages": [ + {"role": "user", "content": "What is today?"}, + ], + "tools": tools, + }) + assert res.status_code == 200 + prompt = res.body["prompt"] + + today_str = datetime.date.today().strftime(format) + assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})" + + +@pytest.mark.parametrize("add_generation_prompt", [False, True]) +@pytest.mark.parametrize("template_name,expected_generation_prompt", [ + ("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"), +]) +def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool): + global server + server.jinja = True + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + + res = server.make_request("POST", "/apply-template", data={ + "messages": [ + {"role": "user", "content": "What is today?"}, + ], + "add_generation_prompt": add_generation_prompt, + }) + assert res.status_code == 200 + prompt = res.body["prompt"] + + if add_generation_prompt: + assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})" + else: + assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})" diff --git a/tools/server/tests/unit/test_tokenize.py b/tools/server/tests/unit/test_tokenize.py new file mode 100644 index 0000000000000000000000000000000000000000..382457c9d602fdfdd3257b668312310c2060db24 --- /dev/null +++ b/tools/server/tests/unit/test_tokenize.py @@ -0,0 +1,59 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +def test_tokenize_detokenize(): + global server + server.start() + # tokenize + content = "What is the capital of France ?" + res_tok = server.make_request("POST", "/tokenize", data={ + "content": content + }) + assert res_tok.status_code == 200 + assert len(res_tok.body["tokens"]) > 5 + # detokenize + res_detok = server.make_request("POST", "/detokenize", data={ + "tokens": res_tok.body["tokens"], + }) + assert res_detok.status_code == 200 + assert res_detok.body["content"].strip() == content + + +def test_tokenize_with_bos(): + global server + server.start() + # tokenize + content = "What is the capital of France ?" + bosId = 1 + res_tok = server.make_request("POST", "/tokenize", data={ + "content": content, + "add_special": True, + }) + assert res_tok.status_code == 200 + assert res_tok.body["tokens"][0] == bosId + + +def test_tokenize_with_pieces(): + global server + server.start() + # tokenize + content = "This is a test string with unicode 媽 and emoji 🤗" + res_tok = server.make_request("POST", "/tokenize", data={ + "content": content, + "with_pieces": True, + }) + assert res_tok.status_code == 200 + for token in res_tok.body["tokens"]: + assert "id" in token + assert token["id"] > 0 + assert "piece" in token + assert len(token["piece"]) > 0 diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py new file mode 100644 index 0000000000000000000000000000000000000000..20f048c6f6aa5a514ded74bf0997501a977e4adc --- /dev/null +++ b/tools/server/tests/unit/test_tool_call.py @@ -0,0 +1,623 @@ +#!/usr/bin/env python +import pytest + +# ensure grandparent path is in sys.path +from pathlib import Path +import sys +path = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(path)) + +from utils import * +from enum import Enum + +server: ServerProcess + +TIMEOUT_SERVER_START = 15*60 +TIMEOUT_HTTP_REQUEST = 60 + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.model_alias = "tinyllama-2-tool-call" + server.server_port = 8081 + server.n_slots = 1 + +class CompletionMode(Enum): + NORMAL = "normal" + STREAMED = "streamed" + +TEST_TOOL = { + "type":"function", + "function": { + "name": "test", + "description": "", + "parameters": { + "type": "object", + "properties": { + "success": {"type": "boolean", "const": True}, + }, + "required": ["success"] + } + } +} + +PYTHON_TOOL = { + "type": "function", + "function": { + "name": "python", + "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The code to run in the ipython interpreter." + } + }, + "required": ["code"] + } + } +} + +WEATHER_TOOL = { + "type":"function", + "function":{ + "name":"get_current_weather", + "description":"Get the current weather in a given location", + "parameters":{ + "type":"object", + "properties":{ + "location":{ + "type":"string", + "description":"The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'" + } + }, + "required":["location"] + } + } +} + +def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs): + body = server.make_any_request("POST", "/v1/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "tool_choice": "required", + "tools": [tool], + "parallel_tool_calls": False, + **kwargs, + }) + # assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' + # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' + expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + assert expected_function_name == tool_call["function"]["name"] + actual_arguments = tool_call["function"]["arguments"] + assert isinstance(actual_arguments, str) + if argument_key is not None: + actual_arguments = json.loads(actual_arguments) + assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + + +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("template_name,tool,argument_key", [ + ("google-gemma-2-2b-it", TEST_TOOL, "success"), + ("google-gemma-2-2b-it", TEST_TOOL, "success"), + ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), + ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), +]) +def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode): + global server + n_predict = 1024 + # server = ServerPreset.stories15m_moe() + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0) + + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("template_name,tool,argument_key", [ + ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"), + + ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), + ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"), + + ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), + # Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own. + # ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"), + + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"), + + ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"), + + ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"), + ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"), + + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"), + + ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"), + ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"), + + ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"), + # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True), + # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"), + +]) +def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode): + global server + n_predict = 512 + # server = ServerPreset.stories15m_moe() + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED) + + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [ + (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + + # (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), +]) +def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode): + global server + n_predict = 512 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + body = server.make_any_request("POST", "/v1/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "tool_choice": "required", + "tools": [tool], + "parallel_tool_calls": False, + "stream": stream == CompletionMode.STREAMED, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }, timeout=TIMEOUT_HTTP_REQUEST) + choice = body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' + expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + assert expected_function_name == tool_call["function"]["name"] + actual_arguments = tool_call["function"]["arguments"] + assert isinstance(actual_arguments, str) + if argument_key is not None: + actual_arguments = json.loads(actual_arguments) + assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + + +def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs): + body = server.make_any_request("POST", "/v1/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "say hello world with python"}, + ], + "tools": tools if tools else None, + "tool_choice": tool_choice, + **kwargs, + }, timeout=TIMEOUT_HTTP_REQUEST) + choice = body["choices"][0] + assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' + + +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ + ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None), + ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None), + ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'), +]) +def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode): + global server + server.n_predict = n_predict + server.jinja = True + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED) + + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ + ("meetkai-functionary-medium-v3.2", 256, [], None), + ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None), + ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'), + ("meetkai-functionary-medium-v3.1", 256, [], None), + ("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None), + ("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [], None), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'), +]) +def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode): + global server + server.n_predict = n_predict + server.jinja = True + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED) + + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("hf_repo,template_override", [ + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + + # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + + # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), + + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")), + + ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + + # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + + # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), +]) +def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode): + global server + n_predict = 512 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict) + + +def do_test_weather(server: ServerProcess, **kwargs): + body = server.make_any_request("POST", "/v1/chat/completions", data={ + "messages": [ + {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."}, + {"role": "user", "content": "What is the weather in Istanbul?"}, + ], + "tools": [WEATHER_TOOL], + **kwargs, + }, timeout=TIMEOUT_HTTP_REQUEST) + choice = body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' + assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}' + # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' + actual_arguments = json.loads(tool_call["function"]["arguments"]) + assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" + location = actual_arguments["location"] + assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}" + assert re.match('^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' + + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [ + (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), + (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), + + # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) + # (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), +]) +def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode): + global server + server.jinja = True + server.n_ctx = 8192 * 2 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED) + + +def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs): + body = server.make_any_request("POST", "/v1/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."}, + {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_6789", + "type": "function", + "function": { + "name": "calculate", + "arguments": "{\"expression\":\"sin(30 * pi / 180)\"}" + } + } + ] + }, + { + "role": "tool", + "name": "calculate", + "content": "0.55644242476", + "tool_call_id": "call_6789" + } + ], + "tools": [ + { + "type":"function", + "function":{ + "name":"calculate", + "description":"A calculator function that computes values of arithmetic expressions in the Python syntax", + "parameters":{ + "type":"object", + "properties":{ + "expression":{ + "type":"string", + "description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)" + } + }, + "required":["expression"] + } + } + } + ], + **kwargs, + }, timeout=TIMEOUT_HTTP_REQUEST) + choice = body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls is None, f'Expected no tool call in {choice["message"]}' + content = choice["message"].get("content") + assert content is not None, f'Expected content in {choice["message"]}' + if result_override is not None: + assert re.match(result_override, content), f'Expected {result_override}, got {content}' + else: + assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \ + f'Expected something like "The y coordinate is 0.56.", got {content}' + + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", [ + (128, 'deepseek', None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (128, None, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (1024, 'deepseek', "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (1024, 'deepseek', "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), + # (1024, 'none', CompletionMode.NORMAL, None, "^(\\s*)?I need[\\s\\S]*?\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None), +]) +def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode): + global server + server.reasoning_format = reasoning_format + server.jinja = True + server.n_ctx = 8192 * 2 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + body = server.make_any_request("POST", "/v1/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "user", "content": "What's the sum of 102 and 7?"}, + ], + "stream": stream == CompletionMode.STREAMED, + }, timeout=TIMEOUT_HTTP_REQUEST) + choice = body["choices"][0] + assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' + + content = choice["message"].get("content") + if expect_content is None: + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' + else: + assert re.match(expect_content, content), f'Expected {expect_content}, got {content}' + + reasoning_content = choice["message"].get("reasoning_content") + if expect_reasoning_content is None: + assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}' + else: + assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}' + + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("hf_repo,template_override", [ + ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), + + # ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None), + + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None), + + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), +]) +def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode): + global server + n_predict = 512 # High because of DeepSeek R1 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + + do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict) + + +def do_test_hello_world(server: ServerProcess, **kwargs): + body = server.make_any_request("POST", "/v1/chat/completions", data={ + "messages": [ + {"role": "system", "content": "You are a tool-calling agent."}, + {"role": "user", "content": "say hello world with python"}, + ], + "tools": [PYTHON_TOOL], + **kwargs, + }, timeout=TIMEOUT_HTTP_REQUEST) + choice = body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' + assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] + # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' + actual_arguments = json.loads(tool_call["function"]["arguments"]) + assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" + code = actual_arguments["code"] + assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" + assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}' diff --git a/tools/server/tests/unit/test_vision_api.py b/tools/server/tests/unit/test_vision_api.py new file mode 100644 index 0000000000000000000000000000000000000000..fc63caa1342939578dadcc5aa1100d41c6754df5 --- /dev/null +++ b/tools/server/tests/unit/test_vision_api.py @@ -0,0 +1,60 @@ +import pytest +from utils import * +import base64 +import requests + +server: ServerProcess + +IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" +IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png" + +response = requests.get(IMG_URL_0) +response.raise_for_status() # Raise an exception for bad status codes +IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") + + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinygemma3() + + +@pytest.mark.parametrize( + "prompt, image_url, success, re_content", + [ + # test model is trained on CIFAR-10, but it's quite dumb due to small size + ("What is this:\n", IMG_URL_0, True, "(cat)+"), + ("What is this:\n", "IMG_BASE64_0", True, "(cat)+"), # exceptional, so that we don't cog up the log + ("What is this:\n", IMG_URL_1, True, "(frog)+"), + ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache + ("What is this:\n", "malformed", False, None), + ("What is this:\n", "https://google.com/404", False, None), # non-existent image + ("What is this:\n", "https://ggml.ai", False, None), # non-image data + # TODO @ngxson : test with multiple images, no images and with audio + ] +) +def test_vision_chat_completion(prompt, image_url, success, re_content): + global server + server.start(timeout_seconds=60) # vision model may take longer to load due to download size + if image_url == "IMG_BASE64_0": + image_url = IMG_BASE64_0 + res = server.make_request("POST", "/chat/completions", data={ + "temperature": 0.0, + "top_k": 1, + "messages": [ + {"role": "user", "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": { + "url": image_url, + }}, + ]}, + ], + }) + if success: + assert res.status_code == 200 + choice = res.body["choices"][0] + assert "assistant" == choice["message"]["role"] + assert match_regex(re_content, choice["message"]["content"]) + else: + assert res.status_code != 200 + diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc547ca03bf1bcb4b1ca9fec2ef1ece1f53efc94 --- /dev/null +++ b/tools/server/tests/utils.py @@ -0,0 +1,557 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# type: ignore[reportUnusedImport] + +import subprocess +import os +import re +import json +import sys +import requests +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import ( + Any, + Callable, + ContextManager, + Iterable, + Iterator, + List, + Literal, + Tuple, + Set, +) +from re import RegexFlag +import wget + + +DEFAULT_HTTP_TIMEOUT = 12 + +if "LLAMA_SANITIZE" in os.environ or "GITHUB_ACTION" in os.environ: + DEFAULT_HTTP_TIMEOUT = 30 + + +class ServerResponse: + headers: dict + status_code: int + body: dict | Any + + +class ServerProcess: + # default options + debug: bool = False + server_port: int = 8080 + server_host: str = "127.0.0.1" + model_hf_repo: str = "ggml-org/models" + model_hf_file: str | None = "tinyllamas/stories260K.gguf" + model_alias: str = "tinyllama-2" + temperature: float = 0.8 + seed: int = 42 + + # custom options + model_alias: str | None = None + model_url: str | None = None + model_file: str | None = None + model_draft: str | None = None + n_threads: int | None = None + n_gpu_layer: int | None = None + n_batch: int | None = None + n_ubatch: int | None = None + n_ctx: int | None = None + n_ga: int | None = None + n_ga_w: int | None = None + n_predict: int | None = None + n_prompts: int | None = 0 + slot_save_path: str | None = None + id_slot: int | None = None + cache_prompt: bool | None = None + n_slots: int | None = None + ctk: str | None = None + ctv: str | None = None + fa: bool | None = None + server_continuous_batching: bool | None = False + server_embeddings: bool | None = False + server_reranking: bool | None = False + server_metrics: bool | None = False + server_slots: bool | None = False + pooling: str | None = None + draft: int | None = None + api_key: str | None = None + lora_files: List[str] | None = None + disable_ctx_shift: int | None = False + draft_min: int | None = None + draft_max: int | None = None + no_webui: bool | None = None + jinja: bool | None = None + reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None + reasoning_budget: int | None = None + chat_template: str | None = None + chat_template_file: str | None = None + server_path: str | None = None + mmproj_url: str | None = None + + # session variables + process: subprocess.Popen | None = None + + def __init__(self): + if "N_GPU_LAYERS" in os.environ: + self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"]) + if "DEBUG" in os.environ: + self.debug = True + if "PORT" in os.environ: + self.server_port = int(os.environ["PORT"]) + + def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: + if self.server_path is not None: + server_path = self.server_path + elif "LLAMA_SERVER_BIN_PATH" in os.environ: + server_path = os.environ["LLAMA_SERVER_BIN_PATH"] + elif os.name == "nt": + server_path = "../../../build/bin/Release/llama-server.exe" + else: + server_path = "../../../build/bin/llama-server" + server_args = [ + "--host", + self.server_host, + "--port", + self.server_port, + "--temp", + self.temperature, + "--seed", + self.seed, + ] + if self.model_file: + server_args.extend(["--model", self.model_file]) + if self.model_url: + server_args.extend(["--model-url", self.model_url]) + if self.model_draft: + server_args.extend(["--model-draft", self.model_draft]) + if self.model_hf_repo: + server_args.extend(["--hf-repo", self.model_hf_repo]) + if self.model_hf_file: + server_args.extend(["--hf-file", self.model_hf_file]) + if self.n_batch: + server_args.extend(["--batch-size", self.n_batch]) + if self.n_ubatch: + server_args.extend(["--ubatch-size", self.n_ubatch]) + if self.n_threads: + server_args.extend(["--threads", self.n_threads]) + if self.n_gpu_layer: + server_args.extend(["--n-gpu-layers", self.n_gpu_layer]) + if self.draft is not None: + server_args.extend(["--draft", self.draft]) + if self.server_continuous_batching: + server_args.append("--cont-batching") + if self.server_embeddings: + server_args.append("--embedding") + if self.server_reranking: + server_args.append("--reranking") + if self.server_metrics: + server_args.append("--metrics") + if self.server_slots: + server_args.append("--slots") + if self.pooling: + server_args.extend(["--pooling", self.pooling]) + if self.model_alias: + server_args.extend(["--alias", self.model_alias]) + if self.n_ctx: + server_args.extend(["--ctx-size", self.n_ctx]) + if self.n_slots: + server_args.extend(["--parallel", self.n_slots]) + if self.ctk: + server_args.extend(["-ctk", self.ctk]) + if self.ctv: + server_args.extend(["-ctv", self.ctv]) + if self.fa is not None: + server_args.append("-fa") + if self.n_predict: + server_args.extend(["--n-predict", self.n_predict]) + if self.slot_save_path: + server_args.extend(["--slot-save-path", self.slot_save_path]) + if self.n_ga: + server_args.extend(["--grp-attn-n", self.n_ga]) + if self.n_ga_w: + server_args.extend(["--grp-attn-w", self.n_ga_w]) + if self.debug: + server_args.append("--verbose") + if self.lora_files: + for lora_file in self.lora_files: + server_args.extend(["--lora", lora_file]) + if self.disable_ctx_shift: + server_args.extend(["--no-context-shift"]) + if self.api_key: + server_args.extend(["--api-key", self.api_key]) + if self.draft_max: + server_args.extend(["--draft-max", self.draft_max]) + if self.draft_min: + server_args.extend(["--draft-min", self.draft_min]) + if self.no_webui: + server_args.append("--no-webui") + if self.jinja: + server_args.append("--jinja") + if self.reasoning_format is not None: + server_args.extend(("--reasoning-format", self.reasoning_format)) + if self.reasoning_budget is not None: + server_args.extend(("--reasoning-budget", self.reasoning_budget)) + if self.chat_template: + server_args.extend(["--chat-template", self.chat_template]) + if self.chat_template_file: + server_args.extend(["--chat-template-file", self.chat_template_file]) + if self.mmproj_url: + server_args.extend(["--mmproj-url", self.mmproj_url]) + + args = [str(arg) for arg in [server_path, *server_args]] + print(f"tests: starting server with: {' '.join(args)}") + + flags = 0 + if "nt" == os.name: + flags |= subprocess.DETACHED_PROCESS + flags |= subprocess.CREATE_NEW_PROCESS_GROUP + flags |= subprocess.CREATE_NO_WINDOW + + self.process = subprocess.Popen( + [str(arg) for arg in [server_path, *server_args]], + creationflags=flags, + stdout=sys.stdout, + stderr=sys.stdout, + env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None, + ) + server_instances.add(self) + + print(f"server pid={self.process.pid}, pytest pid={os.getpid()}") + + # wait for server to start + start_time = time.time() + while time.time() - start_time < timeout_seconds: + try: + response = self.make_request("GET", "/health", headers={ + "Authorization": f"Bearer {self.api_key}" if self.api_key else None + }) + if response.status_code == 200: + self.ready = True + return # server is ready + except Exception as e: + pass + # Check if process died + if self.process.poll() is not None: + raise RuntimeError(f"Server process died with return code {self.process.returncode}") + + print(f"Waiting for server to start...") + time.sleep(0.5) + raise TimeoutError(f"Server did not start within {timeout_seconds} seconds") + + def stop(self) -> None: + if self in server_instances: + server_instances.remove(self) + if self.process: + print(f"Stopping server with pid={self.process.pid}") + self.process.kill() + self.process = None + + def make_request( + self, + method: str, + path: str, + data: dict | Any | None = None, + headers: dict | None = None, + timeout: float | None = None, + ) -> ServerResponse: + url = f"http://{self.server_host}:{self.server_port}{path}" + parse_body = False + if method == "GET": + response = requests.get(url, headers=headers, timeout=timeout) + parse_body = True + elif method == "POST": + response = requests.post(url, headers=headers, json=data, timeout=timeout) + parse_body = True + elif method == "OPTIONS": + response = requests.options(url, headers=headers, timeout=timeout) + else: + raise ValueError(f"Unimplemented method: {method}") + result = ServerResponse() + result.headers = dict(response.headers) + result.status_code = response.status_code + result.body = response.json() if parse_body else None + print("Response from server", json.dumps(result.body, indent=2)) + return result + + def make_stream_request( + self, + method: str, + path: str, + data: dict | None = None, + headers: dict | None = None, + ) -> Iterator[dict]: + url = f"http://{self.server_host}:{self.server_port}{path}" + if method == "POST": + response = requests.post(url, headers=headers, json=data, stream=True) + else: + raise ValueError(f"Unimplemented method: {method}") + for line_bytes in response.iter_lines(): + line = line_bytes.decode("utf-8") + if '[DONE]' in line: + break + elif line.startswith('data: '): + data = json.loads(line[6:]) + print("Partial response from server", json.dumps(data, indent=2)) + yield data + + def make_any_request( + self, + method: str, + path: str, + data: dict | None = None, + headers: dict | None = None, + timeout: float | None = None, + ) -> dict: + stream = data.get('stream', False) + if stream: + content: list[str] = [] + reasoning_content: list[str] = [] + tool_calls: list[dict] = [] + finish_reason: Optional[str] = None + + content_parts = 0 + reasoning_content_parts = 0 + tool_call_parts = 0 + arguments_parts = 0 + + for chunk in self.make_stream_request(method, path, data, headers): + assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}' + choice = chunk['choices'][0] + if choice['delta'].get('content') is not None: + assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!' + content.append(choice['delta']['content']) + content_parts += 1 + if choice['delta'].get('reasoning_content') is not None: + assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!' + reasoning_content.append(choice['delta']['reasoning_content']) + reasoning_content_parts += 1 + if choice['delta'].get('finish_reason') is not None: + finish_reason = choice['delta']['finish_reason'] + for tc in choice['delta'].get('tool_calls', []): + if 'function' not in tc: + raise ValueError(f"Expected function type, got {tc['type']}") + if tc['index'] >= len(tool_calls): + assert 'id' in tc + assert tc.get('type') == 'function' + assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \ + f"Expected function call with name, got {tc.get('function')}" + tool_calls.append(dict( + id="", + type="function", + function=dict( + name="", + arguments="", + ) + )) + tool_call = tool_calls[tc['index']] + if tc.get('id') is not None: + tool_call['id'] = tc['id'] + fct = tc['function'] + assert 'id' not in fct, f"Function call should not have id: {fct}" + if fct.get('name') is not None: + tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name'] + if fct.get('arguments') is not None: + tool_call['function']['arguments'] += fct['arguments'] + arguments_parts += 1 + tool_call_parts += 1 + + print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts') + result = dict( + choices=[ + dict( + index=0, + finish_reason=finish_reason, + message=dict( + role='assistant', + content=''.join(content) if content else None, + reasoning_content=''.join(reasoning_content) if reasoning_content else None, + tool_calls=tool_calls if tool_calls else None, + ), + ) + ], + ) + print("Final response from server", json.dumps(result, indent=2)) + return result + else: + response = self.make_request(method, path, data, headers, timeout=timeout) + assert response.status_code == 200, f"Server returned error: {response.status_code}" + return response.body + + + +server_instances: Set[ServerProcess] = set() + + +class ServerPreset: + @staticmethod + def tinyllama2() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "tinyllamas/stories260K.gguf" + server.model_alias = "tinyllama-2" + server.n_ctx = 512 + server.n_batch = 32 + server.n_slots = 2 + server.n_predict = 64 + server.seed = 42 + return server + + @staticmethod + def bert_bge_small() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf" + server.model_alias = "bert-bge-small" + server.n_ctx = 512 + server.n_batch = 128 + server.n_ubatch = 128 + server.n_slots = 2 + server.seed = 42 + server.server_embeddings = True + return server + + @staticmethod + def bert_bge_small_with_fa() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf" + server.model_alias = "bert-bge-small" + server.n_ctx = 1024 + server.n_batch = 300 + server.n_ubatch = 300 + server.n_slots = 2 + server.fa = True + server.seed = 42 + server.server_embeddings = True + return server + + @staticmethod + def tinyllama_infill() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "tinyllamas/stories260K-infill.gguf" + server.model_alias = "tinyllama-infill" + server.n_ctx = 2048 + server.n_batch = 1024 + server.n_slots = 1 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + return server + + @staticmethod + def stories15m_moe() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/stories15M_MOE" + server.model_hf_file = "stories15M_MOE-F16.gguf" + server.model_alias = "stories15m-moe" + server.n_ctx = 2048 + server.n_batch = 1024 + server.n_slots = 1 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + return server + + @staticmethod + def jina_reranker_tiny() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf" + server.model_alias = "jina-reranker" + server.n_ctx = 512 + server.n_batch = 512 + server.n_slots = 1 + server.seed = 42 + server.server_reranking = True + return server + + @staticmethod + def tinygemma3() -> ServerProcess: + server = ServerProcess() + # mmproj is already provided by HF registry API + server.model_hf_repo = "ggml-org/tinygemma3-GGUF" + server.model_hf_file = "tinygemma3-Q8_0.gguf" + server.mmproj_url = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/mmproj-tinygemma3.gguf" + server.model_alias = "tinygemma3" + server.n_ctx = 1024 + server.n_batch = 32 + server.n_slots = 2 + server.n_predict = 4 + server.seed = 42 + return server + + +def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]: + """ + Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS. + + Example usage: + + results = parallel_function_calls([ + (func1, (arg1, arg2)), + (func2, (arg3, arg4)), + ]) + """ + results = [None] * len(function_list) + exceptions = [] + + def worker(index, func, args): + try: + result = func(*args) + results[index] = result + except Exception as e: + exceptions.append((index, str(e))) + + with ThreadPoolExecutor() as executor: + futures = [] + for i, (func, args) in enumerate(function_list): + future = executor.submit(worker, i, func, args) + futures.append(future) + + # Wait for all futures to complete + for future in as_completed(futures): + pass + + # Check if there were any exceptions + if exceptions: + print("Exceptions occurred:") + for index, error in exceptions: + print(f"Function at index {index}: {error}") + + return results + + +def match_regex(regex: str, text: str) -> bool: + return ( + re.compile( + regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL + ).search(text) + is not None + ) + + +def download_file(url: str, output_file_path: str | None = None) -> str: + """ + Download a file from a URL to a local path. If the file already exists, it will not be downloaded again. + + output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory. + + Returns the local path of the downloaded file. + """ + file_name = url.split('/').pop() + output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path + if not os.path.exists(output_file): + print(f"Downloading {url} to {output_file}") + wget.download(url, out=output_file) + print(f"Done downloading to {output_file}") + else: + print(f"File already exists at {output_file}") + return output_file + + +def is_slow_test_allowed(): + return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON" diff --git a/tools/server/themes/README.md b/tools/server/themes/README.md new file mode 100644 index 0000000000000000000000000000000000000000..62e721a2758d8bbf23bd2e3aaa0aac3a4c3beaa1 --- /dev/null +++ b/tools/server/themes/README.md @@ -0,0 +1,5 @@ +# LLaMA.cpp Server Wild Theme + +Simple themes directory of sample "public" directories. To try any of these add --path to your run like `server --path=wild`. + +![image](wild/wild.png) diff --git a/tools/server/themes/buttons-top/README.md b/tools/server/themes/buttons-top/README.md new file mode 100644 index 0000000000000000000000000000000000000000..808c4cf81a9590fba5070f71099939f82dab7aa3 --- /dev/null +++ b/tools/server/themes/buttons-top/README.md @@ -0,0 +1,7 @@ +# LLaMA.cpp Server Buttons Top Theme + +Simple tweaks to the UI. Chat buttons at the top of the page instead of bottom so you can hit Stop instead of chasing it down the page. + +To use simply run server with `--path=themes/buttons_top` + +![image](buttons_top.png) diff --git a/tools/server/themes/buttons-top/buttons_top.png b/tools/server/themes/buttons-top/buttons_top.png new file mode 100644 index 0000000000000000000000000000000000000000..1c7348e195e393adf05e139188db620ae1eff876 --- /dev/null +++ b/tools/server/themes/buttons-top/buttons_top.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9dc742a3e24623ecbb71e364917a7f20a31783bb0fb2264e16daac1ef1253fa6 +size 119747 diff --git a/tools/server/themes/buttons-top/favicon.ico b/tools/server/themes/buttons-top/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..89e154a0a7525ecd19379e8e200e36851c325d31 Binary files /dev/null and b/tools/server/themes/buttons-top/favicon.ico differ diff --git a/tools/server/themes/buttons-top/index.html b/tools/server/themes/buttons-top/index.html new file mode 100644 index 0000000000000000000000000000000000000000..3fb88fcc88d319e16908314a26e19f641150d65f --- /dev/null +++ b/tools/server/themes/buttons-top/index.html @@ -0,0 +1,1052 @@ + + + + + + + llama.cpp - chat + + + + + + + +
+ +
+
+ + + diff --git a/tools/server/themes/wild/README.md b/tools/server/themes/wild/README.md new file mode 100644 index 0000000000000000000000000000000000000000..560bcc81bfde12bb224c8cd9b9bbe3d8da22d3f4 --- /dev/null +++ b/tools/server/themes/wild/README.md @@ -0,0 +1,5 @@ +# LLaMA.cpp Server Wild Theme + +Simple tweaks to the UI. To use simply run server with `--path=themes/wild` + +![image](wild.png) diff --git a/tools/server/themes/wild/favicon.ico b/tools/server/themes/wild/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..89e154a0a7525ecd19379e8e200e36851c325d31 Binary files /dev/null and b/tools/server/themes/wild/favicon.ico differ diff --git a/tools/server/themes/wild/index.html b/tools/server/themes/wild/index.html new file mode 100644 index 0000000000000000000000000000000000000000..73f36d4b29fdd884baf0a2e91c193a7283449ef4 --- /dev/null +++ b/tools/server/themes/wild/index.html @@ -0,0 +1,1056 @@ + + + + + + + llama.cpp - chat + + + + + + + +
+ +
+
+ + + diff --git a/tools/server/themes/wild/llama_cpp.png b/tools/server/themes/wild/llama_cpp.png new file mode 100644 index 0000000000000000000000000000000000000000..bad1dc9fcdb200d27058c851dfba649332674906 Binary files /dev/null and b/tools/server/themes/wild/llama_cpp.png differ diff --git a/tools/server/themes/wild/llamapattern.png b/tools/server/themes/wild/llamapattern.png new file mode 100644 index 0000000000000000000000000000000000000000..b366b127c6b176f8f80cc4ee96b841024b4c83d9 --- /dev/null +++ b/tools/server/themes/wild/llamapattern.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90a680eb7825b6cd75a9b2594ef29b1db1f3d417fa7fa33ca9410144b26ee87b +size 259586 diff --git a/tools/server/themes/wild/wild.png b/tools/server/themes/wild/wild.png new file mode 100644 index 0000000000000000000000000000000000000000..16ab1564f5e0f73584731fb7317476d8ab0fdc68 --- /dev/null +++ b/tools/server/themes/wild/wild.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b21497c35087e3ea38ffc20b1ae31f8addc427d66c631b59b863ce330fa77ab +size 496463 diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f3e0392a4e9d19ee33b254bec425237845f59186 --- /dev/null +++ b/tools/server/utils.hpp @@ -0,0 +1,1329 @@ +#pragma once + +#include "common.h" +#include "log.h" +#include "llama.h" +#include "arg.h" // common_remote_get_content +#include "base64.hpp" +#include "mtmd.h" +#include "mtmd-helper.h" +#include "chat.h" + +// increase max payload length to allow use of larger context size +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 +// disable Nagle's algorithm +#define CPPHTTPLIB_TCP_NODELAY true +#include + +#define JSON_ASSERT GGML_ASSERT +#include + +#include +#include +#include +#include +#include +#include + +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" + +using json = nlohmann::ordered_json; + +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) + +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +using raw_buffer = std::vector; + +template +static T json_value(const json & body, const std::string & key, const T & default_value) { + // Fallback null to default value + if (body.contains(key) && !body.at(key).is_null()) { + try { + return body.at(key); + } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { + LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name()); + return default_value; + } + } else { + return default_value; + } +} + +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + +// thin wrapper around common_grammar_trigger with (de)serialization functions +struct server_grammar_trigger { + common_grammar_trigger value; + + server_grammar_trigger() = default; + server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} + server_grammar_trigger(const json & in) { + value.type = (common_grammar_trigger_type) in.at("type").get(); + value.value = in.at("value").get(); + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + value.token = (llama_token) in.at("token").get(); + } + } + + json to_json() const { + json out { + {"type", (int) value.type}, + {"value", value.value}, + }; + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + out["token"] = (int) value.token; + } + return out; + } +}; + +// +// tokenizer and input processing utils +// + +static bool json_is_array_of_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (!e.is_number_integer()) { + return false; + } + } + return true; + } + return false; +} + +// is array having BOTH numbers & strings? +static bool json_is_array_of_mixed_numbers_strings(const json & data) { + bool seen_string = false; + bool seen_number = false; + if (data.is_array()) { + for (const auto & e : data) { + seen_string |= e.is_string(); + seen_number |= e.is_number_integer(); + if (seen_number && seen_string) { + return true; + } + } + } + return false; +} + +// get value by path(key1 / key2) +static json json_get_nested_values(const std::vector & paths, const json & js) { + json result = json::object(); + + for (const std::string & path : paths) { + json current = js; + const auto keys = string_split(path, /*separator*/ '/'); + bool valid_path = true; + for (const std::string & k : keys) { + if (valid_path && current.is_object() && current.contains(k)) { + current = current[k]; + } else { + valid_path = false; + } + } + if (valid_path) { + result[path] = current; + } + } + return result; +} + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + llama_tokens prompt_tokens; + + if (json_prompt.is_array()) { + bool first = true; + for (const auto & p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); + + llama_tokens p; + if (first) { + p = common_tokenize(vocab, s, add_special, parse_special); + first = false; + } else { + p = common_tokenize(vocab, s, false, parse_special); + } + + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } + + prompt_tokens.push_back(p.template get()); + } + } + } else { + auto s = json_prompt.template get(); + prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); + } + + return prompt_tokens; +} + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] + */ +static std::vector tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + std::vector result; + if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { + // string or mixed + result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special)); + } else if (json_is_array_of_numbers(json_prompt)) { + // array of tokens + result.push_back(json_prompt.get()); + } else if (json_prompt.is_array()) { + // array of prompts + result.reserve(json_prompt.size()); + for (const auto & p : json_prompt) { + if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { + result.push_back(tokenize_mixed(vocab, p, add_special, parse_special)); + } else if (json_is_array_of_numbers(p)) { + // array of tokens + result.push_back(p.get()); + } else { + throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens"); + } + } + } else { + throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); + } + if (result.empty()) { + throw std::runtime_error("\"prompt\" must not be empty"); + } + return result; +} + +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +static size_t validate_utf8(const std::string& text) { + size_t len = text.size(); + if (len == 0) return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) return len - i; + } + } + + // If no cut-off multi-byte character is found, return full length + return len; +} + +// +// template utils +// + +// format rerank task: [BOS]query[EOS][SEP]doc[EOS] +static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) { + llama_tokens result; + + // Get EOS token - use SEP token as fallback if EOS is not available + llama_token eos_token = llama_vocab_eos(vocab); + if (eos_token == LLAMA_TOKEN_NULL) { + eos_token = llama_vocab_sep(vocab); + } + + result.reserve(doc.size() + query.size() + 4); + result.push_back(llama_vocab_bos(vocab)); + result.insert(result.end(), query.begin(), query.end()); + result.push_back(eos_token); + result.push_back(llama_vocab_sep(vocab)); + result.insert(result.end(), doc.begin(), doc.end()); + result.push_back(eos_token); + + return result; +} + +// format infill task +static llama_tokens format_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt + ) { + // TODO: optimize this block by reducing memory allocations and movement + + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... + // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); + + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); + + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); + + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + } + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } else { + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); + + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + } + + const auto chunk_tokens = common_tokenize(vocab, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); + } + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } + + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); + const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); + + SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); + + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); + + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); + + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); + + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; + + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); + } + + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); + + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); + + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); + + return embd_inp; +} + +// +// base64 utils (TODO: move to common in the future) +// + +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + +static inline bool is_base64(uint8_t c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} + +static inline raw_buffer base64_decode(const std::string & encoded_string) { + int i = 0; + int j = 0; + int in_ = 0; + + int in_len = encoded_string.size(); + + uint8_t char_array_4[4]; + uint8_t char_array_3[3]; + + raw_buffer ret; + + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { + char_array_4[i] = base64_chars.find(char_array_4[i]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) { + ret.push_back(char_array_3[i]); + } + + i = 0; + } + } + + if (i) { + for (j = i; j < 4; j++) { + char_array_4[j] = 0; + } + + for (j = 0; j < 4; j++) { + char_array_4[j] = base64_chars.find(char_array_4[j]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; j < i - 1; j++) { + ret.push_back(char_array_3[j]); + } + } + + return ret; +} + +// +// random string / id +// + +static std::string random_string() { + static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); + + std::random_device rd; + std::mt19937 generator(rd()); + + std::string result(32, ' '); + + for (int i = 0; i < 32; ++i) { + result[i] = str[generator() % str.size()]; + } + + return result; +} + +static std::string gen_chatcmplid() { + return "chatcmpl-" + random_string(); +} + +static std::string gen_tool_call_id() { + return random_string(); +} + +// +// other common utils +// + +// TODO: reuse llama_detokenize +template +static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { + std::string ret; + for (; begin != end; ++begin) { + ret += common_token_to_piece(ctx, *begin); + } + + return ret; +} + +// format incomplete utf-8 multibyte character for output +static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); + + // if the size is 1 and first bit is 1, meaning it's a partial character + // (size > 1 meaning it's already a known token) + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { + std::stringstream ss; + ss << std::hex << (out[0] & 0xff); + std::string res(ss.str()); + out = "byte: \\x" + res; + } + + return out; +} + +static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) { + const std::string str = + std::string(event) + ": " + + data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). + + LOG_DBG("data stream, to_send: %s", str.c_str()); + + return sink.write(str.c_str(), str.size()); +} + +// +// OAI utils +// + +// used by /completions endpoint +static json oaicompat_completion_params_parse(const json & body) { + json llama_params; + + if (!body.contains("prompt")) { + throw std::runtime_error("\"prompt\" is required"); + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Handle "echo" field + if (json_value(body, "echo", false)) { + throw std::runtime_error("Only no echo is supported"); + } + + // Params supported by OAI but unsupported by llama.cpp + static const std::vector unsupported_params { "best_of", "suffix" }; + for (const auto & param : unsupported_params) { + if (body.contains(param)) { + throw std::runtime_error("Unsupported param: " + param); + } + } + + // Copy remaining properties to llama_params + for (const auto & item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } + } + + return llama_params; +} + +struct oaicompat_parser_options { + bool use_jinja; + bool prefill_assistant; + common_reasoning_format reasoning_format; + common_chat_templates * tmpls; + bool allow_image; + bool allow_audio; + bool enable_thinking = true; +}; + +// used by /chat/completions endpoint +static json oaicompat_chat_params_parse( + json & body, /* openai api json semantics */ + const oaicompat_parser_options & opt, + std::vector & out_files) +{ + json llama_params; + + auto tools = json_value(body, "tools", json()); + auto has_tools = tools.is_array() && !tools.empty(); + auto stream = json_value(body, "stream", false); + auto tool_choice = json_value(body, "tool_choice", std::string("auto")); + + if (!opt.use_jinja) { + if (has_tools) { + throw std::runtime_error("tools param requires --jinja flag"); + } + if (tool_choice != "auto") { + throw std::runtime_error("tool_choice param requires --jinja flag"); + } + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + auto json_schema = json_value(body, "json_schema", json()); + auto grammar = json_value(body, "grammar", std::string()); + if (!json_schema.is_null() && !grammar.empty()) { + throw std::runtime_error("Cannot use both json_schema and grammar"); + } + + // Handle "response_format" field + if (body.contains("response_format")) { + json response_format = json_value(body, "response_format", json::object()); + std::string response_type = json_value(response_format, "type", std::string()); + if (response_type == "json_object") { + json_schema = json_value(response_format, "schema", json::object()); + } else if (response_type == "json_schema") { + auto schema_wrapper = json_value(response_format, "json_schema", json::object()); + json_schema = json_value(schema_wrapper, "schema", json::object()); + } else if (!response_type.empty() && response_type != "text") { + throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); + } + } + + // get input files + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + json & messages = body.at("messages"); + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array"); + } + for (auto & msg : messages) { + std::string role = json_value(msg, "role", std::string()); + if (role != "assistant" && !msg.contains("content")) { + throw std::runtime_error("All non-assistant messages must contain 'content'"); + } + if (role == "assistant") { + if (!msg.contains("content") && !msg.contains("tool_calls")) { + throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!"); + } + if (!msg.contains("content")) { + continue; // avoid errors with no content + } + } + json & content = msg.at("content"); + if (content.is_string() || content.is_null()) { + continue; + } + + if (!content.is_array()) { + throw std::runtime_error("Expected 'content' to be a string or an array"); + } + + for (auto & p : content) { + std::string type = json_value(p, "type", std::string()); + if (type == "image_url") { + if (!opt.allow_image) { + throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); + } + + json image_url = json_value(p, "image_url", json::object()); + std::string url = json_value(image_url, "url", std::string()); + if (string_starts_with(url, "http")) { + // download remote image + // TODO @ngxson : maybe make these params configurable + common_remote_params params; + params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.max_size = 1024 * 1024 * 10; // 10MB + params.timeout = 10; // seconds + SRV_INF("downloading image from '%s'\n", url.c_str()); + auto res = common_remote_get_content(url, params); + if (200 <= res.first && res.first < 300) { + SRV_INF("downloaded %ld bytes\n", res.second.size()); + raw_buffer data; + data.insert(data.end(), res.second.begin(), res.second.end()); + out_files.push_back(data); + } else { + throw std::runtime_error("Failed to download image"); + } + + } else { + // try to decode base64 image + std::vector parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid image_url.url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid image_url.url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("image_url.url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + } + + // replace this chunk with a marker + p["type"] = "text"; + p["text"] = mtmd_default_marker(); + p.erase("image_url"); + + } else if (type == "input_audio") { + if (!opt.allow_audio) { + throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); + } + + json input_audio = json_value(p, "input_audio", json::object()); + std::string data = json_value(input_audio, "data", std::string()); + std::string format = json_value(input_audio, "format", std::string()); + // while we also support flac, we don't allow it here so we matches the OAI spec + if (format != "wav" && format != "mp3") { + throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); + } + auto decoded_data = base64_decode(data); // expected to be base64 encoded + out_files.push_back(decoded_data); + + // replace this chunk with a marker + p["type"] = "text"; + p["text"] = mtmd_default_marker(); + p.erase("input_audio"); + + } else if (type != "text") { + throw std::runtime_error("unsupported content[].type"); + } + } + } + + common_chat_templates_inputs inputs; + inputs.messages = common_chat_msgs_parse_oaicompat(messages); + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); + inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); + inputs.grammar = grammar; + inputs.use_jinja = opt.use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); + inputs.reasoning_format = opt.reasoning_format; + inputs.enable_thinking = opt.enable_thinking; + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { + if (body.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + llama_params["parse_tool_calls"] = true; + } + + // if the assistant message appears at the end of list, we do not add end-of-turn token + // for ex. this can be useful to modify the reasoning process in reasoning models + bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; + common_chat_msg last_message; + if (prefill_assistant_message) { + last_message = inputs.messages.back(); + inputs.messages.pop_back(); + + /* sanity check, max one assistant message at the end of the list */ + if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ + throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); + } + + /* TODO: test this properly */ + inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; + inputs.add_generation_prompt = true; + } + + // Apply chat template to the list of messages + auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); + + /* Append assistant prefilled message */ + if (prefill_assistant_message) { + chat_params.prompt += last_message.content; + } + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + if (!chat_params.grammar.empty()) { + llama_params["grammar"] = chat_params.grammar; + } + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + server_grammar_trigger ct(trigger); + grammar_triggers.push_back(ct.to_json()); + } + llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; + llama_params["thinking_forced_open"] = chat_params.thinking_forced_open; + for (const auto & stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Handle "logprobs" field + // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future + if (json_value(body, "logprobs", false)) { + if (has_tools && stream) { + throw std::runtime_error("logprobs is not supported with tools + stream"); + } + llama_params["n_probs"] = json_value(body, "top_logprobs", 20); + } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { + throw std::runtime_error("top_logprobs requires logprobs to be set to true"); + } + + // Copy remaining properties to llama_params + // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. + // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp + for (const auto & item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } + } + + return llama_params; +} + +static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto & elem : embeddings) { + json embedding_obj; + + if (use_base64) { + const auto& vec = json_value(elem, "embedding", json::array()).get>(); + const char* data_ptr = reinterpret_cast(vec.data()); + size_t data_size = vec.size() * sizeof(float); + embedding_obj = { + {"embedding", base64::encode(data_ptr, data_size)}, + {"index", i++}, + {"object", "embedding"}, + {"encoding_format", "base64"} + }; + } else { + embedding_obj = { + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + } + data.push_back(embedding_obj); + + n_tokens += json_value(elem, "tokens_evaluated", 0); + } + + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"data", data} + }; + + return res; +} + +static json format_response_rerank( + const json & request, + const json & ranks, + bool is_tei_format, + std::vector & texts) { + json res; + if (is_tei_format) { + // TEI response format + res = json::array(); + bool return_text = json_value(request, "return_text", false); + for (const auto & rank : ranks) { + int index = json_value(rank, "index", 0); + json elem = json{ + {"index", index}, + {"score", json_value(rank, "score", 0.0)}, + }; + if (return_text) { + elem["text"] = std::move(texts[index]); + } + res.push_back(elem); + } + } else { + // Jina response format + json results = json::array(); + int32_t n_tokens = 0; + for (const auto & rank : ranks) { + results.push_back(json{ + {"index", json_value(rank, "index", 0)}, + {"relevance_score", json_value(rank, "score", 0.0)}, + }); + + n_tokens += json_value(rank, "tokens_evaluated", 0); + } + + res = json{ + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json{ + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"results", results} + }; + } + + return res; +} + +static bool is_valid_utf8(const std::string & str) { + const unsigned char* bytes = reinterpret_cast(str.data()); + const unsigned char* end = bytes + str.length(); + + while (bytes < end) { + if (*bytes <= 0x7F) { + // 1-byte sequence (0xxxxxxx) + bytes++; + } else if ((*bytes & 0xE0) == 0xC0) { + // 2-byte sequence (110xxxxx 10xxxxxx) + if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) + return false; + bytes += 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) + return false; + bytes += 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || + (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) + return false; + bytes += 4; + } else { + // Invalid UTF-8 lead byte + return false; + } + } + + return true; +} + +static json format_tokenizer_response(const json & tokens) { + return json { + {"tokens", tokens} + }; +} + +static json format_detokenized_response(const std::string & content) { + return json { + {"content", content} + }; +} + +static json format_logit_bias(const std::vector & logit_bias) { + json data = json::array(); + for (const auto & lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); + } + return data; +} + +static std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); +} + +static std::vector get_token_probabilities(llama_context * ctx, int idx) { + std::vector cur; + const auto * logits = llama_get_logits_ith(ctx, idx); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + // sort tokens by logits + std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + + // apply softmax + float max_l = cur[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < cur.size(); ++i) { + float p = expf(cur[i].logit - max_l); + cur[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < cur.size(); ++i) { + cur[i].p /= cum_sum; + } + + return cur; +} + +static bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { + return false; + } + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } + } + return true; +} + +// parse lora config from JSON request, returned a copy of lora_base with updated scale +static std::vector parse_lora_request( + const std::vector & lora_base, + const json & data) { + std::vector lora(lora_base); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + return lora; +} + +// +// utils for interacting with libmtmd +// (may need to refactor in near future) +// + +/** + * server_tokens is a helper to manage the input tokens and image for the server. + * it is made this way to simplify the logic of KV cache management. + */ +struct server_tokens { + bool has_mtmd = false; + +private: // disallow accessing these members directly, risking out-of-sync + + // map a **start** position in tokens to the image chunk + std::unordered_map map_pos_to_media; + + // list of tokens + // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token + // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** + // important: for models using mrope, an image can contain multiple tokens but will use only one **position** + llama_tokens tokens; + + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // pos 0 1 2 3 4 5 6 7 8 9 + // map_pos_to_media will contain: {5, img0}, {8, img1} + +public: + server_tokens() = default; + ~server_tokens() = default; + + // Prevent copying + server_tokens(const server_tokens&) = delete; + server_tokens& operator=(const server_tokens&) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_tokens(server_tokens&&) = default; + server_tokens& operator=(server_tokens&&) = default; + + // Allow accessing elements using [] operator + llama_token operator[](size_t index) { return tokens[index]; } + const llama_token& operator[](size_t index) const { return tokens[index]; } + + server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { + for (size_t i = 0; i < mtmd_chunks.size(); ++i) { + push_back(mtmd_chunks[i]); + } + } + + server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + + // for debugging + std::string str() const { + std::ostringstream oss; + oss << "tokens: "; + for (const auto & t : tokens) { + if (t == LLAMA_TOKEN_NULL) { + oss << " "; + } else { + oss << t << " "; + } + } + oss << "\n"; + oss << "image pos: "; + for (const auto & it : map_pos_to_media) { + oss << it.first << ", "; + } + return oss.str(); + } + + const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const { + auto it = map_pos_to_media.find(pos); + if (it != map_pos_to_media.end()) { + return it->second; + } else { + throw std::runtime_error("Chunk not found"); + } + } + + void push_back(llama_token tok) { + if (tok == LLAMA_TOKEN_NULL) { + throw std::runtime_error("Invalid token"); + } + tokens.emplace_back(tok); + } + + // will create a copy of the chunk if it contains non-text data + void push_back(const mtmd_input_chunk * chunk) { + auto type = mtmd_input_chunk_get_type(chunk); + if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { + GGML_ASSERT(has_mtmd); + const int n_pos = mtmd_input_chunk_get_n_pos(chunk); + llama_pos start_pos = tokens.size(); + for (int i = 0; i < n_pos; ++i) { + tokens.emplace_back(LLAMA_TOKEN_NULL); + } + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_pos_to_media[start_pos] = std::move(new_chunk); + } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + for (size_t i = 0; i < n_tokens; ++i) { + push_back(text_tokens[i]); + } + } else { + GGML_ABORT("Invalid chunk type"); + } + } + + // for compatibility with context shift and prompt truncation + void insert(const llama_tokens & inp_tokens) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); + } + + // for compatibility with speculative decoding, ctx shift, slot save/load + const llama_tokens & get_text_tokens() const { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + return tokens; + } + + // for compatibility with speculative decoding + void set_token(llama_pos pos, llama_token id) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens[pos] = id; + } + + size_t size() const { + return tokens.size(); + } + + bool empty() const { + return tokens.empty(); + } + + void clear() { + tokens.clear(); + } + + void keep_first(size_t n) { + GGML_ASSERT(n <= tokens.size()); + if (has_mtmd) { + if (n == tokens.size()) { + return; // nothing to do + } + // we throw an error if we try to remove a token in the middle of an image + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // n 1 2 3 4 5 6 7 8 9 10 + // allowed to resize ^ ^ + // disallowed to resize ^ ^ ^ + if (n > 0) { + llama_token last_token = tokens[n - 1]; + // make sure we never remove tokens in the middle of an image + if (last_token == LLAMA_TOKEN_NULL) { + find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk + } + } + // remove all image chunks that are not used anymore + for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) { + llama_pos pos = it->first; + if (pos >= (llama_pos)n) { + it = map_pos_to_media.erase(it); + } else { + ++it; + } + } + } + tokens.resize(n); + } + + std::string detokenize(const llama_context * ctx, bool special) const { + llama_tokens text_tokens; + text_tokens.reserve(tokens.size()); + for (const auto & t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + text_tokens.push_back(t); + } + } + return common_detokenize(ctx, text_tokens, special); + } + + size_t get_common_prefix(const server_tokens & b) const { + size_t max_idx = std::min(tokens.size(), b.tokens.size()); + for (size_t i = 0; i < max_idx; ++i) { + auto & ai = tokens[i]; + auto & bi = b.tokens[i]; + + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + GGML_ASSERT(has_mtmd); + const auto & a_chunk = find_chunk(i); + const auto & b_chunk = b.find_chunk(i); + GGML_ASSERT(a_chunk && b_chunk); + std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get()); + std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get()); + size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get()); + size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get()); + if (ai_id == bi_id && a_pos == b_pos) { + GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen + i += a_pos - 1; // will be +1 by the for loop + continue; + } else { + return i; + } + } else if (ai == bi) { + continue; + } else { + return i; + } + } + return max_idx; // all tokens are equal + } + + // make sure all text tokens are within the vocab range + bool validate(const struct llama_context * ctx) const { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + for (size_t i = 0; i < tokens.size(); ++i) { + auto & t = tokens[i]; + if (t == LLAMA_TOKEN_NULL) { + try { + const auto & chunk = find_chunk(i); + size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); + i += n_pos - 1; // will be +1 by the for loop + } catch (const std::exception & e) { + return false; + } + } else if (t < 0 || t >= n_vocab) { + return false; + } + } + return true; + } + + // encode and decode the image chunk + int32_t process_chunk( + llama_context * ctx, + mtmd_context * mctx, + llama_pos n_past, + int32_t seq_id, + llama_pos & n_pos_out) { + auto & chunk = find_chunk(n_past); + const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE + ? "image" : "audio"; + SRV_INF("processing %s...\n", name); + int32_t n_batch = llama_n_batch(ctx); + int64_t t0 = ggml_time_ms(); + llama_pos new_n_past = n_past; + int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, + chunk.get(), + n_past, + seq_id, + n_batch, + true, // logits last + &new_n_past); + SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); + if (result != 0) { + LOG_ERR("mtmd_helper_eval failed with status %d", result); + n_pos_out = n_past; + return result; + } + n_pos_out = new_n_past; + return 0; + } +}; + +// Computes FNV-1a hash of the data +static std::string fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return std::to_string(hash); +} diff --git a/tools/server/webui/.gitignore b/tools/server/webui/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a547bf36d8d11a4f89c59c144f24795749086dd1 --- /dev/null +++ b/tools/server/webui/.gitignore @@ -0,0 +1,24 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/tools/server/webui/.prettierignore b/tools/server/webui/.prettierignore new file mode 100644 index 0000000000000000000000000000000000000000..c0cb165b37e86fd8572a0b91351ac6edb8c46a67 --- /dev/null +++ b/tools/server/webui/.prettierignore @@ -0,0 +1,10 @@ +**/.vscode +**/.github +**/.git +**/.svn +**/.hg +**/node_modules +**/dist +**/build + +*.config.js diff --git a/tools/server/webui/eslint.config.js b/tools/server/webui/eslint.config.js new file mode 100644 index 0000000000000000000000000000000000000000..7c0d39b89b50b0e098b3f3cfc1bee3ac751d3027 --- /dev/null +++ b/tools/server/webui/eslint.config.js @@ -0,0 +1,26 @@ +import js from '@eslint/js' +import globals from 'globals' +import reactHooks from 'eslint-plugin-react-hooks' +import reactRefresh from 'eslint-plugin-react-refresh' +import tseslint from 'typescript-eslint' + +export default tseslint.config( + { ignores: ['dist'] }, + { + extends: [js.configs.recommended, ...tseslint.configs.recommended], + files: ['**/*.{ts,tsx}'], + languageOptions: { + ecmaVersion: 2020, + globals: globals.browser, + }, + plugins: { + 'react-hooks': reactHooks, + 'react-refresh': reactRefresh, + }, + rules: { + ...reactHooks.configs.recommended.rules, + 'react-refresh/only-export-components': 'off', + '@typescript-eslint/no-unused-vars': 'off', + }, + }, +) diff --git a/tools/server/webui/index.html b/tools/server/webui/index.html new file mode 100644 index 0000000000000000000000000000000000000000..471f46b3ad19bbe6bd54d32c03a4311f3ee7add7 --- /dev/null +++ b/tools/server/webui/index.html @@ -0,0 +1,16 @@ + + + + + + + 🦙 llama.cpp - chat + + +
+ + + diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json new file mode 100644 index 0000000000000000000000000000000000000000..a05cbcfe5c39268cfadac3644bcfed95d2f3e3b5 --- /dev/null +++ b/tools/server/webui/package-lock.json @@ -0,0 +1,6620 @@ +{ + "name": "webui", + "version": "0.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "webui", + "version": "0.0.0", + "dependencies": { + "@heroicons/react": "^2.2.0", + "@sec-ant/readable-stream": "^0.6.0", + "@tailwindcss/postcss": "^4.1.1", + "@tailwindcss/vite": "^4.1.1", + "@vscode/markdown-it-katex": "^1.1.1", + "autoprefixer": "^10.4.20", + "daisyui": "^5.0.12", + "dexie": "^4.0.11", + "highlight.js": "^11.10.0", + "katex": "^0.16.15", + "pdfjs-dist": "^5.2.133", + "postcss": "^8.4.49", + "react": "^18.3.1", + "react-dom": "^18.3.1", + "react-dropzone": "^14.3.8", + "react-hot-toast": "^2.5.2", + "react-markdown": "^9.0.3", + "react-router": "^7.1.5", + "rehype-highlight": "^7.0.2", + "rehype-katex": "^7.0.1", + "remark-breaks": "^4.0.0", + "remark-gfm": "^4.0.0", + "remark-math": "^6.0.0", + "tailwindcss": "^4.1.1", + "textlinestream": "^1.1.1", + "vite-plugin-singlefile": "^2.0.3" + }, + "devDependencies": { + "@eslint/js": "^9.17.0", + "@types/markdown-it": "^14.1.2", + "@types/node": "^22.13.1", + "@types/react": "^18.3.18", + "@types/react-dom": "^18.3.5", + "@vitejs/plugin-react": "^4.3.4", + "eslint": "^9.17.0", + "eslint-plugin-react-hooks": "^5.0.0", + "eslint-plugin-react-refresh": "^0.4.16", + "fflate": "^0.8.2", + "globals": "^15.14.0", + "prettier": "^3.4.2", + "sass-embedded": "^1.83.4", + "typescript": "~5.6.2", + "typescript-eslint": "^8.18.2", + "vite": "^6.0.5" + } + }, + "node_modules/@alloc/quick-lru": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/@alloc/quick-lru/-/quick-lru-5.2.0.tgz", + "integrity": "sha512-UrcABB+4bUrFABwbluTIBErXwvbsU/V7TZWfmbgJfbkwiBuziS9gxdODUyuiecfdGQ85jglMW6juS3+z5TsKLw==", + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/@ampproject/remapping": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz", + "integrity": "sha512-30iZtAPgz+LTIYoeivqYo853f02jBYSd5uGnGpkFV0M3xOt9aN73erkgYAmZU43x4VfqcnLxW9Kpg3R5LC4YYw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.24" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@babel/code-frame": { + "version": "7.26.2", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", + "integrity": "sha512-RJlIHRueQgwWitWgF8OdFYGZX328Ax5BCemNGlqHfplnRT9ESi8JkFlvaVYbS+UubVY6dpv87Fs2u5M29iNFVQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/helper-validator-identifier": "^7.25.9", + "js-tokens": "^4.0.0", + "picocolors": "^1.0.0" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/compat-data": { + "version": "7.26.5", + "resolved": "https://registry.npmjs.org/@babel/compat-data/-/compat-data-7.26.5.tgz", + "integrity": "sha512-XvcZi1KWf88RVbF9wn8MN6tYFloU5qX8KjuF3E1PVBmJ9eypXfs4GRiJwLuTZL0iSnJUKn1BFPa5BPZZJyFzPg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/core": { + "version": "7.26.7", + "resolved": "https://registry.npmjs.org/@babel/core/-/core-7.26.7.tgz", + "integrity": "sha512-SRijHmF0PSPgLIBYlWnG0hyeJLwXE2CgpsXaMOrtt2yp9/86ALw6oUlj9KYuZ0JN07T4eBMVIW4li/9S1j2BGA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@ampproject/remapping": "^2.2.0", + "@babel/code-frame": "^7.26.2", + "@babel/generator": "^7.26.5", + "@babel/helper-compilation-targets": "^7.26.5", + "@babel/helper-module-transforms": "^7.26.0", + "@babel/helpers": "^7.26.7", + "@babel/parser": "^7.26.7", + "@babel/template": "^7.25.9", + "@babel/traverse": "^7.26.7", + "@babel/types": "^7.26.7", + "convert-source-map": "^2.0.0", + "debug": "^4.1.0", + "gensync": "^1.0.0-beta.2", + "json5": "^2.2.3", + "semver": "^6.3.1" + }, + "engines": { + "node": ">=6.9.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/babel" + } + }, + "node_modules/@babel/generator": { + "version": "7.26.5", + "resolved": "https://registry.npmjs.org/@babel/generator/-/generator-7.26.5.tgz", + "integrity": "sha512-2caSP6fN9I7HOe6nqhtft7V4g7/V/gfDsC3Ag4W7kEzzvRGKqiv0pu0HogPiZ3KaVSoNDhUws6IJjDjpfmYIXw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.26.5", + "@babel/types": "^7.26.5", + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.25", + "jsesc": "^3.0.2" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-compilation-targets": { + "version": "7.26.5", + "resolved": "https://registry.npmjs.org/@babel/helper-compilation-targets/-/helper-compilation-targets-7.26.5.tgz", + "integrity": "sha512-IXuyn5EkouFJscIDuFF5EsiSolseme1s0CZB+QxVugqJLYmKdxI1VfIBOst0SUu4rnk2Z7kqTwmoO1lp3HIfnA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/compat-data": "^7.26.5", + "@babel/helper-validator-option": "^7.25.9", + "browserslist": "^4.24.0", + "lru-cache": "^5.1.1", + "semver": "^6.3.1" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-module-imports": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-module-imports/-/helper-module-imports-7.25.9.tgz", + "integrity": "sha512-tnUA4RsrmflIM6W6RFTLFSXITtl0wKjgpnLgXyowocVPrbYrLUXSBXDgTs8BlbmIzIdlBySRQjINYs2BAkiLtw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/traverse": "^7.25.9", + "@babel/types": "^7.25.9" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-module-transforms": { + "version": "7.26.0", + "resolved": "https://registry.npmjs.org/@babel/helper-module-transforms/-/helper-module-transforms-7.26.0.tgz", + "integrity": "sha512-xO+xu6B5K2czEnQye6BHA7DolFFmS3LB7stHZFaOLb1pAwO1HWLS8fXA+eh0A2yIvltPVmx3eNNDBJA2SLHXFw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/helper-module-imports": "^7.25.9", + "@babel/helper-validator-identifier": "^7.25.9", + "@babel/traverse": "^7.25.9" + }, + "engines": { + "node": ">=6.9.0" + }, + "peerDependencies": { + "@babel/core": "^7.0.0" + } + }, + "node_modules/@babel/helper-plugin-utils": { + "version": "7.26.5", + "resolved": "https://registry.npmjs.org/@babel/helper-plugin-utils/-/helper-plugin-utils-7.26.5.tgz", + "integrity": "sha512-RS+jZcRdZdRFzMyr+wcsaqOmld1/EqTghfaBGQQd/WnRdzdlvSZ//kF7U8VQTxf1ynZ4cjUcYgjVGx13ewNPMg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-string-parser": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", + "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-validator-identifier": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", + "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-validator-option": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-option/-/helper-validator-option-7.25.9.tgz", + "integrity": "sha512-e/zv1co8pp55dNdEcCynfj9X7nyUKUXoUEwfXqaZt0omVOmDe9oOTdKStH4GmAw6zxMFs50ZayuMfHDKlO7Tfw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helpers": { + "version": "7.26.7", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.26.7.tgz", + "integrity": "sha512-8NHiL98vsi0mbPQmYAGWwfcFaOy4j2HY49fXJCfuDcdE7fMIsH9a7GdaeXpIBsbT7307WU8KCMp5pUVDNL4f9A==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/template": "^7.25.9", + "@babel/types": "^7.26.7" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/parser": { + "version": "7.26.7", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.7.tgz", + "integrity": "sha512-kEvgGGgEjRUutvdVvZhbn/BxVt+5VSpwXz1j3WYXQbXDo8KzFOPNG2GQbdAiNq8g6wn1yKk7C/qrke03a84V+w==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/types": "^7.26.7" + }, + "bin": { + "parser": "bin/babel-parser.js" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@babel/plugin-transform-react-jsx-self": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-jsx-self/-/plugin-transform-react-jsx-self-7.25.9.tgz", + "integrity": "sha512-y8quW6p0WHkEhmErnfe58r7x0A70uKphQm8Sp8cV7tjNQwK56sNVK0M73LK3WuYmsuyrftut4xAkjjgU0twaMg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/helper-plugin-utils": "^7.25.9" + }, + "engines": { + "node": ">=6.9.0" + }, + "peerDependencies": { + "@babel/core": "^7.0.0-0" + } + }, + "node_modules/@babel/plugin-transform-react-jsx-source": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-jsx-source/-/plugin-transform-react-jsx-source-7.25.9.tgz", + "integrity": "sha512-+iqjT8xmXhhYv4/uiYd8FNQsraMFZIfxVSqxxVSZP0WbbSAWvBXAul0m/zu+7Vv4O/3WtApy9pmaTMiumEZgfg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/helper-plugin-utils": "^7.25.9" + }, + "engines": { + "node": ">=6.9.0" + }, + "peerDependencies": { + "@babel/core": "^7.0.0-0" + } + }, + "node_modules/@babel/template": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.25.9.tgz", + "integrity": "sha512-9DGttpmPvIxBb/2uwpVo3dqJ+O6RooAFOS+lB+xDqoE2PVCE8nfoHMdZLpfCQRLwvohzXISPZcgxt80xLfsuwg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/code-frame": "^7.25.9", + "@babel/parser": "^7.25.9", + "@babel/types": "^7.25.9" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/traverse": { + "version": "7.26.7", + "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.26.7.tgz", + "integrity": "sha512-1x1sgeyRLC3r5fQOM0/xtQKsYjyxmFjaOrLJNtZ81inNjyJHGIolTULPiSc/2qe1/qfpFLisLQYFnnZl7QoedA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/code-frame": "^7.26.2", + "@babel/generator": "^7.26.5", + "@babel/parser": "^7.26.7", + "@babel/template": "^7.25.9", + "@babel/types": "^7.26.7", + "debug": "^4.3.1", + "globals": "^11.1.0" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/traverse/node_modules/globals": { + "version": "11.12.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-11.12.0.tgz", + "integrity": "sha512-WOBp/EEGUiIsJSp7wcv/y6MO+lV9UoncWqxuFfm8eBwzWNgyfBd6Gz+IeKQ9jCmyhoH99g15M3T+QaVHFjizVA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/@babel/types": { + "version": "7.26.7", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.7.tgz", + "integrity": "sha512-t8kDRGrKXyp6+tjUh7hw2RLyclsW4TRoRvRHtSyAX9Bb5ldlFh+90YAYY6awRXrlB4G5G2izNeGySpATlFzmOg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/helper-string-parser": "^7.25.9", + "@babel/helper-validator-identifier": "^7.25.9" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@bufbuild/protobuf": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/@bufbuild/protobuf/-/protobuf-2.2.3.tgz", + "integrity": "sha512-tFQoXHJdkEOSwj5tRIZSPNUuXK3RaR7T1nUrPgbYX1pUbvqqaaZAsfo+NXBPsz5rZMSKVFrgK1WL8Q/MSLvprg==", + "devOptional": true, + "license": "(Apache-2.0 AND BSD-3-Clause)" + }, + "node_modules/@esbuild/aix-ppc64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.24.2.tgz", + "integrity": "sha512-thpVCb/rhxE/BnMLQ7GReQLLN8q9qbHmI55F4489/ByVg2aQaQ6kbcLb6FHkocZzQhxc4gx0sCk0tJkKBFzDhA==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "aix" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.24.2.tgz", + "integrity": "sha512-tmwl4hJkCfNHwFB3nBa8z1Uy3ypZpxqxfTQOcHX+xRByyYgunVbZ9MzUUfb0RxaHIMnbHagwAxuTL+tnNM+1/Q==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.24.2.tgz", + "integrity": "sha512-cNLgeqCqV8WxfcTIOeL4OAtSmL8JjcN6m09XIgro1Wi7cF4t/THaWEa7eL5CMoMBdjoHOTh/vwTO/o2TRXIyzg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-x64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.24.2.tgz", + "integrity": "sha512-B6Q0YQDqMx9D7rvIcsXfmJfvUYLoP722bgfBlO5cGvNVb5V/+Y7nhBE3mHV9OpxBf4eAS2S68KZztiPaWq4XYw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-arm64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.24.2.tgz", + "integrity": "sha512-kj3AnYWc+CekmZnS5IPu9D+HWtUI49hbnyqk0FLEJDbzCIQt7hg7ucF1SQAilhtYpIujfaHr6O0UHlzzSPdOeA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-x64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.24.2.tgz", + "integrity": "sha512-WeSrmwwHaPkNR5H3yYfowhZcbriGqooyu3zI/3GGpF8AyUdsrrP0X6KumITGA9WOyiJavnGZUwPGvxvwfWPHIA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-arm64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.24.2.tgz", + "integrity": "sha512-UN8HXjtJ0k/Mj6a9+5u6+2eZ2ERD7Edt1Q9IZiB5UZAIdPnVKDoG7mdTVGhHJIeEml60JteamR3qhsr1r8gXvg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-x64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.24.2.tgz", + "integrity": "sha512-TvW7wE/89PYW+IevEJXZ5sF6gJRDY/14hyIGFXdIucxCsbRmLUcjseQu1SyTko+2idmCw94TgyaEZi9HUSOe3Q==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.24.2.tgz", + "integrity": "sha512-n0WRM/gWIdU29J57hJyUdIsk0WarGd6To0s+Y+LwvlC55wt+GT/OgkwoXCXvIue1i1sSNWblHEig00GBWiJgfA==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.24.2.tgz", + "integrity": "sha512-7HnAD6074BW43YvvUmE/35Id9/NB7BeX5EoNkK9obndmZBUk8xmJJeU7DwmUeN7tkysslb2eSl6CTrYz6oEMQg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ia32": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.24.2.tgz", + "integrity": "sha512-sfv0tGPQhcZOgTKO3oBE9xpHuUqguHvSo4jl+wjnKwFpapx+vUDcawbwPNuBIAYdRAvIDBfZVvXprIj3HA+Ugw==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-loong64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.24.2.tgz", + "integrity": "sha512-CN9AZr8kEndGooS35ntToZLTQLHEjtVB5n7dl8ZcTZMonJ7CCfStrYhrzF97eAecqVbVJ7APOEe18RPI4KLhwQ==", + "cpu": [ + "loong64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-mips64el": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.24.2.tgz", + "integrity": "sha512-iMkk7qr/wl3exJATwkISxI7kTcmHKE+BlymIAbHO8xanq/TjHaaVThFF6ipWzPHryoFsesNQJPE/3wFJw4+huw==", + "cpu": [ + "mips64el" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ppc64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.24.2.tgz", + "integrity": "sha512-shsVrgCZ57Vr2L8mm39kO5PPIb+843FStGt7sGGoqiiWYconSxwTiuswC1VJZLCjNiMLAMh34jg4VSEQb+iEbw==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-riscv64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.24.2.tgz", + "integrity": "sha512-4eSFWnU9Hhd68fW16GD0TINewo1L6dRrB+oLNNbYyMUAeOD2yCK5KXGK1GH4qD/kT+bTEXjsyTCiJGHPZ3eM9Q==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-s390x": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.24.2.tgz", + "integrity": "sha512-S0Bh0A53b0YHL2XEXC20bHLuGMOhFDO6GN4b3YjRLK//Ep3ql3erpNcPlEFed93hsQAjAQDNsvcK+hV90FubSw==", + "cpu": [ + "s390x" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-x64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.24.2.tgz", + "integrity": "sha512-8Qi4nQcCTbLnK9WoMjdC9NiTG6/E38RNICU6sUNqK0QFxCYgoARqVqxdFmWkdonVsvGqWhmm7MO0jyTqLqwj0Q==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-arm64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.24.2.tgz", + "integrity": "sha512-wuLK/VztRRpMt9zyHSazyCVdCXlpHkKm34WUyinD2lzK07FAHTq0KQvZZlXikNWkDGoT6x3TD51jKQ7gMVpopw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-x64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.24.2.tgz", + "integrity": "sha512-VefFaQUc4FMmJuAxmIHgUmfNiLXY438XrL4GDNV1Y1H/RW3qow68xTwjZKfj/+Plp9NANmzbH5R40Meudu8mmw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-arm64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.24.2.tgz", + "integrity": "sha512-YQbi46SBct6iKnszhSvdluqDmxCJA+Pu280Av9WICNwQmMxV7nLRHZfjQzwbPs3jeWnuAhE9Jy0NrnJ12Oz+0A==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-x64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.24.2.tgz", + "integrity": "sha512-+iDS6zpNM6EnJyWv0bMGLWSWeXGN/HTaF/LXHXHwejGsVi+ooqDfMCCTerNFxEkM3wYVcExkeGXNqshc9iMaOA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/sunos-x64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.24.2.tgz", + "integrity": "sha512-hTdsW27jcktEvpwNHJU4ZwWFGkz2zRJUz8pvddmXPtXDzVKTTINmlmga3ZzwcuMpUvLw7JkLy9QLKyGpD2Yxig==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-arm64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.24.2.tgz", + "integrity": "sha512-LihEQ2BBKVFLOC9ZItT9iFprsE9tqjDjnbulhHoFxYQtQfai7qfluVODIYxt1PgdoyQkz23+01rzwNwYfutxUQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-ia32": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.24.2.tgz", + "integrity": "sha512-q+iGUwfs8tncmFC9pcnD5IvRHAzmbwQ3GPS5/ceCyHdjXubwQWI12MKWSNSMYLJMq23/IUCvJMS76PDqXe1fxA==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-x64": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.24.2.tgz", + "integrity": "sha512-7VTgWzgMGvup6aSqDPLiW5zHaxYJGTO4OokMjIlrCtf+VpEL+cXKtCvg723iguPYI5oaUNdS+/V7OU2gvXVWEg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@eslint-community/eslint-utils": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.1.tgz", + "integrity": "sha512-s3O3waFUrMV8P/XaF/+ZTp1X9XBZW1a4B97ZnjQF2KYWaFD2A8KyFBsrsfSjEmjn3RGWAIuvlneuZm3CUK3jbA==", + "dev": true, + "license": "MIT", + "dependencies": { + "eslint-visitor-keys": "^3.4.3" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + }, + "peerDependencies": { + "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" + } + }, + "node_modules/@eslint-community/eslint-utils/node_modules/eslint-visitor-keys": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/@eslint-community/regexpp": { + "version": "4.12.1", + "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.12.1.tgz", + "integrity": "sha512-CCZCDJuduB9OUkFkY2IgppNZMi2lBQgD2qzwXkEia16cge2pijY/aXi96CJMquDMn3nJdlPV1A5KrJEXwfLNzQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^12.0.0 || ^14.0.0 || >=16.0.0" + } + }, + "node_modules/@eslint/config-array": { + "version": "0.19.2", + "resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.19.2.tgz", + "integrity": "sha512-GNKqxfHG2ySmJOBSHg7LxeUx4xpuCoFjacmlCoYWEbaPXLwvfIjixRI12xCQZeULksQb23uiA8F40w5TojpV7w==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/object-schema": "^2.1.6", + "debug": "^4.3.1", + "minimatch": "^3.1.2" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/core": { + "version": "0.10.0", + "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.10.0.tgz", + "integrity": "sha512-gFHJ+xBOo4G3WRlR1e/3G8A6/KZAH6zcE/hkLRCZTi/B9avAG365QhFA8uOGzTMqgTghpn7/fSnscW++dpMSAw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@types/json-schema": "^7.0.15" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/eslintrc": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-3.2.0.tgz", + "integrity": "sha512-grOjVNN8P3hjJn/eIETF1wwd12DdnwFDoyceUJLYYdkpbwq3nLi+4fqrTAONx7XDALqlL220wC/RHSC/QTI/0w==", + "dev": true, + "license": "MIT", + "dependencies": { + "ajv": "^6.12.4", + "debug": "^4.3.2", + "espree": "^10.0.1", + "globals": "^14.0.0", + "ignore": "^5.2.0", + "import-fresh": "^3.2.1", + "js-yaml": "^4.1.0", + "minimatch": "^3.1.2", + "strip-json-comments": "^3.1.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/@eslint/eslintrc/node_modules/globals": { + "version": "14.0.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-14.0.0.tgz", + "integrity": "sha512-oahGvuMGQlPw/ivIYBjVSrWAfWLBeku5tpPE2fOPLi+WHffIWbuh2tCjhyQhTBPMf5E9jDEH4FOmTYgYwbKwtQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/@eslint/js": { + "version": "9.19.0", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.19.0.tgz", + "integrity": "sha512-rbq9/g38qjfqFLOVPvwjIvFFdNziEC5S65jmjPw5r6A//QH+W91akh9irMwjDN8zKUTak6W9EsAv4m/7Wnw0UQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/object-schema": { + "version": "2.1.6", + "resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.6.tgz", + "integrity": "sha512-RBMg5FRL0I0gs51M/guSAj5/e14VQ4tpZnQNWwuDT66P14I43ItmPfIZRhO9fUVIPOAQXU47atlywZ/czoqFPA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/plugin-kit": { + "version": "0.2.5", + "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.2.5.tgz", + "integrity": "sha512-lB05FkqEdUg2AA0xEbUz0SnkXT1LcCTa438W4IWTUh4hdOnVbQyOJ81OrDXsJk/LSiJHubgGEFoR5EHq1NsH1A==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/core": "^0.10.0", + "levn": "^0.4.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@heroicons/react": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.2.0.tgz", + "integrity": "sha512-LMcepvRaS9LYHJGsF0zzmgKCUim/X3N/DQKc4jepAXJ7l8QxJ1PmxJzqplF2Z3FE4PqBAIGyJAQ/w4B5dsqbtQ==", + "license": "MIT", + "peerDependencies": { + "react": ">= 16 || ^19.0.0-rc" + } + }, + "node_modules/@humanfs/core": { + "version": "0.19.1", + "resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz", + "integrity": "sha512-5DyQ4+1JEUzejeK1JGICcideyfUbGixgS9jNgex5nqkW+cY7WZhxBigmieN5Qnw9ZosSNVC9KQKyb+GUaGyKUA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanfs/node": { + "version": "0.16.6", + "resolved": "https://registry.npmjs.org/@humanfs/node/-/node-0.16.6.tgz", + "integrity": "sha512-YuI2ZHQL78Q5HbhDiBA1X4LmYdXCKCMQIfw0pw7piHJwyREFebJUvrQN4cMssyES6x+vfUbx1CIpaQUKYdQZOw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@humanfs/core": "^0.19.1", + "@humanwhocodes/retry": "^0.3.0" + }, + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanfs/node/node_modules/@humanwhocodes/retry": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/@humanwhocodes/retry/-/retry-0.3.1.tgz", + "integrity": "sha512-JBxkERygn7Bv/GbN5Rv8Ul6LVknS+5Bp6RgDC/O8gEBU/yeH5Ui5C/OlWrTb6qct7LjjfT6Re2NxB0ln0yYybA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@humanwhocodes/module-importer": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@humanwhocodes/module-importer/-/module-importer-1.0.1.tgz", + "integrity": "sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.22" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@humanwhocodes/retry": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/@humanwhocodes/retry/-/retry-0.4.1.tgz", + "integrity": "sha512-c7hNEllBlenFTHBky65mhq8WD2kbN9Q6gk0bTk8lSBvc554jpXSkST1iePudpt7+A/AQvuHs9EMqjHDXMY1lrA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.8", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz", + "integrity": "sha512-imAbBGkb+ebQyxKgzv5Hu2nmROxoDOXHh80evxdoXNOrvAnVx7zimzc1Oo5h9RlfV4vPXaE2iM5pOFbvOCClWA==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "@jridgewell/set-array": "^1.2.1", + "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/trace-mapping": "^0.3.24" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/set-array": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", + "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/source-map": { + "version": "0.3.6", + "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz", + "integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==", + "license": "MIT", + "optional": true, + "peer": true, + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.25" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "devOptional": true, + "license": "MIT" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.25", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", + "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "node_modules/@napi-rs/canvas": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.70.tgz", + "integrity": "sha512-nD6NGa4JbNYSZYsTnLGrqe9Kn/lCkA4ybXt8sx5ojDqZjr2i0TWAHxx/vhgfjX+i3hCdKWufxYwi7CfXqtITSA==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 10" + }, + "optionalDependencies": { + "@napi-rs/canvas-android-arm64": "0.1.70", + "@napi-rs/canvas-darwin-arm64": "0.1.70", + "@napi-rs/canvas-darwin-x64": "0.1.70", + "@napi-rs/canvas-linux-arm-gnueabihf": "0.1.70", + "@napi-rs/canvas-linux-arm64-gnu": "0.1.70", + "@napi-rs/canvas-linux-arm64-musl": "0.1.70", + "@napi-rs/canvas-linux-riscv64-gnu": "0.1.70", + "@napi-rs/canvas-linux-x64-gnu": "0.1.70", + "@napi-rs/canvas-linux-x64-musl": "0.1.70", + "@napi-rs/canvas-win32-x64-msvc": "0.1.70" + } + }, + "node_modules/@napi-rs/canvas-android-arm64": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.70.tgz", + "integrity": "sha512-I/YOuQ0wbkVYxVaYtCgN42WKTYxNqFA0gTcTrHIGG1jfpDSyZWII/uHcjOo4nzd19io6Y4+/BqP8E5hJgf9OmQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-darwin-arm64": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.70.tgz", + "integrity": "sha512-4pPGyXetHIHkw2TOJHujt3mkCP8LdDu8+CT15ld9Id39c752RcI0amDHSuMLMQfAjvusA9B5kKxazwjMGjEJpQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-darwin-x64": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.70.tgz", + "integrity": "sha512-+2N6Os9LbkmDMHL+raknrUcLQhsXzc5CSXRbXws9C3pv/mjHRVszQ9dhFUUe9FjfPhCJznO6USVdwOtu7pOrzQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-arm-gnueabihf": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.70.tgz", + "integrity": "sha512-QjscX9OaKq/990sVhSMj581xuqLgiaPVMjjYvWaCmAJRkNQ004QfoSMEm3FoTqM4DRoquP8jvuEXScVJsc1rqQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-arm64-gnu": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.70.tgz", + "integrity": "sha512-LNakMOwwqwiHIwMpnMAbFRczQMQ7TkkMyATqFCOtUJNlE6LPP/QiUj/mlFrNbUn/hctqShJ60gWEb52ZTALbVw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-arm64-musl": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.70.tgz", + "integrity": "sha512-wBTOllEYNfJCHOdZj9v8gLzZ4oY3oyPX8MSRvaxPm/s7RfEXxCyZ8OhJ5xAyicsDdbE5YBZqdmaaeP5+xKxvtg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-riscv64-gnu": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.70.tgz", + "integrity": "sha512-GVUUPC8TuuFqHip0rxHkUqArQnlzmlXmTEBuXAWdgCv85zTCFH8nOHk/YCF5yo0Z2eOm8nOi90aWs0leJ4OE5Q==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-x64-gnu": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.70.tgz", + "integrity": "sha512-/kvUa2lZRwGNyfznSn5t1ShWJnr/m5acSlhTV3eXECafObjl0VBuA1HJw0QrilLpb4Fe0VLywkpD1NsMoVDROQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-x64-musl": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.70.tgz", + "integrity": "sha512-aqlv8MLpycoMKRmds7JWCfVwNf1fiZxaU7JwJs9/ExjTD8lX2KjsO7CTeAj5Cl4aEuzxUWbJPUUE2Qu9cZ1vfg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-win32-x64-msvc": { + "version": "0.1.70", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.70.tgz", + "integrity": "sha512-Q9QU3WIpwBTVHk4cPfBjGHGU4U0llQYRXgJtFtYqqGNEOKVN4OT6PQ+ve63xwIPODMpZ0HHyj/KLGc9CWc3EtQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@rollup/rollup-android-arm-eabi": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.34.2.tgz", + "integrity": "sha512-6Fyg9yQbwJR+ykVdT9sid1oc2ewejS6h4wzQltmJfSW53N60G/ah9pngXGANdy9/aaE/TcUFpWosdm7JXS1WTQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-android-arm64": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.34.2.tgz", + "integrity": "sha512-K5GfWe+vtQ3kyEbihrimM38UgX57UqHp+oME7X/EX9Im6suwZfa7Hsr8AtzbJvukTpwMGs+4s29YMSO3rwWtsw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-darwin-arm64": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.34.2.tgz", + "integrity": "sha512-PSN58XG/V/tzqDb9kDGutUruycgylMlUE59f40ny6QIRNsTEIZsrNQTJKUN2keMMSmlzgunMFqyaGLmly39sug==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-darwin-x64": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.34.2.tgz", + "integrity": "sha512-gQhK788rQJm9pzmXyfBB84VHViDERhAhzGafw+E5mUpnGKuxZGkMVDa3wgDFKT6ukLC5V7QTifzsUKdNVxp5qQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-freebsd-arm64": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.34.2.tgz", + "integrity": "sha512-eiaHgQwGPpxLC3+zTAcdKl4VsBl3r0AiJOd1Um/ArEzAjN/dbPK1nROHrVkdnoE6p7Svvn04w3f/jEZSTVHunA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ] + }, + "node_modules/@rollup/rollup-freebsd-x64": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.34.2.tgz", + "integrity": "sha512-lhdiwQ+jf8pewYOTG4bag0Qd68Jn1v2gO1i0mTuiD+Qkt5vNfHVK/jrT7uVvycV8ZchlzXp5HDVmhpzjC6mh0g==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ] + }, + "node_modules/@rollup/rollup-linux-arm-gnueabihf": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.34.2.tgz", + "integrity": "sha512-lfqTpWjSvbgQP1vqGTXdv+/kxIznKXZlI109WkIFPbud41bjigjNmOAAKoazmRGx+k9e3rtIdbq2pQZPV1pMig==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm-musleabihf": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.34.2.tgz", + "integrity": "sha512-RGjqULqIurqqv+NJTyuPgdZhka8ImMLB32YwUle2BPTDqDoXNgwFjdjQC59FbSk08z0IqlRJjrJ0AvDQ5W5lpw==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-gnu": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.34.2.tgz", + "integrity": "sha512-ZvkPiheyXtXlFqHpsdgscx+tZ7hoR59vOettvArinEspq5fxSDSgfF+L5wqqJ9R4t+n53nyn0sKxeXlik7AY9Q==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-musl": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.34.2.tgz", + "integrity": "sha512-UlFk+E46TZEoxD9ufLKDBzfSG7Ki03fo6hsNRRRHF+KuvNZ5vd1RRVQm8YZlGsjcJG8R252XFK0xNPay+4WV7w==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-loongarch64-gnu": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loongarch64-gnu/-/rollup-linux-loongarch64-gnu-4.34.2.tgz", + "integrity": "sha512-hJhfsD9ykx59jZuuoQgYT1GEcNNi3RCoEmbo5OGfG8RlHOiVS7iVNev9rhLKh7UBYq409f4uEw0cclTXx8nh8Q==", + "cpu": [ + "loong64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-powerpc64le-gnu": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-powerpc64le-gnu/-/rollup-linux-powerpc64le-gnu-4.34.2.tgz", + "integrity": "sha512-g/O5IpgtrQqPegvqopvmdCF9vneLE7eqYfdPWW8yjPS8f63DNam3U4ARL1PNNB64XHZDHKpvO2Giftf43puB8Q==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-riscv64-gnu": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.34.2.tgz", + "integrity": "sha512-bSQijDC96M6PuooOuXHpvXUYiIwsnDmqGU8+br2U7iPoykNi9JtMUpN7K6xml29e0evK0/g0D1qbAUzWZFHY5Q==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-s390x-gnu": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.34.2.tgz", + "integrity": "sha512-49TtdeVAsdRuiUHXPrFVucaP4SivazetGUVH8CIxVsNsaPHV4PFkpLmH9LeqU/R4Nbgky9lzX5Xe1NrzLyraVA==", + "cpu": [ + "s390x" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-gnu": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.34.2.tgz", + "integrity": "sha512-j+jFdfOycLIQ7FWKka9Zd3qvsIyugg5LeZuHF6kFlXo6MSOc6R1w37YUVy8VpAKd81LMWGi5g9J25P09M0SSIw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-musl": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.34.2.tgz", + "integrity": "sha512-aDPHyM/D2SpXfSNCVWCxyHmOqN9qb7SWkY1+vaXqMNMXslZYnwh9V/UCudl6psyG0v6Ukj7pXanIpfZwCOEMUg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-win32-arm64-msvc": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.34.2.tgz", + "integrity": "sha512-LQRkCyUBnAo7r8dbEdtNU08EKLCJMgAk2oP5H3R7BnUlKLqgR3dUjrLBVirmc1RK6U6qhtDw29Dimeer8d5hzQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-ia32-msvc": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.34.2.tgz", + "integrity": "sha512-wt8OhpQUi6JuPFkm1wbVi1BByeag87LDFzeKSXzIdGcX4bMLqORTtKxLoCbV57BHYNSUSOKlSL4BYYUghainYA==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-x64-msvc": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.34.2.tgz", + "integrity": "sha512-rUrqINax0TvrPBXrFKg0YbQx18NpPN3NNrgmaao9xRNbTwek7lOXObhx8tQy8gelmQ/gLaGy1WptpU2eKJZImg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@sec-ant/readable-stream": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/@sec-ant/readable-stream/-/readable-stream-0.6.0.tgz", + "integrity": "sha512-uiBh8DrB5FN35gP6/o8JEhEQ7/ci1jUsOZO/VMUjyvTpjtV54VstOXVj1TvTj/wsT23pfX6butxxh3qufsW3+g==", + "license": "MIT" + }, + "node_modules/@tailwindcss/node": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/node/-/node-4.1.1.tgz", + "integrity": "sha512-xvlh4pvfG/bkv0fEtJDABAm1tjtSmSyi2QmS4zyj1EKNI1UiOYiUq1IphSwDsNJ5vJ9cWEGs4rJXpUdCN2kujQ==", + "license": "MIT", + "dependencies": { + "enhanced-resolve": "^5.18.1", + "jiti": "^2.4.2", + "lightningcss": "1.29.2", + "tailwindcss": "4.1.1" + } + }, + "node_modules/@tailwindcss/oxide": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide/-/oxide-4.1.1.tgz", + "integrity": "sha512-7+YBgnPQ4+jv6B6WVOerJ6WOzDzNJXrRKDts674v6TKAqFqYRr9+EBtSziO7nNcwQ8JtoZNMeqA+WJDjtCM/7w==", + "license": "MIT", + "engines": { + "node": ">= 10" + }, + "optionalDependencies": { + "@tailwindcss/oxide-android-arm64": "4.1.1", + "@tailwindcss/oxide-darwin-arm64": "4.1.1", + "@tailwindcss/oxide-darwin-x64": "4.1.1", + "@tailwindcss/oxide-freebsd-x64": "4.1.1", + "@tailwindcss/oxide-linux-arm-gnueabihf": "4.1.1", + "@tailwindcss/oxide-linux-arm64-gnu": "4.1.1", + "@tailwindcss/oxide-linux-arm64-musl": "4.1.1", + "@tailwindcss/oxide-linux-x64-gnu": "4.1.1", + "@tailwindcss/oxide-linux-x64-musl": "4.1.1", + "@tailwindcss/oxide-win32-arm64-msvc": "4.1.1", + "@tailwindcss/oxide-win32-x64-msvc": "4.1.1" + } + }, + "node_modules/@tailwindcss/oxide-android-arm64": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-android-arm64/-/oxide-android-arm64-4.1.1.tgz", + "integrity": "sha512-gTyRzfdParpoCU1yyUC/iN6XK6T0Ra4bDlF8Aeul5NP9cLzKEZDogdNVNGv5WZmCDkVol7qlex7TMmcfytMmmw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-darwin-arm64": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-arm64/-/oxide-darwin-arm64-4.1.1.tgz", + "integrity": "sha512-dI0QbdMWBvLB3MtaTKetzUKG9CUUQow8JSP4Nm+OxVokeZ+N+f1OmZW/hW1LzMxpx9RQCBgSRL+IIvKRat5Wdg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-darwin-x64": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-x64/-/oxide-darwin-x64-4.1.1.tgz", + "integrity": "sha512-2Y+NPQOTRBCItshPgY/CWg4bKi7E9evMg4bgdb6h9iZObCZLOe3doPcuSxGS3DB0dKyMFKE8pTdWtFUbxZBMSA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-freebsd-x64": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-freebsd-x64/-/oxide-freebsd-x64-4.1.1.tgz", + "integrity": "sha512-N97NGMsB/7CHShbc5ube4dcsW/bYENkBrg8yWi8ieN9boYVRdw3cZviVryV/Nfu9bKbBV9kUvduFF2qBI7rEqg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-arm-gnueabihf": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm-gnueabihf/-/oxide-linux-arm-gnueabihf-4.1.1.tgz", + "integrity": "sha512-33Lk6KbHnUZbXqza6RWNFo9wqPQ4+H5BAn1CkUUfC1RZ1vYbyDN6+iJPj53wmnWJ3mhRI8jWt3Jt1fO02IVdUQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-arm64-gnu": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-gnu/-/oxide-linux-arm64-gnu-4.1.1.tgz", + "integrity": "sha512-LyW35RzSUy+80WYScv03HKasAUmMFDaSbNpWfk1gG5gEE9kuRGnDzSrqMoLAmY/kzMCYP/1kqmUiAx8EFLkI2A==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-arm64-musl": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-musl/-/oxide-linux-arm64-musl-4.1.1.tgz", + "integrity": "sha512-1KPnDMlHdqjPTUSFjx55pafvs8RZXRgxfeRgUrukwDKkuj7gFk28vW3Mx65YdiugAc9NWs3VgueZWaM1Po6uGw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-x64-gnu": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-gnu/-/oxide-linux-x64-gnu-4.1.1.tgz", + "integrity": "sha512-4WdzA+MRlsinEEE6yxNMLJxpw0kE9XVipbAKdTL8BeUpyC2TdA3TL46lBulXzKp3BIxh3nqyR/UCqzl5o+3waQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-x64-musl": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-musl/-/oxide-linux-x64-musl-4.1.1.tgz", + "integrity": "sha512-q7Ugbw3ARcjCW2VMUYrcMbJ6aMQuWPArBBE2EqC/swPZTdGADvMQSlvR0VKusUM4HoSsO7ZbvcZ53YwR57+AKw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-win32-arm64-msvc": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.1.tgz", + "integrity": "sha512-0KpqsovgHcIzm7eAGzzEZsEs0/nPYXnRBv+aPq/GehpNQuE/NAQu+YgZXIIof+VflDFuyXOEnaFr7T5MZ1INhA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-win32-x64-msvc": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-x64-msvc/-/oxide-win32-x64-msvc-4.1.1.tgz", + "integrity": "sha512-B1mjeXNS26kBOHv5sXARf6Wd0PWHV9x1TDlW0ummrBUOUAxAy5wcy4Nii1wzNvCdvC448hgiL06ylhwAbNthmg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/postcss": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/postcss/-/postcss-4.1.1.tgz", + "integrity": "sha512-GX9AEM+msH0i2Yh1b6CuDRaZRo3kmbvIrLbSfvJ53C3uaAgsQ//fTQAh9HMQ6t1a9zvoUptlYqG//plWsBQTCw==", + "license": "MIT", + "dependencies": { + "@alloc/quick-lru": "^5.2.0", + "@tailwindcss/node": "4.1.1", + "@tailwindcss/oxide": "4.1.1", + "postcss": "^8.4.41", + "tailwindcss": "4.1.1" + } + }, + "node_modules/@tailwindcss/vite": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/vite/-/vite-4.1.1.tgz", + "integrity": "sha512-tFTkRZwXq4XKr3S2dUZBxy80wbWYHdDSsu4QOB1yE1HJFKjfxKVpXtup4dyTVdQcLInoHC9lZXFPHnjoBP774g==", + "license": "MIT", + "dependencies": { + "@tailwindcss/node": "4.1.1", + "@tailwindcss/oxide": "4.1.1", + "tailwindcss": "4.1.1" + }, + "peerDependencies": { + "vite": "^5.2.0 || ^6" + } + }, + "node_modules/@types/babel__core": { + "version": "7.20.5", + "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", + "integrity": "sha512-qoQprZvz5wQFJwMDqeseRXWv3rqMvhgpbXFfVyWhbx9X47POIA6i/+dXefEmZKoAgOaTdaIgNSMqMIU61yRyzA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.20.7", + "@babel/types": "^7.20.7", + "@types/babel__generator": "*", + "@types/babel__template": "*", + "@types/babel__traverse": "*" + } + }, + "node_modules/@types/babel__generator": { + "version": "7.6.8", + "resolved": "https://registry.npmjs.org/@types/babel__generator/-/babel__generator-7.6.8.tgz", + "integrity": "sha512-ASsj+tpEDsEiFr1arWrlN6V3mdfjRMZt6LtK/Vp/kreFLnr5QH5+DhvD5nINYZXzwJvXeGq+05iUXcAzVrqWtw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/types": "^7.0.0" + } + }, + "node_modules/@types/babel__template": { + "version": "7.4.4", + "resolved": "https://registry.npmjs.org/@types/babel__template/-/babel__template-7.4.4.tgz", + "integrity": "sha512-h/NUaSyG5EyxBIp8YRxo4RMe2/qQgvyowRwVMzhYhBCONbW8PUsg4lkFMrhgZhUe5z3L3MiLDuvyJ/CaPa2A8A==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.1.0", + "@babel/types": "^7.0.0" + } + }, + "node_modules/@types/babel__traverse": { + "version": "7.20.6", + "resolved": "https://registry.npmjs.org/@types/babel__traverse/-/babel__traverse-7.20.6.tgz", + "integrity": "sha512-r1bzfrm0tomOI8g1SzvCaQHo6Lcv6zu0EA+W2kHrt8dyrHQxGzBBL4kdkzIS+jBMV+EYcMAEAqXqYaLJq5rOZg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/types": "^7.20.7" + } + }, + "node_modules/@types/cookie": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/@types/cookie/-/cookie-0.6.0.tgz", + "integrity": "sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA==", + "license": "MIT" + }, + "node_modules/@types/debug": { + "version": "4.1.12", + "resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.12.tgz", + "integrity": "sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==", + "license": "MIT", + "dependencies": { + "@types/ms": "*" + } + }, + "node_modules/@types/estree": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.6.tgz", + "integrity": "sha512-AYnb1nQyY49te+VRAVgmzfcgjYS91mY5P0TKUDCLEM+gNnA+3T6rWITXRLYCpahpqSQbN5cE+gHpnPyXjHWxcw==", + "license": "MIT" + }, + "node_modules/@types/estree-jsx": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@types/estree-jsx/-/estree-jsx-1.0.5.tgz", + "integrity": "sha512-52CcUVNFyfb1A2ALocQw/Dd1BQFNmSdkuC3BkZ6iqhdMfQz7JWOFRuJFloOzjk+6WijU56m9oKXFAXc7o3Towg==", + "license": "MIT", + "dependencies": { + "@types/estree": "*" + } + }, + "node_modules/@types/hast": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/hast/-/hast-3.0.4.tgz", + "integrity": "sha512-WPs+bbQw5aCj+x6laNGWLH3wviHtoCv/P3+otBhbOhJgG8qtpdAMlTCxLtsTWA7LH1Oh/bFCHsBn0TPS5m30EQ==", + "license": "MIT", + "dependencies": { + "@types/unist": "*" + } + }, + "node_modules/@types/json-schema": { + "version": "7.0.15", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/katex": { + "version": "0.16.7", + "resolved": "https://registry.npmjs.org/@types/katex/-/katex-0.16.7.tgz", + "integrity": "sha512-HMwFiRujE5PjrgwHQ25+bsLJgowjGjm5Z8FVSf0N6PwgJrwxH0QxzHYDcKsTfV3wva0vzrpqMTJS2jXPr5BMEQ==", + "license": "MIT" + }, + "node_modules/@types/linkify-it": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/@types/linkify-it/-/linkify-it-5.0.0.tgz", + "integrity": "sha512-sVDA58zAw4eWAffKOaQH5/5j3XeayukzDk+ewSsnv3p4yJEZHCCzMDiZM8e0OUrRvmpGZ85jf4yDHkHsgBNr9Q==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/markdown-it": { + "version": "14.1.2", + "resolved": "https://registry.npmjs.org/@types/markdown-it/-/markdown-it-14.1.2.tgz", + "integrity": "sha512-promo4eFwuiW+TfGxhi+0x3czqTYJkG8qB17ZUJiVF10Xm7NLVRSLUsfRTU/6h1e24VvRnXCx+hG7li58lkzog==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/linkify-it": "^5", + "@types/mdurl": "^2" + } + }, + "node_modules/@types/mdast": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-4.0.4.tgz", + "integrity": "sha512-kGaNbPh1k7AFzgpud/gMdvIm5xuECykRR+JnWKQno9TAXVa6WIVCGTPvYGekIDL4uwCZQSYbUxNBSb1aUo79oA==", + "license": "MIT", + "dependencies": { + "@types/unist": "*" + } + }, + "node_modules/@types/mdurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@types/mdurl/-/mdurl-2.0.0.tgz", + "integrity": "sha512-RGdgjQUZba5p6QEFAVx2OGb8rQDL/cPRG7GiedRzMcJ1tYnUANBncjbSB1NRGwbvjcPeikRABz2nshyPk1bhWg==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/ms": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/@types/ms/-/ms-2.1.0.tgz", + "integrity": "sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA==", + "license": "MIT" + }, + "node_modules/@types/node": { + "version": "22.13.1", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.13.1.tgz", + "integrity": "sha512-jK8uzQlrvXqEU91UxiK5J7pKHyzgnI1Qnl0QDHIgVGuolJhRb9EEl28Cj9b3rGR8B2lhFCtvIm5os8lFnO/1Ew==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "undici-types": "~6.20.0" + } + }, + "node_modules/@types/prop-types": { + "version": "15.7.14", + "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.14.tgz", + "integrity": "sha512-gNMvNH49DJ7OJYv+KAKn0Xp45p8PLl6zo2YnvDIbTd4J6MER2BmWN49TG7n9LvkyihINxeKW8+3bfS2yDC9dzQ==", + "license": "MIT" + }, + "node_modules/@types/react": { + "version": "18.3.18", + "resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.18.tgz", + "integrity": "sha512-t4yC+vtgnkYjNSKlFx1jkAhH8LgTo2N/7Qvi83kdEaUtMDiwpbLAktKDaAMlRcJ5eSxZkH74eEGt1ky31d7kfQ==", + "license": "MIT", + "dependencies": { + "@types/prop-types": "*", + "csstype": "^3.0.2" + } + }, + "node_modules/@types/react-dom": { + "version": "18.3.5", + "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.5.tgz", + "integrity": "sha512-P4t6saawp+b/dFrUr2cvkVsfvPguwsxtH6dNIYRllMsefqFzkZk5UIjzyDOv5g1dXIPdG4Sp1yCR4Z6RCUsG/Q==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "@types/react": "^18.0.0" + } + }, + "node_modules/@types/unist": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/@types/unist/-/unist-3.0.3.tgz", + "integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==", + "license": "MIT" + }, + "node_modules/@typescript-eslint/eslint-plugin": { + "version": "8.23.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.23.0.tgz", + "integrity": "sha512-vBz65tJgRrA1Q5gWlRfvoH+w943dq9K1p1yDBY2pc+a1nbBLZp7fB9+Hk8DaALUbzjqlMfgaqlVPT1REJdkt/w==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/regexpp": "^4.10.0", + "@typescript-eslint/scope-manager": "8.23.0", + "@typescript-eslint/type-utils": "8.23.0", + "@typescript-eslint/utils": "8.23.0", + "@typescript-eslint/visitor-keys": "8.23.0", + "graphemer": "^1.4.0", + "ignore": "^5.3.1", + "natural-compare": "^1.4.0", + "ts-api-utils": "^2.0.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "@typescript-eslint/parser": "^8.0.0 || ^8.0.0-alpha.0", + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <5.8.0" + } + }, + "node_modules/@typescript-eslint/parser": { + "version": "8.23.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-8.23.0.tgz", + "integrity": "sha512-h2lUByouOXFAlMec2mILeELUbME5SZRN/7R9Cw2RD2lRQQY08MWMM+PmVVKKJNK1aIwqTo9t/0CvOxwPbRIE2Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/scope-manager": "8.23.0", + "@typescript-eslint/types": "8.23.0", + "@typescript-eslint/typescript-estree": "8.23.0", + "@typescript-eslint/visitor-keys": "8.23.0", + "debug": "^4.3.4" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <5.8.0" + } + }, + "node_modules/@typescript-eslint/scope-manager": { + "version": "8.23.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-8.23.0.tgz", + "integrity": "sha512-OGqo7+dXHqI7Hfm+WqkZjKjsiRtFUQHPdGMXzk5mYXhJUedO7e/Y7i8AK3MyLMgZR93TX4bIzYrfyVjLC+0VSw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.23.0", + "@typescript-eslint/visitor-keys": "8.23.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/type-utils": { + "version": "8.23.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-8.23.0.tgz", + "integrity": "sha512-iIuLdYpQWZKbiH+RkCGc6iu+VwscP5rCtQ1lyQ7TYuKLrcZoeJVpcLiG8DliXVkUxirW/PWlmS+d6yD51L9jvA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/typescript-estree": "8.23.0", + "@typescript-eslint/utils": "8.23.0", + "debug": "^4.3.4", + "ts-api-utils": "^2.0.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <5.8.0" + } + }, + "node_modules/@typescript-eslint/types": { + "version": "8.23.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-8.23.0.tgz", + "integrity": "sha512-1sK4ILJbCmZOTt9k4vkoulT6/y5CHJ1qUYxqpF1K/DBAd8+ZUL4LlSCxOssuH5m4rUaaN0uS0HlVPvd45zjduQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/typescript-estree": { + "version": "8.23.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-8.23.0.tgz", + "integrity": "sha512-LcqzfipsB8RTvH8FX24W4UUFk1bl+0yTOf9ZA08XngFwMg4Kj8A+9hwz8Cr/ZS4KwHrmo9PJiLZkOt49vPnuvQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.23.0", + "@typescript-eslint/visitor-keys": "8.23.0", + "debug": "^4.3.4", + "fast-glob": "^3.3.2", + "is-glob": "^4.0.3", + "minimatch": "^9.0.4", + "semver": "^7.6.0", + "ts-api-utils": "^2.0.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <5.8.0" + } + }, + "node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", + "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/@typescript-eslint/typescript-estree/node_modules/semver": { + "version": "7.7.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.1.tgz", + "integrity": "sha512-hlq8tAfn0m/61p4BVRcPzIGr6LKiMwo4VM6dGi6pt4qcRkmNzTcWq6eCEjEh+qXjkMDvPlOFFSGwQjoEa6gyMA==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/@typescript-eslint/utils": { + "version": "8.23.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-8.23.0.tgz", + "integrity": "sha512-uB/+PSo6Exu02b5ZEiVtmY6RVYO7YU5xqgzTIVZwTHvvK3HsL8tZZHFaTLFtRG3CsV4A5mhOv+NZx5BlhXPyIA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.4.0", + "@typescript-eslint/scope-manager": "8.23.0", + "@typescript-eslint/types": "8.23.0", + "@typescript-eslint/typescript-estree": "8.23.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <5.8.0" + } + }, + "node_modules/@typescript-eslint/visitor-keys": { + "version": "8.23.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-8.23.0.tgz", + "integrity": "sha512-oWWhcWDLwDfu++BGTZcmXWqpwtkwb5o7fxUIGksMQQDSdPW9prsSnfIOZMlsj4vBOSrcnjIUZMiIjODgGosFhQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.23.0", + "eslint-visitor-keys": "^4.2.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@ungap/structured-clone": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.3.0.tgz", + "integrity": "sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==", + "license": "ISC" + }, + "node_modules/@vitejs/plugin-react": { + "version": "4.3.4", + "resolved": "https://registry.npmjs.org/@vitejs/plugin-react/-/plugin-react-4.3.4.tgz", + "integrity": "sha512-SCCPBJtYLdE8PX/7ZQAs1QAZ8Jqwih+0VBLum1EGqmCCQal+MIUqLCzj3ZUy8ufbC0cAM4LRlSTm7IQJwWT4ug==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/core": "^7.26.0", + "@babel/plugin-transform-react-jsx-self": "^7.25.9", + "@babel/plugin-transform-react-jsx-source": "^7.25.9", + "@types/babel__core": "^7.20.5", + "react-refresh": "^0.14.2" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "peerDependencies": { + "vite": "^4.2.0 || ^5.0.0 || ^6.0.0" + } + }, + "node_modules/@vscode/markdown-it-katex": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@vscode/markdown-it-katex/-/markdown-it-katex-1.1.1.tgz", + "integrity": "sha512-3KTlbsRBPJQLE2YmLL7K6nunTlU+W9T5+FjfNdWuIUKgxSS6HWLQHaO3L4MkJi7z7MpIPpY+g4N+cWNBPE/MSA==", + "license": "MIT", + "dependencies": { + "katex": "^0.16.4" + } + }, + "node_modules/acorn": { + "version": "8.14.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz", + "integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==", + "devOptional": true, + "license": "MIT", + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-jsx": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", + "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/argparse": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", + "dev": true, + "license": "Python-2.0" + }, + "node_modules/attr-accept": { + "version": "2.2.5", + "resolved": "https://registry.npmjs.org/attr-accept/-/attr-accept-2.2.5.tgz", + "integrity": "sha512-0bDNnY/u6pPwHDMoF0FieU354oBi0a8rD9FcsLwzcGWbc8KS8KPIi7y+s13OlVY+gMWc/9xEMUgNE6Qm8ZllYQ==", + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/autoprefixer": { + "version": "10.4.20", + "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.20.tgz", + "integrity": "sha512-XY25y5xSv/wEoqzDyXXME4AFfkZI0P23z6Fs3YgymDnKJkCGOnkL0iTxCa85UTqaSgfcqyf3UA6+c7wUvx/16g==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/autoprefixer" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "browserslist": "^4.23.3", + "caniuse-lite": "^1.0.30001646", + "fraction.js": "^4.3.7", + "normalize-range": "^0.1.2", + "picocolors": "^1.0.1", + "postcss-value-parser": "^4.2.0" + }, + "bin": { + "autoprefixer": "bin/autoprefixer" + }, + "engines": { + "node": "^10 || ^12 || >=14" + }, + "peerDependencies": { + "postcss": "^8.1.0" + } + }, + "node_modules/bail": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/bail/-/bail-2.0.2.tgz", + "integrity": "sha512-0xO6mYd7JB2YesxDKplafRpsiOzPt9V02ddPCLbY1xYGPOX24NTyN50qnUxgCPcSoYMhKpAuBTjQoRZCAkUDRw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "dev": true, + "license": "MIT" + }, + "node_modules/brace-expansion": { + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", + "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/braces": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", + "license": "MIT", + "dependencies": { + "fill-range": "^7.1.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/browserslist": { + "version": "4.24.4", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.24.4.tgz", + "integrity": "sha512-KDi1Ny1gSePi1vm0q4oxSF8b4DR44GF4BbmS2YdhPLOEqd8pDviZOGH/GsmRwoWJ2+5Lr085X7naowMwKHDG1A==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "caniuse-lite": "^1.0.30001688", + "electron-to-chromium": "^1.5.73", + "node-releases": "^2.0.19", + "update-browserslist-db": "^1.1.1" + }, + "bin": { + "browserslist": "cli.js" + }, + "engines": { + "node": "^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7" + } + }, + "node_modules/buffer-builder": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/buffer-builder/-/buffer-builder-0.2.0.tgz", + "integrity": "sha512-7VPMEPuYznPSoR21NE1zvd2Xna6c/CloiZCfcMXR1Jny6PjX0N4Nsa38zcBFo/FMK+BlA+FLKbJCQ0i2yxp+Xg==", + "devOptional": true, + "license": "MIT/X11" + }, + "node_modules/buffer-from": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", + "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", + "license": "MIT", + "optional": true, + "peer": true + }, + "node_modules/callsites": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", + "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/caniuse-lite": { + "version": "1.0.30001697", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001697.tgz", + "integrity": "sha512-GwNPlWJin8E+d7Gxq96jxM6w0w+VFeyyXRsjU58emtkYqnbwHqXm5uT2uCmO0RQE9htWknOP4xtBlLmM/gWxvQ==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/caniuse-lite" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "CC-BY-4.0" + }, + "node_modules/ccount": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/ccount/-/ccount-2.0.1.tgz", + "integrity": "sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/character-entities": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/character-entities/-/character-entities-2.0.2.tgz", + "integrity": "sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/character-entities-html4": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/character-entities-html4/-/character-entities-html4-2.1.0.tgz", + "integrity": "sha512-1v7fgQRj6hnSwFpq1Eu0ynr/CDEw0rXo2B61qXrLNdHZmPKgb7fqS1a2JwF0rISo9q77jDI8VMEHoApn8qDoZA==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/character-entities-legacy": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/character-entities-legacy/-/character-entities-legacy-3.0.0.tgz", + "integrity": "sha512-RpPp0asT/6ufRm//AJVwpViZbGM/MkjQFxJccQRHmISF/22NBtsHqAWmL+/pmkPWoIUJdWyeVleTl1wydHATVQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/character-reference-invalid": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/character-reference-invalid/-/character-reference-invalid-2.0.1.tgz", + "integrity": "sha512-iBZ4F4wRbyORVsu0jPV7gXkOsGYjGHPmAyv+HiHG8gi5PtC9KI2j1+v8/tlibRvjoWX027ypmG/n0HtO5t7unw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true, + "license": "MIT" + }, + "node_modules/colorjs.io": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/colorjs.io/-/colorjs.io-0.5.2.tgz", + "integrity": "sha512-twmVoizEW7ylZSN32OgKdXRmo1qg+wT5/6C3xu5b9QsWzSFAhHLn2xd8ro0diCsKfCj1RdaTP/nrcW+vAoQPIw==", + "devOptional": true, + "license": "MIT" + }, + "node_modules/comma-separated-tokens": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/comma-separated-tokens/-/comma-separated-tokens-2.0.3.tgz", + "integrity": "sha512-Fu4hJdvzeylCfQPp9SGWidpzrMs7tTrlu6Vb8XGaRGck8QSNZJJp538Wrb60Lax4fPwR64ViY468OIUTbRlGZg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/commander": { + "version": "8.3.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz", + "integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==", + "license": "MIT", + "engines": { + "node": ">= 12" + } + }, + "node_modules/concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", + "dev": true, + "license": "MIT" + }, + "node_modules/convert-source-map": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-2.0.0.tgz", + "integrity": "sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==", + "dev": true, + "license": "MIT" + }, + "node_modules/cookie": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-1.0.2.tgz", + "integrity": "sha512-9Kr/j4O16ISv8zBBhJoi4bXOYNTkFLOqSL3UDB0njXxCXNezjeyVrJyGOWtgfs/q2km1gwBcfH8q1yEGoMYunA==", + "license": "MIT", + "engines": { + "node": ">=18" + } + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "dev": true, + "license": "MIT", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/csstype": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz", + "integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==", + "license": "MIT" + }, + "node_modules/daisyui": { + "version": "5.0.12", + "resolved": "https://registry.npmjs.org/daisyui/-/daisyui-5.0.12.tgz", + "integrity": "sha512-01DU0eYBcHgPtuf5fxcrkGkIN6/Uyaqmkle5Yo3ZyW9YVAu036ALZbjv2KH5euvUbeQ4r9q3gAarGcf7Tywhng==", + "license": "MIT", + "funding": { + "url": "https://github.com/saadeghi/daisyui?sponsor=1" + } + }, + "node_modules/debug": { + "version": "4.4.0", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.0.tgz", + "integrity": "sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA==", + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/decode-named-character-reference": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/decode-named-character-reference/-/decode-named-character-reference-1.0.2.tgz", + "integrity": "sha512-O8x12RzrUF8xyVcY0KJowWsmaJxQbmy0/EtnNtHRpsOcT7dFk5W598coHqBVpmWo1oQQfsCqfCmkZN5DJrZVdg==", + "license": "MIT", + "dependencies": { + "character-entities": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/deep-is": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz", + "integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/dequal": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/dequal/-/dequal-2.0.3.tgz", + "integrity": "sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/detect-libc": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.3.tgz", + "integrity": "sha512-bwy0MGW55bG41VqxxypOsdSdGqLwXPI/focwgTYCFMbdUiBAxLg9CFzG08sz2aqzknwiX7Hkl0bQENjg8iLByw==", + "license": "Apache-2.0", + "engines": { + "node": ">=8" + } + }, + "node_modules/devlop": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/devlop/-/devlop-1.1.0.tgz", + "integrity": "sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA==", + "license": "MIT", + "dependencies": { + "dequal": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/dexie": { + "version": "4.0.11", + "resolved": "https://registry.npmjs.org/dexie/-/dexie-4.0.11.tgz", + "integrity": "sha512-SOKO002EqlvBYYKQSew3iymBoN2EQ4BDw/3yprjh7kAfFzjBYkaMNa/pZvcA7HSWlcKSQb9XhPe3wKyQ0x4A8A==", + "license": "Apache-2.0" + }, + "node_modules/electron-to-chromium": { + "version": "1.5.91", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.91.tgz", + "integrity": "sha512-sNSHHyq048PFmZY4S90ax61q+gLCs0X0YmcOII9wG9S2XwbVr+h4VW2wWhnbp/Eys3cCwTxVF292W3qPaxIapQ==", + "license": "ISC" + }, + "node_modules/enhanced-resolve": { + "version": "5.18.1", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.18.1.tgz", + "integrity": "sha512-ZSW3ma5GkcQBIpwZTSRAI8N71Uuwgs93IezB7mf7R60tC8ZbJideoDNKjHn2O9KIlx6rkGTTEk1xUCK2E1Y2Yg==", + "license": "MIT", + "dependencies": { + "graceful-fs": "^4.2.4", + "tapable": "^2.2.0" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/entities": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/entities/-/entities-4.5.0.tgz", + "integrity": "sha512-V0hjH4dGPh9Ao5p0MoRY6BVqtwCjhz6vI5LT8AJ55H+4g9/4vbHx1I54fS0XuclLhDHArPQCiMjDxjaL8fPxhw==", + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.12" + }, + "funding": { + "url": "https://github.com/fb55/entities?sponsor=1" + } + }, + "node_modules/esbuild": { + "version": "0.24.2", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.24.2.tgz", + "integrity": "sha512-+9egpBW8I3CD5XPe0n6BfT5fxLzxrlDzqydF3aviG+9ni1lDC/OvMHcxqEFV0+LANZG5R1bFMWfUrjVsdwxJvA==", + "hasInstallScript": true, + "license": "MIT", + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "@esbuild/aix-ppc64": "0.24.2", + "@esbuild/android-arm": "0.24.2", + "@esbuild/android-arm64": "0.24.2", + "@esbuild/android-x64": "0.24.2", + "@esbuild/darwin-arm64": "0.24.2", + "@esbuild/darwin-x64": "0.24.2", + "@esbuild/freebsd-arm64": "0.24.2", + "@esbuild/freebsd-x64": "0.24.2", + "@esbuild/linux-arm": "0.24.2", + "@esbuild/linux-arm64": "0.24.2", + "@esbuild/linux-ia32": "0.24.2", + "@esbuild/linux-loong64": "0.24.2", + "@esbuild/linux-mips64el": "0.24.2", + "@esbuild/linux-ppc64": "0.24.2", + "@esbuild/linux-riscv64": "0.24.2", + "@esbuild/linux-s390x": "0.24.2", + "@esbuild/linux-x64": "0.24.2", + "@esbuild/netbsd-arm64": "0.24.2", + "@esbuild/netbsd-x64": "0.24.2", + "@esbuild/openbsd-arm64": "0.24.2", + "@esbuild/openbsd-x64": "0.24.2", + "@esbuild/sunos-x64": "0.24.2", + "@esbuild/win32-arm64": "0.24.2", + "@esbuild/win32-ia32": "0.24.2", + "@esbuild/win32-x64": "0.24.2" + } + }, + "node_modules/escalade": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/escape-string-regexp": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", + "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/eslint": { + "version": "9.19.0", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.19.0.tgz", + "integrity": "sha512-ug92j0LepKlbbEv6hD911THhoRHmbdXt2gX+VDABAW/Ir7D3nqKdv5Pf5vtlyY6HQMTEP2skXY43ueqTCWssEA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.2.0", + "@eslint-community/regexpp": "^4.12.1", + "@eslint/config-array": "^0.19.0", + "@eslint/core": "^0.10.0", + "@eslint/eslintrc": "^3.2.0", + "@eslint/js": "9.19.0", + "@eslint/plugin-kit": "^0.2.5", + "@humanfs/node": "^0.16.6", + "@humanwhocodes/module-importer": "^1.0.1", + "@humanwhocodes/retry": "^0.4.1", + "@types/estree": "^1.0.6", + "@types/json-schema": "^7.0.15", + "ajv": "^6.12.4", + "chalk": "^4.0.0", + "cross-spawn": "^7.0.6", + "debug": "^4.3.2", + "escape-string-regexp": "^4.0.0", + "eslint-scope": "^8.2.0", + "eslint-visitor-keys": "^4.2.0", + "espree": "^10.3.0", + "esquery": "^1.5.0", + "esutils": "^2.0.2", + "fast-deep-equal": "^3.1.3", + "file-entry-cache": "^8.0.0", + "find-up": "^5.0.0", + "glob-parent": "^6.0.2", + "ignore": "^5.2.0", + "imurmurhash": "^0.1.4", + "is-glob": "^4.0.0", + "json-stable-stringify-without-jsonify": "^1.0.1", + "lodash.merge": "^4.6.2", + "minimatch": "^3.1.2", + "natural-compare": "^1.4.0", + "optionator": "^0.9.3" + }, + "bin": { + "eslint": "bin/eslint.js" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://eslint.org/donate" + }, + "peerDependencies": { + "jiti": "*" + }, + "peerDependenciesMeta": { + "jiti": { + "optional": true + } + } + }, + "node_modules/eslint-plugin-react-hooks": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-react-hooks/-/eslint-plugin-react-hooks-5.1.0.tgz", + "integrity": "sha512-mpJRtPgHN2tNAvZ35AMfqeB3Xqeo273QxrHJsbBEPWODRM4r0yB6jfoROqKEYrOn27UtRPpcpHc2UqyBSuUNTw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "eslint": "^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0-0 || ^9.0.0" + } + }, + "node_modules/eslint-plugin-react-refresh": { + "version": "0.4.18", + "resolved": "https://registry.npmjs.org/eslint-plugin-react-refresh/-/eslint-plugin-react-refresh-0.4.18.tgz", + "integrity": "sha512-IRGEoFn3OKalm3hjfolEWGqoF/jPqeEYFp+C8B0WMzwGwBMvlRDQd06kghDhF0C61uJ6WfSDhEZE/sAQjduKgw==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "eslint": ">=8.40" + } + }, + "node_modules/eslint-scope": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.2.0.tgz", + "integrity": "sha512-PHlWUfG6lvPc3yvP5A4PNyBL1W8fkDUccmI21JUu/+GKZBoH/W5u6usENXUrWFRsyoW5ACUjFGgAFQp5gUlb/A==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "esrecurse": "^4.3.0", + "estraverse": "^5.2.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint-visitor-keys": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.0.tgz", + "integrity": "sha512-UyLnSehNt62FFhSwjZlHmeokpRK59rcz29j+F1/aDgbkbRTk7wIc9XzdoasMUbRNKDM0qQt/+BJ4BrpFeABemw==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/espree": { + "version": "10.3.0", + "resolved": "https://registry.npmjs.org/espree/-/espree-10.3.0.tgz", + "integrity": "sha512-0QYC8b24HWY8zjRnDTL6RiHfDbAWn63qb4LMj1Z4b076A4une81+z03Kg7l7mn/48PUTqoLptSXez8oknU8Clg==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "acorn": "^8.14.0", + "acorn-jsx": "^5.3.2", + "eslint-visitor-keys": "^4.2.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/esquery": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.6.0.tgz", + "integrity": "sha512-ca9pw9fomFcKPvFLXhBKUK90ZvGibiGOvRJNbjljY7s7uq/5YO4BOzcYtJqExdx99rF6aAcnRxHmcUHcz6sQsg==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "estraverse": "^5.1.0" + }, + "engines": { + "node": ">=0.10" + } + }, + "node_modules/esrecurse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "estraverse": "^5.2.0" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estree-util-is-identifier-name": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/estree-util-is-identifier-name/-/estree-util-is-identifier-name-3.0.0.tgz", + "integrity": "sha512-hFtqIDZTIUZ9BXLb8y4pYGyk6+wekIivNVTcmvk8NoOh+VeRn5y6cEHzbURrWbfp1fIqdVipilzj+lfaadNZmg==", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/esutils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/extend": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", + "integrity": "sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g==", + "license": "MIT" + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-glob": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", + "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@nodelib/fs.stat": "^2.0.2", + "@nodelib/fs.walk": "^1.2.3", + "glob-parent": "^5.1.2", + "merge2": "^1.3.0", + "micromatch": "^4.0.8" + }, + "engines": { + "node": ">=8.6.0" + } + }, + "node_modules/fast-glob/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-levenshtein": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", + "integrity": "sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==", + "dev": true, + "license": "MIT" + }, + "node_modules/fastq": { + "version": "1.19.0", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.19.0.tgz", + "integrity": "sha512-7SFSRCNjBQIZH/xZR3iy5iQYR8aGBE0h3VG6/cwlbrpdciNYBMotQav8c1XI3HjHH+NikUpP53nPdlZSdWmFzA==", + "dev": true, + "license": "ISC", + "dependencies": { + "reusify": "^1.0.4" + } + }, + "node_modules/fflate": { + "version": "0.8.2", + "resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.2.tgz", + "integrity": "sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==", + "dev": true, + "license": "MIT" + }, + "node_modules/file-entry-cache": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz", + "integrity": "sha512-XXTUwCvisa5oacNGRP9SfNtYBNAMi+RPwBFmblZEF7N7swHYQS6/Zfk7SRwx4D5j3CH211YNRco1DEMNVfZCnQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "flat-cache": "^4.0.0" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/file-selector": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/file-selector/-/file-selector-2.1.2.tgz", + "integrity": "sha512-QgXo+mXTe8ljeqUFaX3QVHc5osSItJ/Km+xpocx0aSqWGMSCf6qYs/VnzZgS864Pjn5iceMRFigeAV7AfTlaig==", + "license": "MIT", + "dependencies": { + "tslib": "^2.7.0" + }, + "engines": { + "node": ">= 12" + } + }, + "node_modules/fill-range": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", + "license": "MIT", + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/find-up": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", + "integrity": "sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==", + "dev": true, + "license": "MIT", + "dependencies": { + "locate-path": "^6.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/flat-cache": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-4.0.1.tgz", + "integrity": "sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==", + "dev": true, + "license": "MIT", + "dependencies": { + "flatted": "^3.2.9", + "keyv": "^4.5.4" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/flatted": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.2.tgz", + "integrity": "sha512-AiwGJM8YcNOaobumgtng+6NHuOqC3A7MixFeDafM3X9cIUM+xUXoS5Vfgf+OihAYe20fxqNM9yPBXJzRtZ/4eA==", + "dev": true, + "license": "ISC" + }, + "node_modules/fraction.js": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.3.7.tgz", + "integrity": "sha512-ZsDfxO51wGAXREY55a7la9LScWpwv9RxIrYABrlvOFBlH/ShPnrtsXeuUIfXKKOVicNxQ+o8JTbJvjS4M89yew==", + "license": "MIT", + "engines": { + "node": "*" + }, + "funding": { + "type": "patreon", + "url": "https://github.com/sponsors/rawify" + } + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/gensync": { + "version": "1.0.0-beta.2", + "resolved": "https://registry.npmjs.org/gensync/-/gensync-1.0.0-beta.2.tgz", + "integrity": "sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "dev": true, + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.3" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/globals": { + "version": "15.14.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-15.14.0.tgz", + "integrity": "sha512-OkToC372DtlQeje9/zHIo5CT8lRP/FUgEOKBEhU4e0abL7J7CD24fD9ohiLN5hagG/kWCYj4K5oaxxtj2Z0Dig==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/goober": { + "version": "2.1.16", + "resolved": "https://registry.npmjs.org/goober/-/goober-2.1.16.tgz", + "integrity": "sha512-erjk19y1U33+XAMe1VTvIONHYoSqE4iS7BYUZfHaqeohLmnC0FdxEh7rQU+6MZ4OajItzjZFSRtVANrQwNq6/g==", + "license": "MIT", + "peerDependencies": { + "csstype": "^3.0.10" + } + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", + "license": "ISC" + }, + "node_modules/graphemer": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", + "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", + "dev": true, + "license": "MIT" + }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/hast-util-from-dom": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/hast-util-from-dom/-/hast-util-from-dom-5.0.1.tgz", + "integrity": "sha512-N+LqofjR2zuzTjCPzyDUdSshy4Ma6li7p/c3pA78uTwzFgENbgbUrm2ugwsOdcjI1muO+o6Dgzp9p8WHtn/39Q==", + "license": "ISC", + "dependencies": { + "@types/hast": "^3.0.0", + "hastscript": "^9.0.0", + "web-namespaces": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/hast-util-from-html": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/hast-util-from-html/-/hast-util-from-html-2.0.3.tgz", + "integrity": "sha512-CUSRHXyKjzHov8yKsQjGOElXy/3EKpyX56ELnkHH34vDVw1N1XSQ1ZcAvTyAPtGqLTuKP/uxM+aLkSPqF/EtMw==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "devlop": "^1.1.0", + "hast-util-from-parse5": "^8.0.0", + "parse5": "^7.0.0", + "vfile": "^6.0.0", + "vfile-message": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/hast-util-from-html-isomorphic": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/hast-util-from-html-isomorphic/-/hast-util-from-html-isomorphic-2.0.0.tgz", + "integrity": "sha512-zJfpXq44yff2hmE0XmwEOzdWin5xwH+QIhMLOScpX91e/NSGPsAzNCvLQDIEPyO2TXi+lBmU6hjLIhV8MwP2kw==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "hast-util-from-dom": "^5.0.0", + "hast-util-from-html": "^2.0.0", + "unist-util-remove-position": "^5.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/hast-util-from-parse5": { + "version": "8.0.2", + "resolved": "https://registry.npmjs.org/hast-util-from-parse5/-/hast-util-from-parse5-8.0.2.tgz", + "integrity": "sha512-SfMzfdAi/zAoZ1KkFEyyeXBn7u/ShQrfd675ZEE9M3qj+PMFX05xubzRyF76CCSJu8au9jgVxDV1+okFvgZU4A==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/unist": "^3.0.0", + "devlop": "^1.0.0", + "hastscript": "^9.0.0", + "property-information": "^6.0.0", + "vfile": "^6.0.0", + "vfile-location": "^5.0.0", + "web-namespaces": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/hast-util-is-element": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/hast-util-is-element/-/hast-util-is-element-3.0.0.tgz", + "integrity": "sha512-Val9mnv2IWpLbNPqc/pUem+a7Ipj2aHacCwgNfTiK0vJKl0LF+4Ba4+v1oPHFpf3bLYmreq0/l3Gud9S5OH42g==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/hast-util-parse-selector": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/hast-util-parse-selector/-/hast-util-parse-selector-4.0.0.tgz", + "integrity": "sha512-wkQCkSYoOGCRKERFWcxMVMOcYE2K1AaNLU8DXS9arxnLOUEWbOXKXiJUNzEpqZ3JOKpnha3jkFrumEjVliDe7A==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/hast-util-to-jsx-runtime": { + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/hast-util-to-jsx-runtime/-/hast-util-to-jsx-runtime-2.3.2.tgz", + "integrity": "sha512-1ngXYb+V9UT5h+PxNRa1O1FYguZK/XL+gkeqvp7EdHlB9oHUG0eYRo/vY5inBdcqo3RkPMC58/H94HvkbfGdyg==", + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/unist": "^3.0.0", + "comma-separated-tokens": "^2.0.0", + "devlop": "^1.0.0", + "estree-util-is-identifier-name": "^3.0.0", + "hast-util-whitespace": "^3.0.0", + "mdast-util-mdx-expression": "^2.0.0", + "mdast-util-mdx-jsx": "^3.0.0", + "mdast-util-mdxjs-esm": "^2.0.0", + "property-information": "^6.0.0", + "space-separated-tokens": "^2.0.0", + "style-to-object": "^1.0.0", + "unist-util-position": "^5.0.0", + "vfile-message": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/hast-util-to-text": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/hast-util-to-text/-/hast-util-to-text-4.0.2.tgz", + "integrity": "sha512-KK6y/BN8lbaq654j7JgBydev7wuNMcID54lkRav1P0CaE1e47P72AWWPiGKXTJU271ooYzcvTAn/Zt0REnvc7A==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/unist": "^3.0.0", + "hast-util-is-element": "^3.0.0", + "unist-util-find-after": "^5.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/hast-util-whitespace": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/hast-util-whitespace/-/hast-util-whitespace-3.0.0.tgz", + "integrity": "sha512-88JUN06ipLwsnv+dVn+OIYOvAuvBMy/Qoi6O7mQHxdPXpjy+Cd6xRkWwux7DKO+4sYILtLBRIKgsdpS2gQc7qw==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/hastscript": { + "version": "9.0.0", + "resolved": "https://registry.npmjs.org/hastscript/-/hastscript-9.0.0.tgz", + "integrity": "sha512-jzaLBGavEDKHrc5EfFImKN7nZKKBdSLIdGvCwDZ9TfzbF2ffXiov8CKE445L2Z1Ek2t/m4SKQ2j6Ipv7NyUolw==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "comma-separated-tokens": "^2.0.0", + "hast-util-parse-selector": "^4.0.0", + "property-information": "^6.0.0", + "space-separated-tokens": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/highlight.js": { + "version": "11.11.1", + "resolved": "https://registry.npmjs.org/highlight.js/-/highlight.js-11.11.1.tgz", + "integrity": "sha512-Xwwo44whKBVCYoliBQwaPvtd/2tYFkRQtXDWj1nackaV2JPXx3L0+Jvd8/qCJ2p+ML0/XVkJ2q+Mr+UVdpJK5w==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=12.0.0" + } + }, + "node_modules/html-url-attributes": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/html-url-attributes/-/html-url-attributes-3.0.1.tgz", + "integrity": "sha512-ol6UPyBWqsrO6EJySPz2O7ZSr856WDrEzM5zMqp+FJJLGMW35cLYmmZnl0vztAZxRUoNZJFTCohfjuIJ8I4QBQ==", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/ignore": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", + "integrity": "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/immutable": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/immutable/-/immutable-5.0.3.tgz", + "integrity": "sha512-P8IdPQHq3lA1xVeBRi5VPqUm5HDgKnx0Ru51wZz5mjxHr5n3RWhjIpOFU7ybkUxfB+5IToy+OLaHYDBIWsv+uw==", + "devOptional": true, + "license": "MIT" + }, + "node_modules/import-fresh": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz", + "integrity": "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "parent-module": "^1.0.0", + "resolve-from": "^4.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.8.19" + } + }, + "node_modules/inline-style-parser": { + "version": "0.2.4", + "resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.2.4.tgz", + "integrity": "sha512-0aO8FkhNZlj/ZIbNi7Lxxr12obT7cL1moPfE4tg1LkX7LlLfC6DeX4l2ZEud1ukP9jNQyNnfzQVqwbwmAATY4Q==", + "license": "MIT" + }, + "node_modules/is-alphabetical": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-alphabetical/-/is-alphabetical-2.0.1.tgz", + "integrity": "sha512-FWyyY60MeTNyeSRpkM2Iry0G9hpr7/9kD40mD/cGQEuilcZYS4okz8SN2Q6rLCJ8gbCt6fN+rC+6tMGS99LaxQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/is-alphanumerical": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-alphanumerical/-/is-alphanumerical-2.0.1.tgz", + "integrity": "sha512-hmbYhX/9MUMF5uh7tOXyK/n0ZvWpad5caBA17GsC6vyuCqaWliRG5K1qS9inmUhEMaOBIW7/whAnSwveW/LtZw==", + "license": "MIT", + "dependencies": { + "is-alphabetical": "^2.0.0", + "is-decimal": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/is-decimal": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-decimal/-/is-decimal-2.0.1.tgz", + "integrity": "sha512-AAB9hiomQs5DXWcRB1rqsxGUstbRroFOPPVAomNk/3XHR5JyEZChOyTWe2oayKnsSsr/kcGqF+z6yuH6HHpN0A==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-hexadecimal": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-hexadecimal/-/is-hexadecimal-2.0.1.tgz", + "integrity": "sha512-DgZQp241c8oO6cA1SbTEWiXeoxV42vlcJxgH+B3hi1AiqqKruZR3ZGF8In3fj4+/y/7rHvlOZLZtgJ/4ttYGZg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "license": "MIT", + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/is-plain-obj": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/is-plain-obj/-/is-plain-obj-4.1.0.tgz", + "integrity": "sha512-+Pgi+vMuUNkJyExiMBt5IlFoMyKnr5zhJ4Uspz58WOhBF5QoIZkFyNHIbBAtHwzVAgk5RtndVNsDRN61/mmDqg==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "dev": true, + "license": "ISC" + }, + "node_modules/jiti": { + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/jiti/-/jiti-2.4.2.tgz", + "integrity": "sha512-rg9zJN+G4n2nfJl5MW3BMygZX56zKPNVEYYqq7adpmMh4Jn2QNEwhvQlFy6jPVdcod7txZtKHWnyZiA3a0zP7A==", + "license": "MIT", + "bin": { + "jiti": "lib/jiti-cli.mjs" + } + }, + "node_modules/js-tokens": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", + "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==", + "license": "MIT" + }, + "node_modules/js-yaml": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", + "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "argparse": "^2.0.1" + }, + "bin": { + "js-yaml": "bin/js-yaml.js" + } + }, + "node_modules/jsesc": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-3.1.0.tgz", + "integrity": "sha512-/sM3dO2FOzXjKQhJuo0Q173wf2KOo8t4I8vHy6lF9poUp7bKT0/NHE8fPX23PwfhnykfqnC2xRxOnVw5XuGIaA==", + "dev": true, + "license": "MIT", + "bin": { + "jsesc": "bin/jsesc" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/json-buffer": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.1.tgz", + "integrity": "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "dev": true, + "license": "MIT" + }, + "node_modules/json-stable-stringify-without-jsonify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", + "integrity": "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==", + "dev": true, + "license": "MIT" + }, + "node_modules/json5": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", + "integrity": "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", + "dev": true, + "license": "MIT", + "bin": { + "json5": "lib/cli.js" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/katex": { + "version": "0.16.21", + "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.21.tgz", + "integrity": "sha512-XvqR7FgOHtWupfMiigNzmh+MgUVmDGU2kXZm899ZkPfcuoPuFxyHmXsgATDpFZDAXCI8tvinaVcDo8PIIJSo4A==", + "funding": [ + "https://opencollective.com/katex", + "https://github.com/sponsors/katex" + ], + "license": "MIT", + "dependencies": { + "commander": "^8.3.0" + }, + "bin": { + "katex": "cli.js" + } + }, + "node_modules/keyv": { + "version": "4.5.4", + "resolved": "https://registry.npmjs.org/keyv/-/keyv-4.5.4.tgz", + "integrity": "sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==", + "dev": true, + "license": "MIT", + "dependencies": { + "json-buffer": "3.0.1" + } + }, + "node_modules/levn": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", + "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "prelude-ls": "^1.2.1", + "type-check": "~0.4.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/lightningcss": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss/-/lightningcss-1.29.2.tgz", + "integrity": "sha512-6b6gd/RUXKaw5keVdSEtqFVdzWnU5jMxTUjA2bVcMNPLwSQ08Sv/UodBVtETLCn7k4S1Ibxwh7k68IwLZPgKaA==", + "license": "MPL-2.0", + "dependencies": { + "detect-libc": "^2.0.3" + }, + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + }, + "optionalDependencies": { + "lightningcss-darwin-arm64": "1.29.2", + "lightningcss-darwin-x64": "1.29.2", + "lightningcss-freebsd-x64": "1.29.2", + "lightningcss-linux-arm-gnueabihf": "1.29.2", + "lightningcss-linux-arm64-gnu": "1.29.2", + "lightningcss-linux-arm64-musl": "1.29.2", + "lightningcss-linux-x64-gnu": "1.29.2", + "lightningcss-linux-x64-musl": "1.29.2", + "lightningcss-win32-arm64-msvc": "1.29.2", + "lightningcss-win32-x64-msvc": "1.29.2" + } + }, + "node_modules/lightningcss-darwin-arm64": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-darwin-arm64/-/lightningcss-darwin-arm64-1.29.2.tgz", + "integrity": "sha512-cK/eMabSViKn/PG8U/a7aCorpeKLMlK0bQeNHmdb7qUnBkNPnL+oV5DjJUo0kqWsJUapZsM4jCfYItbqBDvlcA==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-darwin-x64": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-darwin-x64/-/lightningcss-darwin-x64-1.29.2.tgz", + "integrity": "sha512-j5qYxamyQw4kDXX5hnnCKMf3mLlHvG44f24Qyi2965/Ycz829MYqjrVg2H8BidybHBp9kom4D7DR5VqCKDXS0w==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-freebsd-x64": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-freebsd-x64/-/lightningcss-freebsd-x64-1.29.2.tgz", + "integrity": "sha512-wDk7M2tM78Ii8ek9YjnY8MjV5f5JN2qNVO+/0BAGZRvXKtQrBC4/cn4ssQIpKIPP44YXw6gFdpUF+Ps+RGsCwg==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-arm-gnueabihf": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm-gnueabihf/-/lightningcss-linux-arm-gnueabihf-1.29.2.tgz", + "integrity": "sha512-IRUrOrAF2Z+KExdExe3Rz7NSTuuJ2HvCGlMKoquK5pjvo2JY4Rybr+NrKnq0U0hZnx5AnGsuFHjGnNT14w26sg==", + "cpu": [ + "arm" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-arm64-gnu": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-gnu/-/lightningcss-linux-arm64-gnu-1.29.2.tgz", + "integrity": "sha512-KKCpOlmhdjvUTX/mBuaKemp0oeDIBBLFiU5Fnqxh1/DZ4JPZi4evEH7TKoSBFOSOV3J7iEmmBaw/8dpiUvRKlQ==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-arm64-musl": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-musl/-/lightningcss-linux-arm64-musl-1.29.2.tgz", + "integrity": "sha512-Q64eM1bPlOOUgxFmoPUefqzY1yV3ctFPE6d/Vt7WzLW4rKTv7MyYNky+FWxRpLkNASTnKQUaiMJ87zNODIrrKQ==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-x64-gnu": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-gnu/-/lightningcss-linux-x64-gnu-1.29.2.tgz", + "integrity": "sha512-0v6idDCPG6epLXtBH/RPkHvYx74CVziHo6TMYga8O2EiQApnUPZsbR9nFNrg2cgBzk1AYqEd95TlrsL7nYABQg==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-x64-musl": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-musl/-/lightningcss-linux-x64-musl-1.29.2.tgz", + "integrity": "sha512-rMpz2yawkgGT8RULc5S4WiZopVMOFWjiItBT7aSfDX4NQav6M44rhn5hjtkKzB+wMTRlLLqxkeYEtQ3dd9696w==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-win32-arm64-msvc": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-win32-arm64-msvc/-/lightningcss-win32-arm64-msvc-1.29.2.tgz", + "integrity": "sha512-nL7zRW6evGQqYVu/bKGK+zShyz8OVzsCotFgc7judbt6wnB2KbiKKJwBE4SGoDBQ1O94RjW4asrCjQL4i8Fhbw==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-win32-x64-msvc": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-win32-x64-msvc/-/lightningcss-win32-x64-msvc-1.29.2.tgz", + "integrity": "sha512-EdIUW3B2vLuHmv7urfzMI/h2fmlnOQBk1xlsDxkN1tCWKjNFjfLhGxYk8C8mzpSfr+A6jFFIi8fU6LbQGsRWjA==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/locate-path": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", + "integrity": "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-locate": "^5.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/lodash.merge": { + "version": "4.6.2", + "resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz", + "integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/longest-streak": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/longest-streak/-/longest-streak-3.1.0.tgz", + "integrity": "sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/loose-envify": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", + "integrity": "sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==", + "license": "MIT", + "dependencies": { + "js-tokens": "^3.0.0 || ^4.0.0" + }, + "bin": { + "loose-envify": "cli.js" + } + }, + "node_modules/lowlight": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/lowlight/-/lowlight-3.3.0.tgz", + "integrity": "sha512-0JNhgFoPvP6U6lE/UdVsSq99tn6DhjjpAj5MxG49ewd2mOBVtwWYIT8ClyABhq198aXXODMU6Ox8DrGy/CpTZQ==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "devlop": "^1.0.0", + "highlight.js": "~11.11.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/lru-cache": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-5.1.1.tgz", + "integrity": "sha512-KpNARQA3Iwv+jTA0utUVVbrh+Jlrr1Fv0e56GGzAFOXN7dk/FviaDW8LHmK52DlcH4WP2n6gI8vN1aesBFgo9w==", + "dev": true, + "license": "ISC", + "dependencies": { + "yallist": "^3.0.2" + } + }, + "node_modules/markdown-table": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/markdown-table/-/markdown-table-3.0.4.tgz", + "integrity": "sha512-wiYz4+JrLyb/DqW2hkFJxP7Vd7JuTDm77fvbM8VfEQdmSMqcImWeeRbHwZjBjIFki/VaMK2BhFi7oUUZeM5bqw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/mdast-util-find-and-replace": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/mdast-util-find-and-replace/-/mdast-util-find-and-replace-3.0.2.tgz", + "integrity": "sha512-Tmd1Vg/m3Xz43afeNxDIhWRtFZgM2VLyaf4vSTYwudTyeuTneoL3qtWMA5jeLyz/O1vDJmmV4QuScFCA2tBPwg==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "escape-string-regexp": "^5.0.0", + "unist-util-is": "^6.0.0", + "unist-util-visit-parents": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-find-and-replace/node_modules/escape-string-regexp": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-5.0.0.tgz", + "integrity": "sha512-/veY75JbMK4j1yjvuUxuVsiS/hr/4iHs9FTT6cgTexxdE0Ly/glccBAkloH/DofkjRbZU3bnoj38mOmhkZ0lHw==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/mdast-util-from-markdown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/mdast-util-from-markdown/-/mdast-util-from-markdown-2.0.2.tgz", + "integrity": "sha512-uZhTV/8NBuw0WHkPTrCqDOl0zVe1BIng5ZtHoDk49ME1qqcjYmmLmOf0gELgcRMxN4w2iuIeVso5/6QymSrgmA==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "@types/unist": "^3.0.0", + "decode-named-character-reference": "^1.0.0", + "devlop": "^1.0.0", + "mdast-util-to-string": "^4.0.0", + "micromark": "^4.0.0", + "micromark-util-decode-numeric-character-reference": "^2.0.0", + "micromark-util-decode-string": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0", + "unist-util-stringify-position": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm/-/mdast-util-gfm-3.0.0.tgz", + "integrity": "sha512-dgQEX5Amaq+DuUqf26jJqSK9qgixgd6rYDHAv4aTBuA92cTknZlKpPfa86Z/s8Dj8xsAQpFfBmPUHWJBWqS4Bw==", + "license": "MIT", + "dependencies": { + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-gfm-autolink-literal": "^2.0.0", + "mdast-util-gfm-footnote": "^2.0.0", + "mdast-util-gfm-strikethrough": "^2.0.0", + "mdast-util-gfm-table": "^2.0.0", + "mdast-util-gfm-task-list-item": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm-autolink-literal": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/mdast-util-gfm-autolink-literal/-/mdast-util-gfm-autolink-literal-2.0.1.tgz", + "integrity": "sha512-5HVP2MKaP6L+G6YaxPNjuL0BPrq9orG3TsrZ9YXbA3vDw/ACI4MEsnoDpn6ZNm7GnZgtAcONJyPhOP8tNJQavQ==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "ccount": "^2.0.0", + "devlop": "^1.0.0", + "mdast-util-find-and-replace": "^3.0.0", + "micromark-util-character": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm-footnote": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm-footnote/-/mdast-util-gfm-footnote-2.0.0.tgz", + "integrity": "sha512-5jOT2boTSVkMnQ7LTrd6n/18kqwjmuYqo7JUPe+tRCY6O7dAuTFMtTPauYYrMPpox9hlN0uOx/FL8XvEfG9/mQ==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "devlop": "^1.1.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm-strikethrough": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm-strikethrough/-/mdast-util-gfm-strikethrough-2.0.0.tgz", + "integrity": "sha512-mKKb915TF+OC5ptj5bJ7WFRPdYtuHv0yTRxK2tJvi+BDqbkiG7h7u/9SI89nRAYcmap2xHQL9D+QG/6wSrTtXg==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm-table": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm-table/-/mdast-util-gfm-table-2.0.0.tgz", + "integrity": "sha512-78UEvebzz/rJIxLvE7ZtDd/vIQ0RHv+3Mh5DR96p7cS7HsBhYIICDBCu8csTNWNO6tBWfqXPWekRuj2FNOGOZg==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "markdown-table": "^3.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm-task-list-item": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm-task-list-item/-/mdast-util-gfm-task-list-item-2.0.0.tgz", + "integrity": "sha512-IrtvNvjxC1o06taBAVJznEnkiHxLFTzgonUdy8hzFVeDun0uTjxxrRGVaNFqkU1wJR3RBPEfsxmU6jDWPofrTQ==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-math": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-math/-/mdast-util-math-3.0.0.tgz", + "integrity": "sha512-Tl9GBNeG/AhJnQM221bJR2HPvLOSnLE/T9cJI9tlc6zwQk2nPk/4f0cHkOdEixQPC/j8UtKDdITswvLAy1OZ1w==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "longest-streak": "^3.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.1.0", + "unist-util-remove-position": "^5.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-mdx-expression": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/mdast-util-mdx-expression/-/mdast-util-mdx-expression-2.0.1.tgz", + "integrity": "sha512-J6f+9hUp+ldTZqKRSg7Vw5V6MqjATc+3E4gf3CFNcuZNWD8XdyI6zQ8GqH7f8169MM6P7hMBRDVGnn7oHB9kXQ==", + "license": "MIT", + "dependencies": { + "@types/estree-jsx": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-mdx-jsx": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/mdast-util-mdx-jsx/-/mdast-util-mdx-jsx-3.2.0.tgz", + "integrity": "sha512-lj/z8v0r6ZtsN/cGNNtemmmfoLAFZnjMbNyLzBafjzikOM+glrjNHPlf6lQDOTccj9n5b0PPihEBbhneMyGs1Q==", + "license": "MIT", + "dependencies": { + "@types/estree-jsx": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "@types/unist": "^3.0.0", + "ccount": "^2.0.0", + "devlop": "^1.1.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0", + "parse-entities": "^4.0.0", + "stringify-entities": "^4.0.0", + "unist-util-stringify-position": "^4.0.0", + "vfile-message": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-mdxjs-esm": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/mdast-util-mdxjs-esm/-/mdast-util-mdxjs-esm-2.0.1.tgz", + "integrity": "sha512-EcmOpxsZ96CvlP03NghtH1EsLtr0n9Tm4lPUJUBccV9RwUOneqSycg19n5HGzCf+10LozMRSObtVr3ee1WoHtg==", + "license": "MIT", + "dependencies": { + "@types/estree-jsx": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-newline-to-break": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-newline-to-break/-/mdast-util-newline-to-break-2.0.0.tgz", + "integrity": "sha512-MbgeFca0hLYIEx/2zGsszCSEJJ1JSCdiY5xQxRcLDDGa8EPvlLPupJ4DSajbMPAnC0je8jfb9TiUATnxxrHUog==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-find-and-replace": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-phrasing": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/mdast-util-phrasing/-/mdast-util-phrasing-4.1.0.tgz", + "integrity": "sha512-TqICwyvJJpBwvGAMZjj4J2n0X8QWp21b9l0o7eXyVJ25YNWYbJDVIyD1bZXE6WtV6RmKJVYmQAKWa0zWOABz2w==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "unist-util-is": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-to-hast": { + "version": "13.2.0", + "resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.2.0.tgz", + "integrity": "sha512-QGYKEuUsYT9ykKBCMOEDLsU5JRObWQusAolFMeko/tYPufNkRffBAQjIE+99jbA87xv6FgmjLtwjh9wBWajwAA==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "@ungap/structured-clone": "^1.0.0", + "devlop": "^1.0.0", + "micromark-util-sanitize-uri": "^2.0.0", + "trim-lines": "^3.0.0", + "unist-util-position": "^5.0.0", + "unist-util-visit": "^5.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-to-markdown": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/mdast-util-to-markdown/-/mdast-util-to-markdown-2.1.2.tgz", + "integrity": "sha512-xj68wMTvGXVOKonmog6LwyJKrYXZPvlwabaryTjLh9LuvovB/KAH+kvi8Gjj+7rJjsFi23nkUxRQv1KqSroMqA==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "@types/unist": "^3.0.0", + "longest-streak": "^3.0.0", + "mdast-util-phrasing": "^4.0.0", + "mdast-util-to-string": "^4.0.0", + "micromark-util-classify-character": "^2.0.0", + "micromark-util-decode-string": "^2.0.0", + "unist-util-visit": "^5.0.0", + "zwitch": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-to-string": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-to-string/-/mdast-util-to-string-4.0.0.tgz", + "integrity": "sha512-0H44vDimn51F0YwvxSJSm0eCDOJTRlmN0R1yBh4HLj9wiV1Dn0QoXGbvFAWj2hSItVTlCmBF1hqKlIyUBVFLPg==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/merge2": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", + "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 8" + } + }, + "node_modules/micromark": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/micromark/-/micromark-4.0.1.tgz", + "integrity": "sha512-eBPdkcoCNvYcxQOAKAlceo5SNdzZWfF+FcSupREAzdAh9rRmE239CEQAiTwIgblwnoM8zzj35sZ5ZwvSEOF6Kw==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "@types/debug": "^4.0.0", + "debug": "^4.0.0", + "decode-named-character-reference": "^1.0.0", + "devlop": "^1.0.0", + "micromark-core-commonmark": "^2.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-chunked": "^2.0.0", + "micromark-util-combine-extensions": "^2.0.0", + "micromark-util-decode-numeric-character-reference": "^2.0.0", + "micromark-util-encode": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-resolve-all": "^2.0.0", + "micromark-util-sanitize-uri": "^2.0.0", + "micromark-util-subtokenize": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-core-commonmark": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/micromark-core-commonmark/-/micromark-core-commonmark-2.0.2.tgz", + "integrity": "sha512-FKjQKbxd1cibWMM1P9N+H8TwlgGgSkWZMmfuVucLCHaYqeSvJ0hFeHsIa65pA2nYbes0f8LDHPMrd9X7Ujxg9w==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "decode-named-character-reference": "^1.0.0", + "devlop": "^1.0.0", + "micromark-factory-destination": "^2.0.0", + "micromark-factory-label": "^2.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-factory-title": "^2.0.0", + "micromark-factory-whitespace": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-chunked": "^2.0.0", + "micromark-util-classify-character": "^2.0.0", + "micromark-util-html-tag-name": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-resolve-all": "^2.0.0", + "micromark-util-subtokenize": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-extension-gfm": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm/-/micromark-extension-gfm-3.0.0.tgz", + "integrity": "sha512-vsKArQsicm7t0z2GugkCKtZehqUm31oeGBV/KVSorWSy8ZlNAv7ytjFhvaryUiCUJYqs+NoE6AFhpQvBTM6Q4w==", + "license": "MIT", + "dependencies": { + "micromark-extension-gfm-autolink-literal": "^2.0.0", + "micromark-extension-gfm-footnote": "^2.0.0", + "micromark-extension-gfm-strikethrough": "^2.0.0", + "micromark-extension-gfm-table": "^2.0.0", + "micromark-extension-gfm-tagfilter": "^2.0.0", + "micromark-extension-gfm-task-list-item": "^2.0.0", + "micromark-util-combine-extensions": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-autolink-literal": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-autolink-literal/-/micromark-extension-gfm-autolink-literal-2.1.0.tgz", + "integrity": "sha512-oOg7knzhicgQ3t4QCjCWgTmfNhvQbDDnJeVu9v81r7NltNCVmhPy1fJRX27pISafdjL+SVc4d3l48Gb6pbRypw==", + "license": "MIT", + "dependencies": { + "micromark-util-character": "^2.0.0", + "micromark-util-sanitize-uri": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-footnote": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-footnote/-/micromark-extension-gfm-footnote-2.1.0.tgz", + "integrity": "sha512-/yPhxI1ntnDNsiHtzLKYnE3vf9JZ6cAisqVDauhp4CEHxlb4uoOTxOCJ+9s51bIB8U1N1FJ1RXOKTIlD5B/gqw==", + "license": "MIT", + "dependencies": { + "devlop": "^1.0.0", + "micromark-core-commonmark": "^2.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-sanitize-uri": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-strikethrough": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-strikethrough/-/micromark-extension-gfm-strikethrough-2.1.0.tgz", + "integrity": "sha512-ADVjpOOkjz1hhkZLlBiYA9cR2Anf8F4HqZUO6e5eDcPQd0Txw5fxLzzxnEkSkfnD0wziSGiv7sYhk/ktvbf1uw==", + "license": "MIT", + "dependencies": { + "devlop": "^1.0.0", + "micromark-util-chunked": "^2.0.0", + "micromark-util-classify-character": "^2.0.0", + "micromark-util-resolve-all": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-table": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-table/-/micromark-extension-gfm-table-2.1.1.tgz", + "integrity": "sha512-t2OU/dXXioARrC6yWfJ4hqB7rct14e8f7m0cbI5hUmDyyIlwv5vEtooptH8INkbLzOatzKuVbQmAYcbWoyz6Dg==", + "license": "MIT", + "dependencies": { + "devlop": "^1.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-tagfilter": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-tagfilter/-/micromark-extension-gfm-tagfilter-2.0.0.tgz", + "integrity": "sha512-xHlTOmuCSotIA8TW1mDIM6X2O1SiX5P9IuDtqGonFhEK0qgRI4yeC6vMxEV2dgyr2TiD+2PQ10o+cOhdVAcwfg==", + "license": "MIT", + "dependencies": { + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-task-list-item": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-task-list-item/-/micromark-extension-gfm-task-list-item-2.1.0.tgz", + "integrity": "sha512-qIBZhqxqI6fjLDYFTBIa4eivDMnP+OZqsNwmQ3xNLE4Cxwc+zfQEfbs6tzAo2Hjq+bh6q5F+Z8/cksrLFYWQQw==", + "license": "MIT", + "dependencies": { + "devlop": "^1.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-math": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/micromark-extension-math/-/micromark-extension-math-3.1.0.tgz", + "integrity": "sha512-lvEqd+fHjATVs+2v/8kg9i5Q0AP2k85H0WUOwpIVvUML8BapsMvh1XAogmQjOCsLpoKRCVQqEkQBB3NhVBcsOg==", + "license": "MIT", + "dependencies": { + "@types/katex": "^0.16.0", + "devlop": "^1.0.0", + "katex": "^0.16.0", + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-factory-destination": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-factory-destination/-/micromark-factory-destination-2.0.1.tgz", + "integrity": "sha512-Xe6rDdJlkmbFRExpTOmRj9N3MaWmbAgdpSrBQvCFqhezUn4AHqJHbaEnfbVYYiexVSs//tqOdY/DxhjdCiJnIA==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-factory-label": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-factory-label/-/micromark-factory-label-2.0.1.tgz", + "integrity": "sha512-VFMekyQExqIW7xIChcXn4ok29YE3rnuyveW3wZQWWqF4Nv9Wk5rgJ99KzPvHjkmPXF93FXIbBp6YdW3t71/7Vg==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "devlop": "^1.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-factory-space": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-factory-space/-/micromark-factory-space-2.0.1.tgz", + "integrity": "sha512-zRkxjtBxxLd2Sc0d+fbnEunsTj46SWXgXciZmHq0kDYGnck/ZSGj9/wULTV95uoeYiK5hRXP2mJ98Uo4cq/LQg==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-character": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-factory-title": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-factory-title/-/micromark-factory-title-2.0.1.tgz", + "integrity": "sha512-5bZ+3CjhAd9eChYTHsjy6TGxpOFSKgKKJPJxr293jTbfry2KDoWkhBb6TcPVB4NmzaPhMs1Frm9AZH7OD4Cjzw==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-factory-whitespace": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-factory-whitespace/-/micromark-factory-whitespace-2.0.1.tgz", + "integrity": "sha512-Ob0nuZ3PKt/n0hORHyvoD9uZhr+Za8sFoP+OnMcnWK5lngSzALgQYKMr9RJVOWLqQYuyn6ulqGWSXdwf6F80lQ==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-character": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.1.tgz", + "integrity": "sha512-wv8tdUTJ3thSFFFJKtpYKOYiGP2+v96Hvk4Tu8KpCAsTMs6yi+nVmGh1syvSCsaxz45J6Jbw+9DD6g97+NV67Q==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-chunked": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-chunked/-/micromark-util-chunked-2.0.1.tgz", + "integrity": "sha512-QUNFEOPELfmvv+4xiNg2sRYeS/P84pTW0TCgP5zc9FpXetHY0ab7SxKyAQCNCc1eK0459uoLI1y5oO5Vc1dbhA==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-symbol": "^2.0.0" + } + }, + "node_modules/micromark-util-classify-character": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-classify-character/-/micromark-util-classify-character-2.0.1.tgz", + "integrity": "sha512-K0kHzM6afW/MbeWYWLjoHQv1sgg2Q9EccHEDzSkxiP/EaagNzCm7T/WMKZ3rjMbvIpvBiZgwR3dKMygtA4mG1Q==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-combine-extensions": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-combine-extensions/-/micromark-util-combine-extensions-2.0.1.tgz", + "integrity": "sha512-OnAnH8Ujmy59JcyZw8JSbK9cGpdVY44NKgSM7E9Eh7DiLS2E9RNQf0dONaGDzEG9yjEl5hcqeIsj4hfRkLH/Bg==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-chunked": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-decode-numeric-character-reference": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/micromark-util-decode-numeric-character-reference/-/micromark-util-decode-numeric-character-reference-2.0.2.tgz", + "integrity": "sha512-ccUbYk6CwVdkmCQMyr64dXz42EfHGkPQlBj5p7YVGzq8I7CtjXZJrubAYezf7Rp+bjPseiROqe7G6foFd+lEuw==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-symbol": "^2.0.0" + } + }, + "node_modules/micromark-util-decode-string": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-decode-string/-/micromark-util-decode-string-2.0.1.tgz", + "integrity": "sha512-nDV/77Fj6eH1ynwscYTOsbK7rR//Uj0bZXBwJZRfaLEJ1iGBR6kIfNmlNqaqJf649EP0F3NWNdeJi03elllNUQ==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "decode-named-character-reference": "^1.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-decode-numeric-character-reference": "^2.0.0", + "micromark-util-symbol": "^2.0.0" + } + }, + "node_modules/micromark-util-encode": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-encode/-/micromark-util-encode-2.0.1.tgz", + "integrity": "sha512-c3cVx2y4KqUnwopcO9b/SCdo2O67LwJJ/UyqGfbigahfegL9myoEFoDYZgkT7f36T0bLrM9hZTAaAyH+PCAXjw==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT" + }, + "node_modules/micromark-util-html-tag-name": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-html-tag-name/-/micromark-util-html-tag-name-2.0.1.tgz", + "integrity": "sha512-2cNEiYDhCWKI+Gs9T0Tiysk136SnR13hhO8yW6BGNyhOC4qYFnwF1nKfD3HFAIXA5c45RrIG1ub11GiXeYd1xA==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT" + }, + "node_modules/micromark-util-normalize-identifier": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-normalize-identifier/-/micromark-util-normalize-identifier-2.0.1.tgz", + "integrity": "sha512-sxPqmo70LyARJs0w2UclACPUUEqltCkJ6PhKdMIDuJ3gSf/Q+/GIe3WKl0Ijb/GyH9lOpUkRAO2wp0GVkLvS9Q==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-symbol": "^2.0.0" + } + }, + "node_modules/micromark-util-resolve-all": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-resolve-all/-/micromark-util-resolve-all-2.0.1.tgz", + "integrity": "sha512-VdQyxFWFT2/FGJgwQnJYbe1jjQoNTS4RjglmSjTUlpUMa95Htx9NHeYW4rGDJzbjvCsl9eLjMQwGeElsqmzcHg==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-sanitize-uri": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-sanitize-uri/-/micromark-util-sanitize-uri-2.0.1.tgz", + "integrity": "sha512-9N9IomZ/YuGGZZmQec1MbgxtlgougxTodVwDzzEouPKo3qFWvymFHWcnDi2vzV1ff6kas9ucW+o3yzJK9YB1AQ==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-character": "^2.0.0", + "micromark-util-encode": "^2.0.0", + "micromark-util-symbol": "^2.0.0" + } + }, + "node_modules/micromark-util-subtokenize": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/micromark-util-subtokenize/-/micromark-util-subtokenize-2.0.4.tgz", + "integrity": "sha512-N6hXjrin2GTJDe3MVjf5FuXpm12PGm80BrUAeub9XFXca8JZbP+oIwY4LJSVwFUCL1IPm/WwSVUN7goFHmSGGQ==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "devlop": "^1.0.0", + "micromark-util-chunked": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-symbol": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-symbol/-/micromark-util-symbol-2.0.1.tgz", + "integrity": "sha512-vs5t8Apaud9N28kgCrRUdEed4UJ+wWNvicHLPxCa9ENlYuAY31M0ETy5y1vA33YoNPDFTghEbnh6efaE8h4x0Q==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT" + }, + "node_modules/micromark-util-types": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-types/-/micromark-util-types-2.0.1.tgz", + "integrity": "sha512-534m2WhVTddrcKVepwmVEVnUAmtrx9bfIjNoQHRqfnvdaHQiFytEhJoTgpWJvDEXCO5gLTQh3wYC1PgOJA4NSQ==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT" + }, + "node_modules/micromatch": { + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", + "license": "MIT", + "dependencies": { + "braces": "^3.0.3", + "picomatch": "^2.3.1" + }, + "engines": { + "node": ">=8.6" + } + }, + "node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "license": "MIT" + }, + "node_modules/nanoid": { + "version": "3.3.8", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.8.tgz", + "integrity": "sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "bin": { + "nanoid": "bin/nanoid.cjs" + }, + "engines": { + "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" + } + }, + "node_modules/natural-compare": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", + "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", + "dev": true, + "license": "MIT" + }, + "node_modules/node-releases": { + "version": "2.0.19", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.19.tgz", + "integrity": "sha512-xxOWJsBKtzAq7DY0J+DTzuz58K8e7sJbdgwkbMWQe8UYB6ekmsQ45q0M/tJDsGaZmbC+l7n57UV8Hl5tHxO9uw==", + "license": "MIT" + }, + "node_modules/normalize-range": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/normalize-range/-/normalize-range-0.1.2.tgz", + "integrity": "sha512-bdok/XvKII3nUpklnV6P2hxtMNrCboOjAcyBuQnWEhO665FwrSNRxU+AqpsyvO6LgGYPspN+lu5CLtw4jPRKNA==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/optionator": { + "version": "0.9.4", + "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", + "integrity": "sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==", + "dev": true, + "license": "MIT", + "dependencies": { + "deep-is": "^0.1.3", + "fast-levenshtein": "^2.0.6", + "levn": "^0.4.1", + "prelude-ls": "^1.2.1", + "type-check": "^0.4.0", + "word-wrap": "^1.2.5" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/p-limit": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", + "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "yocto-queue": "^0.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-locate": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-5.0.0.tgz", + "integrity": "sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-limit": "^3.0.2" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/parent-module": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", + "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", + "dev": true, + "license": "MIT", + "dependencies": { + "callsites": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/parse-entities": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/parse-entities/-/parse-entities-4.0.2.tgz", + "integrity": "sha512-GG2AQYWoLgL877gQIKeRPGO1xF9+eG1ujIb5soS5gPvLQ1y2o8FL90w2QWNdf9I361Mpp7726c+lj3U0qK1uGw==", + "license": "MIT", + "dependencies": { + "@types/unist": "^2.0.0", + "character-entities-legacy": "^3.0.0", + "character-reference-invalid": "^2.0.0", + "decode-named-character-reference": "^1.0.0", + "is-alphanumerical": "^2.0.0", + "is-decimal": "^2.0.0", + "is-hexadecimal": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/parse-entities/node_modules/@types/unist": { + "version": "2.0.11", + "resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.11.tgz", + "integrity": "sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA==", + "license": "MIT" + }, + "node_modules/parse5": { + "version": "7.2.1", + "resolved": "https://registry.npmjs.org/parse5/-/parse5-7.2.1.tgz", + "integrity": "sha512-BuBYQYlv1ckiPdQi/ohiivi9Sagc9JG+Ozs0r7b/0iK3sKmrb0b9FdWdBbOdx6hBCM/F9Ir82ofnBhtZOjCRPQ==", + "license": "MIT", + "dependencies": { + "entities": "^4.5.0" + }, + "funding": { + "url": "https://github.com/inikulin/parse5?sponsor=1" + } + }, + "node_modules/path-exists": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/pdfjs-dist": { + "version": "5.2.133", + "resolved": "https://registry.npmjs.org/pdfjs-dist/-/pdfjs-dist-5.2.133.tgz", + "integrity": "sha512-abE6ZWDxztt+gGFzfm4bX2ggfxUk9wsDEoFzIJm9LozaY3JdXR7jyLK4Bjs+XLXplCduuWS1wGhPC4tgTn/kzg==", + "license": "Apache-2.0", + "engines": { + "node": ">=20.16.0 || >=22.3.0" + }, + "optionalDependencies": { + "@napi-rs/canvas": "^0.1.67" + } + }, + "node_modules/picocolors": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", + "license": "ISC" + }, + "node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "license": "MIT", + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/postcss": { + "version": "8.5.1", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.1.tgz", + "integrity": "sha512-6oz2beyjc5VMn/KV1pPw8fliQkhBXrVn1Z3TVyqZxU8kZpzEKhBdmCFqI6ZbmGtamQvQGuU1sgPTk8ZrXDD7jQ==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "nanoid": "^3.3.8", + "picocolors": "^1.1.1", + "source-map-js": "^1.2.1" + }, + "engines": { + "node": "^10 || ^12 || >=14" + } + }, + "node_modules/postcss-value-parser": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz", + "integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ==", + "license": "MIT" + }, + "node_modules/prelude-ls": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", + "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/prettier": { + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.4.2.tgz", + "integrity": "sha512-e9MewbtFo+Fevyuxn/4rrcDAaq0IYxPGLvObpQjiZBMAzB9IGmzlnG9RZy3FFas+eBMu2vA0CszMeduow5dIuQ==", + "dev": true, + "license": "MIT", + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, + "node_modules/prop-types": { + "version": "15.8.1", + "resolved": "https://registry.npmjs.org/prop-types/-/prop-types-15.8.1.tgz", + "integrity": "sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==", + "license": "MIT", + "dependencies": { + "loose-envify": "^1.4.0", + "object-assign": "^4.1.1", + "react-is": "^16.13.1" + } + }, + "node_modules/property-information": { + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/property-information/-/property-information-6.5.0.tgz", + "integrity": "sha512-PgTgs/BlvHxOu8QuEN7wi5A0OmXaBcHpmCSTehcs6Uuu9IkDIEo13Hy7n898RHfrQ49vKCoGeWZSaAK01nwVig==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/punycode": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/queue-microtask": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", + "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/react": { + "version": "18.3.1", + "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz", + "integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==", + "license": "MIT", + "dependencies": { + "loose-envify": "^1.1.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/react-dom": { + "version": "18.3.1", + "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz", + "integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==", + "license": "MIT", + "dependencies": { + "loose-envify": "^1.1.0", + "scheduler": "^0.23.2" + }, + "peerDependencies": { + "react": "^18.3.1" + } + }, + "node_modules/react-dropzone": { + "version": "14.3.8", + "resolved": "https://registry.npmjs.org/react-dropzone/-/react-dropzone-14.3.8.tgz", + "integrity": "sha512-sBgODnq+lcA4P296DY4wacOZz3JFpD99fp+hb//iBO2HHnyeZU3FwWyXJ6salNpqQdsZrgMrotuko/BdJMV8Ug==", + "license": "MIT", + "dependencies": { + "attr-accept": "^2.2.4", + "file-selector": "^2.1.0", + "prop-types": "^15.8.1" + }, + "engines": { + "node": ">= 10.13" + }, + "peerDependencies": { + "react": ">= 16.8 || 18.0.0" + } + }, + "node_modules/react-hot-toast": { + "version": "2.5.2", + "resolved": "https://registry.npmjs.org/react-hot-toast/-/react-hot-toast-2.5.2.tgz", + "integrity": "sha512-Tun3BbCxzmXXM7C+NI4qiv6lT0uwGh4oAfeJyNOjYUejTsm35mK9iCaYLGv8cBz9L5YxZLx/2ii7zsIwPtPUdw==", + "license": "MIT", + "dependencies": { + "csstype": "^3.1.3", + "goober": "^2.1.16" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "react": ">=16", + "react-dom": ">=16" + } + }, + "node_modules/react-is": { + "version": "16.13.1", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", + "license": "MIT" + }, + "node_modules/react-markdown": { + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/react-markdown/-/react-markdown-9.0.3.tgz", + "integrity": "sha512-Yk7Z94dbgYTOrdk41Z74GoKA7rThnsbbqBTRYuxoe08qvfQ9tJVhmAKw6BJS/ZORG7kTy/s1QvYzSuaoBA1qfw==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "devlop": "^1.0.0", + "hast-util-to-jsx-runtime": "^2.0.0", + "html-url-attributes": "^3.0.0", + "mdast-util-to-hast": "^13.0.0", + "remark-parse": "^11.0.0", + "remark-rehype": "^11.0.0", + "unified": "^11.0.0", + "unist-util-visit": "^5.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + }, + "peerDependencies": { + "@types/react": ">=18", + "react": ">=18" + } + }, + "node_modules/react-refresh": { + "version": "0.14.2", + "resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.14.2.tgz", + "integrity": "sha512-jCvmsr+1IUSMUyzOkRcvnVbX3ZYC6g9TDrDbFuFmRDq7PD4yaGbLKNQL6k2jnArV8hjYxh7hVhAZB6s9HDGpZA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/react-router": { + "version": "7.1.5", + "resolved": "https://registry.npmjs.org/react-router/-/react-router-7.1.5.tgz", + "integrity": "sha512-8BUF+hZEU4/z/JD201yK6S+UYhsf58bzYIDq2NS1iGpwxSXDu7F+DeGSkIXMFBuHZB21FSiCzEcUb18cQNdRkA==", + "license": "MIT", + "dependencies": { + "@types/cookie": "^0.6.0", + "cookie": "^1.0.1", + "set-cookie-parser": "^2.6.0", + "turbo-stream": "2.4.0" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "react": ">=18", + "react-dom": ">=18" + }, + "peerDependenciesMeta": { + "react-dom": { + "optional": true + } + } + }, + "node_modules/rehype-highlight": { + "version": "7.0.2", + "resolved": "https://registry.npmjs.org/rehype-highlight/-/rehype-highlight-7.0.2.tgz", + "integrity": "sha512-k158pK7wdC2qL3M5NcZROZ2tR/l7zOzjxXd5VGdcfIyoijjQqpHd3JKtYSBDpDZ38UI2WJWuFAtkMDxmx5kstA==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "hast-util-to-text": "^4.0.0", + "lowlight": "^3.0.0", + "unist-util-visit": "^5.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/rehype-katex": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/rehype-katex/-/rehype-katex-7.0.1.tgz", + "integrity": "sha512-OiM2wrZ/wuhKkigASodFoo8wimG3H12LWQaH8qSPVJn9apWKFSH3YOCtbKpBorTVw/eI7cuT21XBbvwEswbIOA==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/katex": "^0.16.0", + "hast-util-from-html-isomorphic": "^2.0.0", + "hast-util-to-text": "^4.0.0", + "katex": "^0.16.0", + "unist-util-visit-parents": "^6.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/remark-breaks": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/remark-breaks/-/remark-breaks-4.0.0.tgz", + "integrity": "sha512-IjEjJOkH4FuJvHZVIW0QCDWxcG96kCq7An/KVH2NfJe6rKZU2AsHeB3OEjPNRxi4QC34Xdx7I2KGYn6IpT7gxQ==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-newline-to-break": "^2.0.0", + "unified": "^11.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/remark-gfm": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/remark-gfm/-/remark-gfm-4.0.0.tgz", + "integrity": "sha512-U92vJgBPkbw4Zfu/IiW2oTZLSL3Zpv+uI7My2eq8JxKgqraFdU8YUGicEJCEgSbeaG+QDFqIcwwfMTOEelPxuA==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-gfm": "^3.0.0", + "micromark-extension-gfm": "^3.0.0", + "remark-parse": "^11.0.0", + "remark-stringify": "^11.0.0", + "unified": "^11.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/remark-math": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/remark-math/-/remark-math-6.0.0.tgz", + "integrity": "sha512-MMqgnP74Igy+S3WwnhQ7kqGlEerTETXMvJhrUzDikVZ2/uogJCb+WHUg97hK9/jcfc0dkD73s3LN8zU49cTEtA==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-math": "^3.0.0", + "micromark-extension-math": "^3.0.0", + "unified": "^11.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/remark-parse": { + "version": "11.0.0", + "resolved": "https://registry.npmjs.org/remark-parse/-/remark-parse-11.0.0.tgz", + "integrity": "sha512-FCxlKLNGknS5ba/1lmpYijMUzX2esxW5xQqjWxw2eHFfS2MSdaHVINFmhjo+qN1WhZhNimq0dZATN9pH0IDrpA==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-from-markdown": "^2.0.0", + "micromark-util-types": "^2.0.0", + "unified": "^11.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/remark-rehype": { + "version": "11.1.1", + "resolved": "https://registry.npmjs.org/remark-rehype/-/remark-rehype-11.1.1.tgz", + "integrity": "sha512-g/osARvjkBXb6Wo0XvAeXQohVta8i84ACbenPpoSsxTOQH/Ae0/RGP4WZgnMH5pMLpsj4FG7OHmcIcXxpza8eQ==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "mdast-util-to-hast": "^13.0.0", + "unified": "^11.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/remark-stringify": { + "version": "11.0.0", + "resolved": "https://registry.npmjs.org/remark-stringify/-/remark-stringify-11.0.0.tgz", + "integrity": "sha512-1OSmLd3awB/t8qdoEOMazZkNsfVTeY4fTsgzcQFdXNq8ToTN4ZGwrMnlda4K6smTFKD+GRV6O48i6Z4iKgPPpw==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-to-markdown": "^2.0.0", + "unified": "^11.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/resolve-from": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", + "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/reusify": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", + "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "dev": true, + "license": "MIT", + "engines": { + "iojs": ">=1.0.0", + "node": ">=0.10.0" + } + }, + "node_modules/rollup": { + "version": "4.34.2", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.34.2.tgz", + "integrity": "sha512-sBDUoxZEaqLu9QeNalL8v3jw6WjPku4wfZGyTU7l7m1oC+rpRihXc/n/H+4148ZkGz5Xli8CHMns//fFGKvpIQ==", + "license": "MIT", + "dependencies": { + "@types/estree": "1.0.6" + }, + "bin": { + "rollup": "dist/bin/rollup" + }, + "engines": { + "node": ">=18.0.0", + "npm": ">=8.0.0" + }, + "optionalDependencies": { + "@rollup/rollup-android-arm-eabi": "4.34.2", + "@rollup/rollup-android-arm64": "4.34.2", + "@rollup/rollup-darwin-arm64": "4.34.2", + "@rollup/rollup-darwin-x64": "4.34.2", + "@rollup/rollup-freebsd-arm64": "4.34.2", + "@rollup/rollup-freebsd-x64": "4.34.2", + "@rollup/rollup-linux-arm-gnueabihf": "4.34.2", + "@rollup/rollup-linux-arm-musleabihf": "4.34.2", + "@rollup/rollup-linux-arm64-gnu": "4.34.2", + "@rollup/rollup-linux-arm64-musl": "4.34.2", + "@rollup/rollup-linux-loongarch64-gnu": "4.34.2", + "@rollup/rollup-linux-powerpc64le-gnu": "4.34.2", + "@rollup/rollup-linux-riscv64-gnu": "4.34.2", + "@rollup/rollup-linux-s390x-gnu": "4.34.2", + "@rollup/rollup-linux-x64-gnu": "4.34.2", + "@rollup/rollup-linux-x64-musl": "4.34.2", + "@rollup/rollup-win32-arm64-msvc": "4.34.2", + "@rollup/rollup-win32-ia32-msvc": "4.34.2", + "@rollup/rollup-win32-x64-msvc": "4.34.2", + "fsevents": "~2.3.2" + } + }, + "node_modules/run-parallel": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", + "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT", + "dependencies": { + "queue-microtask": "^1.2.2" + } + }, + "node_modules/rxjs": { + "version": "7.8.1", + "resolved": "https://registry.npmjs.org/rxjs/-/rxjs-7.8.1.tgz", + "integrity": "sha512-AA3TVj+0A2iuIoQkWEK/tqFjBq2j+6PO6Y0zJcvzLAFhEFIO3HL0vls9hWLncZbAAbK0mar7oZ4V079I/qPMxg==", + "devOptional": true, + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.1.0" + } + }, + "node_modules/sass-embedded": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded/-/sass-embedded-1.83.4.tgz", + "integrity": "sha512-Hf2burRA/y5PGxsg6jB9UpoK/xZ6g/pgrkOcdl6j+rRg1Zj8XhGKZ1MTysZGtTPUUmiiErqzkP5+Kzp95yv9GQ==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "@bufbuild/protobuf": "^2.0.0", + "buffer-builder": "^0.2.0", + "colorjs.io": "^0.5.0", + "immutable": "^5.0.2", + "rxjs": "^7.4.0", + "supports-color": "^8.1.1", + "sync-child-process": "^1.0.2", + "varint": "^6.0.0" + }, + "bin": { + "sass": "dist/bin/sass.js" + }, + "engines": { + "node": ">=16.0.0" + }, + "optionalDependencies": { + "sass-embedded-android-arm": "1.83.4", + "sass-embedded-android-arm64": "1.83.4", + "sass-embedded-android-ia32": "1.83.4", + "sass-embedded-android-riscv64": "1.83.4", + "sass-embedded-android-x64": "1.83.4", + "sass-embedded-darwin-arm64": "1.83.4", + "sass-embedded-darwin-x64": "1.83.4", + "sass-embedded-linux-arm": "1.83.4", + "sass-embedded-linux-arm64": "1.83.4", + "sass-embedded-linux-ia32": "1.83.4", + "sass-embedded-linux-musl-arm": "1.83.4", + "sass-embedded-linux-musl-arm64": "1.83.4", + "sass-embedded-linux-musl-ia32": "1.83.4", + "sass-embedded-linux-musl-riscv64": "1.83.4", + "sass-embedded-linux-musl-x64": "1.83.4", + "sass-embedded-linux-riscv64": "1.83.4", + "sass-embedded-linux-x64": "1.83.4", + "sass-embedded-win32-arm64": "1.83.4", + "sass-embedded-win32-ia32": "1.83.4", + "sass-embedded-win32-x64": "1.83.4" + } + }, + "node_modules/sass-embedded-android-arm": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-android-arm/-/sass-embedded-android-arm-1.83.4.tgz", + "integrity": "sha512-9Z4pJAOgEkXa3VDY/o+U6l5XvV0mZTJcSl0l/mSPHihjAHSpLYnOW6+KOWeM8dxqrsqTYcd6COzhanI/a++5Gw==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-arm64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-android-arm64/-/sass-embedded-android-arm64-1.83.4.tgz", + "integrity": "sha512-tgX4FzmbVqnQmD67ZxQDvI+qFNABrboOQgwsG05E5bA/US42zGajW9AxpECJYiMXVOHmg+d81ICbjb0fsVHskw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-ia32": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-android-ia32/-/sass-embedded-android-ia32-1.83.4.tgz", + "integrity": "sha512-RsFOziFqPcfZXdFRULC4Ayzy9aK6R6FwQ411broCjlOBX+b0gurjRadkue3cfUEUR5mmy0KeCbp7zVKPLTK+5Q==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-riscv64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-android-riscv64/-/sass-embedded-android-riscv64-1.83.4.tgz", + "integrity": "sha512-EHwh0nmQarBBrMRU928eTZkFGx19k/XW2YwbPR4gBVdWLkbTgCA5aGe8hTE6/1zStyx++3nDGvTZ78+b/VvvLg==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-x64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-android-x64/-/sass-embedded-android-x64-1.83.4.tgz", + "integrity": "sha512-0PgQNuPWYy1jEOEPDVsV89KfqOsMLIp9CSbjBY7jRcwRhyVAcigqrUG6bDeNtojHUYKA1kU+Eh/85WxOHUOgBw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-darwin-arm64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-darwin-arm64/-/sass-embedded-darwin-arm64-1.83.4.tgz", + "integrity": "sha512-rp2ywymWc3nymnSnAFG5R/8hvxWCsuhK3wOnD10IDlmNB7o4rzKby1c+2ZfpQGowlYGWsWWTgz8FW2qzmZsQRw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-darwin-x64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-darwin-x64/-/sass-embedded-darwin-x64-1.83.4.tgz", + "integrity": "sha512-kLkN2lXz9PCgGfDS8Ev5YVcl/V2173L6379en/CaFuJJi7WiyPgBymW7hOmfCt4uO4R1y7CP2Uc08DRtZsBlAA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-arm": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-arm/-/sass-embedded-linux-arm-1.83.4.tgz", + "integrity": "sha512-nL90ryxX2lNmFucr9jYUyHHx21AoAgdCL1O5Ltx2rKg2xTdytAGHYo2MT5S0LIeKLa/yKP/hjuSvrbICYNDvtA==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-arm64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-arm64/-/sass-embedded-linux-arm64-1.83.4.tgz", + "integrity": "sha512-E0zjsZX2HgESwyqw31EHtI39DKa7RgK7nvIhIRco1d0QEw227WnoR9pjH3M/ZQy4gQj3GKilOFHM5Krs/omeIA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-ia32": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-ia32/-/sass-embedded-linux-ia32-1.83.4.tgz", + "integrity": "sha512-ew5HpchSzgAYbQoriRh8QhlWn5Kw2nQ2jHoV9YLwGKe3fwwOWA0KDedssvDv7FWnY/FCqXyymhLd6Bxae4Xquw==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-arm": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-arm/-/sass-embedded-linux-musl-arm-1.83.4.tgz", + "integrity": "sha512-0RrJRwMrmm+gG0VOB5b5Cjs7Sd+lhqpQJa6EJNEaZHljJokEfpE5GejZsGMRMIQLxEvVphZnnxl6sonCGFE/QQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-arm64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-arm64/-/sass-embedded-linux-musl-arm64-1.83.4.tgz", + "integrity": "sha512-IzMgalf6MZOxgp4AVCgsaWAFDP/IVWOrgVXxkyhw29fyAEoSWBJH4k87wyPhEtxSuzVHLxKNbc8k3UzdWmlBFg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-ia32": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-ia32/-/sass-embedded-linux-musl-ia32-1.83.4.tgz", + "integrity": "sha512-LLb4lYbcxPzX4UaJymYXC+WwokxUlfTJEFUv5VF0OTuSsHAGNRs/rslPtzVBTvMeG9TtlOQDhku1F7G6iaDotA==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-riscv64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-riscv64/-/sass-embedded-linux-musl-riscv64-1.83.4.tgz", + "integrity": "sha512-zoKlPzD5Z13HKin1UGR74QkEy+kZEk2AkGX5RelRG494mi+IWwRuWCppXIovor9+BQb9eDWPYPoMVahwN5F7VA==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-x64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-x64/-/sass-embedded-linux-musl-x64-1.83.4.tgz", + "integrity": "sha512-hB8+/PYhfEf2zTIcidO5Bpof9trK6WJjZ4T8g2MrxQh8REVtdPcgIkoxczRynqybf9+fbqbUwzXtiUao2GV+vQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-riscv64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-riscv64/-/sass-embedded-linux-riscv64-1.83.4.tgz", + "integrity": "sha512-83fL4n+oeDJ0Y4KjASmZ9jHS1Vl9ESVQYHMhJE0i4xDi/P3BNarm2rsKljq/QtrwGpbqwn8ujzOu7DsNCMDSHA==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-x64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-x64/-/sass-embedded-linux-x64-1.83.4.tgz", + "integrity": "sha512-NlnGdvCmTD5PK+LKXlK3sAuxOgbRIEoZfnHvxd157imCm/s2SYF/R28D0DAAjEViyI8DovIWghgbcqwuertXsA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-win32-arm64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-win32-arm64/-/sass-embedded-win32-arm64-1.83.4.tgz", + "integrity": "sha512-J2BFKrEaeSrVazU2qTjyQdAk+MvbzJeTuCET0uAJEXSKtvQ3AzxvzndS7LqkDPbF32eXAHLw8GVpwcBwKbB3Uw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-win32-ia32": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-win32-ia32/-/sass-embedded-win32-ia32-1.83.4.tgz", + "integrity": "sha512-uPAe9T/5sANFhJS5dcfAOhOJy8/l2TRYG4r+UO3Wp4yhqbN7bggPvY9c7zMYS0OC8tU/bCvfYUDFHYMCl91FgA==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-win32-x64": { + "version": "1.83.4", + "resolved": "https://registry.npmjs.org/sass-embedded-win32-x64/-/sass-embedded-win32-x64-1.83.4.tgz", + "integrity": "sha512-C9fkDY0jKITdJFij4UbfPFswxoXN9O/Dr79v17fJnstVwtUojzVJWKHUXvF0Zg2LIR7TCc4ju3adejKFxj7ueA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded/node_modules/supports-color": { + "version": "8.1.1", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", + "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/supports-color?sponsor=1" + } + }, + "node_modules/scheduler": { + "version": "0.23.2", + "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.23.2.tgz", + "integrity": "sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==", + "license": "MIT", + "dependencies": { + "loose-envify": "^1.1.0" + } + }, + "node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + } + }, + "node_modules/set-cookie-parser": { + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/set-cookie-parser/-/set-cookie-parser-2.7.1.tgz", + "integrity": "sha512-IOc8uWeOZgnb3ptbCURJWNjWUPcO3ZnTTdzsurqERrP6nPyv+paC55vJM0LpOlT2ne+Ix+9+CRG1MNLlyZ4GjQ==", + "license": "MIT" + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dev": true, + "license": "MIT", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "license": "BSD-3-Clause", + "optional": true, + "peer": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-js": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-support": { + "version": "0.5.21", + "resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz", + "integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==", + "license": "MIT", + "optional": true, + "peer": true, + "dependencies": { + "buffer-from": "^1.0.0", + "source-map": "^0.6.0" + } + }, + "node_modules/space-separated-tokens": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/space-separated-tokens/-/space-separated-tokens-2.0.2.tgz", + "integrity": "sha512-PEGlAwrG8yXGXRjW32fGbg66JAlOAwbObuqVoJpv/mRgoWDQfgH1wDPvtzWyUSNAXBGSk8h755YDbbcEy3SH2Q==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/stringify-entities": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.4.tgz", + "integrity": "sha512-IwfBptatlO+QCJUo19AqvrPNqlVMpW9YEL2LIVY+Rpv2qsjCGxaDLNRgeGsQWJhfItebuJhsGSLjaBbNSQ+ieg==", + "license": "MIT", + "dependencies": { + "character-entities-html4": "^2.0.0", + "character-entities-legacy": "^3.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/strip-json-comments": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", + "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/style-to-object": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/style-to-object/-/style-to-object-1.0.8.tgz", + "integrity": "sha512-xT47I/Eo0rwJmaXC4oilDGDWLohVhR6o/xAQcPQN8q6QBuZVL8qMYL85kLmST5cPjAorwvqIA4qXTRQoYHaL6g==", + "license": "MIT", + "dependencies": { + "inline-style-parser": "0.2.4" + } + }, + "node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/sync-child-process": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/sync-child-process/-/sync-child-process-1.0.2.tgz", + "integrity": "sha512-8lD+t2KrrScJ/7KXCSyfhT3/hRq78rC0wBFqNJXv3mZyn6hW2ypM05JmlSvtqRbeq6jqA94oHbxAr2vYsJ8vDA==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "sync-message-port": "^1.0.0" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/sync-message-port": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/sync-message-port/-/sync-message-port-1.1.3.tgz", + "integrity": "sha512-GTt8rSKje5FilG+wEdfCkOcLL7LWqpMlr2c3LRuKt/YXxcJ52aGSbGBAdI4L3aaqfrBt6y711El53ItyH1NWzg==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/tailwindcss": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.1.tgz", + "integrity": "sha512-QNbdmeS979Efzim2g/bEvfuh+fTcIdp1y7gA+sb6OYSW74rt7Cr7M78AKdf6HqWT3d5AiTb7SwTT3sLQxr4/qw==", + "license": "MIT" + }, + "node_modules/tapable": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", + "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/terser": { + "version": "5.39.1", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.39.1.tgz", + "integrity": "sha512-Mm6+uad0ZuDtcV8/4uOZQDQ8RuiC5Pu+iZRedJtF7yA/27sPL7d++In/AJKpWZlU3SYMPPkVfwetn6sgZ66pUA==", + "license": "BSD-2-Clause", + "optional": true, + "peer": true, + "dependencies": { + "@jridgewell/source-map": "^0.3.3", + "acorn": "^8.8.2", + "commander": "^2.20.0", + "source-map-support": "~0.5.20" + }, + "bin": { + "terser": "bin/terser" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/terser/node_modules/commander": { + "version": "2.20.3", + "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz", + "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==", + "license": "MIT", + "optional": true, + "peer": true + }, + "node_modules/textlinestream": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/textlinestream/-/textlinestream-1.1.1.tgz", + "integrity": "sha512-iBHbi7BQxrFmwZUQJsT0SjNzlLLsXhvW/kg7EyOMVMBIrlnj/qYofwo1LVLZi+3GbUEo96Iu2eqToI2+lZoAEQ==", + "license": "MIT" + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "license": "MIT", + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/trim-lines": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/trim-lines/-/trim-lines-3.0.1.tgz", + "integrity": "sha512-kRj8B+YHZCc9kQYdWfJB2/oUl9rA99qbowYYBtr4ui4mZyAQ2JpvVBd/6U2YloATfqBhBTSMhTpgBHtU0Mf3Rg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/trough": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/trough/-/trough-2.2.0.tgz", + "integrity": "sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/ts-api-utils": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.0.1.tgz", + "integrity": "sha512-dnlgjFSVetynI8nzgJ+qF62efpglpWRk8isUEWZGWlJYySCTD6aKvbUDu+zbPeDakk3bg5H4XpitHukgfL1m9w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18.12" + }, + "peerDependencies": { + "typescript": ">=4.8.4" + } + }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/turbo-stream": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/turbo-stream/-/turbo-stream-2.4.0.tgz", + "integrity": "sha512-FHncC10WpBd2eOmGwpmQsWLDoK4cqsA/UT/GqNoaKOQnT8uzhtCbg3EoUDMvqpOSAI0S26mr0rkjzbOO6S3v1g==", + "license": "ISC" + }, + "node_modules/type-check": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", + "integrity": "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", + "dev": true, + "license": "MIT", + "dependencies": { + "prelude-ls": "^1.2.1" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/typescript": { + "version": "5.6.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.6.3.tgz", + "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/typescript-eslint": { + "version": "8.23.0", + "resolved": "https://registry.npmjs.org/typescript-eslint/-/typescript-eslint-8.23.0.tgz", + "integrity": "sha512-/LBRo3HrXr5LxmrdYSOCvoAMm7p2jNizNfbIpCgvG4HMsnoprRUOce/+8VJ9BDYWW68rqIENE/haVLWPeFZBVQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/eslint-plugin": "8.23.0", + "@typescript-eslint/parser": "8.23.0", + "@typescript-eslint/utils": "8.23.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <5.8.0" + } + }, + "node_modules/undici-types": { + "version": "6.20.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.20.0.tgz", + "integrity": "sha512-Ny6QZ2Nju20vw1SRHe3d9jVu6gJ+4e3+MMpqu7pqE5HT6WsTSlce++GQmK5UXS8mzV8DSYHrQH+Xrf2jVcuKNg==", + "devOptional": true, + "license": "MIT" + }, + "node_modules/unified": { + "version": "11.0.5", + "resolved": "https://registry.npmjs.org/unified/-/unified-11.0.5.tgz", + "integrity": "sha512-xKvGhPWw3k84Qjh8bI3ZeJjqnyadK+GEFtazSfZv/rKeTkTjOJho6mFqh2SM96iIcZokxiOpg78GazTSg8+KHA==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "bail": "^2.0.0", + "devlop": "^1.0.0", + "extend": "^3.0.0", + "is-plain-obj": "^4.0.0", + "trough": "^2.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-find-after": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/unist-util-find-after/-/unist-util-find-after-5.0.0.tgz", + "integrity": "sha512-amQa0Ep2m6hE2g72AugUItjbuM8X8cGQnFoHk0pGfrFeT9GZhzN5SW8nRsiGKK7Aif4CrACPENkA6P/Lw6fHGQ==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "unist-util-is": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-is": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/unist-util-is/-/unist-util-is-6.0.0.tgz", + "integrity": "sha512-2qCTHimwdxLfz+YzdGfkqNlH0tLi9xjTnHddPmJwtIG9MGsdbutfTc4P+haPD7l7Cjxf/WZj+we5qfVPvvxfYw==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-position": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/unist-util-position/-/unist-util-position-5.0.0.tgz", + "integrity": "sha512-fucsC7HjXvkB5R3kTCO7kUjRdrS0BJt3M/FPxmHMBOm8JQi2BsHAHFsy27E0EolP8rp0NzXsJ+jNPyDWvOJZPA==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-remove-position": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/unist-util-remove-position/-/unist-util-remove-position-5.0.0.tgz", + "integrity": "sha512-Hp5Kh3wLxv0PHj9m2yZhhLt58KzPtEYKQQ4yxfYFEO7EvHwzyDYnduhHnY1mDxoqr7VUwVuHXk9RXKIiYS1N8Q==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "unist-util-visit": "^5.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-stringify-position": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/unist-util-stringify-position/-/unist-util-stringify-position-4.0.0.tgz", + "integrity": "sha512-0ASV06AAoKCDkS2+xw5RXJywruurpbC4JZSm7nr7MOt1ojAzvyyaO+UxZf18j8FCF6kmzCZKcAgN/yu2gm2XgQ==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-visit": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/unist-util-visit/-/unist-util-visit-5.0.0.tgz", + "integrity": "sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "unist-util-is": "^6.0.0", + "unist-util-visit-parents": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-visit-parents": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/unist-util-visit-parents/-/unist-util-visit-parents-6.0.1.tgz", + "integrity": "sha512-L/PqWzfTP9lzzEa6CKs0k2nARxTdZduw3zyh8d2NVBnsyvHjSX4TWse388YrrQKbvI8w20fGjGlhgT96WwKykw==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "unist-util-is": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/update-browserslist-db": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.1.2.tgz", + "integrity": "sha512-PPypAm5qvlD7XMZC3BujecnaOxwhrtoFR+Dqkk5Aa/6DssiH0ibKoketaj9w8LP7Bont1rYeoV5plxD7RTEPRg==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "escalade": "^3.2.0", + "picocolors": "^1.1.1" + }, + "bin": { + "update-browserslist-db": "cli.js" + }, + "peerDependencies": { + "browserslist": ">= 4.21.0" + } + }, + "node_modules/uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "punycode": "^2.1.0" + } + }, + "node_modules/varint": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/varint/-/varint-6.0.0.tgz", + "integrity": "sha512-cXEIW6cfr15lFv563k4GuVuW/fiwjknytD37jIOLSdSWuOI6WnO/oKwmP2FQTU2l01LP8/M5TSAJpzUaGe3uWg==", + "devOptional": true, + "license": "MIT" + }, + "node_modules/vfile": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz", + "integrity": "sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "vfile-message": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/vfile-location": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/vfile-location/-/vfile-location-5.0.3.tgz", + "integrity": "sha512-5yXvWDEgqeiYiBe1lbxYF7UMAIm/IcopxMHrMQDq3nvKcjPKIhZklUKL+AE7J7uApI4kwe2snsK+eI6UTj9EHg==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/vfile-message": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/vfile-message/-/vfile-message-4.0.2.tgz", + "integrity": "sha512-jRDZ1IMLttGj41KcZvlrYAaI3CfqpLpfpf+Mfig13viT6NKvRzWZ+lXz0Y5D60w6uJIBAOGq9mSHf0gktF0duw==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "unist-util-stringify-position": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/vite": { + "version": "6.0.11", + "resolved": "https://registry.npmjs.org/vite/-/vite-6.0.11.tgz", + "integrity": "sha512-4VL9mQPKoHy4+FE0NnRE/kbY51TOfaknxAjt3fJbGJxhIpBZiqVzlZDEesWWsuREXHwNdAoOFZ9MkPEVXczHwg==", + "license": "MIT", + "dependencies": { + "esbuild": "^0.24.2", + "postcss": "^8.4.49", + "rollup": "^4.23.0" + }, + "bin": { + "vite": "bin/vite.js" + }, + "engines": { + "node": "^18.0.0 || ^20.0.0 || >=22.0.0" + }, + "funding": { + "url": "https://github.com/vitejs/vite?sponsor=1" + }, + "optionalDependencies": { + "fsevents": "~2.3.3" + }, + "peerDependencies": { + "@types/node": "^18.0.0 || ^20.0.0 || >=22.0.0", + "jiti": ">=1.21.0", + "less": "*", + "lightningcss": "^1.21.0", + "sass": "*", + "sass-embedded": "*", + "stylus": "*", + "sugarss": "*", + "terser": "^5.16.0", + "tsx": "^4.8.1", + "yaml": "^2.4.2" + }, + "peerDependenciesMeta": { + "@types/node": { + "optional": true + }, + "jiti": { + "optional": true + }, + "less": { + "optional": true + }, + "lightningcss": { + "optional": true + }, + "sass": { + "optional": true + }, + "sass-embedded": { + "optional": true + }, + "stylus": { + "optional": true + }, + "sugarss": { + "optional": true + }, + "terser": { + "optional": true + }, + "tsx": { + "optional": true + }, + "yaml": { + "optional": true + } + } + }, + "node_modules/vite-plugin-singlefile": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/vite-plugin-singlefile/-/vite-plugin-singlefile-2.1.0.tgz", + "integrity": "sha512-7tJo+UgZABlKpY/nubth/wxJ4+pUGREPnEwNOknxwl2MM0zTvF14KTU4Ln1lc140gjLLV5mjDrvuoquU7OZqCg==", + "license": "MIT", + "dependencies": { + "micromatch": "^4.0.8" + }, + "engines": { + "node": ">18.0.0" + }, + "peerDependencies": { + "rollup": "^4.28.1", + "vite": "^5.4.11 || ^6.0.0" + } + }, + "node_modules/web-namespaces": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/web-namespaces/-/web-namespaces-2.0.1.tgz", + "integrity": "sha512-bKr1DkiNa2krS7qxNtdrtHAmzuYGFQLiQ13TsorsdT6ULTkPLKuu5+GsFpDlg6JFjUTwX2DyhMPG2be8uPrqsQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dev": true, + "license": "ISC", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/word-wrap": { + "version": "1.2.5", + "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz", + "integrity": "sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/yallist": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", + "integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==", + "dev": true, + "license": "ISC" + }, + "node_modules/yaml": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.7.0.tgz", + "integrity": "sha512-+hSoy/QHluxmC9kCIJyL/uyFmLmc+e5CFR5Wa+bpIhIj85LVb9ZH2nVnqrHoSvKogwODv0ClqZkmiSSaIH5LTA==", + "license": "ISC", + "optional": true, + "peer": true, + "bin": { + "yaml": "bin.mjs" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/yocto-queue": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", + "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/zwitch": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/zwitch/-/zwitch-2.0.4.tgz", + "integrity": "sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + } + } +} diff --git a/tools/server/webui/package.json b/tools/server/webui/package.json new file mode 100644 index 0000000000000000000000000000000000000000..8076840324d49d9b7262dba6f77c7dc697bb9b39 --- /dev/null +++ b/tools/server/webui/package.json @@ -0,0 +1,66 @@ +{ + "name": "webui", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "npm run format && tsc -b && vite build", + "format": "eslint . && prettier --write .", + "lint": "eslint .", + "preview": "vite preview" + }, + "dependencies": { + "@heroicons/react": "^2.2.0", + "@sec-ant/readable-stream": "^0.6.0", + "@tailwindcss/postcss": "^4.1.1", + "@tailwindcss/vite": "^4.1.1", + "@vscode/markdown-it-katex": "^1.1.1", + "autoprefixer": "^10.4.20", + "daisyui": "^5.0.12", + "dexie": "^4.0.11", + "highlight.js": "^11.10.0", + "katex": "^0.16.15", + "pdfjs-dist": "^5.2.133", + "postcss": "^8.4.49", + "react": "^18.3.1", + "react-dom": "^18.3.1", + "react-dropzone": "^14.3.8", + "react-hot-toast": "^2.5.2", + "react-markdown": "^9.0.3", + "react-router": "^7.1.5", + "rehype-highlight": "^7.0.2", + "rehype-katex": "^7.0.1", + "remark-breaks": "^4.0.0", + "remark-gfm": "^4.0.0", + "remark-math": "^6.0.0", + "tailwindcss": "^4.1.1", + "textlinestream": "^1.1.1", + "vite-plugin-singlefile": "^2.0.3" + }, + "devDependencies": { + "@eslint/js": "^9.17.0", + "@types/markdown-it": "^14.1.2", + "@types/node": "^22.13.1", + "@types/react": "^18.3.18", + "@types/react-dom": "^18.3.5", + "@vitejs/plugin-react": "^4.3.4", + "eslint": "^9.17.0", + "eslint-plugin-react-hooks": "^5.0.0", + "eslint-plugin-react-refresh": "^0.4.16", + "fflate": "^0.8.2", + "globals": "^15.14.0", + "prettier": "^3.4.2", + "sass-embedded": "^1.83.4", + "typescript": "~5.6.2", + "typescript-eslint": "^8.18.2", + "vite": "^6.0.5" + }, + "prettier": { + "trailingComma": "es5", + "tabWidth": 2, + "semi": true, + "singleQuote": true, + "bracketSameLine": false + } +} diff --git a/tools/server/webui/postcss.config.js b/tools/server/webui/postcss.config.js new file mode 100644 index 0000000000000000000000000000000000000000..fb05b5692bba713f330ce6843112936d5fe4ad04 --- /dev/null +++ b/tools/server/webui/postcss.config.js @@ -0,0 +1,5 @@ +export default { + plugins: { + "@tailwindcss/postcss": {}, + }, +} diff --git a/tools/server/webui/public/demo-conversation.json b/tools/server/webui/public/demo-conversation.json new file mode 100644 index 0000000000000000000000000000000000000000..338b4aea590f28682f900fd14770111c0d7c9040 --- /dev/null +++ b/tools/server/webui/public/demo-conversation.json @@ -0,0 +1,33 @@ +{ + "demo": true, + "id": "conv-1734086746930", + "lastModified": 1734087548943, + "messages": [ + { + "id": 1734086764521, + "role": "user", + "content": "this is a demo conversation, used in dev mode" + }, + { + "id": 1734087548327, + "role": "assistant", + "content": "This is the formula:\n\n$\\frac{e^{x_i}}{\\sum_{j=1}^{n}e^{x_j}}$\n\nGiven an input vector \\(\\mathbf{x} = [x_1, x_2, \\ldots, x_n]\\)\n\n\\[\ny_i = \\frac{e^{x_i}}{\\sum_{j=1}^n e^{x_j}}\n\\]\n\n$2x + y = z$\n\nCode block latex:\n```latex\n\\frac{e^{x_i}}{\\sum_{j=1}^{n}e^{x_j}}\n```\n\nTest dollar sign: $1234 $4567\n\nInvalid latex syntax: $E = mc^$ and $$E = mc^$$", + "timings": { + "prompt_n": 1, + "prompt_ms": 28.923, + "predicted_n": 25, + "predicted_ms": 573.016 + } + }, + { + "id": 1734087548328, + "role": "user", + "content": "this is a demo conversation, used in dev mode" + }, + { + "id": 1734087548329, + "role": "assistant", + "content": "Code block:\n```js\nconsole.log('hello world')\n```\n```sh\nls -la /dev\n```" + } + ] +} diff --git a/tools/server/webui/src/App.tsx b/tools/server/webui/src/App.tsx new file mode 100644 index 0000000000000000000000000000000000000000..8dfcf49075803904f933ce4350530b5715f0ae14 --- /dev/null +++ b/tools/server/webui/src/App.tsx @@ -0,0 +1,52 @@ +import { HashRouter, Outlet, Route, Routes } from 'react-router'; +import Header from './components/Header'; +import Sidebar from './components/Sidebar'; +import { AppContextProvider, useAppContext } from './utils/app.context'; +import ChatScreen from './components/ChatScreen'; +import SettingDialog from './components/SettingDialog'; +import { Toaster } from 'react-hot-toast'; +import { ModalProvider } from './components/ModalProvider'; + +function App() { + return ( + + +
+ + + }> + } /> + } /> + + + +
+
+
+ ); +} + +function AppLayout() { + const { showSettings, setShowSettings } = useAppContext(); + return ( + <> + +
+
+ +
+ { + setShowSettings(false)} + /> + } + + + ); +} + +export default App; diff --git a/tools/server/webui/src/Config.ts b/tools/server/webui/src/Config.ts new file mode 100644 index 0000000000000000000000000000000000000000..c03ac287f3484e4e4002ea3019509fe07b57894e --- /dev/null +++ b/tools/server/webui/src/Config.ts @@ -0,0 +1,96 @@ +import daisyuiThemes from 'daisyui/theme/object'; +import { isNumeric } from './utils/misc'; + +export const isDev = import.meta.env.MODE === 'development'; + +// constants +export const BASE_URL = new URL('.', document.baseURI).href + .toString() + .replace(/\/$/, ''); + +export const CONFIG_DEFAULT = { + // Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value. Do not use null or undefined for default value. + // Do not use nested objects, keep it single level. Prefix the key if you need to group them. + apiKey: '', + systemMessage: '', + showTokensPerSecond: false, + showThoughtInProgress: false, + excludeThoughtOnReq: true, + pasteLongTextToFileLen: 2500, + pdfAsImage: false, + // make sure these default values are in sync with `common.h` + samplers: 'edkypmxt', + temperature: 0.8, + dynatemp_range: 0.0, + dynatemp_exponent: 1.0, + top_k: 40, + top_p: 0.95, + min_p: 0.05, + xtc_probability: 0.0, + xtc_threshold: 0.1, + typical_p: 1.0, + repeat_last_n: 64, + repeat_penalty: 1.0, + presence_penalty: 0.0, + frequency_penalty: 0.0, + dry_multiplier: 0.0, + dry_base: 1.75, + dry_allowed_length: 2, + dry_penalty_last_n: -1, + max_tokens: -1, + custom: '', // custom json-stringified object + // experimental features + pyIntepreterEnabled: false, +}; +export const CONFIG_INFO: Record = { + apiKey: 'Set the API Key if you are using --api-key option for the server.', + systemMessage: 'The starting message that defines how model should behave.', + pasteLongTextToFileLen: + 'On pasting long text, it will be converted to a file. You can control the file length by setting the value of this parameter. Value 0 means disable.', + samplers: + 'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature', + temperature: + 'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.', + dynatemp_range: + 'Addon for the temperature sampler. The added value to the range of dynamic temperature, which adjusts probabilities by entropy of tokens.', + dynatemp_exponent: + 'Addon for the temperature sampler. Smoothes out the probability redistribution based on the most probable token.', + top_k: 'Keeps only k top tokens.', + top_p: + 'Limits tokens to those that together have a cumulative probability of at least p', + min_p: + 'Limits tokens based on the minimum probability for a token to be considered, relative to the probability of the most likely token.', + xtc_probability: + 'XTC sampler cuts out top tokens; this parameter controls the chance of cutting tokens at all. 0 disables XTC.', + xtc_threshold: + 'XTC sampler cuts out top tokens; this parameter controls the token probability that is required to cut that token.', + typical_p: + 'Sorts and limits tokens based on the difference between log-probability and entropy.', + repeat_last_n: 'Last n tokens to consider for penalizing repetition', + repeat_penalty: + 'Controls the repetition of token sequences in the generated text', + presence_penalty: + 'Limits tokens based on whether they appear in the output or not.', + frequency_penalty: + 'Limits tokens based on how often they appear in the output.', + dry_multiplier: + 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling multiplier.', + dry_base: + 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling base value.', + dry_allowed_length: + 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the allowed length for DRY sampling.', + dry_penalty_last_n: + 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets DRY penalty for the last n tokens.', + max_tokens: 'The maximum number of token per output.', + custom: '', // custom json-stringified object +}; +// config keys having numeric value (i.e. temperature, top_k, top_p, etc) +export const CONFIG_NUMERIC_KEYS = Object.entries(CONFIG_DEFAULT) + .filter((e) => isNumeric(e[1])) + .map((e) => e[0]); +// list of themes supported by daisyui +export const THEMES = ['light', 'dark'] + // make sure light & dark are always at the beginning + .concat( + Object.keys(daisyuiThemes).filter((t) => t !== 'light' && t !== 'dark') + ); diff --git a/tools/server/webui/src/components/CanvasPyInterpreter.tsx b/tools/server/webui/src/components/CanvasPyInterpreter.tsx new file mode 100644 index 0000000000000000000000000000000000000000..c2707fe20fcec2b2f3df0a82222bd4d6646cb3d0 --- /dev/null +++ b/tools/server/webui/src/components/CanvasPyInterpreter.tsx @@ -0,0 +1,195 @@ +import { useEffect, useState } from 'react'; +import { useAppContext } from '../utils/app.context'; +import { OpenInNewTab, XCloseButton } from '../utils/common'; +import { CanvasType } from '../utils/types'; +import { PlayIcon, StopIcon } from '@heroicons/react/24/outline'; +import StorageUtils from '../utils/storage'; + +const canInterrupt = typeof SharedArrayBuffer === 'function'; + +// adapted from https://pyodide.org/en/stable/usage/webworker.html +const WORKER_CODE = ` +importScripts("https://cdn.jsdelivr.net/pyodide/v0.27.2/full/pyodide.js"); + +let stdOutAndErr = []; + +let pyodideReadyPromise = loadPyodide({ + stdout: (data) => stdOutAndErr.push(data), + stderr: (data) => stdOutAndErr.push(data), +}); + +let alreadySetBuff = false; + +self.onmessage = async (event) => { + stdOutAndErr = []; + + // make sure loading is done + const pyodide = await pyodideReadyPromise; + const { id, python, context, interruptBuffer } = event.data; + + if (interruptBuffer && !alreadySetBuff) { + pyodide.setInterruptBuffer(interruptBuffer); + alreadySetBuff = true; + } + + // Now load any packages we need, run the code, and send the result back. + await pyodide.loadPackagesFromImports(python); + + // make a Python dictionary with the data from content + const dict = pyodide.globals.get("dict"); + const globals = dict(Object.entries(context)); + try { + self.postMessage({ id, running: true }); + // Execute the python code in this context + const result = pyodide.runPython(python, { globals }); + self.postMessage({ result, id, stdOutAndErr }); + } catch (error) { + self.postMessage({ error: error.message, id }); + } + interruptBuffer[0] = 0; +}; +`; + +let worker: Worker; +const interruptBuffer = canInterrupt + ? new Uint8Array(new SharedArrayBuffer(1)) + : null; + +const startWorker = () => { + if (!worker) { + worker = new Worker( + URL.createObjectURL(new Blob([WORKER_CODE], { type: 'text/javascript' })) + ); + } +}; + +if (StorageUtils.getConfig().pyIntepreterEnabled) { + startWorker(); +} + +const runCodeInWorker = ( + pyCode: string, + callbackRunning: () => void +): { + donePromise: Promise; + interrupt: () => void; +} => { + startWorker(); + const id = Math.random() * 1e8; + const context = {}; + if (interruptBuffer) { + interruptBuffer[0] = 0; + } + + const donePromise = new Promise((resolve) => { + worker.onmessage = (event) => { + const { error, stdOutAndErr, running } = event.data; + if (id !== event.data.id) return; + if (running) { + callbackRunning(); + return; + } else if (error) { + resolve(error.toString()); + } else { + resolve(stdOutAndErr.join('\n')); + } + }; + worker.postMessage({ id, python: pyCode, context, interruptBuffer }); + }); + + const interrupt = () => { + console.log('Interrupting...'); + console.trace(); + if (interruptBuffer) { + interruptBuffer[0] = 2; + } + }; + + return { donePromise, interrupt }; +}; + +export default function CanvasPyInterpreter() { + const { canvasData, setCanvasData } = useAppContext(); + + const [code, setCode] = useState(canvasData?.content ?? ''); // copy to avoid direct mutation + const [running, setRunning] = useState(false); + const [output, setOutput] = useState(''); + const [interruptFn, setInterruptFn] = useState<() => void>(); + const [showStopBtn, setShowStopBtn] = useState(false); + + const runCode = async (pycode: string) => { + interruptFn?.(); + setRunning(true); + setOutput('Loading Pyodide...'); + const { donePromise, interrupt } = runCodeInWorker(pycode, () => { + setOutput('Running...'); + setShowStopBtn(canInterrupt); + }); + setInterruptFn(() => interrupt); + const out = await donePromise; + setOutput(out); + setRunning(false); + setShowStopBtn(false); + }; + + // run code on mount + useEffect(() => { + setCode(canvasData?.content ?? ''); + runCode(canvasData?.content ?? ''); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [canvasData?.content]); + + if (canvasData?.type !== CanvasType.PY_INTERPRETER) { + return null; + } + + return ( +
+
+
+ Python Interpreter + setCanvasData(null)} + /> +
+
+ +
+
+ + {showStopBtn && ( + + )} + + + Report a bug + + +
+ +
+
+
+
+ ); +} diff --git a/tools/server/webui/src/components/ChatInputExtraContextItem.tsx b/tools/server/webui/src/components/ChatInputExtraContextItem.tsx new file mode 100644 index 0000000000000000000000000000000000000000..2d4179ea4703ed2c3a8a795e9d03ddf5154da69c --- /dev/null +++ b/tools/server/webui/src/components/ChatInputExtraContextItem.tsx @@ -0,0 +1,135 @@ +import { + DocumentTextIcon, + SpeakerWaveIcon, + XMarkIcon, +} from '@heroicons/react/24/outline'; +import { MessageExtra } from '../utils/types'; +import { useState } from 'react'; +import { classNames } from '../utils/misc'; + +export default function ChatInputExtraContextItem({ + items, + removeItem, + clickToShow, +}: { + items?: MessageExtra[]; + removeItem?: (index: number) => void; + clickToShow?: boolean; +}) { + const [show, setShow] = useState(-1); + const showingItem = show >= 0 ? items?.[show] : undefined; + + if (!items) return null; + + return ( +
+ {items.map((item, i) => ( +
clickToShow && setShow(i)} + tabIndex={0} + aria-description={ + clickToShow ? `Click to show: ${item.name}` : undefined + } + role={clickToShow ? 'button' : 'menuitem'} + > + {removeItem && ( +
+ +
+ )} + +
+ {item.type === 'imageFile' ? ( + <> + {`Preview + + ) : ( + <> +
+ {item.type === 'audioFile' ? ( + + ) : ( + + )} +
+ +
+ {item.name ?? 'Extra content'} +
+ + )} +
+
+ ))} + + {showingItem && ( + +
+
+ {showingItem.name ?? 'Extra content'} + +
+ {showingItem.type === 'imageFile' ? ( + {`Preview + ) : showingItem.type === 'audioFile' ? ( + + ) : ( +
+
+                  {showingItem.content}
+                
+
+ )} +
+
setShow(-1)}>
+
+ )} +
+ ); +} diff --git a/tools/server/webui/src/components/ChatMessage.tsx b/tools/server/webui/src/components/ChatMessage.tsx new file mode 100644 index 0000000000000000000000000000000000000000..ee59de450d1ffc6c022b50a6a0940da765083571 --- /dev/null +++ b/tools/server/webui/src/components/ChatMessage.tsx @@ -0,0 +1,318 @@ +import { useMemo, useState } from 'react'; +import { useAppContext } from '../utils/app.context'; +import { Message, PendingMessage } from '../utils/types'; +import { classNames } from '../utils/misc'; +import MarkdownDisplay, { CopyButton } from './MarkdownDisplay'; +import { + ArrowPathIcon, + ChevronLeftIcon, + ChevronRightIcon, + PencilSquareIcon, +} from '@heroicons/react/24/outline'; +import ChatInputExtraContextItem from './ChatInputExtraContextItem'; +import { BtnWithTooltips } from '../utils/common'; + +interface SplitMessage { + content: PendingMessage['content']; + thought?: string; + isThinking?: boolean; +} + +export default function ChatMessage({ + msg, + siblingLeafNodeIds, + siblingCurrIdx, + id, + onRegenerateMessage, + onEditMessage, + onChangeSibling, + isPending, +}: { + msg: Message | PendingMessage; + siblingLeafNodeIds: Message['id'][]; + siblingCurrIdx: number; + id?: string; + onRegenerateMessage(msg: Message): void; + onEditMessage(msg: Message, content: string): void; + onChangeSibling(sibling: Message['id']): void; + isPending?: boolean; +}) { + const { viewingChat, config } = useAppContext(); + const [editingContent, setEditingContent] = useState(null); + const timings = useMemo( + () => + msg.timings + ? { + ...msg.timings, + prompt_per_second: + (msg.timings.prompt_n / msg.timings.prompt_ms) * 1000, + predicted_per_second: + (msg.timings.predicted_n / msg.timings.predicted_ms) * 1000, + } + : null, + [msg.timings] + ); + const nextSibling = siblingLeafNodeIds[siblingCurrIdx + 1]; + const prevSibling = siblingLeafNodeIds[siblingCurrIdx - 1]; + + // for reasoning model, we split the message into content and thought + // TODO: implement this as remark/rehype plugin in the future + const { content, thought, isThinking }: SplitMessage = useMemo(() => { + if (msg.content === null || msg.role !== 'assistant') { + return { content: msg.content }; + } + let actualContent = ''; + let thought = ''; + let isThinking = false; + let thinkSplit = msg.content.split('', 2); + actualContent += thinkSplit[0]; + while (thinkSplit[1] !== undefined) { + // tag found + thinkSplit = thinkSplit[1].split('', 2); + thought += thinkSplit[0]; + isThinking = true; + if (thinkSplit[1] !== undefined) { + // closing tag found + isThinking = false; + thinkSplit = thinkSplit[1].split('', 2); + actualContent += thinkSplit[0]; + } + } + return { content: actualContent, thought, isThinking }; + }, [msg]); + + if (!viewingChat) return null; + + const isUser = msg.role === 'user'; + + return ( +
+
+ {msg.extra && msg.extra.length > 0 && ( + + )} + +
+ {/* textarea for editing message */} + {editingContent !== null && ( + <> + +
+ + + + )} + {/* not editing content, render message */} + {editingContent === null && ( + <> + {content === null ? ( + <> + {/* show loading dots for pending message */} + + + ) : ( + <> + {/* render message as markdown */} +
+ {thought && ( + + )} + + +
+ + )} + {/* render timings if enabled */} + {timings && config.showTokensPerSecond && ( +
+
+ Speed: {timings.predicted_per_second.toFixed(1)} t/s +
+
+ Prompt +
- Tokens: {timings.prompt_n} +
- Time: {timings.prompt_ms} ms +
- Speed: {timings.prompt_per_second.toFixed(1)} t/s +
+ Generation +
- Tokens: {timings.predicted_n} +
- Time: {timings.predicted_ms} ms +
- Speed: {timings.predicted_per_second.toFixed(1)} t/s +
+
+
+ )} + + )} +
+
+ + {/* actions for each message */} + {msg.content !== null && ( +
+ {siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && ( +
+ + + {siblingCurrIdx + 1} / {siblingLeafNodeIds.length} + + +
+ )} + {/* user message */} + {msg.role === 'user' && ( + setEditingContent(msg.content)} + disabled={msg.content === null} + tooltipsContent="Edit message" + > + + + )} + {/* assistant message */} + {msg.role === 'assistant' && ( + <> + {!isPending && ( + { + if (msg.content !== null) { + onRegenerateMessage(msg as Message); + } + }} + disabled={msg.content === null} + tooltipsContent="Regenerate response" + > + + + )} + + )} + +
+ )} +
+ ); +} + +function ThoughtProcess({ + isThinking, + content, + open, +}: { + isThinking: boolean; + content: string; + open: boolean; +}) { + return ( +
+ +
+
+ {isThinking ? ( + + + Thinking + + ) : ( + <>Thought Process + )} +
+
+
+
+ +
+
+
+ ); +} diff --git a/tools/server/webui/src/components/ChatScreen.tsx b/tools/server/webui/src/components/ChatScreen.tsx new file mode 100644 index 0000000000000000000000000000000000000000..c1a669144550797f1e2bd60de074b471da134a3e --- /dev/null +++ b/tools/server/webui/src/components/ChatScreen.tsx @@ -0,0 +1,459 @@ +import { ClipboardEvent, useEffect, useMemo, useRef, useState } from 'react'; +import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context'; +import ChatMessage from './ChatMessage'; +import { CanvasType, Message, PendingMessage } from '../utils/types'; +import { classNames, cleanCurrentUrl } from '../utils/misc'; +import CanvasPyInterpreter from './CanvasPyInterpreter'; +import StorageUtils from '../utils/storage'; +import { useVSCodeContext } from '../utils/llama-vscode'; +import { useChatTextarea, ChatTextareaApi } from './useChatTextarea.ts'; +import { + ArrowUpIcon, + StopIcon, + PaperClipIcon, +} from '@heroicons/react/24/solid'; +import { + ChatExtraContextApi, + useChatExtraContext, +} from './useChatExtraContext.tsx'; +import Dropzone from 'react-dropzone'; +import toast from 'react-hot-toast'; +import ChatInputExtraContextItem from './ChatInputExtraContextItem.tsx'; +import { scrollToBottom, useChatScroll } from './useChatScroll.tsx'; + +/** + * A message display is a message node with additional information for rendering. + * For example, siblings of the message node are stored as their last node (aka leaf node). + */ +export interface MessageDisplay { + msg: Message | PendingMessage; + siblingLeafNodeIds: Message['id'][]; + siblingCurrIdx: number; + isPending?: boolean; +} + +/** + * If the current URL contains "?m=...", prefill the message input with the value. + * If the current URL contains "?q=...", prefill and SEND the message. + */ +const prefilledMsg = { + content() { + const url = new URL(window.location.href); + return url.searchParams.get('m') ?? url.searchParams.get('q') ?? ''; + }, + shouldSend() { + const url = new URL(window.location.href); + return url.searchParams.has('q'); + }, + clear() { + cleanCurrentUrl(['m', 'q']); + }, +}; + +function getListMessageDisplay( + msgs: Readonly, + leafNodeId: Message['id'] +): MessageDisplay[] { + const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true); + const res: MessageDisplay[] = []; + const nodeMap = new Map(); + for (const msg of msgs) { + nodeMap.set(msg.id, msg); + } + // find leaf node from a message node + const findLeafNode = (msgId: Message['id']): Message['id'] => { + let currNode: Message | undefined = nodeMap.get(msgId); + while (currNode) { + if (currNode.children.length === 0) break; + currNode = nodeMap.get(currNode.children.at(-1) ?? -1); + } + return currNode?.id ?? -1; + }; + // traverse the current nodes + for (const msg of currNodes) { + const parentNode = nodeMap.get(msg.parent ?? -1); + if (!parentNode) continue; + const siblings = parentNode.children; + if (msg.type !== 'root') { + res.push({ + msg, + siblingLeafNodeIds: siblings.map(findLeafNode), + siblingCurrIdx: siblings.indexOf(msg.id), + }); + } + } + return res; +} + +export default function ChatScreen() { + const { + viewingChat, + sendMessage, + isGenerating, + stopGenerating, + pendingMessages, + canvasData, + replaceMessageAndGenerate, + } = useAppContext(); + + const textarea: ChatTextareaApi = useChatTextarea(prefilledMsg.content()); + const extraContext = useChatExtraContext(); + useVSCodeContext(textarea, extraContext); + + const msgListRef = useRef(null); + useChatScroll(msgListRef); + + // keep track of leaf node for rendering + const [currNodeId, setCurrNodeId] = useState(-1); + const messages: MessageDisplay[] = useMemo(() => { + if (!viewingChat) return []; + else return getListMessageDisplay(viewingChat.messages, currNodeId); + }, [currNodeId, viewingChat]); + + const currConvId = viewingChat?.conv.id ?? null; + const pendingMsg: PendingMessage | undefined = + pendingMessages[currConvId ?? '']; + + useEffect(() => { + // reset to latest node when conversation changes + setCurrNodeId(-1); + // scroll to bottom when conversation changes + scrollToBottom(false, 1); + }, [currConvId]); + + const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => { + if (currLeafNodeId) { + setCurrNodeId(currLeafNodeId); + } + // useChatScroll will handle the auto scroll + }; + + const sendNewMessage = async () => { + const lastInpMsg = textarea.value(); + if (lastInpMsg.trim().length === 0 || isGenerating(currConvId ?? '')) { + toast.error('Please enter a message'); + return; + } + textarea.setValue(''); + scrollToBottom(false); + setCurrNodeId(-1); + // get the last message node + const lastMsgNodeId = messages.at(-1)?.msg.id ?? null; + if ( + !(await sendMessage( + currConvId, + lastMsgNodeId, + lastInpMsg, + extraContext.items, + onChunk + )) + ) { + // restore the input message if failed + textarea.setValue(lastInpMsg); + } + // OK + extraContext.clearItems(); + }; + + // for vscode context + textarea.refOnSubmit.current = sendNewMessage; + + const handleEditMessage = async (msg: Message, content: string) => { + if (!viewingChat) return; + setCurrNodeId(msg.id); + scrollToBottom(false); + await replaceMessageAndGenerate( + viewingChat.conv.id, + msg.parent, + content, + msg.extra, + onChunk + ); + setCurrNodeId(-1); + scrollToBottom(false); + }; + + const handleRegenerateMessage = async (msg: Message) => { + if (!viewingChat) return; + setCurrNodeId(msg.parent); + scrollToBottom(false); + await replaceMessageAndGenerate( + viewingChat.conv.id, + msg.parent, + null, + msg.extra, + onChunk + ); + setCurrNodeId(-1); + scrollToBottom(false); + }; + + const hasCanvas = !!canvasData; + + useEffect(() => { + if (prefilledMsg.shouldSend()) { + // send the prefilled message if needed + sendNewMessage(); + } else { + // otherwise, focus on the input + textarea.focus(); + } + prefilledMsg.clear(); + // no need to keep track of sendNewMessage + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [textarea.ref]); + + // due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg) + const pendingMsgDisplay: MessageDisplay[] = + pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id + ? [ + { + msg: pendingMsg, + siblingLeafNodeIds: [], + siblingCurrIdx: 0, + isPending: true, + }, + ] + : []; + + return ( +
+
+ {/* chat messages */} +
+
+ {/* placeholder to shift the message to the bottom */} + {viewingChat ? ( + '' + ) : ( + <> +
Send a message to start
+ + + )} +
+ {[...messages, ...pendingMsgDisplay].map((msg) => ( + + ))} +
+ + {/* chat input */} + stopGenerating(currConvId ?? '')} + isGenerating={isGenerating(currConvId ?? '')} + /> +
+
+ {canvasData?.type === CanvasType.PY_INTERPRETER && ( + + )} +
+
+ ); +} + +function ServerInfo() { + const { serverProps } = useAppContext(); + const modalities = []; + if (serverProps?.modalities?.audio) { + modalities.push('audio'); + } + if (serverProps?.modalities?.vision) { + modalities.push('vision'); + } + return ( +
+
+ Server Info +

+ Model: {serverProps?.model_path?.split(/(\\|\/)/).pop()} +
+ Build: {serverProps?.build_info} +
+ {modalities.length > 0 ? ( + <> + Supported modalities: {modalities.join(', ')} + + ) : ( + '' + )} +

+
+
+ ); +} + +function ChatInput({ + textarea, + extraContext, + onSend, + onStop, + isGenerating, +}: { + textarea: ChatTextareaApi; + extraContext: ChatExtraContextApi; + onSend: () => void; + onStop: () => void; + isGenerating: boolean; +}) { + const { config } = useAppContext(); + const [isDrag, setIsDrag] = useState(false); + + return ( +
+ { + setIsDrag(false); + extraContext.onFileAdded(files); + }} + onDragEnter={() => setIsDrag(true)} + onDragLeave={() => setIsDrag(false)} + multiple={true} + > + {({ getRootProps, getInputProps }) => ( +
) => { + const text = e.clipboardData.getData('text/plain'); + if ( + text.length > 0 && + config.pasteLongTextToFileLen > 0 && + text.length > config.pasteLongTextToFileLen + ) { + // if the text is too long, we will convert it to a file + extraContext.addItems([ + { + type: 'context', + name: 'Pasted Content', + content: text, + }, + ]); + e.preventDefault(); + return; + } + + // if a file is pasted, we will handle it here + const files = Array.from(e.clipboardData.items) + .filter((item) => item.kind === 'file') + .map((item) => item.getAsFile()) + .filter((file) => file !== null); + + if (files.length > 0) { + e.preventDefault(); + extraContext.onFileAdded(files); + } + }} + {...getRootProps()} + > + {!isGenerating && ( + + )} + +
+ + + {/* buttons area */} +
+ + + {isGenerating ? ( + + ) : ( + + )} +
+
+
+ )} +
+
+ ); +} diff --git a/tools/server/webui/src/components/Header.tsx b/tools/server/webui/src/components/Header.tsx new file mode 100644 index 0000000000000000000000000000000000000000..ccddc21ddab73969685d426c4d9242792b482093 --- /dev/null +++ b/tools/server/webui/src/components/Header.tsx @@ -0,0 +1,92 @@ +import { useEffect, useState } from 'react'; +import StorageUtils from '../utils/storage'; +import { useAppContext } from '../utils/app.context'; +import { classNames } from '../utils/misc'; +import daisyuiThemes from 'daisyui/theme/object'; +import { THEMES } from '../Config'; +import { + Cog8ToothIcon, + MoonIcon, + Bars3Icon, +} from '@heroicons/react/24/outline'; + +export default function Header() { + const [selectedTheme, setSelectedTheme] = useState(StorageUtils.getTheme()); + const { setShowSettings } = useAppContext(); + + const setTheme = (theme: string) => { + StorageUtils.setTheme(theme); + setSelectedTheme(theme); + }; + + useEffect(() => { + document.body.setAttribute('data-theme', selectedTheme); + document.body.setAttribute( + 'data-color-scheme', + daisyuiThemes[selectedTheme]?.['color-scheme'] ?? 'auto' + ); + }, [selectedTheme]); + + return ( +
+ {/* open sidebar button */} + + +
llama.cpp
+ + {/* action buttons (top right) */} +
+
setShowSettings(true)} + > + +
+ + {/* theme controller is copied from https://daisyui.com/components/theme-controller/ */} +
+
+
+ +
+
    +
  • + +
  • + {THEMES.map((theme) => ( +
  • + e.target.checked && setTheme(theme)} + /> +
  • + ))} +
+
+
+
+
+ ); +} diff --git a/tools/server/webui/src/components/MarkdownDisplay.tsx b/tools/server/webui/src/components/MarkdownDisplay.tsx new file mode 100644 index 0000000000000000000000000000000000000000..380dbc570a07cd5286237c9b4642d9eab96a1c3b --- /dev/null +++ b/tools/server/webui/src/components/MarkdownDisplay.tsx @@ -0,0 +1,317 @@ +import React, { useMemo, useState } from 'react'; +import Markdown, { ExtraProps } from 'react-markdown'; +import remarkGfm from 'remark-gfm'; +import rehypeHightlight from 'rehype-highlight'; +import rehypeKatex from 'rehype-katex'; +import remarkMath from 'remark-math'; +import remarkBreaks from 'remark-breaks'; +import 'katex/dist/katex.min.css'; +import { classNames, copyStr } from '../utils/misc'; +import { ElementContent, Root } from 'hast'; +import { visit } from 'unist-util-visit'; +import { useAppContext } from '../utils/app.context'; +import { CanvasType } from '../utils/types'; +import { BtnWithTooltips } from '../utils/common'; +import { DocumentDuplicateIcon, PlayIcon } from '@heroicons/react/24/outline'; + +export default function MarkdownDisplay({ + content, + isGenerating, +}: { + content: string; + isGenerating?: boolean; +}) { + const preprocessedContent = useMemo( + () => preprocessLaTeX(content), + [content] + ); + return ( + ( + + ), + // note: do not use "pre", "p" or other basic html elements here, it will cause the node to re-render when the message is being generated (this should be a bug with react-markdown, not sure how to fix it) + }} + > + {preprocessedContent} + + ); +} + +const CodeBlockButtons: React.ElementType< + React.ClassAttributes & + React.HTMLAttributes & + ExtraProps & { origContent: string; isGenerating?: boolean } +> = ({ node, origContent, isGenerating }) => { + const { config } = useAppContext(); + const startOffset = node?.position?.start.offset ?? 0; + const endOffset = node?.position?.end.offset ?? 0; + + const copiedContent = useMemo( + () => + origContent + .substring(startOffset, endOffset) + .replace(/^```[^\n]+\n/g, '') + .replace(/```$/g, ''), + [origContent, startOffset, endOffset] + ); + + const codeLanguage = useMemo( + () => + origContent + .substring(startOffset, startOffset + 10) + .match(/^```([^\n]+)\n/)?.[1] ?? '', + [origContent, startOffset] + ); + + const canRunCode = + !isGenerating && + config.pyIntepreterEnabled && + codeLanguage.startsWith('py'); + + return ( +
+ + {canRunCode && ( + + )} +
+ ); +}; + +export const CopyButton = ({ + content, + className, +}: { + content: string; + className?: string; +}) => { + const [copied, setCopied] = useState(false); + return ( + { + copyStr(content); + setCopied(true); + }} + onMouseLeave={() => setCopied(false)} + tooltipsContent={copied ? 'Copied!' : 'Copy'} + > + + + ); +}; + +export const RunPyCodeButton = ({ + content, + className, +}: { + content: string; + className?: string; +}) => { + const { setCanvasData } = useAppContext(); + return ( + <> + + setCanvasData({ + type: CanvasType.PY_INTERPRETER, + content, + }) + } + tooltipsContent="Run code" + > + + + + ); +}; + +/** + * This injects the "button" element before each "pre" element. + * The actual button will be replaced with a react component in the MarkdownDisplay. + * We don't replace "pre" node directly because it will cause the node to re-render, which causes this bug: https://github.com/ggerganov/llama.cpp/issues/9608 + */ +function rehypeCustomCopyButton() { + return function (tree: Root) { + visit(tree, 'element', function (node) { + if (node.tagName === 'pre' && !node.properties.visited) { + const preNode = { ...node }; + // replace current node + preNode.properties.visited = 'true'; + node.tagName = 'div'; + node.properties = {}; + // add node for button + const btnNode: ElementContent = { + type: 'element', + tagName: 'button', + properties: {}, + children: [], + position: node.position, + }; + node.children = [btnNode, preNode]; + } + }); + }; +} + +/** + * The part below is copied and adapted from: + * https://github.com/danny-avila/LibreChat/blob/main/client/src/utils/latex.ts + * (MIT License) + */ + +// Regex to check if the processed content contains any potential LaTeX patterns +const containsLatexRegex = + /\\\(.*?\\\)|\\\[.*?\\\]|\$.*?\$|\\begin\{equation\}.*?\\end\{equation\}/; + +// Regex for inline and block LaTeX expressions +const inlineLatex = new RegExp(/\\\((.+?)\\\)/, 'g'); +const blockLatex = new RegExp(/\\\[(.*?[^\\])\\\]/, 'gs'); + +// Function to restore code blocks +const restoreCodeBlocks = (content: string, codeBlocks: string[]) => { + return content.replace( + /<>/g, + (_, index) => codeBlocks[index] + ); +}; + +// Regex to identify code blocks and inline code +const codeBlockRegex = /(```[\s\S]*?```|`.*?`)/g; + +export const processLaTeX = (_content: string) => { + let content = _content; + // Temporarily replace code blocks and inline code with placeholders + const codeBlocks: string[] = []; + let index = 0; + content = content.replace(codeBlockRegex, (match) => { + codeBlocks[index] = match; + return `<>`; + }); + + // Escape dollar signs followed by a digit or space and digit + let processedContent = content.replace(/(\$)(?=\s?\d)/g, '\\$'); + + // If no LaTeX patterns are found, restore code blocks and return the processed content + if (!containsLatexRegex.test(processedContent)) { + return restoreCodeBlocks(processedContent, codeBlocks); + } + + // Convert LaTeX expressions to a markdown compatible format + processedContent = processedContent + .replace(inlineLatex, (_: string, equation: string) => `$${equation}$`) // Convert inline LaTeX + .replace(blockLatex, (_: string, equation: string) => `$$${equation}$$`); // Convert block LaTeX + + // Restore code blocks + return restoreCodeBlocks(processedContent, codeBlocks); +}; + +/** + * Preprocesses LaTeX content by replacing delimiters and escaping certain characters. + * + * @param content The input string containing LaTeX expressions. + * @returns The processed string with replaced delimiters and escaped characters. + */ +export function preprocessLaTeX(content: string): string { + // Step 1: Protect code blocks + const codeBlocks: string[] = []; + content = content.replace(/(```[\s\S]*?```|`[^`\n]+`)/g, (_, code) => { + codeBlocks.push(code); + return `<>`; + }); + + // Step 2: Protect existing LaTeX expressions + const latexExpressions: string[] = []; + + // Protect block math ($$...$$), \[...\], and \(...\) as before. + content = content.replace( + /(\$\$[\s\S]*?\$\$|\\\[[\s\S]*?\\\]|\\\(.*?\\\))/g, + (match) => { + latexExpressions.push(match); + return `<>`; + } + ); + + // Protect inline math ($...$) only if it does NOT match a currency pattern. + // We assume a currency pattern is one where the inner content is purely numeric (with optional decimals). + content = content.replace(/\$([^$]+)\$/g, (match, inner) => { + if (/^\s*\d+(?:\.\d+)?\s*$/.test(inner)) { + // This looks like a currency value (e.g. "$123" or "$12.34"), + // so don't protect it. + return match; + } else { + // Otherwise, treat it as a LaTeX expression. + latexExpressions.push(match); + return `<>`; + } + }); + + // Step 3: Escape dollar signs that are likely currency indicators. + // (Now that inline math is protected, this will only escape dollars not already protected) + content = content.replace(/\$(?=\d)/g, '\\$'); + + // Step 4: Restore LaTeX expressions + content = content.replace( + /<>/g, + (_, index) => latexExpressions[parseInt(index)] + ); + + // Step 5: Restore code blocks + content = content.replace( + /<>/g, + (_, index) => codeBlocks[parseInt(index)] + ); + + // Step 6: Apply additional escaping functions + content = escapeBrackets(content); + content = escapeMhchem(content); + + return content; +} + +export function escapeBrackets(text: string): string { + const pattern = + /(```[\S\s]*?```|`.*?`)|\\\[([\S\s]*?[^\\])\\]|\\\((.*?)\\\)/g; + return text.replace( + pattern, + ( + match: string, + codeBlock: string | undefined, + squareBracket: string | undefined, + roundBracket: string | undefined + ): string => { + if (codeBlock != null) { + return codeBlock; + } else if (squareBracket != null) { + return `$$${squareBracket}$$`; + } else if (roundBracket != null) { + return `$${roundBracket}$`; + } + return match; + } + ); +} + +export function escapeMhchem(text: string) { + return text.replaceAll('$\\ce{', '$\\\\ce{').replaceAll('$\\pu{', '$\\\\pu{'); +} diff --git a/tools/server/webui/src/components/ModalProvider.tsx b/tools/server/webui/src/components/ModalProvider.tsx new file mode 100644 index 0000000000000000000000000000000000000000..f2ebf8e0a7fa4db8f242f3bb5d64b04976176064 --- /dev/null +++ b/tools/server/webui/src/components/ModalProvider.tsx @@ -0,0 +1,151 @@ +import React, { createContext, useState, useContext } from 'react'; + +type ModalContextType = { + showConfirm: (message: string) => Promise; + showPrompt: ( + message: string, + defaultValue?: string + ) => Promise; + showAlert: (message: string) => Promise; +}; +const ModalContext = createContext(null!); + +interface ModalState { + isOpen: boolean; + message: string; + defaultValue?: string; + resolve: ((value: T) => void) | null; +} + +export function ModalProvider({ children }: { children: React.ReactNode }) { + const [confirmState, setConfirmState] = useState>({ + isOpen: false, + message: '', + resolve: null, + }); + const [promptState, setPromptState] = useState< + ModalState + >({ isOpen: false, message: '', resolve: null }); + const [alertState, setAlertState] = useState>({ + isOpen: false, + message: '', + resolve: null, + }); + const inputRef = React.useRef(null); + + const showConfirm = (message: string): Promise => { + return new Promise((resolve) => { + setConfirmState({ isOpen: true, message, resolve }); + }); + }; + + const showPrompt = ( + message: string, + defaultValue?: string + ): Promise => { + return new Promise((resolve) => { + setPromptState({ isOpen: true, message, defaultValue, resolve }); + }); + }; + + const showAlert = (message: string): Promise => { + return new Promise((resolve) => { + setAlertState({ isOpen: true, message, resolve }); + }); + }; + + const handleConfirm = (result: boolean) => { + confirmState.resolve?.(result); + setConfirmState({ isOpen: false, message: '', resolve: null }); + }; + + const handlePrompt = (result?: string) => { + promptState.resolve?.(result); + setPromptState({ isOpen: false, message: '', resolve: null }); + }; + + const handleAlertClose = () => { + alertState.resolve?.(); + setAlertState({ isOpen: false, message: '', resolve: null }); + }; + + return ( + + {children} + + {/* Confirm Modal */} + {confirmState.isOpen && ( + +
+

{confirmState.message}

+
+ + +
+
+
+ )} + + {/* Prompt Modal */} + {promptState.isOpen && ( + +
+

{promptState.message}

+ { + if (e.key === 'Enter') { + handlePrompt((e.target as HTMLInputElement).value); + } + }} + /> +
+ + +
+
+
+ )} + + {/* Alert Modal */} + {alertState.isOpen && ( + +
+

{alertState.message}

+
+ +
+
+
+ )} +
+ ); +} + +export function useModals() { + const context = useContext(ModalContext); + if (!context) throw new Error('useModals must be used within ModalProvider'); + return context; +} diff --git a/tools/server/webui/src/components/SettingDialog.tsx b/tools/server/webui/src/components/SettingDialog.tsx new file mode 100644 index 0000000000000000000000000000000000000000..45a8d73b00592ab009f28afc0506e5ae3c7743a2 --- /dev/null +++ b/tools/server/webui/src/components/SettingDialog.tsx @@ -0,0 +1,553 @@ +import { useState } from 'react'; +import { useAppContext } from '../utils/app.context'; +import { CONFIG_DEFAULT, CONFIG_INFO } from '../Config'; +import { isDev } from '../Config'; +import StorageUtils from '../utils/storage'; +import { classNames, isBoolean, isNumeric, isString } from '../utils/misc'; +import { + BeakerIcon, + ChatBubbleOvalLeftEllipsisIcon, + Cog6ToothIcon, + FunnelIcon, + HandRaisedIcon, + SquaresPlusIcon, +} from '@heroicons/react/24/outline'; +import { OpenInNewTab } from '../utils/common'; +import { useModals } from './ModalProvider'; + +type SettKey = keyof typeof CONFIG_DEFAULT; + +const BASIC_KEYS: SettKey[] = [ + 'temperature', + 'top_k', + 'top_p', + 'min_p', + 'max_tokens', +]; +const SAMPLER_KEYS: SettKey[] = [ + 'dynatemp_range', + 'dynatemp_exponent', + 'typical_p', + 'xtc_probability', + 'xtc_threshold', +]; +const PENALTY_KEYS: SettKey[] = [ + 'repeat_last_n', + 'repeat_penalty', + 'presence_penalty', + 'frequency_penalty', + 'dry_multiplier', + 'dry_base', + 'dry_allowed_length', + 'dry_penalty_last_n', +]; + +enum SettingInputType { + SHORT_INPUT, + LONG_INPUT, + CHECKBOX, + CUSTOM, +} + +interface SettingFieldInput { + type: Exclude; + label: string | React.ReactElement; + help?: string | React.ReactElement; + key: SettKey; +} + +interface SettingFieldCustom { + type: SettingInputType.CUSTOM; + key: SettKey; + component: + | string + | React.FC<{ + value: string | boolean | number; + onChange: (value: string) => void; + }>; +} + +interface SettingSection { + title: React.ReactElement; + fields: (SettingFieldInput | SettingFieldCustom)[]; +} + +const ICON_CLASSNAME = 'w-4 h-4 mr-1 inline'; + +const SETTING_SECTIONS: SettingSection[] = [ + { + title: ( + <> + + General + + ), + fields: [ + { + type: SettingInputType.SHORT_INPUT, + label: 'API Key', + key: 'apiKey', + }, + { + type: SettingInputType.LONG_INPUT, + label: 'System Message (will be disabled if left empty)', + key: 'systemMessage', + }, + ...BASIC_KEYS.map( + (key) => + ({ + type: SettingInputType.SHORT_INPUT, + label: key, + key, + }) as SettingFieldInput + ), + { + type: SettingInputType.SHORT_INPUT, + label: 'Paste length to file', + key: 'pasteLongTextToFileLen', + }, + { + type: SettingInputType.CHECKBOX, + label: 'Parse PDF as image instead of text', + key: 'pdfAsImage', + }, + ], + }, + { + title: ( + <> + + Samplers + + ), + fields: [ + { + type: SettingInputType.SHORT_INPUT, + label: 'Samplers queue', + key: 'samplers', + }, + ...SAMPLER_KEYS.map( + (key) => + ({ + type: SettingInputType.SHORT_INPUT, + label: key, + key, + }) as SettingFieldInput + ), + ], + }, + { + title: ( + <> + + Penalties + + ), + fields: PENALTY_KEYS.map((key) => ({ + type: SettingInputType.SHORT_INPUT, + label: key, + key, + })), + }, + { + title: ( + <> + + Reasoning + + ), + fields: [ + { + type: SettingInputType.CHECKBOX, + label: 'Expand thought process by default when generating messages', + key: 'showThoughtInProgress', + }, + { + type: SettingInputType.CHECKBOX, + label: + 'Exclude thought process when sending requests to API (Recommended for DeepSeek-R1)', + key: 'excludeThoughtOnReq', + }, + ], + }, + { + title: ( + <> + + Advanced + + ), + fields: [ + { + type: SettingInputType.CUSTOM, + key: 'custom', // dummy key, won't be used + component: () => { + const debugImportDemoConv = async () => { + const res = await fetch('/demo-conversation.json'); + const demoConv = await res.json(); + StorageUtils.remove(demoConv.id); + for (const msg of demoConv.messages) { + StorageUtils.appendMsg(demoConv.id, msg); + } + }; + return ( + + ); + }, + }, + { + type: SettingInputType.CHECKBOX, + label: 'Show tokens per second', + key: 'showTokensPerSecond', + }, + { + type: SettingInputType.LONG_INPUT, + label: ( + <> + Custom JSON config (For more info, refer to{' '} + + server documentation + + ) + + ), + key: 'custom', + }, + ], + }, + { + title: ( + <> + + Experimental + + ), + fields: [ + { + type: SettingInputType.CUSTOM, + key: 'custom', // dummy key, won't be used + component: () => ( + <> +

+ Experimental features are not guaranteed to work correctly. +
+
+ If you encounter any problems, create a{' '} + + Bug (misc.) + {' '} + report on Github. Please also specify webui/experimental on + the report title and include screenshots. +
+
+ Some features may require packages downloaded from CDN, so they + need internet connection. +

+ + ), + }, + { + type: SettingInputType.CHECKBOX, + label: ( + <> + Enable Python interpreter +
+ + This feature uses{' '} + pyodide, + downloaded from CDN. To use this feature, ask the LLM to generate + Python code inside a Markdown code block. You will see a "Run" + button on the code block, near the "Copy" button. + + + ), + key: 'pyIntepreterEnabled', + }, + ], + }, +]; + +export default function SettingDialog({ + show, + onClose, +}: { + show: boolean; + onClose: () => void; +}) { + const { config, saveConfig } = useAppContext(); + const [sectionIdx, setSectionIdx] = useState(0); + + // clone the config object to prevent direct mutation + const [localConfig, setLocalConfig] = useState( + JSON.parse(JSON.stringify(config)) + ); + const { showConfirm, showAlert } = useModals(); + + const resetConfig = async () => { + if (await showConfirm('Are you sure you want to reset all settings?')) { + setLocalConfig(CONFIG_DEFAULT); + } + }; + + const handleSave = async () => { + // copy the local config to prevent direct mutation + const newConfig: typeof CONFIG_DEFAULT = JSON.parse( + JSON.stringify(localConfig) + ); + // validate the config + for (const key in newConfig) { + const value = newConfig[key as SettKey]; + const mustBeBoolean = isBoolean(CONFIG_DEFAULT[key as SettKey]); + const mustBeString = isString(CONFIG_DEFAULT[key as SettKey]); + const mustBeNumeric = isNumeric(CONFIG_DEFAULT[key as SettKey]); + if (mustBeString) { + if (!isString(value)) { + await showAlert(`Value for ${key} must be string`); + return; + } + } else if (mustBeNumeric) { + const trimmedValue = value.toString().trim(); + const numVal = Number(trimmedValue); + if (isNaN(numVal) || !isNumeric(numVal) || trimmedValue.length === 0) { + await showAlert(`Value for ${key} must be numeric`); + return; + } + // force conversion to number + // @ts-expect-error this is safe + newConfig[key] = numVal; + } else if (mustBeBoolean) { + if (!isBoolean(value)) { + await showAlert(`Value for ${key} must be boolean`); + return; + } + } else { + console.error(`Unknown default type for key ${key}`); + } + } + if (isDev) console.log('Saving config', newConfig); + saveConfig(newConfig); + onClose(); + }; + + const onChange = (key: SettKey) => (value: string | boolean) => { + // note: we do not perform validation here, because we may get incomplete value as user is still typing it + setLocalConfig({ ...localConfig, [key]: value }); + }; + + return ( + +
+

Settings

+
+ {/* Left panel, showing sections - Desktop version */} +
+ {SETTING_SECTIONS.map((section, idx) => ( + + ))} +
+ + {/* Left panel, showing sections - Mobile version */} + {/* This menu is skipped on a11y, otherwise it's repeated the desktop version */} +
+
+ + {SETTING_SECTIONS[sectionIdx].title} + +
    + {SETTING_SECTIONS.map((section, idx) => ( +
    setSectionIdx(idx)} + dir="auto" + > + {section.title} +
    + ))} +
+
+
+ + {/* Right panel, showing setting fields */} +
+ {SETTING_SECTIONS[sectionIdx].fields.map((field, idx) => { + const key = `${sectionIdx}-${idx}`; + if (field.type === SettingInputType.SHORT_INPUT) { + return ( + + ); + } else if (field.type === SettingInputType.LONG_INPUT) { + return ( + + ); + } else if (field.type === SettingInputType.CHECKBOX) { + return ( + + ); + } else if (field.type === SettingInputType.CUSTOM) { + return ( +
+ {typeof field.component === 'string' + ? field.component + : field.component({ + value: localConfig[field.key], + onChange: onChange(field.key), + })} +
+ ); + } + })} + +

+ Settings are saved in browser's localStorage +

+
+
+ +
+ + + +
+
+
+ ); +} + +function SettingsModalLongInput({ + configKey, + value, + onChange, + label, +}: { + configKey: SettKey; + value: string; + onChange: (value: string) => void; + label?: string; +}) { + return ( +